numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2486 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Dot Products for RISC-V.
|
|
3
|
+
* @file include/numkong/dots/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dots.h
|
|
8
|
+
*
|
|
9
|
+
* Custom RVV-native register-tiled GEMM implementation, analogous to how AMX
|
|
10
|
+
* (dots/sapphireamx.h) and SME (dots/sme.h) each have their own unique implementations
|
|
11
|
+
* independent of the cross-product macros.
|
|
12
|
+
*
|
|
13
|
+
* RVV's variable-length vectors and widening multiply-accumulate (`vfwmacc`) make it
|
|
14
|
+
* fundamentally different from fixed-width SIMD. Key design choices:
|
|
15
|
+
*
|
|
16
|
+
* - f32 GEMM: Uses `vfwmacc_vv_f64m4` for f64 accumulation (vector-vector widened FMA),
|
|
17
|
+
* Process 4 rows per tile (rows_per_tile=4). Narrowed to f32 on store.
|
|
18
|
+
* - f64 GEMM: Uses `vfmul`+Kahan with Kahan compensation,
|
|
19
|
+
* Process 2 rows per tile (rows_per_tile=2, tighter register budget at LMUL=4).
|
|
20
|
+
* - B packing: Column-panel layout with cache-line padding. Each depth step stores
|
|
21
|
+
* contiguous elements along depth — one `vle32`/`vle64` per vectorized chunk.
|
|
22
|
+
* - Edge handling: RVV's `vsetvl` returns actual VL for partial vectors — no separate
|
|
23
|
+
* edge kernel needed.
|
|
24
|
+
* - Vectorization axis: depth (k dimension). Each inner loop iteration loads a chunk of
|
|
25
|
+
* both A and B along depth, computing element-wise widened FMA.
|
|
26
|
+
*
|
|
27
|
+
* - e2m3 GEMM: Integer arithmetic via LUT (5-bit magnitude → i8 value×16).
|
|
28
|
+
* B is pre-packed as signed i8. A is converted on-the-fly via `vluxei8` gather.
|
|
29
|
+
* Uses `vwmul` (i8→i16) then `vwadd_wv` (i32+=i16) for K-vectorized accumulation.
|
|
30
|
+
* Final result scaled by 1/256. Process 4 rows per tile (rows_per_tile=4).
|
|
31
|
+
* - e3m2 GEMM: Integer arithmetic via LUT (5-bit magnitude → i16 value×16).
|
|
32
|
+
* B is pre-packed as signed i16. A is converted on-the-fly via `vluxei16` gather.
|
|
33
|
+
* Uses `vwmacc` (i16×i16→i32) for K-vectorized widening MAC.
|
|
34
|
+
* Final result scaled by 1/256. Process 2 rows per tile (rows_per_tile=2, wider accumulator elements).
|
|
35
|
+
* - e4m3 GEMM: f32 LUT gather (7-bit magnitude → f32 bit pattern, 128 entries).
|
|
36
|
+
* B is pre-packed as f32. A is converted on-the-fly via `vluxei32` gather with
|
|
37
|
+
* sign injection (bit 7 → bit 31). Uses `vfwmacc_vv_f64m4` for f64 accumulation.
|
|
38
|
+
* Process 2 rows per tile (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
|
|
39
|
+
* - e5m2 GEMM: Same f32 LUT gather approach as e4m3, different LUT contents.
|
|
40
|
+
* E5M2 has 5 exponent bits (wider range, lower precision than e4m3).
|
|
41
|
+
* Process 2 rows per tile (rows_per_tile=2).
|
|
42
|
+
*/
|
|
43
|
+
#ifndef NK_DOTS_RVV_H
|
|
44
|
+
#define NK_DOTS_RVV_H
|
|
45
|
+
|
|
46
|
+
#if NK_TARGET_RISCV_
|
|
47
|
+
#if NK_TARGET_RVV
|
|
48
|
+
|
|
49
|
+
#include "numkong/types.h"
|
|
50
|
+
#include "numkong/dots/serial.h"
|
|
51
|
+
#include "numkong/cast/rvv.h" // `nk_bf16m1_to_f32m2_rvv_`
|
|
52
|
+
|
|
53
|
+
#if defined(__clang__)
|
|
54
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
55
|
+
#elif defined(__GNUC__)
|
|
56
|
+
#pragma GCC push_options
|
|
57
|
+
#pragma GCC target("arch=+v")
|
|
58
|
+
#endif
|
|
59
|
+
|
|
60
|
+
#if defined(__cplusplus)
|
|
61
|
+
extern "C" {
|
|
62
|
+
#endif
|
|
63
|
+
|
|
64
|
+
/**
|
|
65
|
+
* @brief E2M3 magnitude LUT: 5-bit magnitude → unsigned value×16 (u8).
|
|
66
|
+
* Shared across scalar helper, packed kernel, and symmetric kernel.
|
|
67
|
+
*/
|
|
68
|
+
static nk_u8_t const nk_e2m3_magnitude_lut_rvv_[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
|
|
69
|
+
22, 24, 26, 28, 30, 32, 36, 40, 44, 48, 52,
|
|
70
|
+
56, 60, 64, 72, 80, 88, 96, 104, 112, 120};
|
|
71
|
+
|
|
72
|
+
/**
|
|
73
|
+
* @brief E3M2 magnitude LUT: 5-bit magnitude → unsigned value×16 (u16).
|
|
74
|
+
* Shared across scalar helper, packed kernel, and symmetric kernel.
|
|
75
|
+
*/
|
|
76
|
+
static nk_u16_t const nk_e3m2_magnitude_lut_rvv_[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12,
|
|
77
|
+
14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80,
|
|
78
|
+
96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
|
|
79
|
+
|
|
80
|
+
#pragma region Single Precision Floats
|
|
81
|
+
|
|
82
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
83
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
84
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
85
|
+
// Break power-of-2 strides for cache associativity
|
|
86
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
87
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
88
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
89
|
+
column_count * sizeof(nk_f64_t); // per-column norms
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
NK_PUBLIC void nk_dots_pack_f32_rvv(nk_f32_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
93
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
94
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
95
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
96
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
97
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
98
|
+
|
|
99
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
100
|
+
header->column_count = (nk_u32_t)column_count;
|
|
101
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
102
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
103
|
+
|
|
104
|
+
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
105
|
+
nk_size_t total = column_count * depth_padded;
|
|
106
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
107
|
+
|
|
108
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
109
|
+
nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
110
|
+
nk_f32_t *dst = packed + column * depth_padded;
|
|
111
|
+
for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
// Append per-column norms after packed data
|
|
115
|
+
nk_f64_t *norms = (nk_f64_t *)(packed + total);
|
|
116
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
117
|
+
nk_f32_t const *src = (nk_f32_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
118
|
+
norms[column] = nk_dots_reduce_sumsq_f32_(src, depth);
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
/**
|
|
123
|
+
* @brief f32 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
|
|
124
|
+
*
|
|
125
|
+
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
126
|
+
* acc_f64 = sum_k f64(a[row][k]) * f64(b_packed[column][k])
|
|
127
|
+
* using `vfwmacc_vv_f64m4` which widens both operands from f32m2 to f64m4.
|
|
128
|
+
*
|
|
129
|
+
* Register tile: process 4 rows per iteration (rows_per_tile=4).
|
|
130
|
+
* Each row loads its own A vector; B vector is shared across rows per depth chunk.
|
|
131
|
+
*/
|
|
132
|
+
NK_INTERNAL void nk_dots_packed_f32_rvv_aligned_(nk_f32_t const *a_matrix, void const *b_packed_buffer,
|
|
133
|
+
nk_f64_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
|
|
134
|
+
nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
135
|
+
nk_size_t c_stride_in_bytes) {
|
|
136
|
+
nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
|
|
137
|
+
nk_size_t const depth_padded = header->depth_padded_values;
|
|
138
|
+
nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
|
|
139
|
+
sizeof(nk_cross_packed_buffer_header_t));
|
|
140
|
+
|
|
141
|
+
// Zero output matrix
|
|
142
|
+
for (nk_size_t i = 0; i < row_count; ++i) {
|
|
143
|
+
nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + i * c_stride_in_bytes);
|
|
144
|
+
for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// mr=4 register tile over rows
|
|
148
|
+
nk_size_t row = 0;
|
|
149
|
+
for (; row + 4 <= row_count; row += 4) {
|
|
150
|
+
nk_f32_t const *a_row_0 = (nk_f32_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
|
|
151
|
+
nk_f32_t const *a_row_1 = (nk_f32_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
|
|
152
|
+
nk_f32_t const *a_row_2 = (nk_f32_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
|
|
153
|
+
nk_f32_t const *a_row_3 = (nk_f32_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
|
|
154
|
+
nk_f64_t *c_row_0 = (nk_f64_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
|
|
155
|
+
nk_f64_t *c_row_1 = (nk_f64_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
|
|
156
|
+
nk_f64_t *c_row_2 = (nk_f64_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
|
|
157
|
+
nk_f64_t *c_row_3 = (nk_f64_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
|
|
158
|
+
|
|
159
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
160
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
161
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
162
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
163
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
164
|
+
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
165
|
+
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
166
|
+
|
|
167
|
+
nk_size_t remaining = depth;
|
|
168
|
+
nk_size_t k = 0;
|
|
169
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
170
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
171
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
172
|
+
vfloat32m2_t a_vector_0_f32m2 = __riscv_vle32_v_f32m2(a_row_0 + k, vector_length);
|
|
173
|
+
vfloat32m2_t a_vector_1_f32m2 = __riscv_vle32_v_f32m2(a_row_1 + k, vector_length);
|
|
174
|
+
vfloat32m2_t a_vector_2_f32m2 = __riscv_vle32_v_f32m2(a_row_2 + k, vector_length);
|
|
175
|
+
vfloat32m2_t a_vector_3_f32m2 = __riscv_vle32_v_f32m2(a_row_3 + k, vector_length);
|
|
176
|
+
accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
|
|
177
|
+
vector_length);
|
|
178
|
+
accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
|
|
179
|
+
vector_length);
|
|
180
|
+
accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
|
|
181
|
+
vector_length);
|
|
182
|
+
accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
|
|
183
|
+
vector_length);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// Horizontal reduce directly to f64
|
|
187
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
188
|
+
c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
189
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
|
|
190
|
+
c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
191
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
|
|
192
|
+
c_row_2[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
193
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
|
|
194
|
+
c_row_3[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
195
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
// Remainder rows (mr < 4)
|
|
199
|
+
for (; row < row_count; ++row) {
|
|
200
|
+
nk_f32_t const *a_row = (nk_f32_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
|
|
201
|
+
nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
202
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
203
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
204
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
205
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
206
|
+
nk_size_t remaining = depth;
|
|
207
|
+
nk_size_t k = 0;
|
|
208
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
209
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
210
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
211
|
+
vfloat32m2_t a_vector_f32m2 = __riscv_vle32_v_f32m2(a_row + k, vector_length);
|
|
212
|
+
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
|
|
213
|
+
vector_length);
|
|
214
|
+
}
|
|
215
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
216
|
+
c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
|
|
217
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
/**
|
|
223
|
+
* @brief Public f32 packed GEMM wrapper matching the declared signature in dots.h.
|
|
224
|
+
*
|
|
225
|
+
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
226
|
+
* vectors naturally, so no separate edge kernel is needed.
|
|
227
|
+
*/
|
|
228
|
+
NK_PUBLIC void nk_dots_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
|
|
229
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
230
|
+
nk_dots_packed_f32_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
/**
|
|
234
|
+
* @brief Symmetric f32 GEMM: C = A * A^T, upper triangle + mirror.
|
|
235
|
+
*
|
|
236
|
+
* Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
|
|
237
|
+
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
238
|
+
*/
|
|
239
|
+
NK_PUBLIC void nk_dots_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
240
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
241
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
242
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
|
|
243
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
|
|
244
|
+
nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
|
|
245
|
+
|
|
246
|
+
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
247
|
+
nk_f32_t const *a_i = vectors + i * stride_elements;
|
|
248
|
+
for (nk_size_t j = i; j < n_vectors; ++j) {
|
|
249
|
+
nk_f32_t const *a_j = vectors + j * stride_elements;
|
|
250
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
251
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
252
|
+
nk_size_t remaining = depth;
|
|
253
|
+
nk_size_t k = 0;
|
|
254
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
255
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
256
|
+
vfloat32m2_t a_vector_f32m2 = __riscv_vle32_v_f32m2(a_i + k, vector_length);
|
|
257
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(a_j + k, vector_length);
|
|
258
|
+
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
|
|
259
|
+
vector_length);
|
|
260
|
+
}
|
|
261
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
262
|
+
nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
|
|
263
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
|
|
264
|
+
result[i * result_stride_elements + j] = dot;
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
#pragma endregion // Single Precision Floats
|
|
270
|
+
|
|
271
|
+
#pragma region Double Precision Floats
|
|
272
|
+
|
|
273
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
274
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e64m4();
|
|
275
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
276
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f64_t);
|
|
277
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
278
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f64_t) +
|
|
279
|
+
column_count * sizeof(nk_f64_t); // per-column norms
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
/**
 *  @brief Packs the B matrix for the f64 RVV GEMM path.
 *
 *  Columns are stored depth-contiguously with the depth rounded up to the maximal
 *  e64/LMUL=4 vector length (plus one extra vector when the stride would be a power
 *  of two). Padding is zeroed, and per-column sum-of-squares norms are appended
 *  right after the packed panel.
 */
NK_PUBLIC void nk_dots_pack_f64_rvv(nk_f64_t const *b, nk_size_t column_count, nk_size_t depth,
                                    nk_size_t b_stride_in_bytes, void *b_packed) {
    nk_size_t const lanes = __riscv_vsetvlmax_e64m4();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    nk_size_t const column_bytes = padded_depth * sizeof(nk_f64_t);
    if (column_bytes > 0 && (column_bytes & (column_bytes - 1)) == 0) padded_depth += lanes;

    // Describe the layout in the leading header so the kernel can reconstruct it.
    nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
    header->column_count = (nk_u32_t)column_count;
    header->depth_dimensions = (nk_u32_t)depth;
    header->depth_padded_values = (nk_u32_t)padded_depth;

    nk_f64_t *panel = (nk_f64_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
    nk_size_t const panel_values = column_count * padded_depth;

    // Copy each column and zero its padding tail in one pass.
    for (nk_size_t column = 0; column != column_count; ++column) {
        nk_f64_t const *source = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
        nk_f64_t *destination = panel + column * padded_depth;
        nk_size_t k = 0;
        for (; k != depth; ++k) destination[k] = source[k];
        for (; k != padded_depth; ++k) destination[k] = 0;
    }

    // Per-column norms follow immediately after the packed panel.
    nk_f64_t *norms = (nk_f64_t *)(panel + panel_values);
    for (nk_size_t column = 0; column != column_count; ++column) {
        nk_f64_t const *source = (nk_f64_t const *)((char const *)b + column * b_stride_in_bytes);
        norms[column] = nk_dots_reduce_sumsq_f64_(source, depth);
    }
}
|
|
311
|
+
|
|
312
|
+
/**
 *  @brief f64 packed GEMM kernel: C += A * B_packed^T with Kahan compensation.
 *
 *  Vectorizes over depth dimension k using `vfmul`+Kahan (vector-vector multiply).
 *  Uses Kahan summation over full depth to maintain precision — the compensation
 *  vector tracks the rounding error lost on every accumulate step, so the final
 *  horizontal reduction sees a much tighter error bound than a plain running sum.
 *  Register tile: process 2 rows per iteration (rows_per_tile=2, budget: 32 regs at LMUL=4).
 *
 *  The `_tu` (tail-undisturbed) intrinsic variants are essential: on the final,
 *  shorter strip (vl < vlmax) the tail lanes of the accumulators must retain their
 *  previously accumulated values, since the reduction at the end runs over vlmax lanes.
 */
NK_INTERNAL void nk_dots_packed_f64_rvv_aligned_(nk_f64_t const *a_matrix, void const *b_packed_buffer,
                                                 nk_f64_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                 nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                 nk_size_t c_stride_in_bytes) {
    // Layout metadata written by `nk_dots_pack_f64_rvv`: columns are depth-contiguous
    // panels of `depth_padded` values each, zero-padded beyond `depth`.
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f64_t const *packed_data = (nk_f64_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // Process 2 rows per tile (rows_per_tile=2, tighter register budget for f64 at LMUL=4)
    nk_size_t row = 0;
    for (; row + 2 <= row_count; row += 2) {
        nk_f64_t const *a_row_0 = (nk_f64_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_f64_t const *a_row_1 = (nk_f64_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_f64_t *c_row_0 = (nk_f64_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f64_t *c_row_1 = (nk_f64_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f64_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
            // One accumulator + one compensation vector per tile row, zeroed over all vlmax lanes.
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t compensation_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t compensation_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            // Strip-mine the depth dimension; the last strip may be shorter than vlmax.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e64m4(remaining);
                vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(b_column + k, vector_length);
                vfloat64m4_t a_vector_0_f64m4 = __riscv_vle64_v_f64m4(a_row_0 + k, vector_length);
                vfloat64m4_t a_vector_1_f64m4 = __riscv_vle64_v_f64m4(a_row_1 + k, vector_length);

                // Kahan step for row 0: product = a*b; corrected = product - comp; running = acc + corrected; comp =
                // (running - acc) - corrected; acc = running
                vfloat64m4_t product_0_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_0_f64m4, b_vector_f64m4, vector_length);
                vfloat64m4_t corrected_term_0_f64m4 = __riscv_vfsub_vv_f64m4(product_0_f64m4, compensation_0_f64m4,
                                                                             vector_length);
                // `_tu` keeps tail lanes of the accumulator intact on short strips.
                vfloat64m4_t running_sum_0_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_0_f64m4, accumulator_0_f64m4,
                                                                             corrected_term_0_f64m4, vector_length);
                compensation_0_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                    compensation_0_f64m4,
                    __riscv_vfsub_vv_f64m4(running_sum_0_f64m4, accumulator_0_f64m4, vector_length),
                    corrected_term_0_f64m4, vector_length);
                accumulator_0_f64m4 = running_sum_0_f64m4;

                // Kahan step for row 1
                vfloat64m4_t product_1_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_1_f64m4, b_vector_f64m4, vector_length);
                vfloat64m4_t corrected_term_1_f64m4 = __riscv_vfsub_vv_f64m4(product_1_f64m4, compensation_1_f64m4,
                                                                             vector_length);
                vfloat64m4_t running_sum_1_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_1_f64m4, accumulator_1_f64m4,
                                                                             corrected_term_1_f64m4, vector_length);
                compensation_1_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                    compensation_1_f64m4,
                    __riscv_vfsub_vv_f64m4(running_sum_1_f64m4, accumulator_1_f64m4, vector_length),
                    corrected_term_1_f64m4, vector_length);
                accumulator_1_f64m4 = running_sum_1_f64m4;
            }

            // Horizontal reduce
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = __riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = __riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows
    for (; row < row_count; ++row) {
        nk_f64_t const *a_row = (nk_f64_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f64_t *c_row = (nk_f64_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f64_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e64m4(remaining);
                vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(b_column + k, vector_length);
                vfloat64m4_t a_vector_f64m4 = __riscv_vle64_v_f64m4(a_row + k, vector_length);

                // Same Kahan recurrence as the tiled path above, for a single row.
                vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_f64m4, b_vector_f64m4, vector_length);
                vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
                                                                           vector_length);
                vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_f64m4, accumulator_f64m4,
                                                                           corrected_term_f64m4, vector_length);
                compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                    compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, accumulator_f64m4, vector_length),
                    corrected_term_f64m4, vector_length);
                accumulator_f64m4 = running_sum_f64m4;
            }

            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = __riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
|
|
426
|
+
|
|
427
|
+
/**
|
|
428
|
+
* @brief Public f64 packed GEMM wrapper matching the declared signature in dots.h.
|
|
429
|
+
*/
|
|
430
|
+
NK_PUBLIC void nk_dots_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t m, nk_size_t n,
|
|
431
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
432
|
+
nk_dots_packed_f64_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
/**
 *  @brief Symmetric f64 GEMM: C = A * A^T, upper triangle + mirror.
 *
 *  Uses Kahan compensation over full depth for precision.
 *  Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 *  NOTE(review): this function only writes result[i][j] for j >= i — the mirrored
 *  lower triangle is NOT written here despite the "mirror" wording above;
 *  presumably the caller fills it in — confirm against the dispatch layer.
 */
NK_PUBLIC void nk_dots_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                         nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
                                         nk_size_t row_start, nk_size_t row_count) {
    // Strides arrive in bytes; converted to element counts (assumes both strides are
    // multiples of sizeof(nk_f64_t) — TODO confirm this precondition is documented).
    nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
    // Clamp the row window so callers may over-request past n_vectors.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_f64_t const *a_i = vectors + i * stride_elements;
        // Start at j = i: only the diagonal and upper triangle are computed.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_f64_t const *a_j = vectors + j * stride_elements;
            nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            // Strip-mine depth with Kahan-compensated accumulation; `_tu` variants keep
            // tail lanes of the accumulators intact on the final short strip.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e64m4(remaining);
                vfloat64m4_t a_vector_f64m4 = __riscv_vle64_v_f64m4(a_i + k, vector_length);
                vfloat64m4_t b_vector_f64m4 = __riscv_vle64_v_f64m4(a_j + k, vector_length);

                // Kahan: corrected = product - comp; running = acc + corrected;
                // comp = (running - acc) - corrected; acc = running.
                vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(a_vector_f64m4, b_vector_f64m4, vector_length);
                vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
                                                                           vector_length);
                vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(accumulator_f64m4, accumulator_f64m4,
                                                                           corrected_term_f64m4, vector_length);
                compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                    compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, accumulator_f64m4, vector_length),
                    corrected_term_f64m4, vector_length);
                accumulator_f64m4 = running_sum_f64m4;
            }

            // Horizontal reduction over all vlmax lanes into a scalar dot product.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f64_t dot = __riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
481
|
+
|
|
482
|
+
#pragma endregion // Double Precision Floats
|
|
483
|
+
|
|
484
|
+
#pragma region Micro Precision E2M3
|
|
485
|
+
|
|
486
|
+
/**
|
|
487
|
+
* @brief Scalar conversion helper: e2m3 byte → signed i8 (value × 16).
|
|
488
|
+
*
|
|
489
|
+
* Extracts 5-bit magnitude, looks up in LUT, applies sign from bit 5.
|
|
490
|
+
* Every e2m3 value × 16 is an exact integer in [-120, +120], fitting in i8.
|
|
491
|
+
*/
|
|
492
|
+
NK_INTERNAL nk_i8_t nk_e2m3_to_i8_rvv_(nk_u8_t raw) {
|
|
493
|
+
nk_u8_t magnitude = raw & 0x1Fu;
|
|
494
|
+
nk_i8_t val = (nk_i8_t)nk_e2m3_magnitude_lut_rvv_[magnitude];
|
|
495
|
+
return (raw & 0x20u) ? (nk_i8_t)(-val) : val;
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
/**
 *  @brief Returns the byte size of the packed B buffer for the e2m3 RVV GEMM path.
 *
 *  B is stored as signed i8 (value × 16); depth is rounded up to the maximal e8/LMUL=1
 *  vector length, with one extra vector of padding when the per-column stride would be
 *  a power of two. Per-column f32 norms trail the packed data.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const lanes = __riscv_vsetvlmax_e8m1();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    nk_size_t const column_bytes = padded_depth * sizeof(nk_i8_t);
    int const stride_is_power_of_two = column_bytes > 0 && (column_bytes & (column_bytes - 1)) == 0;
    if (stride_is_power_of_two) padded_depth += lanes;
    nk_size_t const panel_bytes = column_count * padded_depth * sizeof(nk_i8_t);
    nk_size_t const norms_bytes = column_count * sizeof(nk_f32_t); // per-column norms
    return sizeof(nk_cross_packed_buffer_header_t) + panel_bytes + norms_bytes;
}
|
|
506
|
+
|
|
507
|
+
/**
|
|
508
|
+
* @brief Pack B matrix from e2m3 to signed i8 (value × 16) for integer dot product.
|
|
509
|
+
*
|
|
510
|
+
* Each e2m3 byte is converted to a signed i8 via scalar LUT lookup.
|
|
511
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
512
|
+
*/
|
|
513
|
+
NK_PUBLIC void nk_dots_pack_e2m3_rvv(nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
514
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
515
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
|
|
516
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
517
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
518
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
519
|
+
|
|
520
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
521
|
+
header->column_count = (nk_u32_t)column_count;
|
|
522
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
523
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
524
|
+
|
|
525
|
+
nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
526
|
+
nk_size_t total = column_count * depth_padded;
|
|
527
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
528
|
+
|
|
529
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
530
|
+
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
531
|
+
nk_i8_t *dst = packed + column * depth_padded;
|
|
532
|
+
for (nk_size_t k = 0; k < depth; ++k) dst[k] = nk_e2m3_to_i8_rvv_(src[k]);
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
// Append per-column norms after packed data
|
|
536
|
+
nk_f32_t *norms = (nk_f32_t *)(packed + total);
|
|
537
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
538
|
+
nk_e2m3_t const *src = (nk_e2m3_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
539
|
+
norms[column] = nk_dots_reduce_sumsq_e2m3_(src, depth);
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
/**
 *  @brief e2m3 packed GEMM kernel: C += A * B_packed^T with integer i8 LUT arithmetic.
 *
 *  Vectorizes over the depth dimension (k). For each (row, column) pair:
 *  - Load raw e2m3 bytes from A, extract magnitude via `vluxei8` gather LUT
 *  - Apply sign from bit 5 via masked negate to produce signed i8 A values
 *  - Load pre-packed signed i8 values from B
 *  - Widening multiply i8×i8 → i16, then widen-accumulate i32 += i16
 *  - Final result = i32_sum / 256.0f (each operand carries a ×16 scale, so the
 *    product is scaled by 16×16 = 256)
 *
 *  Register tile: process 4 rows per iteration (rows_per_tile=4).
 *  The LUT gather on A magnitudes uses `vluxei8_v_u8m1` (byte-indexed byte gather).
 */
NK_INTERNAL void nk_dots_packed_e2m3_rvv_aligned_(nk_e2m3_t const *a_matrix, void const *b_packed_buffer,
                                                  nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                  nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                  nk_size_t c_stride_in_bytes) {
    // Undo the ×16 per-operand LUT scaling after the integer accumulation.
    nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;

    // Layout metadata written by `nk_dots_pack_e2m3_rvv`.
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_i8_t const *packed_data = (nk_i8_t const *)((char const *)b_packed_buffer +
                                                   sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_u8_t const *a_row_2 = (nk_u8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_u8_t const *a_row_3 = (nk_u8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_i8_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);

            // Strip-mine the depth dimension at e8/LMUL=1 granularity.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);

                // Load pre-packed i8 B values
                vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);

                // Load raw e2m3 bytes from each A row and convert via LUT
                vuint8m1_t raw0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
                vuint8m1_t raw1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
                vuint8m1_t raw2_u8m1 = __riscv_vle8_v_u8m1(a_row_2 + k, vector_length);
                vuint8m1_t raw3_u8m1 = __riscv_vle8_v_u8m1(a_row_3 + k, vector_length);

                // Extract magnitudes (low 5 bits) and gather from LUT
                vuint8m1_t mag0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x1F, vector_length);
                vuint8m1_t mag1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x1F, vector_length);
                vuint8m1_t mag2_u8m1 = __riscv_vand_vx_u8m1(raw2_u8m1, 0x1F, vector_length);
                vuint8m1_t mag3_u8m1 = __riscv_vand_vx_u8m1(raw3_u8m1, 0x1F, vector_length);
                vuint8m1_t uval0_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag0_u8m1, vector_length);
                vuint8m1_t uval1_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag1_u8m1, vector_length);
                vuint8m1_t uval2_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag2_u8m1, vector_length);
                vuint8m1_t uval3_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag3_u8m1, vector_length);

                // Apply sign to A: negate where bit 5 is set.
                // B is already signed from packing, so A sign completes the product sign.
                vint8m1_t a_vector_0_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval0_u8m1);
                vbool8_t negated_0_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw0_u8m1, 0x20, vector_length),
                                                                 0, vector_length);
                a_vector_0_i8m1 = __riscv_vneg_v_i8m1_mu(negated_0_b8, a_vector_0_i8m1, a_vector_0_i8m1, vector_length);

                vint8m1_t a_vector_1_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval1_u8m1);
                vbool8_t negated_1_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw1_u8m1, 0x20, vector_length),
                                                                 0, vector_length);
                a_vector_1_i8m1 = __riscv_vneg_v_i8m1_mu(negated_1_b8, a_vector_1_i8m1, a_vector_1_i8m1, vector_length);

                vint8m1_t a_vector_2_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval2_u8m1);
                vbool8_t negated_2_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw2_u8m1, 0x20, vector_length),
                                                                 0, vector_length);
                a_vector_2_i8m1 = __riscv_vneg_v_i8m1_mu(negated_2_b8, a_vector_2_i8m1, a_vector_2_i8m1, vector_length);

                vint8m1_t a_vector_3_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval3_u8m1);
                vbool8_t negated_3_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw3_u8m1, 0x20, vector_length),
                                                                 0, vector_length);
                a_vector_3_i8m1 = __riscv_vneg_v_i8m1_mu(negated_3_b8, a_vector_3_i8m1, a_vector_3_i8m1, vector_length);

                // Widening multiply: i8×i8 → i16, then accumulate: i32 += i16.
                // `_tu` keeps tail lanes of the i32 accumulators intact on short strips.
                vint16m2_t product_0_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_0_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_1_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_1_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_2_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_2_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_3_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_3_i8m1, b_vector_i8m1, vector_length);
                accumulator_0_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_0_i32m4, accumulator_0_i32m4,
                                                                product_0_i16m2, vector_length);
                accumulator_1_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_1_i32m4, accumulator_1_i32m4,
                                                                product_1_i16m2, vector_length);
                accumulator_2_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_2_i32m4, accumulator_2_i32m4,
                                                                product_2_i16m2, vector_length);
                accumulator_3_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_3_i32m4, accumulator_3_i32m4,
                                                                product_3_i16m2, vector_length);
            }

            // Horizontal reduce and convert to f32 with scaling
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                                  __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
                              lut_scale_reciprocal;
            c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                                  __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
                              lut_scale_reciprocal;
            c_row_2[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                                  __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax)) *
                              lut_scale_reciprocal;
            c_row_3[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                                  __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax)) *
                              lut_scale_reciprocal;
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_i8_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
                // Same LUT gather + masked negate as in the tiled path, for a single row.
                vuint8m1_t raw_a_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
                vuint8m1_t mag_a_u8m1 = __riscv_vand_vx_u8m1(raw_a_u8m1, 0x1F, vector_length);
                vuint8m1_t uval_a_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_a_u8m1, vector_length);
                vint8m1_t a_vector_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_a_u8m1);
                vbool8_t negated_a_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_a_u8m1, 0x20, vector_length),
                                                                 0, vector_length);
                a_vector_i8m1 = __riscv_vneg_v_i8m1_mu(negated_a_b8, a_vector_i8m1, a_vector_i8m1, vector_length);
                vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
                accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
                                                              vector_length);
            }
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                                __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
                            lut_scale_reciprocal;
        }
    }
}
|
|
701
|
+
|
|
702
|
+
/**
|
|
703
|
+
* @brief Public e2m3 packed GEMM wrapper matching the declared signature in dots.h.
|
|
704
|
+
*/
|
|
705
|
+
NK_PUBLIC void nk_dots_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
|
|
706
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
707
|
+
nk_dots_packed_e2m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
/**
 *  @brief Symmetric e2m3 GEMM: C = A * A^T, upper triangle + mirror.
 *
 *  Uses integer i8 LUT arithmetic with i32 accumulation, scaled by 1/256
 *  (each LUT operand carries a ×16 scale, so products are scaled by 256).
 *  Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 *  NOTE(review): only result[i][j] for j >= i is written here; the mirrored lower
 *  triangle is not filled by this function — presumably done by the caller, confirm.
 */
NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;

    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the row window so callers may over-request past n_vectors.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        // e2m3 values are single bytes, so the byte stride is used directly.
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vuint8m1_t raw_i_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
                vuint8m1_t raw_j_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);

                // Extract magnitudes (low 5 bits) and gather from LUT
                vuint8m1_t mag_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x1F, vector_length);
                vuint8m1_t mag_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x1F, vector_length);
                vuint8m1_t uval_i_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_i_u8m1, vector_length);
                vuint8m1_t uval_j_u8m1 = __riscv_vluxei8_v_u8m1(nk_e2m3_magnitude_lut_rvv_, mag_j_u8m1, vector_length);

                // Combined sign: XOR of the two sign bits (bit 5) decides whether the
                // product is negative; apply it as a single conditional negate on the j side.
                vuint8m1_t sign_xor_u8m1 = __riscv_vand_vx_u8m1(
                    __riscv_vxor_vv_u8m1(raw_i_u8m1, raw_j_u8m1, vector_length), 0x20, vector_length);
                vbool8_t negate_b8 = __riscv_vmsne_vx_u8m1_b8(sign_xor_u8m1, 0, vector_length);
                vint8m1_t val_i_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_i_u8m1);
                vint8m1_t val_j_i8m1 = __riscv_vreinterpret_v_u8m1_i8m1(uval_j_u8m1);
                val_j_i8m1 = __riscv_vneg_v_i8m1_mu(negate_b8, val_j_i8m1, val_j_i8m1, vector_length);

                // Widening multiply: i8×i8 → i16, then accumulate: i32 += i16
                vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(val_i_i8m1, val_j_i8m1, vector_length);
                accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
                                                              vector_length);
            }
            // Horizontal reduce to a scalar, then undo the ×256 LUT scaling.
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
                               __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
                           lut_scale_reciprocal;
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
764
|
+
|
|
765
|
+
#pragma endregion // Micro Precision E2M3
|
|
766
|
+
|
|
767
|
+
#pragma region Micro Precision E3M2
|
|
768
|
+
|
|
769
|
+
/**
|
|
770
|
+
* @brief Scalar conversion helper: e3m2 byte → signed i16 (value × 16).
|
|
771
|
+
*
|
|
772
|
+
* Extracts 5-bit magnitude, looks up in LUT, applies sign from bit 5.
|
|
773
|
+
* Every e3m2 value × 16 is an exact integer in [-448, +448], requiring i16.
|
|
774
|
+
*/
|
|
775
|
+
NK_INTERNAL nk_i16_t nk_e3m2_to_i16_rvv_(nk_u8_t raw) {
|
|
776
|
+
nk_u8_t magnitude = raw & 0x1Fu;
|
|
777
|
+
nk_i16_t val = (nk_i16_t)nk_e3m2_magnitude_lut_rvv_[magnitude];
|
|
778
|
+
return (raw & 0x20u) ? (nk_i16_t)(-val) : val;
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
/**
 *  @brief Returns the byte size of the packed B buffer for the e3m2 RVV GEMM path.
 *
 *  B is stored as signed i16 (value × 16); depth is rounded up to the maximal e16/LMUL=2
 *  vector length, with one extra vector of padding when the per-column stride would be
 *  a power of two. Per-column f32 norms trail the packed data.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const lanes = __riscv_vsetvlmax_e16m2();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    nk_size_t const column_bytes = padded_depth * sizeof(nk_i16_t);
    int const stride_is_power_of_two = column_bytes > 0 && (column_bytes & (column_bytes - 1)) == 0;
    if (stride_is_power_of_two) padded_depth += lanes;
    nk_size_t const panel_bytes = column_count * padded_depth * sizeof(nk_i16_t);
    nk_size_t const norms_bytes = column_count * sizeof(nk_f32_t); // per-column norms
    return sizeof(nk_cross_packed_buffer_header_t) + panel_bytes + norms_bytes;
}
|
|
789
|
+
|
|
790
|
+
/**
|
|
791
|
+
* @brief Pack B matrix from e3m2 to signed i16 (value × 16) for integer dot product.
|
|
792
|
+
*
|
|
793
|
+
* Each e3m2 byte is converted to a signed i16 via scalar LUT lookup.
|
|
794
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
795
|
+
*/
|
|
796
|
+
NK_PUBLIC void nk_dots_pack_e3m2_rvv(nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
797
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
798
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e16m2();
|
|
799
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
800
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_i16_t);
|
|
801
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
802
|
+
|
|
803
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
804
|
+
header->column_count = (nk_u32_t)column_count;
|
|
805
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
806
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
807
|
+
|
|
808
|
+
nk_i16_t *packed = (nk_i16_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
809
|
+
nk_size_t total = column_count * depth_padded;
|
|
810
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
811
|
+
|
|
812
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
813
|
+
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
814
|
+
nk_i16_t *dst = packed + column * depth_padded;
|
|
815
|
+
for (nk_size_t k = 0; k < depth; ++k) dst[k] = nk_e3m2_to_i16_rvv_(src[k]);
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
// Append per-column norms after packed data
|
|
819
|
+
nk_f32_t *norms = (nk_f32_t *)(packed + total);
|
|
820
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
821
|
+
nk_e3m2_t const *src = (nk_e3m2_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
822
|
+
norms[column] = nk_dots_reduce_sumsq_e3m2_(src, depth);
|
|
823
|
+
}
|
|
824
|
+
}
|
|
825
|
+
|
|
826
|
+
/**
|
|
827
|
+
* @brief e3m2 packed GEMM kernel: C += A * B_packed^T with integer i16 LUT arithmetic.
|
|
828
|
+
*
|
|
829
|
+
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
830
|
+
* - Load raw e3m2 bytes from A, convert to signed i16 via `vluxei16` gather LUT
|
|
831
|
+
* - Load pre-packed i16 values from B
|
|
832
|
+
* - Widening multiply-accumulate: i16×i16 → i32 via `vwmacc`
|
|
833
|
+
* - Final result = i32_sum / 256.0f
|
|
834
|
+
*
|
|
835
|
+
* Register tile: process 2 rows per iteration (rows_per_tile=2, wider i16/i32 elements reduce VL).
|
|
836
|
+
* The LUT gather on A magnitudes uses `vluxei16_v_u16m2` (16-bit indexed 16-bit gather).
|
|
837
|
+
*/
|
|
838
|
+
NK_INTERNAL void nk_dots_packed_e3m2_rvv_aligned_(nk_e3m2_t const *a_matrix, void const *b_packed_buffer,
|
|
839
|
+
nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
|
|
840
|
+
nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
841
|
+
nk_size_t c_stride_in_bytes) {
|
|
842
|
+
nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
|
|
843
|
+
|
|
844
|
+
nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
|
|
845
|
+
nk_size_t const depth_padded = header->depth_padded_values;
|
|
846
|
+
nk_i16_t const *packed_data = (nk_i16_t const *)((char const *)b_packed_buffer +
|
|
847
|
+
sizeof(nk_cross_packed_buffer_header_t));
|
|
848
|
+
|
|
849
|
+
// Zero output matrix
|
|
850
|
+
for (nk_size_t i = 0; i < row_count; ++i) {
|
|
851
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
|
|
852
|
+
for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
// mr=2 register tile (i16 at LMUL=2 and i32 at LMUL=4 leaves fewer spare registers)
|
|
856
|
+
nk_size_t row = 0;
|
|
857
|
+
for (; row + 2 <= row_count; row += 2) {
|
|
858
|
+
nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
|
|
859
|
+
nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
|
|
860
|
+
nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
|
|
861
|
+
nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
|
|
862
|
+
|
|
863
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
864
|
+
nk_i16_t const *b_column = packed_data + column * depth_padded;
|
|
865
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
866
|
+
vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
|
|
867
|
+
vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
|
|
868
|
+
|
|
869
|
+
nk_size_t remaining = depth;
|
|
870
|
+
nk_size_t k = 0;
|
|
871
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
872
|
+
vector_length = __riscv_vsetvl_e16m2(remaining);
|
|
873
|
+
|
|
874
|
+
// Load pre-packed i16 B values
|
|
875
|
+
vint16m2_t b_vector_i16m2 = __riscv_vle16_v_i16m2(b_column + k, vector_length);
|
|
876
|
+
|
|
877
|
+
// Load raw e3m2 bytes from each A row
|
|
878
|
+
vuint8m1_t raw0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
|
|
879
|
+
vuint8m1_t raw1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
|
|
880
|
+
|
|
881
|
+
// Extract magnitudes, zero-extend to u16, compute byte offsets for i16 LUT gather
|
|
882
|
+
vuint8m1_t mag0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x1F, vector_length);
|
|
883
|
+
vuint8m1_t mag1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x1F, vector_length);
|
|
884
|
+
vuint16m2_t idx0_u16m2 = __riscv_vzext_vf2_u16m2(mag0_u8m1, vector_length);
|
|
885
|
+
vuint16m2_t idx1_u16m2 = __riscv_vzext_vf2_u16m2(mag1_u8m1, vector_length);
|
|
886
|
+
vuint16m2_t off0_u16m2 = __riscv_vsll_vx_u16m2(idx0_u16m2, 1,
|
|
887
|
+
vector_length); // byte offsets = index × 2
|
|
888
|
+
vuint16m2_t off1_u16m2 = __riscv_vsll_vx_u16m2(idx1_u16m2, 1, vector_length);
|
|
889
|
+
|
|
890
|
+
// Gather unsigned magnitudes from i16 LUT
|
|
891
|
+
vuint16m2_t uval0_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off0_u16m2,
|
|
892
|
+
vector_length);
|
|
893
|
+
vuint16m2_t uval1_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off1_u16m2,
|
|
894
|
+
vector_length);
|
|
895
|
+
|
|
896
|
+
// Apply sign: negate where bit 5 is set
|
|
897
|
+
vuint8m1_t sign0_u8m1 = __riscv_vand_vx_u8m1(raw0_u8m1, 0x20, vector_length);
|
|
898
|
+
vuint8m1_t sign1_u8m1 = __riscv_vand_vx_u8m1(raw1_u8m1, 0x20, vector_length);
|
|
899
|
+
vbool8_t negated_0_b8 = __riscv_vmsne_vx_u8m1_b8(sign0_u8m1, 0, vector_length);
|
|
900
|
+
vbool8_t negated_1_b8 = __riscv_vmsne_vx_u8m1_b8(sign1_u8m1, 0, vector_length);
|
|
901
|
+
|
|
902
|
+
vint16m2_t a_vector_0_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval0_u16m2);
|
|
903
|
+
a_vector_0_i16m2 = __riscv_vneg_v_i16m2_mu(negated_0_b8, a_vector_0_i16m2, a_vector_0_i16m2,
|
|
904
|
+
vector_length);
|
|
905
|
+
vint16m2_t a_vector_1_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval1_u16m2);
|
|
906
|
+
a_vector_1_i16m2 = __riscv_vneg_v_i16m2_mu(negated_1_b8, a_vector_1_i16m2, a_vector_1_i16m2,
|
|
907
|
+
vector_length);
|
|
908
|
+
|
|
909
|
+
// Widening multiply-accumulate: i16×i16 → i32
|
|
910
|
+
accumulator_0_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_0_i32m4, a_vector_0_i16m2, b_vector_i16m2,
|
|
911
|
+
vector_length);
|
|
912
|
+
accumulator_1_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_1_i32m4, a_vector_1_i16m2, b_vector_i16m2,
|
|
913
|
+
vector_length);
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
// Horizontal reduce and convert to f32 with scaling
|
|
917
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
918
|
+
c_row_0[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
919
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax)) *
|
|
920
|
+
lut_scale_reciprocal;
|
|
921
|
+
c_row_1[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
922
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax)) *
|
|
923
|
+
lut_scale_reciprocal;
|
|
924
|
+
}
|
|
925
|
+
}
|
|
926
|
+
// Remainder rows
|
|
927
|
+
for (; row < row_count; ++row) {
|
|
928
|
+
nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
|
|
929
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
930
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
931
|
+
nk_i16_t const *b_column = packed_data + column * depth_padded;
|
|
932
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
933
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
|
|
934
|
+
nk_size_t remaining = depth;
|
|
935
|
+
nk_size_t k = 0;
|
|
936
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
937
|
+
vector_length = __riscv_vsetvl_e16m2(remaining);
|
|
938
|
+
vint16m2_t b_vector_i16m2 = __riscv_vle16_v_i16m2(b_column + k, vector_length);
|
|
939
|
+
vuint8m1_t raw_a_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
|
|
940
|
+
vuint8m1_t mag_a_u8m1 = __riscv_vand_vx_u8m1(raw_a_u8m1, 0x1F, vector_length);
|
|
941
|
+
vuint16m2_t idx_a_u16m2 = __riscv_vzext_vf2_u16m2(mag_a_u8m1, vector_length);
|
|
942
|
+
vuint16m2_t off_a_u16m2 = __riscv_vsll_vx_u16m2(idx_a_u16m2, 1, vector_length);
|
|
943
|
+
vuint16m2_t uval_a_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_a_u16m2,
|
|
944
|
+
vector_length);
|
|
945
|
+
vint16m2_t a_vector_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_a_u16m2);
|
|
946
|
+
vbool8_t negated_a_b8 = __riscv_vmsne_vx_u8m1_b8(__riscv_vand_vx_u8m1(raw_a_u8m1, 0x20, vector_length),
|
|
947
|
+
0, vector_length);
|
|
948
|
+
a_vector_i16m2 = __riscv_vneg_v_i16m2_mu(negated_a_b8, a_vector_i16m2, a_vector_i16m2, vector_length);
|
|
949
|
+
accumulator_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_i32m4, a_vector_i16m2, b_vector_i16m2,
|
|
950
|
+
vector_length);
|
|
951
|
+
}
|
|
952
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
953
|
+
c_row[column] = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
954
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
|
|
955
|
+
lut_scale_reciprocal;
|
|
956
|
+
}
|
|
957
|
+
}
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
/**
|
|
961
|
+
* @brief Public e3m2 packed GEMM wrapper matching the declared signature in dots.h.
|
|
962
|
+
*/
|
|
963
|
+
NK_PUBLIC void nk_dots_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
|
|
964
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
965
|
+
nk_dots_packed_e3m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
966
|
+
}
|
|
967
|
+
|
|
968
|
+
/**
|
|
969
|
+
* @brief Symmetric e3m2 GEMM: C = A * A^T, upper triangle + mirror.
|
|
970
|
+
*
|
|
971
|
+
* Uses integer i16 LUT arithmetic with i32 widening MAC, scaled by 1/256.
|
|
972
|
+
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
973
|
+
*/
|
|
974
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
975
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
976
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
977
|
+
nk_f32_t const lut_scale_reciprocal = 1.0f / 256.0f;
|
|
978
|
+
|
|
979
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
980
|
+
nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
|
|
981
|
+
|
|
982
|
+
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
983
|
+
nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
|
|
984
|
+
for (nk_size_t j = i; j < n_vectors; ++j) {
|
|
985
|
+
nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
|
|
986
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
987
|
+
vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
|
|
988
|
+
nk_size_t remaining = depth;
|
|
989
|
+
nk_size_t k = 0;
|
|
990
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
991
|
+
vector_length = __riscv_vsetvl_e16m2(remaining);
|
|
992
|
+
vuint8m1_t raw_i_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
|
|
993
|
+
vuint8m1_t raw_j_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);
|
|
994
|
+
|
|
995
|
+
// Extract magnitudes, zero-extend to u16, compute byte offsets
|
|
996
|
+
vuint8m1_t mag_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x1F, vector_length);
|
|
997
|
+
vuint8m1_t mag_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x1F, vector_length);
|
|
998
|
+
vuint16m2_t idx_i_u16m2 = __riscv_vzext_vf2_u16m2(mag_i_u8m1, vector_length);
|
|
999
|
+
vuint16m2_t idx_j_u16m2 = __riscv_vzext_vf2_u16m2(mag_j_u8m1, vector_length);
|
|
1000
|
+
vuint16m2_t off_i_u16m2 = __riscv_vsll_vx_u16m2(idx_i_u16m2, 1, vector_length);
|
|
1001
|
+
vuint16m2_t off_j_u16m2 = __riscv_vsll_vx_u16m2(idx_j_u16m2, 1, vector_length);
|
|
1002
|
+
|
|
1003
|
+
// Gather unsigned magnitudes
|
|
1004
|
+
vuint16m2_t uval_i_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_i_u16m2,
|
|
1005
|
+
vector_length);
|
|
1006
|
+
vuint16m2_t uval_j_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_magnitude_lut_rvv_, off_j_u16m2,
|
|
1007
|
+
vector_length);
|
|
1008
|
+
|
|
1009
|
+
// Apply individual signs
|
|
1010
|
+
vuint8m1_t sign_i_u8m1 = __riscv_vand_vx_u8m1(raw_i_u8m1, 0x20, vector_length);
|
|
1011
|
+
vuint8m1_t sign_j_u8m1 = __riscv_vand_vx_u8m1(raw_j_u8m1, 0x20, vector_length);
|
|
1012
|
+
vbool8_t negated_i_b8 = __riscv_vmsne_vx_u8m1_b8(sign_i_u8m1, 0, vector_length);
|
|
1013
|
+
vbool8_t negated_j_b8 = __riscv_vmsne_vx_u8m1_b8(sign_j_u8m1, 0, vector_length);
|
|
1014
|
+
|
|
1015
|
+
vint16m2_t val_i_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_i_u16m2);
|
|
1016
|
+
val_i_i16m2 = __riscv_vneg_v_i16m2_mu(negated_i_b8, val_i_i16m2, val_i_i16m2, vector_length);
|
|
1017
|
+
vint16m2_t val_j_i16m2 = __riscv_vreinterpret_v_u16m2_i16m2(uval_j_u16m2);
|
|
1018
|
+
val_j_i16m2 = __riscv_vneg_v_i16m2_mu(negated_j_b8, val_j_i16m2, val_j_i16m2, vector_length);
|
|
1019
|
+
|
|
1020
|
+
// Widening multiply-accumulate: i16×i16 → i32
|
|
1021
|
+
accumulator_i32m4 = __riscv_vwmacc_vv_i32m4_tu(accumulator_i32m4, val_i_i16m2, val_j_i16m2,
|
|
1022
|
+
vector_length);
|
|
1023
|
+
}
|
|
1024
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
|
|
1025
|
+
nk_f32_t dot = (nk_f32_t)__riscv_vmv_x_s_i32m1_i32(
|
|
1026
|
+
__riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax)) *
|
|
1027
|
+
lut_scale_reciprocal;
|
|
1028
|
+
result[i * result_stride_elements + j] = dot;
|
|
1029
|
+
}
|
|
1030
|
+
}
|
|
1031
|
+
}
|
|
1032
|
+
|
|
1033
|
+
#pragma endregion // Micro Precision E3M2
|
|
1034
|
+
|
|
1035
|
+
#pragma region Brain Float 16
|
|
1036
|
+
|
|
1037
|
+
/**
|
|
1038
|
+
* @brief Compute the packed buffer size for bf16 GEMM (B stored as f32).
|
|
1039
|
+
*
|
|
1040
|
+
* VL is determined by `__riscv_vsetvlmax_e32m2()` since B is stored as f32.
|
|
1041
|
+
* Layout: column-panel with depth-contiguous f32 values, cache-line padding.
|
|
1042
|
+
*/
|
|
1043
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1044
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
1045
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1046
|
+
// Break power-of-2 strides for cache associativity
|
|
1047
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1048
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1049
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
1050
|
+
column_count * sizeof(nk_f32_t); // per-column norms
|
|
1051
|
+
}
|
|
1052
|
+
|
|
1053
|
+
/**
|
|
1054
|
+
* @brief Pack B matrix from bf16 to f32 for widened dot product.
|
|
1055
|
+
*
|
|
1056
|
+
* Each bf16 value is converted to f32 via bit shift (bf16 is the upper 16 bits of f32).
|
|
1057
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
1058
|
+
*/
|
|
1059
|
+
NK_PUBLIC void nk_dots_pack_bf16_rvv(nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1060
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1061
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
1062
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1063
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1064
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1065
|
+
|
|
1066
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1067
|
+
header->column_count = (nk_u32_t)column_count;
|
|
1068
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
1069
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
1070
|
+
|
|
1071
|
+
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1072
|
+
nk_size_t total = column_count * depth_padded;
|
|
1073
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
1074
|
+
|
|
1075
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1076
|
+
nk_u16_t const *src = (nk_u16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1077
|
+
nk_f32_t *dst = packed + column * depth_padded;
|
|
1078
|
+
for (nk_size_t k = 0; k < depth; ++k) {
|
|
1079
|
+
union {
|
|
1080
|
+
nk_u32_t u;
|
|
1081
|
+
nk_f32_t f;
|
|
1082
|
+
} conv;
|
|
1083
|
+
conv.u = (nk_u32_t)src[k] << 16;
|
|
1084
|
+
dst[k] = conv.f;
|
|
1085
|
+
}
|
|
1086
|
+
}
|
|
1087
|
+
|
|
1088
|
+
// Append per-column norms after packed data
|
|
1089
|
+
nk_f32_t *norms = (nk_f32_t *)(packed + total);
|
|
1090
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1091
|
+
nk_bf16_t const *src = (nk_bf16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1092
|
+
norms[column] = nk_dots_reduce_sumsq_bf16_(src, depth);
|
|
1093
|
+
}
|
|
1094
|
+
}
|
|
1095
|
+
|
|
1096
|
+
/**
|
|
1097
|
+
* @brief bf16 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
|
|
1098
|
+
*
|
|
1099
|
+
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
1100
|
+
* - Load A as u16m1 and convert to f32m2 via `nk_bf16m1_to_f32m2_rvv_`
|
|
1101
|
+
* - Load B as f32m2 directly (pre-packed)
|
|
1102
|
+
* - Accumulate via `vfwmacc_vv_f64m4` which widens both f32 operands to f64
|
|
1103
|
+
* - Horizontal reduce and narrow to f32 on store
|
|
1104
|
+
*
|
|
1105
|
+
* Register tile: process 4 rows per iteration (rows_per_tile=4).
|
|
1106
|
+
*/
|
|
1107
|
+
NK_INTERNAL void nk_dots_packed_bf16_rvv_aligned_(nk_bf16_t const *a_matrix, void const *b_packed_buffer,
|
|
1108
|
+
nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
|
|
1109
|
+
nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1110
|
+
nk_size_t c_stride_in_bytes) {
|
|
1111
|
+
nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
|
|
1112
|
+
nk_size_t const depth_padded = header->depth_padded_values;
|
|
1113
|
+
nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
|
|
1114
|
+
sizeof(nk_cross_packed_buffer_header_t));
|
|
1115
|
+
|
|
1116
|
+
// Zero output matrix
|
|
1117
|
+
for (nk_size_t i = 0; i < row_count; ++i) {
|
|
1118
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
|
|
1119
|
+
for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
|
|
1120
|
+
}
|
|
1121
|
+
|
|
1122
|
+
// mr=4 register tile over rows
|
|
1123
|
+
nk_size_t row = 0;
|
|
1124
|
+
for (; row + 4 <= row_count; row += 4) {
|
|
1125
|
+
nk_u16_t const *a_row_0 = (nk_u16_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
|
|
1126
|
+
nk_u16_t const *a_row_1 = (nk_u16_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
|
|
1127
|
+
nk_u16_t const *a_row_2 = (nk_u16_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
|
|
1128
|
+
nk_u16_t const *a_row_3 = (nk_u16_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
|
|
1129
|
+
nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
|
|
1130
|
+
nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
|
|
1131
|
+
nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
|
|
1132
|
+
nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);
|
|
1133
|
+
|
|
1134
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1135
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1136
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
1137
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1138
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1139
|
+
vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1140
|
+
vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1141
|
+
|
|
1142
|
+
nk_size_t remaining = depth;
|
|
1143
|
+
nk_size_t k = 0;
|
|
1144
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
1145
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
1146
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
1147
|
+
// Load A as u16m1 and convert to f32m2
|
|
1148
|
+
vuint16m1_t a_raw_0_u16m1 = __riscv_vle16_v_u16m1(a_row_0 + k, vector_length);
|
|
1149
|
+
vfloat32m2_t a_vector_0_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_0_u16m1, vector_length);
|
|
1150
|
+
vuint16m1_t a_raw_1_u16m1 = __riscv_vle16_v_u16m1(a_row_1 + k, vector_length);
|
|
1151
|
+
vfloat32m2_t a_vector_1_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_1_u16m1, vector_length);
|
|
1152
|
+
vuint16m1_t a_raw_2_u16m1 = __riscv_vle16_v_u16m1(a_row_2 + k, vector_length);
|
|
1153
|
+
vfloat32m2_t a_vector_2_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_2_u16m1, vector_length);
|
|
1154
|
+
vuint16m1_t a_raw_3_u16m1 = __riscv_vle16_v_u16m1(a_row_3 + k, vector_length);
|
|
1155
|
+
vfloat32m2_t a_vector_3_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_3_u16m1, vector_length);
|
|
1156
|
+
accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
|
|
1157
|
+
vector_length);
|
|
1158
|
+
accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
|
|
1159
|
+
vector_length);
|
|
1160
|
+
accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
|
|
1161
|
+
vector_length);
|
|
1162
|
+
accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
|
|
1163
|
+
vector_length);
|
|
1164
|
+
}
|
|
1165
|
+
|
|
1166
|
+
// Horizontal reduce and narrow to f32
|
|
1167
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1168
|
+
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1169
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
|
|
1170
|
+
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1171
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
|
|
1172
|
+
c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1173
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
|
|
1174
|
+
c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1175
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
|
|
1176
|
+
}
|
|
1177
|
+
}
|
|
1178
|
+
// Remainder rows (mr < 4)
|
|
1179
|
+
for (; row < row_count; ++row) {
|
|
1180
|
+
nk_u16_t const *a_row = (nk_u16_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
|
|
1181
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
1182
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1183
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
1184
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
1185
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1186
|
+
nk_size_t remaining = depth;
|
|
1187
|
+
nk_size_t k = 0;
|
|
1188
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
1189
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
1190
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
1191
|
+
vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_row + k, vector_length);
|
|
1192
|
+
vfloat32m2_t a_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
|
|
1193
|
+
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
|
|
1194
|
+
vector_length);
|
|
1195
|
+
}
|
|
1196
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1197
|
+
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1198
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
|
|
1199
|
+
}
|
|
1200
|
+
}
|
|
1201
|
+
}
|
|
1202
|
+
|
|
1203
|
+
/**
|
|
1204
|
+
* @brief Public bf16 packed GEMM wrapper matching the declared signature in dots.h.
|
|
1205
|
+
*
|
|
1206
|
+
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1207
|
+
* vectors naturally, so no separate edge kernel is needed.
|
|
1208
|
+
*/
|
|
1209
|
+
NK_PUBLIC void nk_dots_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
|
|
1210
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
1211
|
+
nk_dots_packed_bf16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
1212
|
+
}
|
|
1213
|
+
|
|
1214
|
+
/**
|
|
1215
|
+
* @brief Symmetric bf16 GEMM: C = A * A^T, upper triangle + mirror.
|
|
1216
|
+
*
|
|
1217
|
+
* Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
|
|
1218
|
+
* Both inputs are bf16, loaded as u16 and converted to f32 via `nk_bf16m1_to_f32m2_rvv_`.
|
|
1219
|
+
* Stride is in bytes.
|
|
1220
|
+
* Processes only the rows in [row_start, row_start + row_count) for parallelism.
|
|
1221
|
+
*/
|
|
1222
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1223
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1224
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1225
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1226
|
+
nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;
|
|
1227
|
+
|
|
1228
|
+
for (nk_size_t i = row_start; i < row_end; ++i) {
|
|
1229
|
+
nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
|
|
1230
|
+
for (nk_size_t j = i; j < n_vectors; ++j) {
|
|
1231
|
+
nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
|
|
1232
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
1233
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
1234
|
+
nk_size_t remaining = depth;
|
|
1235
|
+
nk_size_t k = 0;
|
|
1236
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
1237
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
1238
|
+
vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_i + k, vector_length);
|
|
1239
|
+
vfloat32m2_t a_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
|
|
1240
|
+
vuint16m1_t b_raw_u16m1 = __riscv_vle16_v_u16m1(a_j + k, vector_length);
|
|
1241
|
+
vfloat32m2_t b_vector_f32m2 = nk_bf16m1_to_f32m2_rvv_(b_raw_u16m1, vector_length);
|
|
1242
|
+
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
|
|
1243
|
+
vector_length);
|
|
1244
|
+
}
|
|
1245
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
1246
|
+
nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
1247
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
|
|
1248
|
+
result[i * result_stride_elements + j] = dot;
|
|
1249
|
+
}
|
|
1250
|
+
}
|
|
1251
|
+
}
|
|
1252
|
+
|
|
1253
|
+
#pragma endregion // Brain Float 16
|
|
1254
|
+
|
|
1255
|
+
#pragma region Half Precision Floats
|
|
1256
|
+
|
|
1257
|
+
/**
|
|
1258
|
+
* @brief Compute the packed buffer size for f16 GEMM (B stored as f32).
|
|
1259
|
+
*
|
|
1260
|
+
* VL is determined by `__riscv_vsetvlmax_e32m2()` since B is stored as f32.
|
|
1261
|
+
* Layout: column-panel with depth-contiguous f32 values, cache-line padding.
|
|
1262
|
+
*/
|
|
1263
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1264
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
1265
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1266
|
+
// Break power-of-2 strides for cache associativity
|
|
1267
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1268
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1269
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_f32_t) +
|
|
1270
|
+
column_count * sizeof(nk_f32_t); // per-column norms
|
|
1271
|
+
}
|
|
1272
|
+
|
|
1273
|
+
/**
|
|
1274
|
+
* @brief Pack B matrix from f16 to f32 for widened dot product.
|
|
1275
|
+
*
|
|
1276
|
+
* Each f16 value is converted to f32 via `nk_f16_to_f32_serial`.
|
|
1277
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
1278
|
+
*/
|
|
1279
|
+
NK_PUBLIC void nk_dots_pack_f16_rvv(nk_f16_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1280
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1281
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e32m2();
|
|
1282
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1283
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_f32_t);
|
|
1284
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1285
|
+
|
|
1286
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1287
|
+
header->column_count = (nk_u32_t)column_count;
|
|
1288
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
1289
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
1290
|
+
|
|
1291
|
+
nk_f32_t *packed = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1292
|
+
nk_size_t total = column_count * depth_padded;
|
|
1293
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
1294
|
+
|
|
1295
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1296
|
+
nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1297
|
+
nk_f32_t *dst = packed + column * depth_padded;
|
|
1298
|
+
for (nk_size_t k = 0; k < depth; ++k) nk_f16_to_f32_serial(&src[k], &dst[k]);
|
|
1299
|
+
}
|
|
1300
|
+
|
|
1301
|
+
// Append per-column norms after packed data
|
|
1302
|
+
nk_f32_t *norms = (nk_f32_t *)(packed + total);
|
|
1303
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1304
|
+
nk_f16_t const *src = (nk_f16_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1305
|
+
norms[column] = nk_dots_reduce_sumsq_f16_(src, depth);
|
|
1306
|
+
}
|
|
1307
|
+
}
|
|
1308
|
+
|
|
1309
|
+
/**
 * @brief f16 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
 *
 * Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load A as u16m1 and convert to f32m2 via `nk_f16m1_to_f32m2_rvv_`
 * - Load B as f32m2 directly (pre-packed)
 * - Accumulate via `vfwmacc_vv_f64m4` which widens both f32 operands to f64
 * - Horizontal reduce and narrow to f32 on store
 *
 * Register tile: process 4 rows per iteration (rows_per_tile=4).
 *
 * @param a_matrix          Row-major f16 matrix A, `row_count x depth`; rows `a_stride_in_bytes` apart.
 * @param b_packed_buffer   Buffer starting with an `nk_cross_packed_buffer_header_t`, followed by
 *                          column panels of f32 values, each panel `depth_padded_values` elements long.
 * @param c_matrix          Output f32 matrix, `row_count x column_count`; fully zeroed here before accumulation.
 * @param depth             Shared inner dimension (number of k iterations).
 * @param a_stride_in_bytes Byte stride between consecutive rows of A.
 * @param c_stride_in_bytes Byte stride between consecutive rows of C.
 */
NK_INTERNAL void nk_dots_packed_f16_rvv_aligned_(nk_f16_t const *a_matrix, void const *b_packed_buffer,
                                                 nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                 nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                 nk_size_t c_stride_in_bytes) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    // Padded depth fixes the element stride between consecutive column panels.
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        // A rows are loaded through u16 pointers: f16 bits are reinterpreted and
        // converted in-register by `nk_f16m1_to_f32m2_rvv_`.
        nk_u16_t const *a_row_0 = (nk_u16_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u16_t const *a_row_1 = (nk_u16_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_u16_t const *a_row_2 = (nk_u16_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_u16_t const *a_row_3 = (nk_u16_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_f32_t *c_row_2 = (nk_f32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_f32_t *c_row_3 = (nk_f32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            // e32m2 vlmax equals the e64m4 element count (both are 2*VLEN/32 = 4*VLEN/64),
            // so the same `vlmax` later covers the whole f64m4 accumulator in the reduction.
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_2_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_3_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                // Load A as u16m1 and convert to f32m2
                vuint16m1_t a_raw_0_u16m1 = __riscv_vle16_v_u16m1(a_row_0 + k, vector_length);
                vfloat32m2_t a_vector_0_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_0_u16m1, vector_length);
                vuint16m1_t a_raw_1_u16m1 = __riscv_vle16_v_u16m1(a_row_1 + k, vector_length);
                vfloat32m2_t a_vector_1_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_1_u16m1, vector_length);
                vuint16m1_t a_raw_2_u16m1 = __riscv_vle16_v_u16m1(a_row_2 + k, vector_length);
                vfloat32m2_t a_vector_2_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_2_u16m1, vector_length);
                vuint16m1_t a_raw_3_u16m1 = __riscv_vle16_v_u16m1(a_row_3 + k, vector_length);
                vfloat32m2_t a_vector_3_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_3_u16m1, vector_length);
                // `_tu` (tail-undisturbed) preserves the accumulator elements beyond
                // `vector_length` on the final short iteration, keeping earlier partial sums.
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_2_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_2_f64m4, a_vector_2_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_3_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_3_f64m4, a_vector_3_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
            c_row_2[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_2_f64m4, zero_f64m1, vlmax));
            c_row_3[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_3_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_u16_t const *a_row = (nk_u16_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_row + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
|
|
1415
|
+
|
|
1416
|
+
/**
 * @brief Public f16 packed GEMM wrapper matching the declared signature in dots.h.
 *
 * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
 * vectors naturally, so no separate edge kernel is needed.
 *
 * @param a        Row-major f16 matrix A, `m x k`.
 * @param b_packed Packed B buffer (header + column panels + per-column norms).
 * @param c        Output f32 matrix, `m x n`.
 * @param a_stride Byte stride between rows of A.
 * @param c_stride Byte stride between rows of C.
 */
NK_PUBLIC void nk_dots_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
                                      nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    nk_dots_packed_f16_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
|
|
1426
|
+
|
|
1427
|
+
/**
 * @brief Symmetric f16 GEMM: computes C = A * A^T for the upper triangle.
 *
 * Uses f64 widened accumulation via `vfwmacc_vv_f64m4` for precision.
 * Both inputs are f16, loaded as u16 and converted to f32 via `nk_f16m1_to_f32m2_rvv_`.
 * Stride is in bytes.
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 * NOTE(review): only `result[i][j]` for j >= i is written here — there is no
 * mirror store to the lower triangle in this function. Presumably the caller
 * mirrors after all row ranges complete; confirm against the dispatch layer.
 */
NK_PUBLIC void nk_dots_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                         nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                         nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the requested range so a generous row_count cannot run past the matrix.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u16_t const *a_i = (nk_u16_t const *)((char const *)vectors + i * stride);
        // Start at j = i: the diagonal and everything to its right.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u16_t const *a_j = (nk_u16_t const *)((char const *)vectors + j * stride);
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint16m1_t a_raw_u16m1 = __riscv_vle16_v_u16m1(a_i + k, vector_length);
                vfloat32m2_t a_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(a_raw_u16m1, vector_length);
                vuint16m1_t b_raw_u16m1 = __riscv_vle16_v_u16m1(a_j + k, vector_length);
                vfloat32m2_t b_vector_f32m2 = nk_f16m1_to_f32m2_rvv_(b_raw_u16m1, vector_length);
                // Tail-undisturbed keeps partial sums intact on the final short iteration.
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
1465
|
+
|
|
1466
|
+
#pragma endregion // Half Precision Floats
|
|
1467
|
+
|
|
1468
|
+
#pragma region Signed 8-bit Integers
|
|
1469
|
+
|
|
1470
|
+
/**
|
|
1471
|
+
* @brief Compute the packed buffer size for i8 GEMM (B stored as i8).
|
|
1472
|
+
*
|
|
1473
|
+
* VL is determined by `__riscv_vsetvlmax_e8m1()` since B is stored as i8.
|
|
1474
|
+
* Layout: column-panel with depth-contiguous i8 values, cache-line padding.
|
|
1475
|
+
*/
|
|
1476
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1477
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
|
|
1478
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1479
|
+
// Break power-of-2 strides for cache associativity
|
|
1480
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
1481
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1482
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_i8_t) +
|
|
1483
|
+
column_count * sizeof(nk_u32_t); // per-column norms
|
|
1484
|
+
}
|
|
1485
|
+
|
|
1486
|
+
/**
|
|
1487
|
+
* @brief Pack B matrix from i8 to i8 (direct copy) for integer dot product.
|
|
1488
|
+
*
|
|
1489
|
+
* No conversion needed — values are copied directly.
|
|
1490
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
1491
|
+
*/
|
|
1492
|
+
NK_PUBLIC void nk_dots_pack_i8_rvv(nk_i8_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1493
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1494
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
|
|
1495
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1496
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_i8_t);
|
|
1497
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1498
|
+
|
|
1499
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1500
|
+
header->column_count = (nk_u32_t)column_count;
|
|
1501
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
1502
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
1503
|
+
|
|
1504
|
+
nk_i8_t *packed = (nk_i8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1505
|
+
nk_size_t total = column_count * depth_padded;
|
|
1506
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
1507
|
+
|
|
1508
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1509
|
+
nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1510
|
+
nk_i8_t *dst = packed + column * depth_padded;
|
|
1511
|
+
for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
|
|
1512
|
+
}
|
|
1513
|
+
|
|
1514
|
+
// Append per-column norms after packed data
|
|
1515
|
+
nk_u32_t *norms = (nk_u32_t *)(packed + total);
|
|
1516
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1517
|
+
nk_i8_t const *src = (nk_i8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1518
|
+
norms[column] = nk_dots_reduce_sumsq_i8_(src, depth);
|
|
1519
|
+
}
|
|
1520
|
+
}
|
|
1521
|
+
|
|
1522
|
+
/**
 * @brief i8 packed GEMM kernel: C += A * B_packed^T with i32 accumulation.
 *
 * Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load i8 values from A and pre-packed i8 values from B
 * - Widening multiply: i8 x i8 -> i16 via `vwmul`
 * - Widen-accumulate: i32 += i16 via `vwadd_wv`
 * - Horizontal reduce via `vredsum`
 *
 * Register tile: process 4 rows per iteration (rows_per_tile=4).
 * Output is nk_i32_t (integer result, no scaling).
 *
 * @param a_matrix          Row-major i8 matrix A, `row_count x depth`.
 * @param b_packed_buffer   Header followed by i8 column panels, `depth_padded_values` elements each.
 * @param c_matrix          Output i32 matrix, zeroed here before accumulation.
 * @param a_stride_in_bytes Byte stride between consecutive rows of A.
 * @param c_stride_in_bytes Byte stride between consecutive rows of C.
 */
NK_INTERNAL void nk_dots_packed_i8_rvv_aligned_(nk_i8_t const *a_matrix, void const *b_packed_buffer,
                                                nk_i32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                nk_size_t c_stride_in_bytes) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    // Padded depth is the element stride between consecutive column panels.
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_i8_t const *packed_data = (nk_i8_t const *)((char const *)b_packed_buffer +
                                                   sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        nk_i8_t const *a_row_0 = (nk_i8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_i8_t const *a_row_1 = (nk_i8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_i8_t const *a_row_2 = (nk_i8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_i8_t const *a_row_3 = (nk_i8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_i32_t *c_row_0 = (nk_i32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_i32_t *c_row_1 = (nk_i32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_i32_t *c_row_2 = (nk_i32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_i32_t *c_row_3 = (nk_i32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_i8_t const *b_column = packed_data + column * depth_padded;
            // e32m4 vlmax equals the e8m1 element count (both are VLEN/8), so the
            // i32m4 accumulator has one lane per i8 element processed per iteration.
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_0_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_1_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_2_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            vint32m4_t accumulator_3_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
                vint8m1_t a_vector_0_i8m1 = __riscv_vle8_v_i8m1(a_row_0 + k, vector_length);
                vint8m1_t a_vector_1_i8m1 = __riscv_vle8_v_i8m1(a_row_1 + k, vector_length);
                vint8m1_t a_vector_2_i8m1 = __riscv_vle8_v_i8m1(a_row_2 + k, vector_length);
                vint8m1_t a_vector_3_i8m1 = __riscv_vle8_v_i8m1(a_row_3 + k, vector_length);
                vint16m2_t product_0_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_0_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_1_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_1_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_2_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_2_i8m1, b_vector_i8m1, vector_length);
                vint16m2_t product_3_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_3_i8m1, b_vector_i8m1, vector_length);
                // `vwadd_wv` adds the wide (i32) accumulator to the narrow (i16) product;
                // the accumulator doubles as the `_tu` merge operand so the tail is preserved.
                accumulator_0_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_0_i32m4, accumulator_0_i32m4,
                                                                product_0_i16m2, vector_length);
                accumulator_1_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_1_i32m4, accumulator_1_i32m4,
                                                                product_1_i16m2, vector_length);
                accumulator_2_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_2_i32m4, accumulator_2_i32m4,
                                                                product_2_i16m2, vector_length);
                accumulator_3_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_3_i32m4, accumulator_3_i32m4,
                                                                product_3_i16m2, vector_length);
            }

            // Horizontal reduce
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            c_row_0[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_0_i32m4, zero_i32m1, vlmax));
            c_row_1[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_1_i32m4, zero_i32m1, vlmax));
            c_row_2[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_2_i32m4, zero_i32m1, vlmax));
            c_row_3[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_3_i32m4, zero_i32m1, vlmax));
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_i8_t const *a_row = (nk_i8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_i32_t *c_row = (nk_i32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_i8_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(b_column + k, vector_length);
                vint8m1_t a_vector_i8m1 = __riscv_vle8_v_i8m1(a_row + k, vector_length);
                vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
                accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
                                                              vector_length);
            }
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            c_row[column] = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
        }
    }
}
|
|
1628
|
+
|
|
1629
|
+
/**
 * @brief Public i8 packed GEMM wrapper matching the declared signature in dots.h.
 *
 * Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
 * vectors naturally, so no separate edge kernel is needed.
 *
 * @param a        Row-major i8 matrix A, `m x k`.
 * @param b_packed Packed B buffer (header + column panels + per-column norms).
 * @param c        Output i32 matrix, `m x n`.
 * @param a_stride Byte stride between rows of A.
 * @param c_stride Byte stride between rows of C.
 */
NK_PUBLIC void nk_dots_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t m, nk_size_t n,
                                     nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    nk_dots_packed_i8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
|
|
1639
|
+
|
|
1640
|
+
/**
 * @brief Symmetric i8 GEMM: computes C = A * A^T for the upper triangle.
 *
 * Uses integer i8 arithmetic with i32 accumulation.
 * Both inputs are i8, widened via i8 x i8 -> i16 -> i32 accumulation.
 * Stride is in bytes.
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 * NOTE(review): only `result[i][j]` for j >= i is written here — there is no
 * mirror store to the lower triangle in this function. Presumably the caller
 * mirrors after all row ranges complete; confirm against the dispatch layer.
 */
NK_PUBLIC void nk_dots_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
                                        nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
                                        nk_size_t row_count) {
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);
    // Clamp the requested range so a generous row_count cannot run past the matrix.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_i8_t const *a_i = (nk_i8_t const *)((char const *)vectors + i * stride);
        // Start at j = i: the diagonal and everything to its right.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_i8_t const *a_j = (nk_i8_t const *)((char const *)vectors + j * stride);
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vint32m4_t accumulator_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vint8m1_t a_vector_i8m1 = __riscv_vle8_v_i8m1(a_i + k, vector_length);
                vint8m1_t b_vector_i8m1 = __riscv_vle8_v_i8m1(a_j + k, vector_length);
                vint16m2_t product_i16m2 = __riscv_vwmul_vv_i16m2(a_vector_i8m1, b_vector_i8m1, vector_length);
                // Wide + narrow add; accumulator doubles as the `_tu` merge operand.
                accumulator_i32m4 = __riscv_vwadd_wv_i32m4_tu(accumulator_i32m4, accumulator_i32m4, product_i16m2,
                                                              vector_length);
            }
            vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
            nk_i32_t dot = (nk_i32_t)__riscv_vmv_x_s_i32m1_i32(
                __riscv_vredsum_vs_i32m4_i32m1(accumulator_i32m4, zero_i32m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
1677
|
+
|
|
1678
|
+
#pragma endregion // Signed 8-bit Integers
|
|
1679
|
+
|
|
1680
|
+
#pragma region Unsigned 8-bit Integers
|
|
1681
|
+
|
|
1682
|
+
/**
|
|
1683
|
+
* @brief Compute the packed buffer size for u8 GEMM (B stored as u8).
|
|
1684
|
+
*
|
|
1685
|
+
* VL is determined by `__riscv_vsetvlmax_e8m1()` since B is stored as u8.
|
|
1686
|
+
* Layout: column-panel with depth-contiguous u8 values, cache-line padding.
|
|
1687
|
+
*/
|
|
1688
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t column_count, nk_size_t depth) {
|
|
1689
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
|
|
1690
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1691
|
+
// Break power-of-2 strides for cache associativity
|
|
1692
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
|
|
1693
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1694
|
+
return sizeof(nk_cross_packed_buffer_header_t) + column_count * depth_padded * sizeof(nk_u8_t) +
|
|
1695
|
+
column_count * sizeof(nk_u32_t); // per-column norms
|
|
1696
|
+
}
|
|
1697
|
+
|
|
1698
|
+
/**
|
|
1699
|
+
* @brief Pack B matrix from u8 to u8 (direct copy) for integer dot product.
|
|
1700
|
+
*
|
|
1701
|
+
* No conversion needed — values are copied directly.
|
|
1702
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
1703
|
+
*/
|
|
1704
|
+
NK_PUBLIC void nk_dots_pack_u8_rvv(nk_u8_t const *b, nk_size_t column_count, nk_size_t depth,
|
|
1705
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1706
|
+
nk_size_t vector_length = __riscv_vsetvlmax_e8m1();
|
|
1707
|
+
nk_size_t depth_padded = nk_size_round_up_to_multiple_(depth, vector_length);
|
|
1708
|
+
nk_size_t stride_bytes = depth_padded * sizeof(nk_u8_t);
|
|
1709
|
+
if (stride_bytes > 0 && (stride_bytes & (stride_bytes - 1)) == 0) depth_padded += vector_length;
|
|
1710
|
+
|
|
1711
|
+
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
|
|
1712
|
+
header->column_count = (nk_u32_t)column_count;
|
|
1713
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
1714
|
+
header->depth_padded_values = (nk_u32_t)depth_padded;
|
|
1715
|
+
|
|
1716
|
+
nk_u8_t *packed = (nk_u8_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
|
|
1717
|
+
nk_size_t total = column_count * depth_padded;
|
|
1718
|
+
for (nk_size_t i = 0; i < total; ++i) packed[i] = 0;
|
|
1719
|
+
|
|
1720
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1721
|
+
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1722
|
+
nk_u8_t *dst = packed + column * depth_padded;
|
|
1723
|
+
for (nk_size_t k = 0; k < depth; ++k) dst[k] = src[k];
|
|
1724
|
+
}
|
|
1725
|
+
|
|
1726
|
+
// Append per-column norms after packed data
|
|
1727
|
+
nk_u32_t *norms = (nk_u32_t *)(packed + total);
|
|
1728
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
1729
|
+
nk_u8_t const *src = (nk_u8_t const *)((char const *)b + column * b_stride_in_bytes);
|
|
1730
|
+
norms[column] = nk_dots_reduce_sumsq_u8_(src, depth);
|
|
1731
|
+
}
|
|
1732
|
+
}
|
|
1733
|
+
|
|
1734
|
+
/**
 * @brief u8 packed GEMM kernel: C += A * B_packed^T with u32 accumulation.
 *
 * Vectorizes over the depth dimension (k). For each (row, column) pair:
 * - Load u8 values from A and pre-packed u8 values from B
 * - Widening multiply: u8 x u8 -> u16 via `vwmulu`
 * - Widen-accumulate: u32 += u16 via `vwaddu_wv`
 * - Horizontal reduce via `vredsum`
 *
 * Register tile: process 4 rows per iteration (rows_per_tile=4).
 * Output is nk_u32_t (unsigned integer result, no scaling).
 *
 * @param a_matrix          Row-major u8 matrix A, `row_count x depth`.
 * @param b_packed_buffer   Header followed by u8 column panels, `depth_padded_values` elements each.
 * @param c_matrix          Output u32 matrix, zeroed here before accumulation.
 * @param a_stride_in_bytes Byte stride between consecutive rows of A.
 * @param c_stride_in_bytes Byte stride between consecutive rows of C.
 */
NK_INTERNAL void nk_dots_packed_u8_rvv_aligned_(nk_u8_t const *a_matrix, void const *b_packed_buffer,
                                                nk_u32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                nk_size_t c_stride_in_bytes) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    // Padded depth is the element stride between consecutive column panels.
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_u8_t const *packed_data = (nk_u8_t const *)((char const *)b_packed_buffer +
                                                   sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=4 register tile over rows
    nk_size_t row = 0;
    for (; row + 4 <= row_count; row += 4) {
        nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_u8_t const *a_row_2 = (nk_u8_t const *)((char const *)a_matrix + (row + 2) * a_stride_in_bytes);
        nk_u8_t const *a_row_3 = (nk_u8_t const *)((char const *)a_matrix + (row + 3) * a_stride_in_bytes);
        nk_u32_t *c_row_0 = (nk_u32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_u32_t *c_row_1 = (nk_u32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
        nk_u32_t *c_row_2 = (nk_u32_t *)((char *)c_matrix + (row + 2) * c_stride_in_bytes);
        nk_u32_t *c_row_3 = (nk_u32_t *)((char *)c_matrix + (row + 3) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_u8_t const *b_column = packed_data + column * depth_padded;
            // e32m4 vlmax equals the e8m1 element count (both are VLEN/8), so the
            // u32m4 accumulator has one lane per u8 element processed per iteration.
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vuint32m4_t accumulator_0_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
            vuint32m4_t accumulator_1_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
            vuint32m4_t accumulator_2_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
            vuint32m4_t accumulator_3_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);

            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(b_column + k, vector_length);
                vuint8m1_t a_vector_0_u8m1 = __riscv_vle8_v_u8m1(a_row_0 + k, vector_length);
                vuint8m1_t a_vector_1_u8m1 = __riscv_vle8_v_u8m1(a_row_1 + k, vector_length);
                vuint8m1_t a_vector_2_u8m1 = __riscv_vle8_v_u8m1(a_row_2 + k, vector_length);
                vuint8m1_t a_vector_3_u8m1 = __riscv_vle8_v_u8m1(a_row_3 + k, vector_length);
                vuint16m2_t product_0_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_0_u8m1, b_vector_u8m1, vector_length);
                vuint16m2_t product_1_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_1_u8m1, b_vector_u8m1, vector_length);
                vuint16m2_t product_2_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_2_u8m1, b_vector_u8m1, vector_length);
                vuint16m2_t product_3_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_3_u8m1, b_vector_u8m1, vector_length);
                // `vwaddu_wv` adds the wide (u32) accumulator to the narrow (u16) product;
                // the accumulator doubles as the `_tu` merge operand so the tail is preserved.
                accumulator_0_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_0_u32m4, accumulator_0_u32m4,
                                                                 product_0_u16m2, vector_length);
                accumulator_1_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_1_u32m4, accumulator_1_u32m4,
                                                                 product_1_u16m2, vector_length);
                accumulator_2_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_2_u32m4, accumulator_2_u32m4,
                                                                 product_2_u16m2, vector_length);
                accumulator_3_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_3_u32m4, accumulator_3_u32m4,
                                                                 product_3_u16m2, vector_length);
            }

            // Horizontal reduce
            vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
            c_row_0[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_0_u32m4, zero_u32m1, vlmax));
            c_row_1[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_1_u32m4, zero_u32m1, vlmax));
            c_row_2[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_2_u32m4, zero_u32m1, vlmax));
            c_row_3[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_3_u32m4, zero_u32m1, vlmax));
        }
    }
    // Remainder rows (mr < 4)
    for (; row < row_count; ++row) {
        nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_u32_t *c_row = (nk_u32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_u8_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(b_column + k, vector_length);
                vuint8m1_t a_vector_u8m1 = __riscv_vle8_v_u8m1(a_row + k, vector_length);
                vuint16m2_t product_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_u8m1, b_vector_u8m1, vector_length);
                accumulator_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_u32m4, accumulator_u32m4, product_u16m2,
                                                               vector_length);
            }
            vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
            c_row[column] = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
        }
    }
}
|
|
1840
|
+
|
|
1841
|
+
/**
|
|
1842
|
+
* @brief Public u8 packed GEMM wrapper matching the declared signature in dots.h.
|
|
1843
|
+
*
|
|
1844
|
+
* Dispatches to the aligned kernel for all cases — RVV's `vsetvl` handles partial
|
|
1845
|
+
* vectors naturally, so no separate edge kernel is needed.
|
|
1846
|
+
*/
|
|
1847
|
+
NK_PUBLIC void nk_dots_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t m, nk_size_t n,
                                     nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    // Thin forwarder: RVV strip-mines via `vsetvl`, so the "aligned" kernel already
    // copes with arbitrary m/n/k — no separate edge-case kernel is required.
    nk_dots_packed_u8_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
|
|
1851
|
+
|
|
1852
|
+
/**
 *  @brief Symmetric u8 GEMM over a row slab of C = A * A^T.
 *
 *  Writes only the upper-triangle entries (j >= i); mirroring into the lower
 *  triangle, if required, is left to the caller — this routine never stores
 *  below the diagonal.
 *
 *  Uses unsigned integer u8 arithmetic with u32 accumulation.
 *  Both inputs are u8, widened via u8 x u8 -> u16 -> u32 accumulation.
 *  Stride is in bytes.
 *  Processes only the rows in [row_start, row_start + row_count) for parallelism.
 */
|
|
1860
|
+
NK_PUBLIC void nk_dots_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
                                        nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
                                        nk_size_t row_count) {
    // `result_stride` arrives in bytes; convert to u32 elements for indexing.
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
    // Clamp the slab so a generous row_count never reads past the last vector.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u8_t const *a_i = (nk_u8_t const *)((char const *)vectors + i * stride);
        // Upper triangle only: j starts at i; the lower triangle is never written here.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)((char const *)vectors + j * stride);
            nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
            vuint32m4_t accumulator_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
            // Strip-mine over the depth dimension; `vsetvl` sizes each chunk,
            // so no scalar tail loop is needed.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e8m1(remaining);
                vuint8m1_t a_vector_u8m1 = __riscv_vle8_v_u8m1(a_i + k, vector_length);
                vuint8m1_t b_vector_u8m1 = __riscv_vle8_v_u8m1(a_j + k, vector_length);
                // Widening multiply u8 x u8 -> u16, then widening add into u32 lanes.
                vuint16m2_t product_u16m2 = __riscv_vwmulu_vv_u16m2(a_vector_u8m1, b_vector_u8m1, vector_length);
                // `_tu` (tail-undisturbed): lanes beyond `vector_length` keep their
                // prior (zero-initialized) values, so the vlmax-wide reduction below stays valid.
                accumulator_u32m4 = __riscv_vwaddu_wv_u32m4_tu(accumulator_u32m4, accumulator_u32m4, product_u16m2,
                                                               vector_length);
            }
            // Horizontal sum of all vlmax u32 lanes into a scalar dot product.
            vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
            nk_u32_t dot = (nk_u32_t)__riscv_vmv_x_s_u32m1_u32(
                __riscv_vredsum_vs_u32m4_u32m1(accumulator_u32m4, zero_u32m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
1889
|
+
|
|
1890
|
+
#pragma endregion // Unsigned 8-bit Integers
|
|
1891
|
+
|
|
1892
|
+
#pragma region Quarter Precision E4M3
|
|
1893
|
+
|
|
1894
|
+
/**
|
|
1895
|
+
* @brief E4M3 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
|
|
1896
|
+
* nk_e4m3_magnitude_lut_rvv_[i] = float_to_bits(e4m3_to_f32(i)) for i=0..127.
|
|
1897
|
+
* E4M3FN: 4 exponent bits (bias=7), 3 mantissa bits, no infinity,
|
|
1898
|
+
* NaN = magnitude 0x7F only.
|
|
1899
|
+
*/
|
|
1900
|
+
// 128-entry table, indexed by the low 7 bits of an e4m3 byte (sign stripped).
// Indices 1..7 hold the subnormal magnitudes (entry 1 = 0x3B000000 = 2^-9);
// index 0x7F is the sole E4M3FN NaN, mapped to the f32 quiet NaN 0x7FC00000.
static nk_u32_t const nk_e4m3_magnitude_lut_rvv_[128] = {
    0x00000000u, 0x3B000000u, 0x3B800000u, 0x3BC00000u,
    0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u, /* [  0..  7] zero + subnormals */
    0x3C800000u, 0x3C900000u, 0x3CA00000u, 0x3CB00000u,
    0x3CC00000u, 0x3CD00000u, 0x3CE00000u, 0x3CF00000u, /* [  8.. 15] */
    0x3D000000u, 0x3D100000u, 0x3D200000u, 0x3D300000u,
    0x3D400000u, 0x3D500000u, 0x3D600000u, 0x3D700000u, /* [ 16.. 23] */
    0x3D800000u, 0x3D900000u, 0x3DA00000u, 0x3DB00000u,
    0x3DC00000u, 0x3DD00000u, 0x3DE00000u, 0x3DF00000u, /* [ 24.. 31] */
    0x3E000000u, 0x3E100000u, 0x3E200000u, 0x3E300000u,
    0x3E400000u, 0x3E500000u, 0x3E600000u, 0x3E700000u, /* [ 32.. 39] */
    0x3E800000u, 0x3E900000u, 0x3EA00000u, 0x3EB00000u,
    0x3EC00000u, 0x3ED00000u, 0x3EE00000u, 0x3EF00000u, /* [ 40.. 47] */
    0x3F000000u, 0x3F100000u, 0x3F200000u, 0x3F300000u,
    0x3F400000u, 0x3F500000u, 0x3F600000u, 0x3F700000u, /* [ 48.. 55] */
    0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
    0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 56.. 63] 1.0 at index 56 */
    0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
    0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 64.. 71] */
    0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
    0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u, /* [ 72.. 79] */
    0x41000000u, 0x41100000u, 0x41200000u, 0x41300000u,
    0x41400000u, 0x41500000u, 0x41600000u, 0x41700000u, /* [ 80.. 87] */
    0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
    0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
    0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
    0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
    0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
    0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
    0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
    0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
    0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
    0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u /* [120..127] 0x7F -> NaN (no Inf in E4M3FN) */
};
|
|
1934
|
+
|
|
1935
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t column_count, nk_size_t depth) {
    // Each column panel is padded to a whole number of e32m2 vectors of f32 values.
    nk_size_t const lanes = __riscv_vsetvlmax_e32m2();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    // If the panel's byte stride lands exactly on a power of two, grow it by one
    // more vector — presumably to avoid pathological cache-set aliasing across
    // columns; must stay in lockstep with nk_dots_pack_e4m3_rvv.
    nk_size_t const panel_bytes = padded_depth * sizeof(nk_f32_t);
    int const stride_is_pow2 = panel_bytes > 0 && (panel_bytes & (panel_bytes - 1)) == 0;
    if (stride_is_pow2) padded_depth += lanes;
    // Layout: header, then column_count f32 panels, then one f32 norm per column.
    return sizeof(nk_cross_packed_buffer_header_t) + column_count * padded_depth * sizeof(nk_f32_t) +
           column_count * sizeof(nk_f32_t);
}
|
|
1943
|
+
|
|
1944
|
+
/**
|
|
1945
|
+
* @brief Pack B matrix from e4m3 to f32 for floating-point dot product.
|
|
1946
|
+
*
|
|
1947
|
+
* Each e4m3 byte is converted to f32 via `nk_e4m3_to_f32_serial`.
|
|
1948
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
1949
|
+
*/
|
|
1950
|
+
NK_PUBLIC void nk_dots_pack_e4m3_rvv(nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth,
                                     nk_size_t b_stride_in_bytes, void *b_packed) {
    // Recompute the exact padding that nk_dots_packed_size_e4m3_rvv budgeted for.
    nk_size_t const lanes = __riscv_vsetvlmax_e32m2();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    nk_size_t const panel_bytes = padded_depth * sizeof(nk_f32_t);
    if (panel_bytes > 0 && (panel_bytes & (panel_bytes - 1)) == 0) padded_depth += lanes;

    // Record layout parameters in the leading header so the GEMM kernel can
    // recover them without re-deriving the padding.
    nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
    header->column_count = (nk_u32_t)column_count;
    header->depth_dimensions = (nk_u32_t)depth;
    header->depth_padded_values = (nk_u32_t)padded_depth;

    nk_f32_t *panels = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
    nk_size_t const panel_values = column_count * padded_depth;
    // Zero everything up front so the padding slots contribute nothing to dot products.
    for (nk_size_t i = 0; i < panel_values; ++i) panels[i] = 0;

    // Widen every e4m3 byte to f32, one depth-contiguous column panel at a time.
    for (nk_size_t column = 0; column < column_count; ++column) {
        nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
        nk_f32_t *dst = panels + column * padded_depth;
        for (nk_size_t k = 0; k < depth; ++k) nk_e4m3_to_f32_serial(&src[k], &dst[k]);
    }

    // Per-column norms are stored immediately after the packed panels.
    nk_f32_t *norms = (nk_f32_t *)(panels + panel_values);
    for (nk_size_t column = 0; column < column_count; ++column) {
        nk_e4m3_t const *src = (nk_e4m3_t const *)((char const *)b + column * b_stride_in_bytes);
        norms[column] = nk_dots_reduce_sumsq_e4m3_(src, depth);
    }
}
|
|
1979
|
+
|
|
1980
|
+
/**
|
|
1981
|
+
* @brief e4m3 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
|
|
1982
|
+
*
|
|
1983
|
+
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
1984
|
+
* - Load pre-packed f32 values from B
|
|
1985
|
+
* - Load raw e4m3 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
|
|
1986
|
+
* extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
|
|
1987
|
+
* gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
|
|
1988
|
+
* - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
|
|
1989
|
+
*
|
|
1990
|
+
* Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
|
|
1991
|
+
*/
|
|
1992
|
+
NK_INTERNAL void nk_dots_packed_e4m3_rvv_aligned_(nk_e4m3_t const *a_matrix, void const *b_packed_buffer,
                                                  nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
                                                  nk_size_t depth, nk_size_t a_stride_in_bytes,
                                                  nk_size_t c_stride_in_bytes) {
    // The packed buffer begins with a header recording the padded depth chosen at pack time.
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
    nk_size_t const depth_padded = header->depth_padded_values;
    nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
                                                     sizeof(nk_cross_packed_buffer_header_t));

    // Zero output matrix so each dot product can be stored directly below.
    for (nk_size_t i = 0; i < row_count; ++i) {
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
        for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
    }

    // mr=2 register tile over rows: each B chunk is loaded once and reused for both A rows.
    nk_size_t row = 0;
    for (; row + 2 <= row_count; row += 2) {
        nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
        nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
        nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
        nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);

        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);

            // Strip-mine along depth: `vsetvl` sizes every chunk, so no scalar tail loop is needed.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);

                // Load pre-packed f32 B values
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);

                // Load raw e4m3 bytes from each A row
                vuint8mf2_t raw0_u8mf2 = __riscv_vle8_v_u8mf2(a_row_0 + k, vector_length);
                vuint8mf2_t raw1_u8mf2 = __riscv_vle8_v_u8mf2(a_row_1 + k, vector_length);

                // Extract 7-bit magnitudes, zero-extend to u32, compute byte offsets for f32 LUT
                vuint8mf2_t mag0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x7F, vector_length);
                vuint8mf2_t mag1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx0_u32m2 = __riscv_vzext_vf4_u32m2(mag0_u8mf2, vector_length);
                vuint32m2_t idx1_u32m2 = __riscv_vzext_vf4_u32m2(mag1_u8mf2, vector_length);
                vuint32m2_t off0_u32m2 = __riscv_vsll_vx_u32m2(idx0_u32m2, 2,
                                                               vector_length); // byte offsets = index * 4
                vuint32m2_t off1_u32m2 = __riscv_vsll_vx_u32m2(idx1_u32m2, 2, vector_length);

                // Gather f32 bit patterns from magnitude LUT (`vluxei32` consumes byte offsets)
                vuint32m2_t bits0_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off0_u32m2,
                                                                   vector_length);
                vuint32m2_t bits1_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off1_u32m2,
                                                                   vector_length);

                // Extract sign bit 7, shift to f32 sign position (bit 31): zero-extend puts it
                // at bit 7 of a u32 lane, then << 24 lands it at bit 31.
                vuint8mf2_t sign0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x80, vector_length);
                vuint8mf2_t sign1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x80, vector_length);
                vuint32m2_t sign0_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign0_u8mf2, vector_length), 24,
                                                                vector_length);
                vuint32m2_t sign1_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign1_u8mf2, vector_length), 24,
                                                                vector_length);

                // Apply sign and reinterpret as f32
                vfloat32m2_t a_vector_0_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits0_u32m2, sign0_u32m2, vector_length));
                vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64. `_tu` (tail-undisturbed) keeps lanes beyond
                // `vector_length` at their prior (zero-initialized) values, so the vlmax-wide
                // reduction below remains correct even when the last chunk is short.
                accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
                                                                  vector_length);
                accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
                                                                  vector_length);
            }

            // Horizontal reduce and narrow to f32
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
            c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
        }
    }
    // Remainder rows (at most one when row_count is odd): same conversion pipeline,
    // single accumulator.
    for (; row < row_count; ++row) {
        nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
        nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
        for (nk_size_t column = 0; column < column_count; ++column) {
            nk_f32_t const *b_column = packed_data + column * depth_padded;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
                vuint8mf2_t raw_a_u8mf2 = __riscv_vle8_v_u8mf2(a_row + k, vector_length);
                // LUT-gather conversion: magnitude index -> f32 bits, then re-attach the sign.
                vuint8mf2_t mag_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_a_u32m2 = __riscv_vzext_vf4_u32m2(mag_a_u8mf2, vector_length);
                vuint32m2_t off_a_u32m2 = __riscv_vsll_vx_u32m2(idx_a_u32m2, 2, vector_length);
                vuint32m2_t bits_a_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_a_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_a_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_a_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t a_vector_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_a_u32m2, sign_a_u32m2, vector_length));
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
                                                                vector_length);
            }
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
        }
    }
}
|
|
2110
|
+
|
|
2111
|
+
/**
|
|
2112
|
+
* @brief Public e4m3 packed GEMM wrapper matching the declared signature in dots.h.
|
|
2113
|
+
*/
|
|
2114
|
+
NK_PUBLIC void nk_dots_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
                                       nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
    // Thin forwarder: the aligned kernel strip-mines with `vsetvl` and therefore
    // already handles arbitrary m/n/k without a dedicated edge-case path.
    nk_dots_packed_e4m3_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
}
|
|
2118
|
+
|
|
2119
|
+
/**
 *  @brief Symmetric e4m3 GEMM over a row slab of C = A * A^T.
 *
 *  Writes only the upper-triangle entries (j >= i); mirroring into the lower
 *  triangle, if required, is left to the caller — this routine never stores
 *  below the diagonal.
 *
 *  Uses f32 LUT gather with f64 widened accumulation for precision.
 *  Both operands are converted from e4m3 on-the-fly via magnitude LUT.
 *  Processes only the rows in [row_start, row_start + row_count) for parallelism.
 */
|
|
2126
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    // `result_stride` arrives in bytes; convert to f32 elements for indexing.
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the slab so a generous row_count never reads past the last vector.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        // e4m3 is one byte, so byte-stride arithmetic on a u8 pointer is exact.
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        // Upper triangle only: j starts at i; the lower triangle is never written here.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            // Strip-mine over depth; `vsetvl` sizes each chunk, so no scalar tail loop is needed.
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                vuint8mf2_t raw_i_u8mf2 = __riscv_vle8_v_u8mf2(a_i + k, vector_length);
                vuint8mf2_t raw_j_u8mf2 = __riscv_vle8_v_u8mf2(a_j + k, vector_length);

                // Convert i-vector via LUT gather: 7-bit magnitude -> byte offset -> f32 bits,
                // then re-attach the sign (bit 7 zero-extended to bit 7 of u32, << 24 -> bit 31).
                vuint8mf2_t mag_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_i_u32m2 = __riscv_vzext_vf4_u32m2(mag_i_u8mf2, vector_length);
                vuint32m2_t off_i_u32m2 = __riscv_vsll_vx_u32m2(idx_i_u32m2, 2, vector_length);
                vuint32m2_t bits_i_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_i_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_i_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_i_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_i_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_i_u32m2, sign_i_u32m2, vector_length));

                // Convert j-vector via LUT gather (same pipeline as the i-vector)
                vuint8mf2_t mag_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_j_u32m2 = __riscv_vzext_vf4_u32m2(mag_j_u8mf2, vector_length);
                vuint32m2_t off_j_u32m2 = __riscv_vsll_vx_u32m2(idx_j_u32m2, 2, vector_length);
                vuint32m2_t bits_j_u32m2 = __riscv_vluxei32_v_u32m2(nk_e4m3_magnitude_lut_rvv_, off_j_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_j_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_j_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64. `_tu` keeps lanes beyond `vector_length` at
                // their prior (zero-initialized) values so the vlmax-wide reduction stays valid.
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
                                                                vector_length);
            }
            // Horizontal sum across all vlmax f64 lanes, then narrow to f32 for storage.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
2180
|
+
|
|
2181
|
+
#pragma endregion // Quarter Precision E4M3
|
|
2182
|
+
|
|
2183
|
+
#pragma region Quarter Precision E5M2
|
|
2184
|
+
|
|
2185
|
+
/**
|
|
2186
|
+
* @brief E5M2 magnitude LUT: 7-bit magnitude -> f32 bit pattern (u32).
|
|
2187
|
+
* nk_e5m2_magnitude_lut_rvv_[i] = float_to_bits(e5m2_to_f32(i)) for i=0..127.
|
|
2188
|
+
* E5M2: 5 exponent bits (bias=15), 2 mantissa bits, has infinity (0x7C) and
|
|
2189
|
+
* NaN (magnitudes 0x7D..0x7F).
|
|
2190
|
+
*/
|
|
2191
|
+
// 128-entry table, indexed by the low 7 bits of an e5m2 byte (sign stripped).
// Indices 1..3 hold the subnormal magnitudes (entry 1 = 0x37800000 = 2^-16);
// index 0x7C maps to +Inf (0x7F800000); 0x7D..0x7F map to the f32 quiet NaN.
static nk_u32_t const nk_e5m2_magnitude_lut_rvv_[128] = {
    0x00000000u, 0x37800000u, 0x38000000u, 0x38400000u,
    0x38800000u, 0x38A00000u, 0x38C00000u, 0x38E00000u, /* [  0..  7] zero + subnormals */
    0x39000000u, 0x39200000u, 0x39400000u, 0x39600000u,
    0x39800000u, 0x39A00000u, 0x39C00000u, 0x39E00000u, /* [  8.. 15] */
    0x3A000000u, 0x3A200000u, 0x3A400000u, 0x3A600000u,
    0x3A800000u, 0x3AA00000u, 0x3AC00000u, 0x3AE00000u, /* [ 16.. 23] */
    0x3B000000u, 0x3B200000u, 0x3B400000u, 0x3B600000u,
    0x3B800000u, 0x3BA00000u, 0x3BC00000u, 0x3BE00000u, /* [ 24.. 31] */
    0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u,
    0x3C800000u, 0x3CA00000u, 0x3CC00000u, 0x3CE00000u, /* [ 32.. 39] */
    0x3D000000u, 0x3D200000u, 0x3D400000u, 0x3D600000u,
    0x3D800000u, 0x3DA00000u, 0x3DC00000u, 0x3DE00000u, /* [ 40.. 47] */
    0x3E000000u, 0x3E200000u, 0x3E400000u, 0x3E600000u,
    0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 48.. 55] */
    0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
    0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 56.. 63] 1.0 at index 60 */
    0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
    0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 64.. 71] */
    0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
    0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u, /* [ 72.. 79] */
    0x42000000u, 0x42200000u, 0x42400000u, 0x42600000u,
    0x42800000u, 0x42A00000u, 0x42C00000u, 0x42E00000u, /* [ 80.. 87] */
    0x43000000u, 0x43200000u, 0x43400000u, 0x43600000u,
    0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
    0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
    0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
    0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
    0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
    0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
    0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
    0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
    0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u /* [120..127] 0x7C -> Inf, 0x7D..0x7F -> NaN */
};
|
|
2225
|
+
|
|
2226
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t column_count, nk_size_t depth) {
    // Each column panel is padded to a whole number of e32m2 vectors of f32 values.
    nk_size_t const lanes = __riscv_vsetvlmax_e32m2();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    // If the panel's byte stride lands exactly on a power of two, grow it by one
    // more vector — presumably to avoid pathological cache-set aliasing across
    // columns; must stay in lockstep with nk_dots_pack_e5m2_rvv.
    nk_size_t const panel_bytes = padded_depth * sizeof(nk_f32_t);
    int const stride_is_pow2 = panel_bytes > 0 && (panel_bytes & (panel_bytes - 1)) == 0;
    if (stride_is_pow2) padded_depth += lanes;
    // Layout: header, then column_count f32 panels, then one f32 norm per column.
    return sizeof(nk_cross_packed_buffer_header_t) + column_count * padded_depth * sizeof(nk_f32_t) +
           column_count * sizeof(nk_f32_t);
}
|
|
2234
|
+
|
|
2235
|
+
/**
|
|
2236
|
+
* @brief Pack B matrix from e5m2 to f32 for floating-point dot product.
|
|
2237
|
+
*
|
|
2238
|
+
* Each e5m2 byte is converted to f32 via `nk_e5m2_to_f32_serial`.
|
|
2239
|
+
* Padding values are zeroed. Column-panel layout with depth-contiguous storage.
|
|
2240
|
+
*/
|
|
2241
|
+
NK_PUBLIC void nk_dots_pack_e5m2_rvv(nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth,
                                     nk_size_t b_stride_in_bytes, void *b_packed) {
    // Recompute the exact padding that nk_dots_packed_size_e5m2_rvv budgeted for.
    nk_size_t const lanes = __riscv_vsetvlmax_e32m2();
    nk_size_t padded_depth = nk_size_round_up_to_multiple_(depth, lanes);
    nk_size_t const panel_bytes = padded_depth * sizeof(nk_f32_t);
    if (panel_bytes > 0 && (panel_bytes & (panel_bytes - 1)) == 0) padded_depth += lanes;

    // Record layout parameters in the leading header so the GEMM kernel can
    // recover them without re-deriving the padding.
    nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed;
    header->column_count = (nk_u32_t)column_count;
    header->depth_dimensions = (nk_u32_t)depth;
    header->depth_padded_values = (nk_u32_t)padded_depth;

    nk_f32_t *panels = (nk_f32_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t));
    nk_size_t const panel_values = column_count * padded_depth;
    // Zero everything up front so the padding slots contribute nothing to dot products.
    for (nk_size_t i = 0; i < panel_values; ++i) panels[i] = 0;

    // Widen every e5m2 byte to f32, one depth-contiguous column panel at a time.
    for (nk_size_t column = 0; column < column_count; ++column) {
        nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
        nk_f32_t *dst = panels + column * padded_depth;
        for (nk_size_t k = 0; k < depth; ++k) nk_e5m2_to_f32_serial(&src[k], &dst[k]);
    }

    // Per-column norms are stored immediately after the packed panels.
    nk_f32_t *norms = (nk_f32_t *)(panels + panel_values);
    for (nk_size_t column = 0; column < column_count; ++column) {
        nk_e5m2_t const *src = (nk_e5m2_t const *)((char const *)b + column * b_stride_in_bytes);
        norms[column] = nk_dots_reduce_sumsq_e5m2_(src, depth);
    }
}
|
|
2270
|
+
|
|
2271
|
+
/**
|
|
2272
|
+
* @brief e5m2 packed GEMM kernel: C += A * B_packed^T with f64 widened accumulation.
|
|
2273
|
+
*
|
|
2274
|
+
* Vectorizes over the depth dimension (k). For each (row, column) pair:
|
|
2275
|
+
* - Load pre-packed f32 values from B
|
|
2276
|
+
* - Load raw e5m2 bytes from A, convert on-the-fly via 128-entry f32 LUT gather:
|
|
2277
|
+
* extract 7-bit magnitude, zero-extend to u32, compute byte offsets (x4),
|
|
2278
|
+
* gather f32 bit patterns, inject sign bit from bit 7 (<<24), reinterpret as f32
|
|
2279
|
+
* - Widening FMA: f32xf32 -> f64 via `vfwmacc_vv_f64m4`
|
|
2280
|
+
*
|
|
2281
|
+
* Register tile: process 2 rows per iteration (rows_per_tile=2, u32m2 gather + f64m4 accumulator is register-heavy).
|
|
2282
|
+
*/
|
|
2283
|
+
NK_INTERNAL void nk_dots_packed_e5m2_rvv_aligned_(nk_e5m2_t const *a_matrix, void const *b_packed_buffer,
|
|
2284
|
+
nk_f32_t *c_matrix, nk_size_t row_count, nk_size_t column_count,
|
|
2285
|
+
nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
2286
|
+
nk_size_t c_stride_in_bytes) {
|
|
2287
|
+
nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer;
|
|
2288
|
+
nk_size_t const depth_padded = header->depth_padded_values;
|
|
2289
|
+
nk_f32_t const *packed_data = (nk_f32_t const *)((char const *)b_packed_buffer +
|
|
2290
|
+
sizeof(nk_cross_packed_buffer_header_t));
|
|
2291
|
+
|
|
2292
|
+
// Zero output matrix
|
|
2293
|
+
for (nk_size_t i = 0; i < row_count; ++i) {
|
|
2294
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + i * c_stride_in_bytes);
|
|
2295
|
+
for (nk_size_t j = 0; j < column_count; ++j) c_row[j] = 0;
|
|
2296
|
+
}
|
|
2297
|
+
|
|
2298
|
+
// mr=2 register tile over rows
|
|
2299
|
+
nk_size_t row = 0;
|
|
2300
|
+
for (; row + 2 <= row_count; row += 2) {
|
|
2301
|
+
nk_u8_t const *a_row_0 = (nk_u8_t const *)((char const *)a_matrix + (row + 0) * a_stride_in_bytes);
|
|
2302
|
+
nk_u8_t const *a_row_1 = (nk_u8_t const *)((char const *)a_matrix + (row + 1) * a_stride_in_bytes);
|
|
2303
|
+
nk_f32_t *c_row_0 = (nk_f32_t *)((char *)c_matrix + (row + 0) * c_stride_in_bytes);
|
|
2304
|
+
nk_f32_t *c_row_1 = (nk_f32_t *)((char *)c_matrix + (row + 1) * c_stride_in_bytes);
|
|
2305
|
+
|
|
2306
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2307
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2308
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
2309
|
+
vfloat64m4_t accumulator_0_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2310
|
+
vfloat64m4_t accumulator_1_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2311
|
+
|
|
2312
|
+
nk_size_t remaining = depth;
|
|
2313
|
+
nk_size_t k = 0;
|
|
2314
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
2315
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
2316
|
+
|
|
2317
|
+
// Load pre-packed f32 B values
|
|
2318
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
2319
|
+
|
|
2320
|
+
// Load raw e5m2 bytes from each A row
|
|
2321
|
+
vuint8mf2_t raw0_u8mf2 = __riscv_vle8_v_u8mf2(a_row_0 + k, vector_length);
|
|
2322
|
+
vuint8mf2_t raw1_u8mf2 = __riscv_vle8_v_u8mf2(a_row_1 + k, vector_length);
|
|
2323
|
+
|
|
2324
|
+
// Extract 7-bit magnitudes, zero-extend to u32, compute byte offsets for f32 LUT
|
|
2325
|
+
vuint8mf2_t mag0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x7F, vector_length);
|
|
2326
|
+
vuint8mf2_t mag1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x7F, vector_length);
|
|
2327
|
+
vuint32m2_t idx0_u32m2 = __riscv_vzext_vf4_u32m2(mag0_u8mf2, vector_length);
|
|
2328
|
+
vuint32m2_t idx1_u32m2 = __riscv_vzext_vf4_u32m2(mag1_u8mf2, vector_length);
|
|
2329
|
+
vuint32m2_t off0_u32m2 = __riscv_vsll_vx_u32m2(idx0_u32m2, 2,
|
|
2330
|
+
vector_length); // byte offsets = index * 4
|
|
2331
|
+
vuint32m2_t off1_u32m2 = __riscv_vsll_vx_u32m2(idx1_u32m2, 2, vector_length);
|
|
2332
|
+
|
|
2333
|
+
// Gather f32 bit patterns from magnitude LUT
|
|
2334
|
+
vuint32m2_t bits0_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off0_u32m2,
|
|
2335
|
+
vector_length);
|
|
2336
|
+
vuint32m2_t bits1_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off1_u32m2,
|
|
2337
|
+
vector_length);
|
|
2338
|
+
|
|
2339
|
+
// Extract sign bit 7, shift to f32 sign position (bit 31)
|
|
2340
|
+
vuint8mf2_t sign0_u8mf2 = __riscv_vand_vx_u8mf2(raw0_u8mf2, 0x80, vector_length);
|
|
2341
|
+
vuint8mf2_t sign1_u8mf2 = __riscv_vand_vx_u8mf2(raw1_u8mf2, 0x80, vector_length);
|
|
2342
|
+
vuint32m2_t sign0_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign0_u8mf2, vector_length), 24,
|
|
2343
|
+
vector_length);
|
|
2344
|
+
vuint32m2_t sign1_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign1_u8mf2, vector_length), 24,
|
|
2345
|
+
vector_length);
|
|
2346
|
+
|
|
2347
|
+
// Apply sign and reinterpret as f32
|
|
2348
|
+
vfloat32m2_t a_vector_0_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2349
|
+
__riscv_vor_vv_u32m2(bits0_u32m2, sign0_u32m2, vector_length));
|
|
2350
|
+
vfloat32m2_t a_vector_1_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2351
|
+
__riscv_vor_vv_u32m2(bits1_u32m2, sign1_u32m2, vector_length));
|
|
2352
|
+
|
|
2353
|
+
// Widening FMA: f32xf32 -> f64
|
|
2354
|
+
accumulator_0_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_0_f64m4, a_vector_0_f32m2, b_vector_f32m2,
|
|
2355
|
+
vector_length);
|
|
2356
|
+
accumulator_1_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_1_f64m4, a_vector_1_f32m2, b_vector_f32m2,
|
|
2357
|
+
vector_length);
|
|
2358
|
+
}
|
|
2359
|
+
|
|
2360
|
+
// Horizontal reduce and narrow to f32
|
|
2361
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2362
|
+
c_row_0[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2363
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_0_f64m4, zero_f64m1, vlmax));
|
|
2364
|
+
c_row_1[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2365
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_1_f64m4, zero_f64m1, vlmax));
|
|
2366
|
+
}
|
|
2367
|
+
}
|
|
2368
|
+
// Remainder rows
|
|
2369
|
+
for (; row < row_count; ++row) {
|
|
2370
|
+
nk_u8_t const *a_row = (nk_u8_t const *)((char const *)a_matrix + row * a_stride_in_bytes);
|
|
2371
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)c_matrix + row * c_stride_in_bytes);
|
|
2372
|
+
for (nk_size_t column = 0; column < column_count; ++column) {
|
|
2373
|
+
nk_f32_t const *b_column = packed_data + column * depth_padded;
|
|
2374
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
2375
|
+
vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
2376
|
+
nk_size_t remaining = depth;
|
|
2377
|
+
nk_size_t k = 0;
|
|
2378
|
+
for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
|
|
2379
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
2380
|
+
vfloat32m2_t b_vector_f32m2 = __riscv_vle32_v_f32m2(b_column + k, vector_length);
|
|
2381
|
+
vuint8mf2_t raw_a_u8mf2 = __riscv_vle8_v_u8mf2(a_row + k, vector_length);
|
|
2382
|
+
vuint8mf2_t mag_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x7F, vector_length);
|
|
2383
|
+
vuint32m2_t idx_a_u32m2 = __riscv_vzext_vf4_u32m2(mag_a_u8mf2, vector_length);
|
|
2384
|
+
vuint32m2_t off_a_u32m2 = __riscv_vsll_vx_u32m2(idx_a_u32m2, 2, vector_length);
|
|
2385
|
+
vuint32m2_t bits_a_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_a_u32m2,
|
|
2386
|
+
vector_length);
|
|
2387
|
+
vuint8mf2_t sign_a_u8mf2 = __riscv_vand_vx_u8mf2(raw_a_u8mf2, 0x80, vector_length);
|
|
2388
|
+
vuint32m2_t sign_a_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_a_u8mf2, vector_length),
|
|
2389
|
+
24, vector_length);
|
|
2390
|
+
vfloat32m2_t a_vector_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
2391
|
+
__riscv_vor_vv_u32m2(bits_a_u32m2, sign_a_u32m2, vector_length));
|
|
2392
|
+
accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, a_vector_f32m2, b_vector_f32m2,
|
|
2393
|
+
vector_length);
|
|
2394
|
+
}
|
|
2395
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
2396
|
+
c_row[column] = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
|
|
2397
|
+
__riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
|
|
2398
|
+
}
|
|
2399
|
+
}
|
|
2400
|
+
}
|
|
2401
|
+
|
|
2402
|
+
/**
|
|
2403
|
+
* @brief Public e5m2 packed GEMM wrapper matching the declared signature in dots.h.
|
|
2404
|
+
*/
|
|
2405
|
+
NK_PUBLIC void nk_dots_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t m, nk_size_t n,
|
|
2406
|
+
nk_size_t k, nk_size_t a_stride, nk_size_t c_stride) {
|
|
2407
|
+
nk_dots_packed_e5m2_rvv_aligned_(a, b_packed, c, m, n, k, a_stride, c_stride);
|
|
2408
|
+
}
|
|
2409
|
+
|
|
2410
|
+
/**
 * @brief Symmetric e5m2 GEMM: C = A * A^T, upper triangle + mirror.
 *
 * Uses f32 LUT gather with f64 widened accumulation for precision.
 * Both operands are converted from e5m2 on-the-fly via magnitude LUT.
 * Processes only the rows in [row_start, row_start + row_count) for parallelism.
 *
 * NOTE(review): this body writes only the upper-triangle entry result[i][j];
 * the "mirror" to result[j][i] mentioned above is not performed here —
 * presumably the caller fills the lower triangle; confirm against dispatch code.
 *
 * @param vectors        Base pointer to n_vectors rows of e5m2 values, each `depth` long.
 * @param n_vectors      Total number of vectors (rows of A).
 * @param depth          Elements per vector (the reduction dimension).
 * @param stride         Row stride of `vectors` in bytes (e5m2 is 1 byte, so bytes == elements).
 * @param result         Output f32 matrix, row-major.
 * @param result_stride  Row stride of `result` in BYTES; converted to elements below.
 * @param row_start      First row index this call is responsible for.
 * @param row_count      Number of rows this call processes (clamped to n_vectors).
 */
NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                          nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                          nk_size_t row_start, nk_size_t row_count) {
    // result_stride arrives in bytes; indexing below needs f32 elements.
    nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
    // Clamp the row window so row_start + row_count never runs past the matrix.
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors;

    for (nk_size_t i = row_start; i < row_end; ++i) {
        nk_u8_t const *a_i = (nk_u8_t const *)vectors + i * stride;
        // j starts at i: only the upper triangle (including the diagonal) is computed.
        for (nk_size_t j = i; j < n_vectors; ++j) {
            nk_u8_t const *a_j = (nk_u8_t const *)vectors + j * stride;
            nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
            // f64 accumulator at LMUL=4 (twice the f32 LMUL=2 operand width).
            vfloat64m4_t accumulator_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
            nk_size_t remaining = depth;
            nk_size_t k = 0;
            // Strip-mined loop: each iteration processes vsetvl(remaining) elements.
            for (nk_size_t vector_length = 0; remaining > 0; remaining -= vector_length, k += vector_length) {
                vector_length = __riscv_vsetvl_e32m2(remaining);
                // Raw e5m2 bytes for both vectors of the current pair.
                vuint8mf2_t raw_i_u8mf2 = __riscv_vle8_v_u8mf2(a_i + k, vector_length);
                vuint8mf2_t raw_j_u8mf2 = __riscv_vle8_v_u8mf2(a_j + k, vector_length);

                // Convert i-vector via LUT gather:
                // low 7 bits index a 128-entry f32 magnitude table (byte offsets = index * 4).
                vuint8mf2_t mag_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_i_u32m2 = __riscv_vzext_vf4_u32m2(mag_i_u8mf2, vector_length);
                vuint32m2_t off_i_u32m2 = __riscv_vsll_vx_u32m2(idx_i_u32m2, 2, vector_length);
                vuint32m2_t bits_i_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_i_u32m2,
                                                                    vector_length);
                // Sign bit 7 of the e5m2 byte moves to f32 bit 31 (zext to bit 7, then << 24).
                vuint8mf2_t sign_i_u8mf2 = __riscv_vand_vx_u8mf2(raw_i_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_i_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_i_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_i_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_i_u32m2, sign_i_u32m2, vector_length));

                // Convert j-vector via LUT gather (same decode as above).
                vuint8mf2_t mag_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x7F, vector_length);
                vuint32m2_t idx_j_u32m2 = __riscv_vzext_vf4_u32m2(mag_j_u8mf2, vector_length);
                vuint32m2_t off_j_u32m2 = __riscv_vsll_vx_u32m2(idx_j_u32m2, 2, vector_length);
                vuint32m2_t bits_j_u32m2 = __riscv_vluxei32_v_u32m2(nk_e5m2_magnitude_lut_rvv_, off_j_u32m2,
                                                                    vector_length);
                vuint8mf2_t sign_j_u8mf2 = __riscv_vand_vx_u8mf2(raw_j_u8mf2, 0x80, vector_length);
                vuint32m2_t sign_j_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vzext_vf4_u32m2(sign_j_u8mf2, vector_length),
                                                                 24, vector_length);
                vfloat32m2_t val_j_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
                    __riscv_vor_vv_u32m2(bits_j_u32m2, sign_j_u32m2, vector_length));

                // Widening FMA: f32xf32 -> f64; _tu keeps tail lanes (beyond
                // vector_length) undisturbed so partial strips don't corrupt the sum.
                accumulator_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(accumulator_f64m4, val_i_f32m2, val_j_f32m2,
                                                                vector_length);
            }
            // Unordered f64 reduction over all vlmax lanes, then narrow once to f32.
            vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
            nk_f32_t dot = (nk_f32_t)__riscv_vfmv_f_s_f64m1_f64(
                __riscv_vfredusum_vs_f64m4_f64m1(accumulator_f64m4, zero_f64m1, vlmax));
            result[i * result_stride_elements + j] = dot;
        }
    }
}
|
|
2471
|
+
|
|
2472
|
+
#pragma endregion // Quarter Precision E5M2
|
|
2473
|
+
|
|
2474
|
+
#if defined(__cplusplus)
|
|
2475
|
+
} // extern "C"
|
|
2476
|
+
#endif
|
|
2477
|
+
|
|
2478
|
+
#if defined(__clang__)
|
|
2479
|
+
#pragma clang attribute pop
|
|
2480
|
+
#elif defined(__GNUC__)
|
|
2481
|
+
#pragma GCC pop_options
|
|
2482
|
+
#endif
|
|
2483
|
+
|
|
2484
|
+
#endif // NK_TARGET_RVV
|
|
2485
|
+
#endif // NK_TARGET_RISCV_
|
|
2486
|
+
#endif // NK_DOTS_RVV_H
|