numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -54,7 +54,7 @@
|
|
|
54
54
|
extern "C" {
|
|
55
55
|
#endif
|
|
56
56
|
|
|
57
|
-
#pragma region
|
|
57
|
+
#pragma region Register to Register Helpers
|
|
58
58
|
|
|
59
59
|
/**
|
|
60
60
|
* @brief Convert bf16 (m1) to f32 (m2) register-to-register.
|
|
@@ -90,7 +90,11 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_bf16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_
|
|
|
90
90
|
* F16 format: S EEEEE MMMMMMMMMM (1 sign, 5 exponent bits with bias=15, 10 mantissa bits)
|
|
91
91
|
* F32 format: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM (1 sign, 8 exponent bits with bias=127, 23 mantissa bits)
|
|
92
92
|
*
|
|
93
|
-
*
|
|
93
|
+
* Uses the Giesen magic-multiply trick: shift the magnitude bits into a tiny f32 and
|
|
94
|
+
* multiply by 2^112 to rebias the exponent. This correctly handles ±zero, denormals,
|
|
95
|
+
* and normals in a single FP multiply; only inf/NaN needs a fixup compare+merge.
|
|
96
|
+
*
|
|
97
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
|
|
94
98
|
*/
|
|
95
99
|
NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t vector_length) {
|
|
96
100
|
// Widen to 32-bit for manipulation
|
|
@@ -98,45 +102,31 @@ NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t
|
|
|
98
102
|
// Extract sign: (raw >> 15) << 31
|
|
99
103
|
vuint32m2_t sign_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 15, vector_length), 31,
|
|
100
104
|
vector_length);
|
|
101
|
-
//
|
|
102
|
-
|
|
105
|
+
// Strip sign, shift magnitude into f32 mantissa position.
|
|
106
|
+
// For a normal f16 with exp E, this places E into the f32 exponent field,
|
|
107
|
+
// creating a tiny f32 whose value is proportional to the f16 magnitude.
|
|
108
|
+
vuint32m2_t nonsign_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x7FFF, vector_length);
|
|
109
|
+
vuint32m2_t shifted_u32m2 = __riscv_vsll_vx_u32m2(nonsign_u32m2, 13, vector_length);
|
|
110
|
+
// Multiply by 2^112 (= magic 0x77800000 as f32) to rebias the exponent.
|
|
111
|
+
// This single multiply correctly handles zero, denormals, and normals:
|
|
112
|
+
// zero: 0.0 × 2^112 = 0.0
|
|
113
|
+
// denormal: (M × 2^-136) × 2^112 = M × 2^-24 (correct f16 denormal value)
|
|
114
|
+
// normal: (2^(E-127) × …) × 2^112 = 2^(E-15) × … (correct rebiased value)
|
|
115
|
+
vfloat32m2_t magic_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
116
|
+
__riscv_vmv_v_x_u32m2(((nk_u32_t)(254 - 15) << 23), vector_length));
|
|
117
|
+
vfloat32m2_t result_f32m2 = __riscv_vfmul_vv_f32m2(__riscv_vreinterpret_v_u32m2_f32m2(shifted_u32m2), magic_f32m2,
|
|
103
118
|
vector_length);
|
|
104
|
-
//
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
// Special case: exponent == 0 (zero or denormal)
|
|
116
|
-
// Zero: sign | 0. Denormal: mantissa × 2^(-24), handled via FPU normalization trick.
|
|
117
|
-
// For denormals, convert mantissa to float and subtract 0x0C000000 (24 from exponent),
|
|
118
|
-
// matching the serial implementation. For zeros (mantissa==0), (float)0 - bias = 0.
|
|
119
|
-
vbool16_t is_exp_zero = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 0, vector_length);
|
|
120
|
-
vfloat32m2_t mantissa_f32m2 = __riscv_vfcvt_f_xu_v_f32m2(mantissa_u32m2, vector_length);
|
|
121
|
-
vuint32m2_t denorm_bits_u32m2 = __riscv_vsub_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(mantissa_f32m2),
|
|
122
|
-
0x0C000000, vector_length);
|
|
123
|
-
vuint32m2_t zero_or_denorm_u32m2 = __riscv_vor_vv_u32m2(sign_u32m2, denorm_bits_u32m2, vector_length);
|
|
124
|
-
// For true zeros (mantissa==0), the FPU converts 0 to 0x00000000, minus bias wraps,
|
|
125
|
-
// so force to sign-only.
|
|
126
|
-
vbool16_t is_true_zero = __riscv_vmand_mm_b16(
|
|
127
|
-
is_exp_zero, __riscv_vmseq_vx_u32m2_b16(mantissa_u32m2, 0, vector_length), vector_length);
|
|
128
|
-
zero_or_denorm_u32m2 = __riscv_vmerge_vvm_u32m2(zero_or_denorm_u32m2, sign_u32m2, is_true_zero, vector_length);
|
|
129
|
-
|
|
130
|
-
// Special case: exponent == 31 (infinity or NaN)
|
|
131
|
-
// sign | 0x7F800000 | (mantissa << 13)
|
|
132
|
-
vbool16_t is_exp_max = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 31, vector_length);
|
|
133
|
-
vuint32m2_t inf_nan_u32m2 = __riscv_vor_vv_u32m2(__riscv_vor_vx_u32m2(sign_u32m2, 0x7F800000, vector_length),
|
|
134
|
-
__riscv_vsll_vx_u32m2(mantissa_u32m2, 13, vector_length),
|
|
135
|
-
vector_length);
|
|
136
|
-
|
|
137
|
-
// Select: exp==0 → zero_or_denorm, exp==31 → inf_nan, else → normal
|
|
138
|
-
vuint32m2_t result_u32m2 = __riscv_vmerge_vvm_u32m2(normal_u32m2, zero_or_denorm_u32m2, is_exp_zero, vector_length);
|
|
139
|
-
result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, inf_nan_u32m2, is_exp_max, vector_length);
|
|
119
|
+
// Inf/NaN fixup: the multiply maps f16 exp=31 to a large finite f32.
|
|
120
|
+
// Detect those lanes and force the f32 exponent to 255 (inf/NaN).
|
|
121
|
+
// Threshold 0x47800000 = 2^16; any f16 with exp=31 scales to at least 2^16 (inf maps to exactly 2^16).
|
|
122
|
+
vfloat32m2_t infnan_threshold_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
|
|
123
|
+
__riscv_vmv_v_x_u32m2(((nk_u32_t)(127 + 16) << 23), vector_length));
|
|
124
|
+
vbool16_t is_infnan = __riscv_vmfge_vv_f32m2_b16(result_f32m2, infnan_threshold_f32m2, vector_length);
|
|
125
|
+
vuint32m2_t result_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(result_f32m2);
|
|
126
|
+
vuint32m2_t fixed_u32m2 = __riscv_vor_vx_u32m2(result_u32m2, 0x7F800000, vector_length);
|
|
127
|
+
result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, fixed_u32m2, is_infnan, vector_length);
|
|
128
|
+
// Restore sign
|
|
129
|
+
result_u32m2 = __riscv_vor_vv_u32m2(result_u32m2, sign_u32m2, vector_length);
|
|
140
130
|
return __riscv_vreinterpret_v_u32m2_f32m2(result_u32m2);
|
|
141
131
|
}
|
|
142
132
|
|
|
@@ -162,13 +152,8 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t
|
|
|
162
152
|
exponent_i32m2 = __riscv_vmax_vx_i32m2(exponent_i32m2, 0, vector_length);
|
|
163
153
|
vuint32m2_t f16_exponent_u32m2 = __riscv_vreinterpret_v_i32m2_u32m2(
|
|
164
154
|
__riscv_vmin_vx_i32m2(exponent_i32m2, 31, vector_length));
|
|
165
|
-
// Round mantissa: add 0x1000 (half of truncated bits) then shift
|
|
166
|
-
// If rounding overflows the mantissa (bit 23 set), carry into exponent.
|
|
155
|
+
// Round mantissa: add 0x1000 (half of truncated bits) then shift
|
|
167
156
|
vuint32m2_t rounded_mantissa_u32m2 = __riscv_vadd_vx_u32m2(mantissa_u32m2, 0x1000, vector_length);
|
|
168
|
-
vbool16_t mantissa_overflow_b16 = __riscv_vmsne_vx_u32m2_b16(
|
|
169
|
-
__riscv_vand_vx_u32m2(rounded_mantissa_u32m2, 0x800000, vector_length), 0, vector_length);
|
|
170
|
-
f16_exponent_u32m2 = __riscv_vadd_vx_u32m2_mu(mantissa_overflow_b16, f16_exponent_u32m2, f16_exponent_u32m2, 1,
|
|
171
|
-
vector_length);
|
|
172
157
|
vuint32m2_t f16_mantissa_u32m2 = __riscv_vsrl_vx_u32m2(rounded_mantissa_u32m2, 13, vector_length);
|
|
173
158
|
f16_mantissa_u32m2 = __riscv_vand_vx_u32m2(f16_mantissa_u32m2, 0x3FF, vector_length);
|
|
174
159
|
// Combine: sign | (exponent << 10) | mantissa
|
|
@@ -181,242 +166,206 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t
|
|
|
181
166
|
}
|
|
182
167
|
|
|
183
168
|
/**
|
|
184
|
-
* @brief Convert e4m3 (m1) to f32 (m4) via
|
|
185
|
-
*
|
|
169
|
+
* @brief Convert e4m3 (m1) to f32 (m4) via Giesen magic-multiply.
|
|
170
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
171
|
+
* Handles zero, subnormals, and normals in a single vfmul. NaN fixup for magnitude 0x7F.
|
|
172
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
|
|
186
173
|
*/
|
|
187
174
|
NK_INTERNAL vfloat32m4_t nk_e4m3m1_to_f32m4_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
|
|
212
|
-
0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
|
|
213
|
-
0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
|
|
214
|
-
0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
|
|
215
|
-
0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
|
|
216
|
-
0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
|
|
217
|
-
0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
|
|
218
|
-
0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
|
|
219
|
-
0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
|
|
220
|
-
0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u /* [120..127] */
|
|
221
|
-
};
|
|
222
|
-
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
|
|
223
|
-
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
|
|
224
|
-
vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
|
|
225
|
-
vector_length);
|
|
226
|
-
vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e4m3_mag_to_f32_lut_, offsets_u32m4, vector_length);
|
|
227
|
-
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
|
|
228
|
-
vector_length);
|
|
229
|
-
return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
|
|
175
|
+
// Extract sign: (raw & 0x80) → bit 7, shift to bit 31
|
|
176
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
|
|
177
|
+
__riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length), vector_length), 24,
|
|
178
|
+
vector_length);
|
|
179
|
+
// Strip sign to get 7-bit magnitude, widen to u32, shift left by 20
|
|
180
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
|
|
181
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
|
|
182
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
|
|
183
|
+
|
|
184
|
+
// Magic multiply: reinterpret as f32 × 2^120 rebiases from E4M3 (bias=7) to f32 (bias=127).
|
|
185
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
|
|
186
|
+
__riscv_vmv_v_x_u32m4(0x7B800000, vector_length)); // 2^120 = (254-7)<<23
|
|
187
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
188
|
+
vector_length);
|
|
189
|
+
|
|
190
|
+
// NaN fixup: masked OR writes sign|0x7FC00000 only into NaN lanes
|
|
191
|
+
vbool8_t is_nan = __riscv_vmseq_vx_u8m1_b8(nonsign_u8m1, 0x7F, vector_length);
|
|
192
|
+
vuint32m4_t result_u32m4 = __riscv_vor_vx_u32m4_mu(is_nan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
|
|
193
|
+
sign_u32m4, 0x7FC00000, vector_length);
|
|
194
|
+
|
|
195
|
+
// Restore sign
|
|
196
|
+
result_u32m4 = __riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length);
|
|
197
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
|
|
230
198
|
}
|
|
231
199
|
|
|
232
200
|
/**
|
|
233
|
-
* @brief Convert e5m2 (m1) to f32 (m4) via
|
|
234
|
-
*
|
|
201
|
+
* @brief Convert e5m2 (m1) to f32 (m4) via Giesen magic-multiply.
|
|
202
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
203
|
+
* Handles zero, subnormals, and normals in a single vfmul. Inf/NaN fixup for exp=31.
|
|
204
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
|
|
235
205
|
*/
|
|
236
206
|
NK_INTERNAL vfloat32m4_t nk_e5m2m1_to_f32m4_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
|
|
262
|
-
0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
|
|
263
|
-
0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
|
|
264
|
-
0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
|
|
265
|
-
0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
|
|
266
|
-
0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
|
|
267
|
-
0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
|
|
268
|
-
0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
|
|
269
|
-
0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u /* [120..127] */
|
|
270
|
-
};
|
|
271
|
-
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
|
|
272
|
-
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
|
|
273
|
-
vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
|
|
274
|
-
vector_length);
|
|
275
|
-
vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e5m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
|
|
276
|
-
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
|
|
277
|
-
vector_length);
|
|
278
|
-
return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
|
|
207
|
+
// Extract sign: (raw & 0x80) → bit 7, shift to bit 31
|
|
208
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
|
|
209
|
+
__riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length), vector_length), 24,
|
|
210
|
+
vector_length);
|
|
211
|
+
// Strip sign to get 7-bit magnitude, widen to u32, shift left by 21
|
|
212
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length),
|
|
213
|
+
vector_length);
|
|
214
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
|
|
215
|
+
|
|
216
|
+
// Magic multiply: reinterpret as f32 × 2^112 rebiases from E5M2 (bias=15) to f32 (bias=127).
|
|
217
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
|
|
218
|
+
__riscv_vmv_v_x_u32m4(0x77800000, vector_length)); // 2^112 = (254-15)<<23
|
|
219
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
220
|
+
vector_length);
|
|
221
|
+
|
|
222
|
+
// Inf/NaN fixup: masked OR writes 0x7F800000 only into inf/NaN lanes (nonsign > 123)
|
|
223
|
+
vbool8_t is_infnan = __riscv_vmsgtu_vx_u32m4_b8(nonsign_u32m4, 123, vector_length);
|
|
224
|
+
vuint32m4_t result_u32m4 = __riscv_vor_vx_u32m4_mu(is_infnan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
|
|
225
|
+
__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 0x7F800000,
|
|
226
|
+
vector_length);
|
|
227
|
+
|
|
228
|
+
// Restore sign
|
|
229
|
+
result_u32m4 = __riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length);
|
|
230
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
|
|
279
231
|
}
|
|
280
232
|
|
|
281
233
|
/**
|
|
282
|
-
* @brief Convert e2m3 (m1) to f32 (m4) via
|
|
283
|
-
*
|
|
234
|
+
* @brief Convert e2m3 (m1) to f32 (m4) via Giesen magic-multiply.
|
|
235
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
236
|
+
* Handles zero, subnormals, and normals in a single vfmul. No inf/NaN in E2M3FN.
|
|
237
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
|
|
284
238
|
*/
|
|
285
239
|
NK_INTERNAL vfloat32m4_t nk_e2m3m1_to_f32m4_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
240
|
+
// Extract sign: bit 5 → bit 31
|
|
241
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
|
|
242
|
+
__riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length), vector_length), 26,
|
|
243
|
+
vector_length);
|
|
244
|
+
// Strip sign to get 5-bit magnitude, widen to u32, shift left by 20
|
|
245
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length),
|
|
246
|
+
vector_length);
|
|
247
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
|
|
248
|
+
|
|
249
|
+
// Magic multiply: reinterpret as f32 × 2^126 rebiases from E2M3 (bias=1) to f32 (bias=127).
|
|
250
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
|
|
251
|
+
__riscv_vmv_v_x_u32m4(0x7E800000, vector_length)); // 2^126 = (254-1)<<23
|
|
252
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
253
|
+
vector_length);
|
|
254
|
+
|
|
255
|
+
// Restore sign (no inf/NaN fixup needed for E2M3FN)
|
|
256
|
+
vuint32m4_t result_u32m4 = __riscv_vor_vv_u32m4(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), sign_u32m4,
|
|
257
|
+
vector_length);
|
|
258
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
|
|
304
259
|
}
|
|
305
260
|
|
|
306
261
|
/**
|
|
307
|
-
* @brief Convert e3m2 (m1) to f32 (m4) via
|
|
308
|
-
*
|
|
262
|
+
* @brief Convert e3m2 (m1) to f32 (m4) via Giesen magic-multiply.
|
|
263
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
264
|
+
* Handles zero, subnormals, and normals in a single vfmul. No inf/NaN in E3M2FN.
|
|
265
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
|
|
309
266
|
*/
|
|
310
267
|
NK_INTERNAL vfloat32m4_t nk_e3m2m1_to_f32m4_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
268
|
+
// Extract sign: bit 5 → bit 31
|
|
269
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
|
|
270
|
+
__riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length), vector_length), 26,
|
|
271
|
+
vector_length);
|
|
272
|
+
// Strip sign to get 5-bit magnitude, widen to u32, shift left by 21
|
|
273
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length),
|
|
274
|
+
vector_length);
|
|
275
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
|
|
276
|
+
|
|
277
|
+
// Magic multiply: reinterpret as f32 × 2^124 rebiases from E3M2 (bias=3) to f32 (bias=127).
|
|
278
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
|
|
279
|
+
__riscv_vmv_v_x_u32m4(0x7D800000, vector_length)); // 2^124 = (254-3)<<23
|
|
280
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
281
|
+
vector_length);
|
|
282
|
+
|
|
283
|
+
// Restore sign (no inf/NaN fixup needed for E3M2FN)
|
|
284
|
+
vuint32m4_t result_u32m4 = __riscv_vor_vv_u32m4(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), sign_u32m4,
|
|
285
|
+
vector_length);
|
|
286
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
|
|
329
287
|
}
|
|
330
288
|
|
|
331
|
-
/** @brief Convert e4m3 (m1) to bf16 (m2) via
|
|
289
|
+
/** @brief Convert e4m3 (m1) to bf16 (m2) via Giesen magic-multiply.
|
|
290
|
+
* Magic-multiply to f32, then keep the upper 16 bits as bf16. NaN fixup for magnitude 0x7F. */
|
|
332
291
|
NK_INTERNAL vuint16m2_t nk_e4m3m1_to_bf16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
|
|
333
|
-
static nk_u16_t const nk_e4m3_mag_to_bf16_lut_[128] = {
|
|
334
|
-
0x0000u, 0x3B00u, 0x3B80u, 0x3BC0u, 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, /* [ 0.. 7] */
|
|
335
|
-
0x3C80u, 0x3C90u, 0x3CA0u, 0x3CB0u, 0x3CC0u, 0x3CD0u, 0x3CE0u, 0x3CF0u, /* [ 8.. 15] */
|
|
336
|
-
0x3D00u, 0x3D10u, 0x3D20u, 0x3D30u, 0x3D40u, 0x3D50u, 0x3D60u, 0x3D70u, /* [ 16.. 23] */
|
|
337
|
-
0x3D80u, 0x3D90u, 0x3DA0u, 0x3DB0u, 0x3DC0u, 0x3DD0u, 0x3DE0u, 0x3DF0u, /* [ 24.. 31] */
|
|
338
|
-
0x3E00u, 0x3E10u, 0x3E20u, 0x3E30u, 0x3E40u, 0x3E50u, 0x3E60u, 0x3E70u, /* [ 32.. 39] */
|
|
339
|
-
0x3E80u, 0x3E90u, 0x3EA0u, 0x3EB0u, 0x3EC0u, 0x3ED0u, 0x3EE0u, 0x3EF0u, /* [ 40.. 47] */
|
|
340
|
-
0x3F00u, 0x3F10u, 0x3F20u, 0x3F30u, 0x3F40u, 0x3F50u, 0x3F60u, 0x3F70u, /* [ 48.. 55] */
|
|
341
|
-
0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 56.. 63] */
|
|
342
|
-
0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 64.. 71] */
|
|
343
|
-
0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u, /* [ 72.. 79] */
|
|
344
|
-
0x4100u, 0x4110u, 0x4120u, 0x4130u, 0x4140u, 0x4150u, 0x4160u, 0x4170u, /* [ 80.. 87] */
|
|
345
|
-
0x4180u, 0x4190u, 0x41A0u, 0x41B0u, 0x41C0u, 0x41D0u, 0x41E0u, 0x41F0u, /* [ 88.. 95] */
|
|
346
|
-
0x4200u, 0x4210u, 0x4220u, 0x4230u, 0x4240u, 0x4250u, 0x4260u, 0x4270u, /* [ 96..103] */
|
|
347
|
-
0x4280u, 0x4290u, 0x42A0u, 0x42B0u, 0x42C0u, 0x42D0u, 0x42E0u, 0x42F0u, /* [104..111] */
|
|
348
|
-
0x4300u, 0x4310u, 0x4320u, 0x4330u, 0x4340u, 0x4350u, 0x4360u, 0x4370u, /* [112..119] */
|
|
349
|
-
0x4380u, 0x4390u, 0x43A0u, 0x43B0u, 0x43C0u, 0x43D0u, 0x43E0u, 0x7FC0u /* [120..127] */
|
|
350
|
-
};
|
|
351
292
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
|
|
352
|
-
vuint8m1_t
|
|
353
|
-
|
|
293
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
|
|
294
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
|
|
295
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
|
|
296
|
+
// Magic multiply: reinterpret as f32 × 2^120
|
|
297
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7B800000, vector_length));
|
|
298
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
299
|
+
vector_length);
|
|
300
|
+
// Truncate f32 → bf16 (right shift 16, exact for all e4m3 values)
|
|
301
|
+
vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
|
|
354
302
|
vector_length);
|
|
355
|
-
|
|
303
|
+
// NaN fixup: magnitude 0x7F → bf16 quiet NaN 0x7FC0
|
|
304
|
+
vbool8_t is_nan = __riscv_vmseq_vx_u8m1_b8(nonsign_u8m1, 0x7F, vector_length);
|
|
305
|
+
result_u16m2 = __riscv_vmerge_vxm_u16m2(result_u16m2, 0x7FC0, is_nan, vector_length);
|
|
306
|
+
// Restore sign: bit 7 → bf16 bit 15 (<<8)
|
|
356
307
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
|
|
357
308
|
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
358
309
|
}
|
|
359
310
|
|
|
360
|
-
/** @brief Convert e5m2 (m1) to bf16 (m2) via
|
|
311
|
+
/** @brief Convert e5m2 (m1) to bf16 (m2) via Giesen magic-multiply.
|
|
312
|
+
* Magic-multiply to f32, inf/NaN fixup, then keep the upper 16 bits as bf16. */
|
|
361
313
|
NK_INTERNAL vuint16m2_t nk_e5m2m1_to_bf16m2_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
|
|
362
|
-
static nk_u16_t const nk_e5m2_mag_to_bf16_lut_[128] = {
|
|
363
|
-
0x0000u, 0x3780u, 0x3800u, 0x3840u, 0x3880u, 0x38A0u, 0x38C0u, 0x38E0u, /* [ 0.. 7] */
|
|
364
|
-
0x3900u, 0x3920u, 0x3940u, 0x3960u, 0x3980u, 0x39A0u, 0x39C0u, 0x39E0u, /* [ 8.. 15] */
|
|
365
|
-
0x3A00u, 0x3A20u, 0x3A40u, 0x3A60u, 0x3A80u, 0x3AA0u, 0x3AC0u, 0x3AE0u, /* [ 16.. 23] */
|
|
366
|
-
0x3B00u, 0x3B20u, 0x3B40u, 0x3B60u, 0x3B80u, 0x3BA0u, 0x3BC0u, 0x3BE0u, /* [ 24.. 31] */
|
|
367
|
-
0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, 0x3C80u, 0x3CA0u, 0x3CC0u, 0x3CE0u, /* [ 32.. 39] */
|
|
368
|
-
0x3D00u, 0x3D20u, 0x3D40u, 0x3D60u, 0x3D80u, 0x3DA0u, 0x3DC0u, 0x3DE0u, /* [ 40.. 47] */
|
|
369
|
-
0x3E00u, 0x3E20u, 0x3E40u, 0x3E60u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 48.. 55] */
|
|
370
|
-
0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 56.. 63] */
|
|
371
|
-
0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 64.. 71] */
|
|
372
|
-
0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u, /* [ 72.. 79] */
|
|
373
|
-
0x4200u, 0x4220u, 0x4240u, 0x4260u, 0x4280u, 0x42A0u, 0x42C0u, 0x42E0u, /* [ 80.. 87] */
|
|
374
|
-
0x4300u, 0x4320u, 0x4340u, 0x4360u, 0x4380u, 0x43A0u, 0x43C0u, 0x43E0u, /* [ 88.. 95] */
|
|
375
|
-
0x4400u, 0x4420u, 0x4440u, 0x4460u, 0x4480u, 0x44A0u, 0x44C0u, 0x44E0u, /* [ 96..103] */
|
|
376
|
-
0x4500u, 0x4520u, 0x4540u, 0x4560u, 0x4580u, 0x45A0u, 0x45C0u, 0x45E0u, /* [104..111] */
|
|
377
|
-
0x4600u, 0x4620u, 0x4640u, 0x4660u, 0x4680u, 0x46A0u, 0x46C0u, 0x46E0u, /* [112..119] */
|
|
378
|
-
0x4700u, 0x4720u, 0x4740u, 0x4760u, 0x7F80u, 0x7FC0u, 0x7FC0u, 0x7FC0u /* [120..127] */
|
|
379
|
-
};
|
|
380
314
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
|
|
381
|
-
vuint8m1_t
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
315
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
|
|
316
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
|
|
317
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
|
|
318
|
+
// Magic multiply: reinterpret as f32 × 2^112
|
|
319
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x77800000, vector_length));
|
|
320
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
321
|
+
vector_length);
|
|
322
|
+
// Inf/NaN fixup: masked OR sets the exponent bits (0x7F800000) only in inf/NaN lanes (nonsign > 123)
|
|
323
|
+
vbool8_t is_infnan = __riscv_vmsgtu_vx_u32m4_b8(nonsign_u32m4, 123, vector_length);
|
|
324
|
+
vuint32m4_t f32_bits = __riscv_vor_vx_u32m4_mu(is_infnan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
|
|
325
|
+
__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 0x7F800000,
|
|
326
|
+
vector_length);
|
|
327
|
+
// Truncate f32 → bf16 (right shift 16, exact for all e5m2 values)
|
|
328
|
+
vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(f32_bits, 16, vector_length);
|
|
329
|
+
// Restore sign: bit 7 → bf16 bit 15 (<<8)
|
|
385
330
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
|
|
386
331
|
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
387
332
|
}
|
|
388
333
|
|
|
389
|
-
/** @brief Convert e2m3 (m1) to bf16 (m2) via
|
|
334
|
+
/** @brief Convert e2m3 (m1) to bf16 (m2) via Giesen magic-multiply.
|
|
335
|
+
* Magic-multiply to f32, then keep the upper 16 bits as bf16. No inf/NaN in E2M3FN. */
|
|
390
336
|
NK_INTERNAL vuint16m2_t nk_e2m3m1_to_bf16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
|
|
391
|
-
static nk_u16_t const nk_e2m3_mag_to_bf16_lut_[32] = {
|
|
392
|
-
0x0000u, 0x3E00u, 0x3E80u, 0x3EC0u, 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, /* [ 0.. 7] */
|
|
393
|
-
0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 8.. 15] */
|
|
394
|
-
0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 16.. 23] */
|
|
395
|
-
0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u /* [ 24.. 31] */
|
|
396
|
-
};
|
|
397
337
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
|
|
398
|
-
vuint8m1_t
|
|
399
|
-
|
|
338
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
|
|
339
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
|
|
340
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
|
|
341
|
+
// Magic multiply: reinterpret as f32 × 2^126
|
|
342
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7E800000, vector_length));
|
|
343
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
344
|
+
vector_length);
|
|
345
|
+
// Truncate f32 → bf16 (right shift 16, exact for all e2m3 values)
|
|
346
|
+
vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
|
|
400
347
|
vector_length);
|
|
401
|
-
|
|
348
|
+
// Restore sign: bit 5 → bf16 bit 15 (<<10)
|
|
402
349
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
403
350
|
vector_length);
|
|
404
351
|
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
405
352
|
}
|
|
406
353
|
|
|
407
|
-
/** @brief Convert e3m2 (m1) to bf16 (m2) via
|
|
354
|
+
/** @brief Convert e3m2 (m1) to bf16 (m2) via Giesen magic-multiply.
|
|
355
|
+
* Magic-multiply to f32, then keep the upper 16 bits as bf16. No inf/NaN in E3M2FN. */
|
|
408
356
|
NK_INTERNAL vuint16m2_t nk_e3m2m1_to_bf16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
|
|
409
|
-
static nk_u16_t const nk_e3m2_mag_to_bf16_lut_[32] = {
|
|
410
|
-
0x0000u, 0x3D80u, 0x3E00u, 0x3E40u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 0.. 7] */
|
|
411
|
-
0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 8.. 15] */
|
|
412
|
-
0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 16.. 23] */
|
|
413
|
-
0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u /* [ 24.. 31] */
|
|
414
|
-
};
|
|
415
357
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
|
|
416
|
-
vuint8m1_t
|
|
417
|
-
|
|
358
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
|
|
359
|
+
vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
|
|
360
|
+
vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
|
|
361
|
+
// Magic multiply: reinterpret as f32 × 2^124
|
|
362
|
+
vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7D800000, vector_length));
|
|
363
|
+
vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
|
|
364
|
+
vector_length);
|
|
365
|
+
// Truncate f32 → bf16 (right shift 16, exact for all e3m2 values)
|
|
366
|
+
vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
|
|
418
367
|
vector_length);
|
|
419
|
-
|
|
368
|
+
// Restore sign: bit 5 → bf16 bit 15 (<<10)
|
|
420
369
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
421
370
|
vector_length);
|
|
422
371
|
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
@@ -443,8 +392,8 @@ NK_INTERNAL vuint16m2_t nk_e4m3m1_to_f16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t
|
|
|
443
392
|
0x5C00u, 0x5C80u, 0x5D00u, 0x5D80u, 0x5E00u, 0x5E80u, 0x5F00u, 0x7E00u /* [120..127] */
|
|
444
393
|
};
|
|
445
394
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
|
|
446
|
-
vuint8m1_t
|
|
447
|
-
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(
|
|
395
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
|
|
396
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
|
|
448
397
|
vector_length);
|
|
449
398
|
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
|
|
450
399
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
|
|
@@ -460,8 +409,8 @@ NK_INTERNAL vuint16m2_t nk_e2m3m1_to_f16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t
|
|
|
460
409
|
0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u /* [ 24.. 31] */
|
|
461
410
|
};
|
|
462
411
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
|
|
463
|
-
vuint8m1_t
|
|
464
|
-
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(
|
|
412
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
|
|
413
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
|
|
465
414
|
vector_length);
|
|
466
415
|
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
|
|
467
416
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
@@ -478,8 +427,8 @@ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t
|
|
|
478
427
|
0x4800u, 0x4900u, 0x4A00u, 0x4B00u, 0x4C00u, 0x4D00u, 0x4E00u, 0x4F00u /* [ 24.. 31] */
|
|
479
428
|
};
|
|
480
429
|
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
|
|
481
|
-
vuint8m1_t
|
|
482
|
-
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(
|
|
430
|
+
vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
|
|
431
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
|
|
483
432
|
vector_length);
|
|
484
433
|
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_f16_lut_, offsets_u16m2, vector_length);
|
|
485
434
|
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
@@ -501,18 +450,18 @@ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t
|
|
|
501
450
|
*/
|
|
502
451
|
NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
|
|
503
452
|
// Extract high nibble (even indices in output)
|
|
504
|
-
vuint8m1_t
|
|
453
|
+
vuint8m1_t high_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
|
|
505
454
|
// Sign extend: (x ^ 8) - 8
|
|
506
|
-
vint8m1_t
|
|
507
|
-
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(
|
|
455
|
+
vint8m1_t high_i8m1 = __riscv_vsub_vx_i8m1(
|
|
456
|
+
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(high_u8m1), 8, vector_length), 8, vector_length);
|
|
508
457
|
|
|
509
458
|
// Extract low nibble (odd indices in output)
|
|
510
|
-
vuint8m1_t
|
|
459
|
+
vuint8m1_t low_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
|
|
511
460
|
// Sign extend: (x ^ 8) - 8
|
|
512
|
-
vint8m1_t
|
|
513
|
-
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(
|
|
461
|
+
vint8m1_t low_i8m1 = __riscv_vsub_vx_i8m1(
|
|
462
|
+
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(low_u8m1), 8, vector_length), 8, vector_length);
|
|
514
463
|
|
|
515
|
-
return __riscv_vcreate_v_i8m1x2(
|
|
464
|
+
return __riscv_vcreate_v_i8m1x2(high_i8m1, low_i8m1);
|
|
516
465
|
}
|
|
517
466
|
|
|
518
467
|
/**
|
|
@@ -522,12 +471,12 @@ NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t v
|
|
|
522
471
|
*/
|
|
523
472
|
NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
|
|
524
473
|
// Extract high nibble (even indices in output)
|
|
525
|
-
vuint8m1_t
|
|
474
|
+
vuint8m1_t high_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
|
|
526
475
|
|
|
527
476
|
// Extract low nibble (odd indices in output)
|
|
528
|
-
vuint8m1_t
|
|
477
|
+
vuint8m1_t low_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
|
|
529
478
|
|
|
530
|
-
return __riscv_vcreate_v_u8m1x2(
|
|
479
|
+
return __riscv_vcreate_v_u8m1x2(high_u8m1, low_u8m1);
|
|
531
480
|
}
|
|
532
481
|
|
|
533
482
|
/**
|
|
@@ -536,17 +485,17 @@ NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t
|
|
|
536
485
|
* Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
|
|
537
486
|
* Values are clamped to [-8, 7] before packing.
|
|
538
487
|
*/
|
|
539
|
-
NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t
|
|
488
|
+
NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t high_i8m1, vint8m1_t low_i8m1, nk_size_t vector_length) {
|
|
540
489
|
// Clamp to [-8, 7]
|
|
541
|
-
|
|
542
|
-
|
|
490
|
+
high_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(high_i8m1, 7, vector_length), -8, vector_length);
|
|
491
|
+
low_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(low_i8m1, 7, vector_length), -8, vector_length);
|
|
543
492
|
|
|
544
493
|
// Convert to unsigned nibbles: value & 0x0F
|
|
545
|
-
vuint8m1_t
|
|
546
|
-
vuint8m1_t
|
|
494
|
+
vuint8m1_t high_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(high_i8m1), 0x0F, vector_length);
|
|
495
|
+
vuint8m1_t low_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(low_i8m1), 0x0F, vector_length);
|
|
547
496
|
|
|
548
497
|
// Pack: (hi << 4) | lo
|
|
549
|
-
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(
|
|
498
|
+
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(high_u4m1, 4, vector_length), low_u4m1, vector_length);
|
|
550
499
|
}
|
|
551
500
|
|
|
552
501
|
/**
|
|
@@ -555,13 +504,13 @@ NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t hi_i8m1, vint8m1_t lo_i8m1
|
|
|
555
504
|
* Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
|
|
556
505
|
* Values are clamped to [0, 15] before packing.
|
|
557
506
|
*/
|
|
558
|
-
NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t
|
|
507
|
+
NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t high_u8m1, vuint8m1_t low_u8m1, nk_size_t vector_length) {
|
|
559
508
|
// Clamp to [0, 15]
|
|
560
|
-
|
|
561
|
-
|
|
509
|
+
high_u8m1 = __riscv_vminu_vx_u8m1(high_u8m1, 15, vector_length);
|
|
510
|
+
low_u8m1 = __riscv_vminu_vx_u8m1(low_u8m1, 15, vector_length);
|
|
562
511
|
|
|
563
512
|
// Pack: (hi << 4) | lo
|
|
564
|
-
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(
|
|
513
|
+
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(high_u8m1, 4, vector_length), low_u8m1, vector_length);
|
|
565
514
|
}
|
|
566
515
|
|
|
567
516
|
/**
|
|
@@ -721,9 +670,9 @@ NK_INTERNAL vuint8m1_t nk_f32m4_to_e5m2m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t
|
|
|
721
670
|
return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
|
|
722
671
|
}
|
|
723
672
|
|
|
724
|
-
#pragma endregion
|
|
673
|
+
#pragma endregion Register - to - Register Helpers
|
|
725
674
|
|
|
726
|
-
#pragma region
|
|
675
|
+
#pragma region Unified Cast Dispatcher
|
|
727
676
|
|
|
728
677
|
NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t count, void *to, nk_dtype_t to_type) {
|
|
729
678
|
// bf16 → f32
|
|
@@ -975,9 +924,9 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
|
|
|
975
924
|
n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
|
|
976
925
|
vector_length = __riscv_vsetvl_e8m1(n_bytes);
|
|
977
926
|
vint8m1x2_t loaded_i8m1x2 = __riscv_vlseg2e8_v_i8m1x2(source, vector_length);
|
|
978
|
-
vint8m1_t
|
|
979
|
-
vint8m1_t
|
|
980
|
-
vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(
|
|
927
|
+
vint8m1_t high_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 0);
|
|
928
|
+
vint8m1_t low_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 1);
|
|
929
|
+
vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(high_i8m1, low_i8m1, vector_length);
|
|
981
930
|
__riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
|
|
982
931
|
}
|
|
983
932
|
return;
|
|
@@ -992,9 +941,9 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
|
|
|
992
941
|
n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
|
|
993
942
|
vector_length = __riscv_vsetvl_e8m1(n_bytes);
|
|
994
943
|
vuint8m1x2_t loaded_u8m1x2 = __riscv_vlseg2e8_v_u8m1x2(source, vector_length);
|
|
995
|
-
vuint8m1_t
|
|
996
|
-
vuint8m1_t
|
|
997
|
-
vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(
|
|
944
|
+
vuint8m1_t high_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 0);
|
|
945
|
+
vuint8m1_t low_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 1);
|
|
946
|
+
vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(high_u8m1, low_u8m1, vector_length);
|
|
998
947
|
__riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
|
|
999
948
|
}
|
|
1000
949
|
return;
|
|
@@ -1004,7 +953,7 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
|
|
|
1004
953
|
nk_cast_serial(from, from_type, count, to, to_type);
|
|
1005
954
|
}
|
|
1006
955
|
|
|
1007
|
-
#pragma endregion
|
|
956
|
+
#pragma endregion Unified Cast Dispatcher
|
|
1008
957
|
|
|
1009
958
|
#if defined(__cplusplus)
|
|
1010
959
|
} // extern "C"
|