numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1021 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for RISC-V.
|
|
3
|
+
* @file include/numkong/cast/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/cast.h
|
|
8
|
+
*
|
|
9
|
+
* SpacemiT K1 and similar chips implement RVA22 profile with base RVV 1.0.
|
|
10
|
+
* This file provides vectorized type conversions for:
|
|
11
|
+
* - BF16 ↔ F32 (bit manipulation, no hardware support)
|
|
12
|
+
* - F16 ↔ F32 (bit manipulation, no hardware support)
|
|
13
|
+
* - E4M3 ↔ F32 (FP8 format for ML inference)
|
|
14
|
+
* - E5M2 ↔ F32 (FP8 format for ML training)
|
|
15
|
+
* - i4/u4 unpacking to i8/u8
|
|
16
|
+
*
|
|
17
|
+
* Mini-float conversions use sign-symmetric magnitude LUTs: every mini-float
|
|
18
|
+
* format is sign|magnitude, so we store only the positive-half (magnitude)
|
|
19
|
+
* entries and extract the sign bit separately. This cuts LUT memory by 50-87%
|
|
20
|
+
* and fixes the E2M3FN NaN bug (E2M3FN has NO NaN; index 31 is +7.5, not NaN).
|
|
21
|
+
*
|
|
22
|
+
* 8-bit formats (e4m3, e5m2): sign = bit 7, magnitude = bits 6:0 (128 entries)
|
|
23
|
+
* 6-bit formats (e2m3, e3m2): sign = bit 5, magnitude = bits 4:0 (32 entries)
|
|
24
|
+
*
|
|
25
|
+
* @section rvv_cast_instructions Key RVV Cast Instructions
|
|
26
|
+
*
|
|
27
|
+
* Intrinsic Purpose
|
|
28
|
+
* vzext_vf4_u32m4 Zero-extend u8 → u32 (4x widening)
|
|
29
|
+
* vsext_vf4_i32m4 Sign-extend i8 → i32 (4x widening)
|
|
30
|
+
* vsll_vx / vsrl_vx Bit shifts for field extraction
|
|
31
|
+
* vand_vx Bit masking
|
|
32
|
+
* vor_vv Combining bit fields
|
|
33
|
+
* vfcvt_f_xu_v Unsigned int → float
|
|
34
|
+
* vmseq_vx Compare for conditional selection
|
|
35
|
+
* vmerge_vvm Conditional select (blend)
|
|
36
|
+
*/
|
|
37
|
+
#ifndef NK_CAST_RVV_H
|
|
38
|
+
#define NK_CAST_RVV_H
|
|
39
|
+
|
|
40
|
+
#if NK_TARGET_RISCV_
|
|
41
|
+
#if NK_TARGET_RVV
|
|
42
|
+
|
|
43
|
+
#include "numkong/types.h"
|
|
44
|
+
#include "numkong/cast/serial.h" // `nk_cast_serial`
|
|
45
|
+
|
|
46
|
+
#if defined(__clang__)
|
|
47
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
48
|
+
#elif defined(__GNUC__)
|
|
49
|
+
#pragma GCC push_options
|
|
50
|
+
#pragma GCC target("arch=+v")
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
#if defined(__cplusplus)
|
|
54
|
+
extern "C" {
|
|
55
|
+
#endif
|
|
56
|
+
|
|
57
|
+
#pragma region - Register-to-Register Helpers
|
|
58
|
+
|
|
59
|
+
/**
|
|
60
|
+
* @brief Convert bf16 (m1) to f32 (m2) register-to-register.
|
|
61
|
+
*
|
|
62
|
+
* BF16 is the upper 16 bits of F32 (same sign + exponent + top 7 mantissa bits).
|
|
63
|
+
* Conversion is simply: f32_bits = bf16_bits << 16.
|
|
64
|
+
*/
|
|
65
|
+
NK_INTERNAL vfloat32m2_t nk_bf16m1_to_f32m2_rvv_(vuint16m1_t bf16_u16m1, nk_size_t vector_length) {
|
|
66
|
+
vuint32m2_t bits_u32m2 = __riscv_vzext_vf2_u32m2(bf16_u16m1, vector_length);
|
|
67
|
+
bits_u32m2 = __riscv_vsll_vx_u32m2(bits_u32m2, 16, vector_length);
|
|
68
|
+
return __riscv_vreinterpret_v_u32m2_f32m2(bits_u32m2);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/**
|
|
72
|
+
* @brief Convert f32 (m2) to bf16 (m1) register-to-register.
|
|
73
|
+
*
|
|
74
|
+
* Conversion with round-to-nearest-even (RNE): add (0x7FFF + lsb) to match hardware BF16 behavior.
|
|
75
|
+
*/
|
|
76
|
+
NK_INTERNAL vuint16m1_t nk_f32m2_to_bf16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t vector_length) {
|
|
77
|
+
vuint32m2_t bits_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(f32_f32m2);
|
|
78
|
+
// Extract LSB of result (bit 16) for round-to-nearest-even
|
|
79
|
+
vuint32m2_t lsb_u32m2 = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 16, vector_length), 1,
|
|
80
|
+
vector_length);
|
|
81
|
+
vuint32m2_t rounding_u32m2 = __riscv_vadd_vx_u32m2(lsb_u32m2, 0x7FFF, vector_length);
|
|
82
|
+
vuint32m2_t rounded_u32m2 = __riscv_vadd_vv_u32m2(bits_u32m2, rounding_u32m2, vector_length);
|
|
83
|
+
vuint32m2_t shifted_u32m2 = __riscv_vsrl_vx_u32m2(rounded_u32m2, 16, vector_length);
|
|
84
|
+
return __riscv_vncvt_x_x_w_u16m1(shifted_u32m2, vector_length);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
/**
 *  @brief Convert f16 (m1) to f32 (m2) register-to-register.
 *
 *  F16 format: S EEEEE MMMMMMMMMM (1 sign, 5 exponent bits with bias=15, 10 mantissa bits)
 *  F32 format: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM (1 sign, 8 exponent bits with bias=127, 23 mantissa bits)
 *
 *  Handles all IEEE-754 edge cases: ±zero, denormals, normals, ±inf, NaN.
 *  Strategy: compute the "normal" rebias result for every lane, then patch the
 *  two special exponent classes (0 and 31) in with mask merges.
 */
NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t vector_length) {
    // Widen to 32-bit for bit manipulation
    vuint32m2_t bits_u32m2 = __riscv_vzext_vf2_u32m2(f16_u16m1, vector_length);
    // Extract sign: (raw >> 15) << 31 — already positioned at the f32 sign bit
    vuint32m2_t sign_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 15, vector_length), 31,
                                                   vector_length);
    // Extract exponent: (raw >> 10) & 0x1F
    vuint32m2_t exponent_u32m2 = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 10, vector_length), 0x1F,
                                                       vector_length);
    // Extract mantissa: raw & 0x3FF
    vuint32m2_t mantissa_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x3FF, vector_length);

    // Normal path: rebias exponent (15 → 127): add 112, shift fields into f32
    // positions (exponent << 23, mantissa << 13) and OR with the sign.
    vuint32m2_t f32_exponent_u32m2 = __riscv_vadd_vx_u32m2(exponent_u32m2, 112, vector_length);
    vuint32m2_t normal_u32m2 = __riscv_vor_vv_u32m2(
        sign_u32m2,
        __riscv_vor_vv_u32m2(__riscv_vsll_vx_u32m2(f32_exponent_u32m2, 23, vector_length),
                             __riscv_vsll_vx_u32m2(mantissa_u32m2, 13, vector_length), vector_length),
        vector_length);

    // Special case: exponent == 0 (zero or denormal).
    // Denormal value is mantissa × 2^(-24). Convert the integer mantissa to
    // float with the FPU (which normalizes it), then subtract 24 from the f32
    // exponent field by subtracting 0x0C000000 (24 << 23) from the bit pattern —
    // the same trick as the serial implementation.
    vbool16_t is_exp_zero = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 0, vector_length);
    vfloat32m2_t mantissa_f32m2 = __riscv_vfcvt_f_xu_v_f32m2(mantissa_u32m2, vector_length);
    vuint32m2_t denorm_bits_u32m2 = __riscv_vsub_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(mantissa_f32m2),
                                                          0x0C000000, vector_length);
    vuint32m2_t zero_or_denorm_u32m2 = __riscv_vor_vv_u32m2(sign_u32m2, denorm_bits_u32m2, vector_length);
    // For true zeros (mantissa==0), (float)0 is 0x00000000 and the bias subtraction
    // wraps to a garbage pattern, so force those lanes to sign-only (±0.0f).
    vbool16_t is_true_zero = __riscv_vmand_mm_b16(
        is_exp_zero, __riscv_vmseq_vx_u32m2_b16(mantissa_u32m2, 0, vector_length), vector_length);
    zero_or_denorm_u32m2 = __riscv_vmerge_vvm_u32m2(zero_or_denorm_u32m2, sign_u32m2, is_true_zero, vector_length);

    // Special case: exponent == 31 (infinity or NaN):
    // sign | 0x7F800000 | (mantissa << 13). Mantissa 0 gives ±inf; nonzero
    // mantissa keeps its payload in the upper NaN bits.
    vbool16_t is_exp_max = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 31, vector_length);
    vuint32m2_t inf_nan_u32m2 = __riscv_vor_vv_u32m2(__riscv_vor_vx_u32m2(sign_u32m2, 0x7F800000, vector_length),
                                                     __riscv_vsll_vx_u32m2(mantissa_u32m2, 13, vector_length),
                                                     vector_length);

    // Select per lane: exp==0 → zero_or_denorm, exp==31 → inf_nan, else → normal
    vuint32m2_t result_u32m2 = __riscv_vmerge_vvm_u32m2(normal_u32m2, zero_or_denorm_u32m2, is_exp_zero, vector_length);
    result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, inf_nan_u32m2, is_exp_max, vector_length);
    return __riscv_vreinterpret_v_u32m2_f32m2(result_u32m2);
}
|
|
142
|
+
|
|
143
|
+
/**
 *  @brief Convert f32 (m2) to f16 (m1) register-to-register.
 *
 *  Conversion: rebias exponent from 127 to 15 (subtract 112), clamp it to
 *  [0, 31], round the mantissa from 23 to 10 bits by adding half an ULP
 *  (0x1000) before truncating, and carry any mantissa overflow into the
 *  exponent.
 *
 *  NOTE(review): as the inline comment below states, this is a simplified
 *  conversion — f16 subnormal outputs and overflow-to-infinity are not handled
 *  exactly (the exponent is merely clamped). Inputs needing those paths should
 *  be confirmed against the serial implementation.
 */
NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t vector_length) {
    vuint32m2_t bits_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(f32_f32m2);
    // Extract sign: (raw >> 31) << 15 — positioned at the f16 sign bit
    vuint32m2_t sign_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 31, vector_length), 15,
                                                   vector_length);
    // Extract exponent: (raw >> 23) & 0xFF
    vuint32m2_t exponent_u32m2 = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 23, vector_length), 0xFF,
                                                       vector_length);
    // Extract mantissa: raw & 0x7FFFFF
    vuint32m2_t mantissa_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x7FFFFF, vector_length);
    // Rebias exponent (127 → 15): subtract 112, clamp to [0, 31]
    // Note: This is a simplified conversion that doesn't handle subnormals or overflow properly
    vint32m2_t exponent_i32m2 = __riscv_vsub_vx_i32m2(__riscv_vreinterpret_v_u32m2_i32m2(exponent_u32m2), 112,
                                                      vector_length);
    exponent_i32m2 = __riscv_vmax_vx_i32m2(exponent_i32m2, 0, vector_length);
    vuint32m2_t f16_exponent_u32m2 = __riscv_vreinterpret_v_i32m2_u32m2(
        __riscv_vmin_vx_i32m2(exponent_i32m2, 31, vector_length));
    // Round mantissa: add 0x1000 (half of the 13 truncated bits) then shift.
    // If rounding overflows the mantissa (bit 23 set), carry into exponent.
    vuint32m2_t rounded_mantissa_u32m2 = __riscv_vadd_vx_u32m2(mantissa_u32m2, 0x1000, vector_length);
    vbool16_t mantissa_overflow_b16 = __riscv_vmsne_vx_u32m2_b16(
        __riscv_vand_vx_u32m2(rounded_mantissa_u32m2, 0x800000, vector_length), 0, vector_length);
    // Masked add: only overflowing lanes get their exponent incremented.
    f16_exponent_u32m2 = __riscv_vadd_vx_u32m2_mu(mantissa_overflow_b16, f16_exponent_u32m2, f16_exponent_u32m2, 1,
                                                  vector_length);
    vuint32m2_t f16_mantissa_u32m2 = __riscv_vsrl_vx_u32m2(rounded_mantissa_u32m2, 13, vector_length);
    f16_mantissa_u32m2 = __riscv_vand_vx_u32m2(f16_mantissa_u32m2, 0x3FF, vector_length);
    // Combine: sign | (exponent << 10) | mantissa
    vuint32m2_t result_u32m2 = __riscv_vor_vv_u32m2(
        sign_u32m2,
        __riscv_vor_vv_u32m2(__riscv_vsll_vx_u32m2(f16_exponent_u32m2, 10, vector_length), f16_mantissa_u32m2,
                             vector_length),
        vector_length);
    // Narrow each 32-bit lane to the final 16-bit f16 pattern.
    return __riscv_vncvt_x_x_w_u16m1(result_u32m2, vector_length);
}
|
|
182
|
+
|
|
183
|
+
/**
 *  @brief Convert e4m3 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
 *  E4M3FN: sign = bit 7, magnitude = bits 6:0 (128 entries). Sign bit 7 → f32 bit 31 (<<24).
 *
 *  Only positive-half magnitudes are stored (E4M3FN is sign|magnitude), halving
 *  the table versus a full 256-entry LUT. Index 127 is the E4M3FN NaN pattern
 *  and maps to a quiet f32 NaN (0x7FC00000).
 */
NK_INTERNAL vfloat32m4_t nk_e4m3m1_to_f32m4_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
    // F32 bit patterns for all 128 E4M3FN magnitudes; last entry is NaN.
    static nk_u32_t const nk_e4m3_mag_to_f32_lut_[128] = {
        0x00000000u, 0x3B000000u, 0x3B800000u, 0x3BC00000u,
        0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u, /* [  0..  7] */
        0x3C800000u, 0x3C900000u, 0x3CA00000u, 0x3CB00000u,
        0x3CC00000u, 0x3CD00000u, 0x3CE00000u, 0x3CF00000u, /* [  8.. 15] */
        0x3D000000u, 0x3D100000u, 0x3D200000u, 0x3D300000u,
        0x3D400000u, 0x3D500000u, 0x3D600000u, 0x3D700000u, /* [ 16.. 23] */
        0x3D800000u, 0x3D900000u, 0x3DA00000u, 0x3DB00000u,
        0x3DC00000u, 0x3DD00000u, 0x3DE00000u, 0x3DF00000u, /* [ 24.. 31] */
        0x3E000000u, 0x3E100000u, 0x3E200000u, 0x3E300000u,
        0x3E400000u, 0x3E500000u, 0x3E600000u, 0x3E700000u, /* [ 32.. 39] */
        0x3E800000u, 0x3E900000u, 0x3EA00000u, 0x3EB00000u,
        0x3EC00000u, 0x3ED00000u, 0x3EE00000u, 0x3EF00000u, /* [ 40.. 47] */
        0x3F000000u, 0x3F100000u, 0x3F200000u, 0x3F300000u,
        0x3F400000u, 0x3F500000u, 0x3F600000u, 0x3F700000u, /* [ 48.. 55] */
        0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
        0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 56.. 63] */
        0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
        0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 64.. 71] */
        0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
        0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u, /* [ 72.. 79] */
        0x41000000u, 0x41100000u, 0x41200000u, 0x41300000u,
        0x41400000u, 0x41500000u, 0x41600000u, 0x41700000u, /* [ 80.. 87] */
        0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
        0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
        0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
        0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
        0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
        0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
        0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
        0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
        0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
        0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u /* [120..127] */
    };
    // Split byte into sign (bit 7) and magnitude (bits 6:0).
    vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
    vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
    // vluxei32 takes BYTE offsets: index * sizeof(u32) = index << 2.
    vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
                                                      vector_length);
    vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e4m3_mag_to_f32_lut_, offsets_u32m4, vector_length);
    // Move the sign from bit 7 to f32 bit 31 and OR it into the magnitude.
    vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
                                                   vector_length);
    return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
}
|
|
231
|
+
|
|
232
|
+
/**
 *  @brief Convert e5m2 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
 *  E5M2: sign = bit 7, magnitude = bits 6:0 (128 entries). Sign bit 7 → f32 bit 31 (<<24).
 *
 *  Only positive-half magnitudes are stored (E5M2 is sign|magnitude). Unlike
 *  E4M3FN, E5M2 follows IEEE conventions: index 124 maps to +inf (0x7F800000)
 *  and indices 125..127 map to a quiet f32 NaN (0x7FC00000).
 */
NK_INTERNAL vfloat32m4_t nk_e5m2m1_to_f32m4_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
    // F32 bit patterns for all 128 E5M2 magnitudes; tail entries are inf/NaN.
    static nk_u32_t const nk_e5m2_mag_to_f32_lut_[128] = {
        0x00000000u, 0x37800000u, 0x38000000u, 0x38400000u,
        0x38800000u, 0x38A00000u, 0x38C00000u, 0x38E00000u, /* [  0..  7] */
        0x39000000u, 0x39200000u, 0x39400000u, 0x39600000u,
        0x39800000u, 0x39A00000u, 0x39C00000u, 0x39E00000u, /* [  8.. 15] */
        0x3A000000u, 0x3A200000u, 0x3A400000u, 0x3A600000u,
        0x3A800000u, 0x3AA00000u, 0x3AC00000u, 0x3AE00000u, /* [ 16.. 23] */
        0x3B000000u, 0x3B200000u, 0x3B400000u, 0x3B600000u,
        0x3B800000u, 0x3BA00000u, 0x3BC00000u, 0x3BE00000u, /* [ 24.. 31] */
        0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u,
        0x3C800000u, 0x3CA00000u, 0x3CC00000u, 0x3CE00000u, /* [ 32.. 39] */
        0x3D000000u, 0x3D200000u, 0x3D400000u, 0x3D600000u,
        0x3D800000u, 0x3DA00000u, 0x3DC00000u, 0x3DE00000u, /* [ 40.. 47] */
        0x3E000000u, 0x3E200000u, 0x3E400000u, 0x3E600000u,
        0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 48.. 55] */
        0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
        0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 56.. 63] */
        0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
        0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 64.. 71] */
        0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
        0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u, /* [ 72.. 79] */
        0x42000000u, 0x42200000u, 0x42400000u, 0x42600000u,
        0x42800000u, 0x42A00000u, 0x42C00000u, 0x42E00000u, /* [ 80.. 87] */
        0x43000000u, 0x43200000u, 0x43400000u, 0x43600000u,
        0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
        0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
        0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
        0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
        0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
        0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
        0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
        0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
        0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u /* [120..127] */
    };
    // Split byte into sign (bit 7) and magnitude (bits 6:0).
    vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
    vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
    // vluxei32 takes BYTE offsets: index * sizeof(u32) = index << 2.
    vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
                                                      vector_length);
    vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e5m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
    // Move the sign from bit 7 to f32 bit 31 and OR it into the magnitude.
    vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
                                                   vector_length);
    return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
}
|
|
280
|
+
|
|
281
|
+
/**
|
|
282
|
+
* @brief Convert e2m3 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
|
|
283
|
+
* E2M3FN: sign = bit 5, magnitude = bits 4:0 (32 entries). Sign bit 5 → f32 bit 31 (<<26).
|
|
284
|
+
*/
|
|
285
|
+
NK_INTERNAL vfloat32m4_t nk_e2m3m1_to_f32m4_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
|
|
286
|
+
static nk_u32_t const nk_e2m3_mag_to_f32_lut_[32] = {
|
|
287
|
+
0x00000000u, 0x3E000000u, 0x3E800000u, 0x3EC00000u,
|
|
288
|
+
0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u, /* [ 0.. 7] */
|
|
289
|
+
0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
|
|
290
|
+
0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 8.. 15] */
|
|
291
|
+
0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
|
|
292
|
+
0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 16.. 23] */
|
|
293
|
+
0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
|
|
294
|
+
0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u /* [ 24.. 31] */
|
|
295
|
+
};
|
|
296
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
|
|
297
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
|
|
298
|
+
vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
|
|
299
|
+
vector_length);
|
|
300
|
+
vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e2m3_mag_to_f32_lut_, offsets_u32m4, vector_length);
|
|
301
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 26,
|
|
302
|
+
vector_length);
|
|
303
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
/**
|
|
307
|
+
* @brief Convert e3m2 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
|
|
308
|
+
* E3M2FN: sign = bit 5, magnitude = bits 4:0 (32 entries). Sign bit 5 → f32 bit 31 (<<26).
|
|
309
|
+
*/
|
|
310
|
+
NK_INTERNAL vfloat32m4_t nk_e3m2m1_to_f32m4_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
|
|
311
|
+
static nk_u32_t const nk_e3m2_mag_to_f32_lut_[32] = {
|
|
312
|
+
0x00000000u, 0x3D800000u, 0x3E000000u, 0x3E400000u,
|
|
313
|
+
0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 0.. 7] */
|
|
314
|
+
0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
|
|
315
|
+
0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 8.. 15] */
|
|
316
|
+
0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
|
|
317
|
+
0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 16.. 23] */
|
|
318
|
+
0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
|
|
319
|
+
0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u /* [ 24.. 31] */
|
|
320
|
+
};
|
|
321
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
|
|
322
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
|
|
323
|
+
vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
|
|
324
|
+
vector_length);
|
|
325
|
+
vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e3m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
|
|
326
|
+
vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 26,
|
|
327
|
+
vector_length);
|
|
328
|
+
return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
/**
 * @brief Convert e4m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8).
 *
 * The 7-bit magnitude (exponent + mantissa) indexes a 128-entry table of bf16 bit
 * patterns; the final entry is 0x7FC0u (quiet-NaN pattern), matching the E4M3FN
 * quirk that code 0x7F encodes NaN.
 */
NK_INTERNAL vuint16m2_t nk_e4m3m1_to_bf16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
    static nk_u16_t const nk_e4m3_mag_to_bf16_lut_[128] = {
        0x0000u, 0x3B00u, 0x3B80u, 0x3BC0u, 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, /* [  0..  7] */
        0x3C80u, 0x3C90u, 0x3CA0u, 0x3CB0u, 0x3CC0u, 0x3CD0u, 0x3CE0u, 0x3CF0u, /* [  8.. 15] */
        0x3D00u, 0x3D10u, 0x3D20u, 0x3D30u, 0x3D40u, 0x3D50u, 0x3D60u, 0x3D70u, /* [ 16.. 23] */
        0x3D80u, 0x3D90u, 0x3DA0u, 0x3DB0u, 0x3DC0u, 0x3DD0u, 0x3DE0u, 0x3DF0u, /* [ 24.. 31] */
        0x3E00u, 0x3E10u, 0x3E20u, 0x3E30u, 0x3E40u, 0x3E50u, 0x3E60u, 0x3E70u, /* [ 32.. 39] */
        0x3E80u, 0x3E90u, 0x3EA0u, 0x3EB0u, 0x3EC0u, 0x3ED0u, 0x3EE0u, 0x3EF0u, /* [ 40.. 47] */
        0x3F00u, 0x3F10u, 0x3F20u, 0x3F30u, 0x3F40u, 0x3F50u, 0x3F60u, 0x3F70u, /* [ 48.. 55] */
        0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 56.. 63] */
        0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 64.. 71] */
        0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u, /* [ 72.. 79] */
        0x4100u, 0x4110u, 0x4120u, 0x4130u, 0x4140u, 0x4150u, 0x4160u, 0x4170u, /* [ 80.. 87] */
        0x4180u, 0x4190u, 0x41A0u, 0x41B0u, 0x41C0u, 0x41D0u, 0x41E0u, 0x41F0u, /* [ 88.. 95] */
        0x4200u, 0x4210u, 0x4220u, 0x4230u, 0x4240u, 0x4250u, 0x4260u, 0x4270u, /* [ 96..103] */
        0x4280u, 0x4290u, 0x42A0u, 0x42B0u, 0x42C0u, 0x42D0u, 0x42E0u, 0x42F0u, /* [104..111] */
        0x4300u, 0x4310u, 0x4320u, 0x4330u, 0x4340u, 0x4350u, 0x4360u, 0x4370u, /* [112..119] */
        0x4380u, 0x4390u, 0x43A0u, 0x43B0u, 0x43C0u, 0x43D0u, 0x43E0u, 0x7FC0u /* [120..127] */
    };
    // Split sign (bit 7) from the 7-bit magnitude used as the LUT index.
    vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
    vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
    // Widen to u16 and scale by sizeof(nk_u16_t) to form byte offsets for the indexed gather.
    vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
                                                      vector_length);
    vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
    // Relocate the sign from bit 7 to bf16 bit 15 and merge it into the magnitude bits.
    vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
    return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
}
|
|
359
|
+
|
|
360
|
+
/**
 * @brief Convert e5m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8).
 *
 * The 7-bit magnitude indexes a 128-entry table of bf16 bit patterns. Entry 124 is
 * 0x7F80u (infinity pattern) and entries 125..127 are 0x7FC0u (quiet-NaN pattern),
 * mirroring e5m2's IEEE-like Inf/NaN encodings at the top of the exponent range.
 */
NK_INTERNAL vuint16m2_t nk_e5m2m1_to_bf16m2_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
    static nk_u16_t const nk_e5m2_mag_to_bf16_lut_[128] = {
        0x0000u, 0x3780u, 0x3800u, 0x3840u, 0x3880u, 0x38A0u, 0x38C0u, 0x38E0u, /* [  0..  7] */
        0x3900u, 0x3920u, 0x3940u, 0x3960u, 0x3980u, 0x39A0u, 0x39C0u, 0x39E0u, /* [  8.. 15] */
        0x3A00u, 0x3A20u, 0x3A40u, 0x3A60u, 0x3A80u, 0x3AA0u, 0x3AC0u, 0x3AE0u, /* [ 16.. 23] */
        0x3B00u, 0x3B20u, 0x3B40u, 0x3B60u, 0x3B80u, 0x3BA0u, 0x3BC0u, 0x3BE0u, /* [ 24.. 31] */
        0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, 0x3C80u, 0x3CA0u, 0x3CC0u, 0x3CE0u, /* [ 32.. 39] */
        0x3D00u, 0x3D20u, 0x3D40u, 0x3D60u, 0x3D80u, 0x3DA0u, 0x3DC0u, 0x3DE0u, /* [ 40.. 47] */
        0x3E00u, 0x3E20u, 0x3E40u, 0x3E60u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 48.. 55] */
        0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 56.. 63] */
        0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 64.. 71] */
        0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u, /* [ 72.. 79] */
        0x4200u, 0x4220u, 0x4240u, 0x4260u, 0x4280u, 0x42A0u, 0x42C0u, 0x42E0u, /* [ 80.. 87] */
        0x4300u, 0x4320u, 0x4340u, 0x4360u, 0x4380u, 0x43A0u, 0x43C0u, 0x43E0u, /* [ 88.. 95] */
        0x4400u, 0x4420u, 0x4440u, 0x4460u, 0x4480u, 0x44A0u, 0x44C0u, 0x44E0u, /* [ 96..103] */
        0x4500u, 0x4520u, 0x4540u, 0x4560u, 0x4580u, 0x45A0u, 0x45C0u, 0x45E0u, /* [104..111] */
        0x4600u, 0x4620u, 0x4640u, 0x4660u, 0x4680u, 0x46A0u, 0x46C0u, 0x46E0u, /* [112..119] */
        0x4700u, 0x4720u, 0x4740u, 0x4760u, 0x7F80u, 0x7FC0u, 0x7FC0u, 0x7FC0u /* [120..127] */
    };
    // Split sign (bit 7) from the 7-bit magnitude used as the LUT index.
    vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
    vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
    // Widen to u16 and scale by sizeof(nk_u16_t) to form byte offsets for the indexed gather.
    vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
                                                      vector_length);
    vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e5m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
    // Relocate the sign from bit 7 to bf16 bit 15 and merge it into the magnitude bits.
    vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
    return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
}
|
|
388
|
+
|
|
389
|
+
/** @brief Convert e2m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
|
|
390
|
+
NK_INTERNAL vuint16m2_t nk_e2m3m1_to_bf16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
|
|
391
|
+
static nk_u16_t const nk_e2m3_mag_to_bf16_lut_[32] = {
|
|
392
|
+
0x0000u, 0x3E00u, 0x3E80u, 0x3EC0u, 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, /* [ 0.. 7] */
|
|
393
|
+
0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 8.. 15] */
|
|
394
|
+
0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 16.. 23] */
|
|
395
|
+
0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u /* [ 24.. 31] */
|
|
396
|
+
};
|
|
397
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
|
|
398
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
|
|
399
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
|
|
400
|
+
vector_length);
|
|
401
|
+
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
|
|
402
|
+
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
403
|
+
vector_length);
|
|
404
|
+
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
/** @brief Convert e3m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
|
|
408
|
+
NK_INTERNAL vuint16m2_t nk_e3m2m1_to_bf16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
|
|
409
|
+
static nk_u16_t const nk_e3m2_mag_to_bf16_lut_[32] = {
|
|
410
|
+
0x0000u, 0x3D80u, 0x3E00u, 0x3E40u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 0.. 7] */
|
|
411
|
+
0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 8.. 15] */
|
|
412
|
+
0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 16.. 23] */
|
|
413
|
+
0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u /* [ 24.. 31] */
|
|
414
|
+
};
|
|
415
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
|
|
416
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
|
|
417
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
|
|
418
|
+
vector_length);
|
|
419
|
+
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
|
|
420
|
+
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
421
|
+
vector_length);
|
|
422
|
+
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
/**
 * @brief Convert e4m3 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → f16 bit 15 (<<8).
 *
 * The 7-bit magnitude indexes a 128-entry table of f16 bit patterns; the final
 * entry is 0x7E00u (quiet-NaN pattern), matching the E4M3FN quirk that code 0x7F
 * encodes NaN.
 */
NK_INTERNAL vuint16m2_t nk_e4m3m1_to_f16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
    static nk_u16_t const nk_e4m3_mag_to_f16_lut_[128] = {
        0x0000u, 0x1800u, 0x1C00u, 0x1E00u, 0x2000u, 0x2100u, 0x2200u, 0x2300u, /* [  0..  7] */
        0x2400u, 0x2480u, 0x2500u, 0x2580u, 0x2600u, 0x2680u, 0x2700u, 0x2780u, /* [  8.. 15] */
        0x2800u, 0x2880u, 0x2900u, 0x2980u, 0x2A00u, 0x2A80u, 0x2B00u, 0x2B80u, /* [ 16.. 23] */
        0x2C00u, 0x2C80u, 0x2D00u, 0x2D80u, 0x2E00u, 0x2E80u, 0x2F00u, 0x2F80u, /* [ 24.. 31] */
        0x3000u, 0x3080u, 0x3100u, 0x3180u, 0x3200u, 0x3280u, 0x3300u, 0x3380u, /* [ 32.. 39] */
        0x3400u, 0x3480u, 0x3500u, 0x3580u, 0x3600u, 0x3680u, 0x3700u, 0x3780u, /* [ 40.. 47] */
        0x3800u, 0x3880u, 0x3900u, 0x3980u, 0x3A00u, 0x3A80u, 0x3B00u, 0x3B80u, /* [ 48.. 55] */
        0x3C00u, 0x3C80u, 0x3D00u, 0x3D80u, 0x3E00u, 0x3E80u, 0x3F00u, 0x3F80u, /* [ 56.. 63] */
        0x4000u, 0x4080u, 0x4100u, 0x4180u, 0x4200u, 0x4280u, 0x4300u, 0x4380u, /* [ 64.. 71] */
        0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u, /* [ 72.. 79] */
        0x4800u, 0x4880u, 0x4900u, 0x4980u, 0x4A00u, 0x4A80u, 0x4B00u, 0x4B80u, /* [ 80.. 87] */
        0x4C00u, 0x4C80u, 0x4D00u, 0x4D80u, 0x4E00u, 0x4E80u, 0x4F00u, 0x4F80u, /* [ 88.. 95] */
        0x5000u, 0x5080u, 0x5100u, 0x5180u, 0x5200u, 0x5280u, 0x5300u, 0x5380u, /* [ 96..103] */
        0x5400u, 0x5480u, 0x5500u, 0x5580u, 0x5600u, 0x5680u, 0x5700u, 0x5780u, /* [104..111] */
        0x5800u, 0x5880u, 0x5900u, 0x5980u, 0x5A00u, 0x5A80u, 0x5B00u, 0x5B80u, /* [112..119] */
        0x5C00u, 0x5C80u, 0x5D00u, 0x5D80u, 0x5E00u, 0x5E80u, 0x5F00u, 0x7E00u /* [120..127] */
    };
    // Split sign (bit 7) from the 7-bit magnitude used as the LUT index.
    vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
    vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
    // Widen to u16 and scale by sizeof(nk_u16_t) to form byte offsets for the indexed gather.
    vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
                                                      vector_length);
    vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
    // Relocate the sign from bit 7 to f16 bit 15 and merge it into the magnitude bits.
    vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
    return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
}
|
|
453
|
+
|
|
454
|
+
/** @brief Convert e2m3 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → f16 bit 15 (<<10). */
|
|
455
|
+
NK_INTERNAL vuint16m2_t nk_e2m3m1_to_f16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
|
|
456
|
+
static nk_u16_t const nk_e2m3_mag_to_f16_lut_[32] = {
|
|
457
|
+
0x0000u, 0x3000u, 0x3400u, 0x3600u, 0x3800u, 0x3900u, 0x3A00u, 0x3B00u, /* [ 0.. 7] */
|
|
458
|
+
0x3C00u, 0x3C80u, 0x3D00u, 0x3D80u, 0x3E00u, 0x3E80u, 0x3F00u, 0x3F80u, /* [ 8.. 15] */
|
|
459
|
+
0x4000u, 0x4080u, 0x4100u, 0x4180u, 0x4200u, 0x4280u, 0x4300u, 0x4380u, /* [ 16.. 23] */
|
|
460
|
+
0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u /* [ 24.. 31] */
|
|
461
|
+
};
|
|
462
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
|
|
463
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
|
|
464
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
|
|
465
|
+
vector_length);
|
|
466
|
+
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
|
|
467
|
+
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
468
|
+
vector_length);
|
|
469
|
+
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
/** @brief Convert e3m2 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → f16 bit 15 (<<10). */
|
|
473
|
+
NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
|
|
474
|
+
static nk_u16_t const nk_e3m2_mag_to_f16_lut_[32] = {
|
|
475
|
+
0x0000u, 0x2C00u, 0x3000u, 0x3200u, 0x3400u, 0x3500u, 0x3600u, 0x3700u, /* [ 0.. 7] */
|
|
476
|
+
0x3800u, 0x3900u, 0x3A00u, 0x3B00u, 0x3C00u, 0x3D00u, 0x3E00u, 0x3F00u, /* [ 8.. 15] */
|
|
477
|
+
0x4000u, 0x4100u, 0x4200u, 0x4300u, 0x4400u, 0x4500u, 0x4600u, 0x4700u, /* [ 16.. 23] */
|
|
478
|
+
0x4800u, 0x4900u, 0x4A00u, 0x4B00u, 0x4C00u, 0x4D00u, 0x4E00u, 0x4F00u /* [ 24.. 31] */
|
|
479
|
+
};
|
|
480
|
+
vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
|
|
481
|
+
vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
|
|
482
|
+
vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
|
|
483
|
+
vector_length);
|
|
484
|
+
vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_f16_lut_, offsets_u16m2, vector_length);
|
|
485
|
+
vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
|
|
486
|
+
vector_length);
|
|
487
|
+
return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
/**
|
|
491
|
+
* @brief Unpack i4 (m1) nibbles to i8 (m2) register-to-register.
|
|
492
|
+
*
|
|
493
|
+
* Packed format: byte[i] contains two nibbles:
|
|
494
|
+
* - High nibble (bits [7:4]) → output[i*2]
|
|
495
|
+
* - Low nibble (bits [3:0]) → output[i*2+1]
|
|
496
|
+
*
|
|
497
|
+
* Sign extension: 4-bit signed value [-8,7] extended to 8-bit.
|
|
498
|
+
* Trick: (x ^ 8) - 8 sign-extends a 4-bit value to larger type.
|
|
499
|
+
*
|
|
500
|
+
* Returns a tuple of two m1 vectors (high nibbles, low nibbles) for segment store.
|
|
501
|
+
*/
|
|
502
|
+
NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
|
|
503
|
+
// Extract high nibble (even indices in output)
|
|
504
|
+
vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
|
|
505
|
+
// Sign extend: (x ^ 8) - 8
|
|
506
|
+
vint8m1_t hi_i8m1 = __riscv_vsub_vx_i8m1(
|
|
507
|
+
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(hi_u8m1), 8, vector_length), 8, vector_length);
|
|
508
|
+
|
|
509
|
+
// Extract low nibble (odd indices in output)
|
|
510
|
+
vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
|
|
511
|
+
// Sign extend: (x ^ 8) - 8
|
|
512
|
+
vint8m1_t lo_i8m1 = __riscv_vsub_vx_i8m1(
|
|
513
|
+
__riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(lo_u8m1), 8, vector_length), 8, vector_length);
|
|
514
|
+
|
|
515
|
+
return __riscv_vcreate_v_i8m1x2(hi_i8m1, lo_i8m1);
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
/**
|
|
519
|
+
* @brief Unpack u4 (m1) nibbles to u8 (m2) register-to-register.
|
|
520
|
+
*
|
|
521
|
+
* Returns a tuple of two m1 vectors (high nibbles, low nibbles) for segment store.
|
|
522
|
+
*/
|
|
523
|
+
NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
|
|
524
|
+
// Extract high nibble (even indices in output)
|
|
525
|
+
vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
|
|
526
|
+
|
|
527
|
+
// Extract low nibble (odd indices in output)
|
|
528
|
+
vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
|
|
529
|
+
|
|
530
|
+
return __riscv_vcreate_v_u8m1x2(hi_u8m1, lo_u8m1);
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
/**
|
|
534
|
+
* @brief Pack i8 (m2) to i4 (m1) nibbles register-to-register.
|
|
535
|
+
*
|
|
536
|
+
* Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
|
|
537
|
+
* Values are clamped to [-8, 7] before packing.
|
|
538
|
+
*/
|
|
539
|
+
NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t hi_i8m1, vint8m1_t lo_i8m1, nk_size_t vector_length) {
|
|
540
|
+
// Clamp to [-8, 7]
|
|
541
|
+
hi_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(hi_i8m1, 7, vector_length), -8, vector_length);
|
|
542
|
+
lo_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(lo_i8m1, 7, vector_length), -8, vector_length);
|
|
543
|
+
|
|
544
|
+
// Convert to unsigned nibbles: value & 0x0F
|
|
545
|
+
vuint8m1_t hi_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(hi_i8m1), 0x0F, vector_length);
|
|
546
|
+
vuint8m1_t lo_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(lo_i8m1), 0x0F, vector_length);
|
|
547
|
+
|
|
548
|
+
// Pack: (hi << 4) | lo
|
|
549
|
+
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u4m1, 4, vector_length), lo_u4m1, vector_length);
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
/**
|
|
553
|
+
* @brief Pack u8 (m2) to u4 (m1) nibbles register-to-register.
|
|
554
|
+
*
|
|
555
|
+
* Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
|
|
556
|
+
* Values are clamped to [0, 15] before packing.
|
|
557
|
+
*/
|
|
558
|
+
NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t hi_u8m1, vuint8m1_t lo_u8m1, nk_size_t vector_length) {
|
|
559
|
+
// Clamp to [0, 15]
|
|
560
|
+
hi_u8m1 = __riscv_vminu_vx_u8m1(hi_u8m1, 15, vector_length);
|
|
561
|
+
lo_u8m1 = __riscv_vminu_vx_u8m1(lo_u8m1, 15, vector_length);
|
|
562
|
+
|
|
563
|
+
// Pack: (hi << 4) | lo
|
|
564
|
+
return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u8m1, 4, vector_length), lo_u8m1, vector_length);
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
/**
 * @brief Convert f32 (m4) to e4m3 (m1) register-to-register.
 *
 * E4M3FN format: S EEEE MMM (1 sign, 4 exponent bits with bias=7, 3 mantissa bits)
 * Handles normal, subnormal, overflow, and NaN. Uses RNE mantissa rounding.
 * E4M3FN quirk: exp=15 with mant=7 is NaN (0x7F), so max finite is 0x7E (exp=15, mant=6).
 *
 * The result is computed along four parallel paths (normal, subnormal, overflow
 * saturation, NaN) and merged with masks, then narrowed u32 → u16 → u8.
 */
NK_INTERNAL vuint8m1_t nk_f32m4_to_e4m3m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t vector_length) {
    // Decompose the f32 bit pattern into sign, magnitude, and biased exponent fields.
    vuint32m4_t bits_u32m4 = __riscv_vreinterpret_v_f32m4_u32m4(f32_f32m4);
    vuint32m4_t sign_u32m4 = __riscv_vsrl_vx_u32m4(bits_u32m4, 31, vector_length);
    vuint32m4_t abs_bits_u32m4 = __riscv_vand_vx_u32m4(bits_u32m4, 0x7FFFFFFF, vector_length);
    vuint32m4_t f32_exp_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(bits_u32m4, 23, vector_length), 0xFF,
                                                      vector_length);

    // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even):
    // add (lsb + 0x0007FFFF) so exact halves round toward the even 3-bit result.
    vuint32m4_t significand_u32m4 = __riscv_vor_vx_u32m4(__riscv_vand_vx_u32m4(bits_u32m4, 0x007FFFFF, vector_length),
                                                         0x00800000, vector_length);
    vuint32m4_t lsb_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(significand_u32m4, 20, vector_length), 1,
                                                  vector_length);
    vuint32m4_t rounding_bias_u32m4 = __riscv_vadd_vx_u32m4(lsb_u32m4, 0x0007FFFF, vector_length);
    vuint32m4_t rounded_sig_u32m4 = __riscv_vadd_vv_u32m4(significand_u32m4, rounding_bias_u32m4, vector_length);
    vuint32m4_t carry_u32m4 = __riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 24, vector_length);
    vuint32m4_t f32_mantissa_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 20, vector_length),
                                                           0x07, vector_length);
    // If carry, mantissa becomes 0 (rounded up to next power of 2)
    vbool8_t has_carry_b8 = __riscv_vmsne_vx_u32m4_b8(carry_u32m4, 0, vector_length);
    f32_mantissa_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, has_carry_b8, vector_length);

    // e4m3_exp = f32_exp + carry - 120 (rebias: f32 bias 127 minus e4m3 bias 7)
    vint32m4_t e4m3_exp_i32m4 = __riscv_vsub_vx_i32m4(
        __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vadd_vv_u32m4(f32_exp_u32m4, carry_u32m4, vector_length)), 120,
        vector_length);

    // Detect subnormal (exp <= 0) and overflow (exp > 15)
    vbool8_t is_subnormal_b8 = __riscv_vmsle_vx_i32m4_b8(e4m3_exp_i32m4, 0, vector_length);
    vbool8_t is_overflow_b8 = __riscv_vmsgt_vx_i32m4_b8(e4m3_exp_i32m4, 15, vector_length);

    // Normal path: clamp exp to [1,15]
    vint32m4_t clamped_exp_i32m4 = __riscv_vmax_vx_i32m4(e4m3_exp_i32m4, 1, vector_length);
    clamped_exp_i32m4 = __riscv_vmin_vx_i32m4(clamped_exp_i32m4, 15, vector_length);
    // E4M3FN quirk: exp=15 with mant=7 is NaN, so cap mantissa to 6 when exp=15
    vbool8_t is_max_exp_b8 = __riscv_vmseq_vx_i32m4_b8(clamped_exp_i32m4, 15, vector_length);
    vuint32m4_t max_mant_u32m4 = __riscv_vmerge_vxm_u32m4(__riscv_vmv_v_x_u32m4(7, vector_length), 6, is_max_exp_b8,
                                                          vector_length);
    vuint32m4_t normal_mant_u32m4 = __riscv_vminu_vv_u32m4(f32_mantissa_u32m4, max_mant_u32m4, vector_length);
    // On overflow, saturate to max finite (exp=15, mant=6 = 0x7E with sign)
    normal_mant_u32m4 = __riscv_vmerge_vxm_u32m4(normal_mant_u32m4, 0x06, is_overflow_b8, vector_length);
    // Assemble normal result: sign<<7 | exp<<3 | mant.
    vuint32m4_t normal_u32m4 = __riscv_vor_vv_u32m4(
        __riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
        __riscv_vor_vv_u32m4(
            __riscv_vsll_vx_u32m4(__riscv_vreinterpret_v_i32m4_u32m4(clamped_exp_i32m4), 3, vector_length),
            normal_mant_u32m4, vector_length),
        vector_length);

    // Subnormal path: mantissa = round(|f32| * 512)
    vfloat32m4_t abs_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(abs_bits_u32m4);
    vfloat32m4_t scaled_f32m4 = __riscv_vfmul_vf_f32m4(abs_f32m4, 512.0f, vector_length);
    vint32m4_t subnorm_mant_i32m4 = __riscv_vfcvt_x_f_v_i32m4(scaled_f32m4, vector_length); // RNE rounding
    // If rounds to 8+, promote to first normal (exp=1, mant=0 = 0x08)
    vbool8_t promotes_b8 = __riscv_vmsgt_vx_i32m4_b8(subnorm_mant_i32m4, 7, vector_length);
    subnorm_mant_i32m4 = __riscv_vmin_vx_i32m4(subnorm_mant_i32m4, 7, vector_length);
    subnorm_mant_i32m4 = __riscv_vmax_vx_i32m4(subnorm_mant_i32m4, 0, vector_length);
    vuint32m4_t subnorm_u32m4 = __riscv_vor_vv_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
                                                     __riscv_vreinterpret_v_i32m4_u32m4(subnorm_mant_i32m4),
                                                     vector_length);
    vuint32m4_t first_normal_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x08,
                                                          vector_length);
    subnorm_u32m4 = __riscv_vmerge_vvm_u32m4(subnorm_u32m4, first_normal_u32m4, promotes_b8, vector_length);

    // Select: subnormal when exp <= 0, else normal
    vuint32m4_t result_u32m4 = __riscv_vmerge_vvm_u32m4(normal_u32m4, subnorm_u32m4, is_subnormal_b8, vector_length);

    // Handle NaN: f32 NaN (abs_bits > 0x7F800000) → e4m3 NaN (sign | 0x7F)
    vbool8_t is_nan_b8 = __riscv_vmsgtu_vx_u32m4_b8(abs_bits_u32m4, 0x7F800000, vector_length);
    vuint32m4_t nan_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x7F,
                                                 vector_length);
    result_u32m4 = __riscv_vmerge_vvm_u32m4(result_u32m4, nan_u32m4, is_nan_b8, vector_length);

    // Narrow u32m4 → u16m2 → u8m1 (result fits in the low byte after assembly above)
    vuint16m2_t result_u16m2 = __riscv_vncvt_x_x_w_u16m2(result_u32m4, vector_length);
    return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
}
|
|
649
|
+
|
|
650
|
+
/**
 * @brief Convert f32 (m4) to e5m2 (m1) register-to-register.
 *
 * E5M2 format: S EEEEE MM (1 sign, 5 exponent bits with bias=15, 2 mantissa bits)
 * Handles normal, subnormal, overflow (→ infinity), and NaN. Uses RNE mantissa rounding.
 *
 * The result is computed along parallel paths (normal/overflow, subnormal, NaN)
 * and merged with masks, then narrowed u32 → u16 → u8.
 */
NK_INTERNAL vuint8m1_t nk_f32m4_to_e5m2m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t vector_length) {
    // Decompose the f32 bit pattern into sign, magnitude, and biased exponent fields.
    vuint32m4_t bits_u32m4 = __riscv_vreinterpret_v_f32m4_u32m4(f32_f32m4);
    vuint32m4_t sign_u32m4 = __riscv_vsrl_vx_u32m4(bits_u32m4, 31, vector_length);
    vuint32m4_t abs_bits_u32m4 = __riscv_vand_vx_u32m4(bits_u32m4, 0x7FFFFFFF, vector_length);
    vuint32m4_t f32_exp_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(bits_u32m4, 23, vector_length), 0xFF,
                                                      vector_length);

    // Round mantissa from 23 to 2 bits using RNE:
    // add (lsb + 0x000FFFFF) so exact halves round toward the even 2-bit result.
    vuint32m4_t significand_u32m4 = __riscv_vor_vx_u32m4(__riscv_vand_vx_u32m4(bits_u32m4, 0x007FFFFF, vector_length),
                                                         0x00800000, vector_length);
    vuint32m4_t lsb_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(significand_u32m4, 21, vector_length), 1,
                                                  vector_length);
    vuint32m4_t rounding_bias_u32m4 = __riscv_vadd_vx_u32m4(lsb_u32m4, 0x000FFFFF, vector_length);
    vuint32m4_t rounded_sig_u32m4 = __riscv_vadd_vv_u32m4(significand_u32m4, rounding_bias_u32m4, vector_length);
    vuint32m4_t carry_u32m4 = __riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 24, vector_length);
    vuint32m4_t f32_mantissa_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 21, vector_length),
                                                           0x03, vector_length);
    // If rounding carried out of the significand, the mantissa wraps to 0 and the exponent bumps.
    vbool8_t has_carry_b8 = __riscv_vmsne_vx_u32m4_b8(carry_u32m4, 0, vector_length);
    f32_mantissa_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, has_carry_b8, vector_length);

    // e5m2_exp = f32_exp + carry - 112 (rebias: f32 bias 127 minus e5m2 bias 15)
    vint32m4_t e5m2_exp_i32m4 = __riscv_vsub_vx_i32m4(
        __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vadd_vv_u32m4(f32_exp_u32m4, carry_u32m4, vector_length)), 112,
        vector_length);

    // Detect subnormal (exp <= 0) and overflow (exp > 31)
    vbool8_t is_subnormal_b8 = __riscv_vmsle_vx_i32m4_b8(e5m2_exp_i32m4, 0, vector_length);
    vbool8_t is_overflow_b8 = __riscv_vmsgt_vx_i32m4_b8(e5m2_exp_i32m4, 31, vector_length);

    // Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mant=0)
    vint32m4_t clamped_exp_i32m4 = __riscv_vmax_vx_i32m4(e5m2_exp_i32m4, 1, vector_length);
    clamped_exp_i32m4 = __riscv_vmin_vx_i32m4(clamped_exp_i32m4, 31, vector_length);
    vuint32m4_t normal_mant_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, is_overflow_b8, vector_length);
    // Assemble normal result: sign<<7 | exp<<2 | mant.
    vuint32m4_t normal_u32m4 = __riscv_vor_vv_u32m4(
        __riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
        __riscv_vor_vv_u32m4(
            __riscv_vsll_vx_u32m4(__riscv_vreinterpret_v_i32m4_u32m4(clamped_exp_i32m4), 2, vector_length),
            normal_mant_u32m4, vector_length),
        vector_length);

    // Subnormal path: mantissa = round(|f32| * 65536)
    vfloat32m4_t abs_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(abs_bits_u32m4);
    vfloat32m4_t scaled_f32m4 = __riscv_vfmul_vf_f32m4(abs_f32m4, 65536.0f, vector_length);
    vint32m4_t subnorm_mant_i32m4 = __riscv_vfcvt_x_f_v_i32m4(scaled_f32m4, vector_length);
    // If rounds to 4+, promote to first normal (exp=1, mant=0 = 0x04)
    vbool8_t promotes_b8 = __riscv_vmsgt_vx_i32m4_b8(subnorm_mant_i32m4, 3, vector_length);
    subnorm_mant_i32m4 = __riscv_vmin_vx_i32m4(subnorm_mant_i32m4, 3, vector_length);
    subnorm_mant_i32m4 = __riscv_vmax_vx_i32m4(subnorm_mant_i32m4, 0, vector_length);
    vuint32m4_t subnorm_u32m4 = __riscv_vor_vv_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
                                                     __riscv_vreinterpret_v_i32m4_u32m4(subnorm_mant_i32m4),
                                                     vector_length);
    vuint32m4_t first_normal_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x04,
                                                          vector_length);
    subnorm_u32m4 = __riscv_vmerge_vvm_u32m4(subnorm_u32m4, first_normal_u32m4, promotes_b8, vector_length);

    // Select: subnormal when exp <= 0, else normal
    vuint32m4_t result_u32m4 = __riscv_vmerge_vvm_u32m4(normal_u32m4, subnorm_u32m4, is_subnormal_b8, vector_length);

    // Handle NaN: f32 NaN (abs_bits > 0x7F800000) → e5m2 NaN (sign | 0x7D)
    vbool8_t is_nan_b8 = __riscv_vmsgtu_vx_u32m4_b8(abs_bits_u32m4, 0x7F800000, vector_length);
    vuint32m4_t nan_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x7D,
                                                 vector_length);
    result_u32m4 = __riscv_vmerge_vvm_u32m4(result_u32m4, nan_u32m4, is_nan_b8, vector_length);

    // Narrow u32m4 → u16m2 → u8m1 (result fits in the low byte after assembly above)
    vuint16m2_t result_u16m2 = __riscv_vncvt_x_x_w_u16m2(result_u32m4, vector_length);
    return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
}
|
|
723
|
+
|
|
724
|
+
#pragma endregion - Register-to-Register Helpers
|
|
725
|
+
|
|
726
|
+
#pragma region - Unified Cast Dispatcher
|
|
727
|
+
|
|
728
|
+
/**
 *  @brief  Unified RVV element-wise cast between numeric dtypes.
 *
 *  Converts `count` elements from `from` (interpreted as `from_type`) into
 *  `to` (interpreted as `to_type`) using vector-length-agnostic strip-mining
 *  loops. Pairs handled here: bf16↔f32, f16↔f32, e4m3/e5m2/e2m3/e3m2→f32,
 *  e4m3/e5m2/e2m3/e3m2→bf16, e4m3/e2m3/e3m2→f16, i4↔i8, and u4↔u8.
 *  Any other pair falls through to `nk_cast_serial`.
 *
 *  @param from       Source buffer (read-only), elements of type `from_type`.
 *  @param from_type  Source dtype tag.
 *  @param count      Number of logical elements to convert.
 *  @param to         Destination buffer, elements of type `to_type`.
 *  @param to_type    Destination dtype tag.
 *
 *  Loop shape: `vector_length` is declared uninitialized in each `for` header
 *  but is assigned by `__riscv_vsetvl_*` at the top of every iteration before
 *  the update expression runs, so no uninitialized read occurs.
 *
 *  f16/bf16 values are moved as raw u16 bit patterns (vle16/vse16 on
 *  `nk_u16_t`); the bit-level conversion is done by the `nk_*_rvv_` helpers.
 *
 *  NOTE(review): the nibble-packed i4/u4 paths process `count / 2` bytes —
 *  this assumes `count` is even; an odd trailing element would be silently
 *  dropped. TODO confirm callers guarantee even counts for 4-bit dtypes.
 */
NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t count, void *to, nk_dtype_t to_type) {
    // bf16 → f32: widening, e16m1 input tile → f32m2 output tile
    if (from_type == nk_bf16_k && to_type == nk_f32_k) {
        nk_bf16_t const *source = (nk_bf16_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e16m1(count);
            vuint16m1_t bf16_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)source, vector_length);
            vfloat32m2_t f32_f32m2 = nk_bf16m1_to_f32m2_rvv_(bf16_u16m1, vector_length);
            __riscv_vse32_v_f32m2(destination, f32_f32m2, vector_length);
        }
        return;
    }

    // f32 → bf16: narrowing, f32m2 input tile → u16m1 output tile
    if (from_type == nk_f32_k && to_type == nk_bf16_k) {
        nk_f32_t const *source = (nk_f32_t const *)from;
        nk_bf16_t *destination = (nk_bf16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e32m2(count);
            vfloat32m2_t f32_f32m2 = __riscv_vle32_v_f32m2(source, vector_length);
            vuint16m1_t bf16_u16m1 = nk_f32m2_to_bf16m1_rvv_(f32_f32m2, vector_length);
            __riscv_vse16_v_u16m1((nk_u16_t *)destination, bf16_u16m1, vector_length);
        }
        return;
    }

    // f16 → f32: widening, same tile shapes as the bf16 path
    if (from_type == nk_f16_k && to_type == nk_f32_k) {
        nk_f16_t const *source = (nk_f16_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e16m1(count);
            vuint16m1_t f16_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)source, vector_length);
            vfloat32m2_t f32_f32m2 = nk_f16m1_to_f32m2_rvv_(f16_u16m1, vector_length);
            __riscv_vse32_v_f32m2(destination, f32_f32m2, vector_length);
        }
        return;
    }

    // f32 → f16: narrowing
    if (from_type == nk_f32_k && to_type == nk_f16_k) {
        nk_f32_t const *source = (nk_f32_t const *)from;
        nk_f16_t *destination = (nk_f16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e32m2(count);
            vfloat32m2_t f32_f32m2 = __riscv_vle32_v_f32m2(source, vector_length);
            vuint16m1_t f16_u16m1 = nk_f32m2_to_f16m1_rvv_(f32_f32m2, vector_length);
            __riscv_vse16_v_u16m1((nk_u16_t *)destination, f16_u16m1, vector_length);
        }
        return;
    }

    // e4m3 → f32: 8-bit FP → f32, tile sized by the e8m1 input (u8m1 → f32m4)
    if (from_type == nk_e4m3_k && to_type == nk_f32_k) {
        nk_e4m3_t const *source = (nk_e4m3_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vfloat32m4_t f32_f32m4 = nk_e4m3m1_to_f32m4_rvv_(e4m3_u8m1, vector_length);
            __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
        }
        return;
    }

    // e5m2 → f32
    if (from_type == nk_e5m2_k && to_type == nk_f32_k) {
        nk_e5m2_t const *source = (nk_e5m2_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e5m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vfloat32m4_t f32_f32m4 = nk_e5m2m1_to_f32m4_rvv_(e5m2_u8m1, vector_length);
            __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
        }
        return;
    }

    // e2m3 → f32
    if (from_type == nk_e2m3_k && to_type == nk_f32_k) {
        nk_e2m3_t const *source = (nk_e2m3_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vfloat32m4_t f32_f32m4 = nk_e2m3m1_to_f32m4_rvv_(e2m3_u8m1, vector_length);
            __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
        }
        return;
    }

    // e3m2 → f32
    if (from_type == nk_e3m2_k && to_type == nk_f32_k) {
        nk_e3m2_t const *source = (nk_e3m2_t const *)from;
        nk_f32_t *destination = (nk_f32_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vfloat32m4_t f32_f32m4 = nk_e3m2m1_to_f32m4_rvv_(e3m2_u8m1, vector_length);
            __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
        }
        return;
    }

    // e4m3 → bf16: 8-bit FP → 16-bit (u8m1 → u16m2 bit patterns)
    if (from_type == nk_e4m3_k && to_type == nk_bf16_k) {
        nk_e4m3_t const *source = (nk_e4m3_t const *)from;
        nk_bf16_t *destination = (nk_bf16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t bf16_u16m2 = nk_e4m3m1_to_bf16m2_rvv_(e4m3_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
        }
        return;
    }

    // e5m2 → bf16
    if (from_type == nk_e5m2_k && to_type == nk_bf16_k) {
        nk_e5m2_t const *source = (nk_e5m2_t const *)from;
        nk_bf16_t *destination = (nk_bf16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e5m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t bf16_u16m2 = nk_e5m2m1_to_bf16m2_rvv_(e5m2_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
        }
        return;
    }

    // e2m3 → bf16
    if (from_type == nk_e2m3_k && to_type == nk_bf16_k) {
        nk_e2m3_t const *source = (nk_e2m3_t const *)from;
        nk_bf16_t *destination = (nk_bf16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t bf16_u16m2 = nk_e2m3m1_to_bf16m2_rvv_(e2m3_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
        }
        return;
    }

    // e3m2 → bf16
    if (from_type == nk_e3m2_k && to_type == nk_bf16_k) {
        nk_e3m2_t const *source = (nk_e3m2_t const *)from;
        nk_bf16_t *destination = (nk_bf16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t bf16_u16m2 = nk_e3m2m1_to_bf16m2_rvv_(e3m2_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
        }
        return;
    }

    // e4m3 → f16
    // NOTE(review): no e5m2 → f16 branch exists here, so that pair takes the
    // serial fallback below — presumably intentional; confirm against the
    // serial backend's supported pairs.
    if (from_type == nk_e4m3_k && to_type == nk_f16_k) {
        nk_e4m3_t const *source = (nk_e4m3_t const *)from;
        nk_f16_t *destination = (nk_f16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t f16_u16m2 = nk_e4m3m1_to_f16m2_rvv_(e4m3_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
        }
        return;
    }

    // e2m3 → f16
    if (from_type == nk_e2m3_k && to_type == nk_f16_k) {
        nk_e2m3_t const *source = (nk_e2m3_t const *)from;
        nk_f16_t *destination = (nk_f16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t f16_u16m2 = nk_e2m3m1_to_f16m2_rvv_(e2m3_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
        }
        return;
    }

    // e3m2 → f16
    if (from_type == nk_e3m2_k && to_type == nk_f16_k) {
        nk_e3m2_t const *source = (nk_e3m2_t const *)from;
        nk_f16_t *destination = (nk_f16_t *)to;
        for (nk_size_t vector_length; count > 0;
             count -= vector_length, source += vector_length, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(count);
            vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint16m2_t f16_u16m2 = nk_e3m2m1_to_f16m2_rvv_(e3m2_u8m1, vector_length);
            __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
        }
        return;
    }

    // i4 → i8: each source byte packs two nibbles; the helper returns a
    // two-register group which vsseg2 interleaves back into element order,
    // so the destination advances 2 elements per source byte.
    if (from_type == nk_i4_k && to_type == nk_i8_k) {
        nk_i4x2_t const *source = (nk_i4x2_t const *)from;
        nk_i8_t *destination = (nk_i8_t *)to;
        nk_size_t n_bytes = count / 2; // two 4-bit elements per byte
        for (nk_size_t vector_length; n_bytes > 0;
             n_bytes -= vector_length, source += vector_length, destination += vector_length * 2) {
            vector_length = __riscv_vsetvl_e8m1(n_bytes);
            vuint8m1_t packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vint8m1x2_t unpacked_i8m1x2 = nk_i4m1_to_i8m2_rvv_(packed_u8m1, vector_length);
            __riscv_vsseg2e8_v_i8m1x2(destination, unpacked_i8m1x2, vector_length);
        }
        return;
    }

    // u4 → u8: unsigned variant of the nibble unpack above
    if (from_type == nk_u4_k && to_type == nk_u8_k) {
        nk_u4x2_t const *source = (nk_u4x2_t const *)from;
        nk_u8_t *destination = (nk_u8_t *)to;
        nk_size_t n_bytes = count / 2;
        for (nk_size_t vector_length; n_bytes > 0;
             n_bytes -= vector_length, source += vector_length, destination += vector_length * 2) {
            vector_length = __riscv_vsetvl_e8m1(n_bytes);
            vuint8m1_t packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
            vuint8m1x2_t unpacked_u8m1x2 = nk_u4m1_to_u8m2_rvv_(packed_u8m1, vector_length);
            __riscv_vsseg2e8_v_u8m1x2(destination, unpacked_u8m1x2, vector_length);
        }
        return;
    }

    // i8 → i4: vlseg2 de-interleaves consecutive byte pairs into two
    // registers, which the helper packs back into one nibble-packed byte
    // per pair; the source advances 2 elements per destination byte.
    if (from_type == nk_i8_k && to_type == nk_i4_k) {
        nk_i8_t const *source = (nk_i8_t const *)from;
        nk_i4x2_t *destination = (nk_i4x2_t *)to;
        nk_size_t n_bytes = count / 2;
        for (nk_size_t vector_length; n_bytes > 0;
             n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(n_bytes);
            vint8m1x2_t loaded_i8m1x2 = __riscv_vlseg2e8_v_i8m1x2(source, vector_length);
            vint8m1_t hi_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 0);
            vint8m1_t lo_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 1);
            vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(hi_i8m1, lo_i8m1, vector_length);
            __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
        }
        return;
    }

    // u8 → u4: unsigned variant of the nibble pack above
    if (from_type == nk_u8_k && to_type == nk_u4_k) {
        nk_u8_t const *source = (nk_u8_t const *)from;
        nk_u4x2_t *destination = (nk_u4x2_t *)to;
        nk_size_t n_bytes = count / 2;
        for (nk_size_t vector_length; n_bytes > 0;
             n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
            vector_length = __riscv_vsetvl_e8m1(n_bytes);
            vuint8m1x2_t loaded_u8m1x2 = __riscv_vlseg2e8_v_u8m1x2(source, vector_length);
            vuint8m1_t hi_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 0);
            vuint8m1_t lo_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 1);
            vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(hi_u8m1, lo_u8m1, vector_length);
            __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
        }
        return;
    }

    // Fallback to serial for unimplemented conversions
    nk_cast_serial(from, from_type, count, to, to_type);
}
|
|
1006
|
+
|
|
1007
|
+
#pragma endregion - Unified Cast Dispatcher
|
|
1008
|
+
|
|
1009
|
+
#if defined(__cplusplus)
|
|
1010
|
+
} // extern "C"
|
|
1011
|
+
#endif
|
|
1012
|
+
|
|
1013
|
+
#if defined(__clang__)
|
|
1014
|
+
#pragma clang attribute pop
|
|
1015
|
+
#elif defined(__GNUC__)
|
|
1016
|
+
#pragma GCC pop_options
|
|
1017
|
+
#endif
|
|
1018
|
+
|
|
1019
|
+
#endif // NK_TARGET_RVV
|
|
1020
|
+
#endif // NK_TARGET_RISCV_
|
|
1021
|
+
#endif // NK_CAST_RVV_H
|