numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -43,41 +43,45 @@ extern "C" {
|
|
|
43
43
|
/** @brief Compensated horizontal sum of RVV f64m1 lanes via TwoSum tree reduction.
|
|
44
44
|
*
|
|
45
45
|
* Uses vslidedown to extract the upper half at each tree level (same pattern as
|
|
46
|
-
* nk_reduce_vsaddu_u64m1_rvv_ in reduce/rvv.h). Tail lanes beyond
|
|
46
|
+
* nk_reduce_vsaddu_u64m1_rvv_ in reduce/rvv.h). Tail lanes beyond vector_length are zero
|
|
47
47
|
* from the initial vfmv_v_f, so they are harmless in the reduction.
|
|
48
48
|
*/
|
|
49
49
|
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64m1_rvv_(vfloat64m1_t sum_f64m1, vfloat64m1_t compensation_f64m1) {
|
|
50
|
-
nk_size_t
|
|
50
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
51
51
|
// Stage 0: TwoSum merge of sum + compensation
|
|
52
|
-
vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, compensation_f64m1,
|
|
53
|
-
vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1,
|
|
52
|
+
vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, compensation_f64m1, max_vector_length);
|
|
53
|
+
vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1, max_vector_length);
|
|
54
54
|
vfloat64m1_t accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
|
|
55
|
-
__riscv_vfsub_vv_f64m1(sum_f64m1,
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
__riscv_vfsub_vv_f64m1(sum_f64m1,
|
|
56
|
+
__riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, max_vector_length),
|
|
57
|
+
max_vector_length),
|
|
58
|
+
__riscv_vfsub_vv_f64m1(compensation_f64m1, virtual_addend_f64m1, max_vector_length), max_vector_length);
|
|
58
59
|
// Tree reduction: TwoSum halving at each level
|
|
59
|
-
for (nk_size_t half =
|
|
60
|
-
vfloat64m1_t upper_sum_f64m1 = __riscv_vslidedown_vx_f64m1(tentative_sum_f64m1, half,
|
|
61
|
-
vfloat64m1_t upper_error_f64m1 = __riscv_vslidedown_vx_f64m1(accumulated_error_f64m1, half,
|
|
62
|
-
vfloat64m1_t halved_tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(tentative_sum_f64m1, upper_sum_f64m1,
|
|
60
|
+
for (nk_size_t half = max_vector_length / 2; half > 0; half >>= 1) {
|
|
61
|
+
vfloat64m1_t upper_sum_f64m1 = __riscv_vslidedown_vx_f64m1(tentative_sum_f64m1, half, max_vector_length);
|
|
62
|
+
vfloat64m1_t upper_error_f64m1 = __riscv_vslidedown_vx_f64m1(accumulated_error_f64m1, half, max_vector_length);
|
|
63
|
+
vfloat64m1_t halved_tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(tentative_sum_f64m1, upper_sum_f64m1,
|
|
64
|
+
max_vector_length);
|
|
63
65
|
vfloat64m1_t halved_virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1,
|
|
64
|
-
tentative_sum_f64m1,
|
|
66
|
+
tentative_sum_f64m1, max_vector_length);
|
|
65
67
|
vfloat64m1_t rounding_error_f64m1 = __riscv_vfadd_vv_f64m1(
|
|
66
68
|
__riscv_vfsub_vv_f64m1(
|
|
67
69
|
tentative_sum_f64m1,
|
|
68
|
-
__riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1, halved_virtual_addend_f64m1,
|
|
69
|
-
|
|
70
|
+
__riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1, halved_virtual_addend_f64m1, max_vector_length),
|
|
71
|
+
max_vector_length),
|
|
72
|
+
__riscv_vfsub_vv_f64m1(upper_sum_f64m1, halved_virtual_addend_f64m1, max_vector_length), max_vector_length);
|
|
70
73
|
tentative_sum_f64m1 = halved_tentative_sum_f64m1;
|
|
71
74
|
accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
|
|
72
|
-
__riscv_vfadd_vv_f64m1(accumulated_error_f64m1, upper_error_f64m1,
|
|
75
|
+
__riscv_vfadd_vv_f64m1(accumulated_error_f64m1, upper_error_f64m1, max_vector_length), rounding_error_f64m1,
|
|
76
|
+
max_vector_length);
|
|
73
77
|
}
|
|
74
78
|
return __riscv_vfmv_f_s_f64m1_f64(tentative_sum_f64m1) + __riscv_vfmv_f_s_f64m1_f64(accumulated_error_f64m1);
|
|
75
79
|
}
|
|
76
80
|
|
|
77
81
|
NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
78
82
|
nk_i32_t *result) {
|
|
79
|
-
nk_size_t
|
|
80
|
-
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
83
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
84
|
+
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
81
85
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
82
86
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
83
87
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -89,14 +93,14 @@ NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars,
|
|
|
89
93
|
sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_i16m2, vector_length);
|
|
90
94
|
}
|
|
91
95
|
// Single horizontal reduction at the end
|
|
92
|
-
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0,
|
|
93
|
-
*result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1,
|
|
96
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
|
|
97
|
+
*result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
|
|
94
98
|
}
|
|
95
99
|
|
|
96
100
|
NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
97
101
|
nk_u32_t *result) {
|
|
98
|
-
nk_size_t
|
|
99
|
-
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
102
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
103
|
+
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
100
104
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
101
105
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
102
106
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -108,14 +112,14 @@ NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars,
|
|
|
108
112
|
sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_u16m2, vector_length);
|
|
109
113
|
}
|
|
110
114
|
// Single horizontal reduction at the end
|
|
111
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
112
|
-
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1,
|
|
115
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
116
|
+
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
|
|
113
117
|
}
|
|
114
118
|
|
|
115
119
|
NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
116
120
|
nk_f64_t *result) {
|
|
117
|
-
nk_size_t
|
|
118
|
-
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
121
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
122
|
+
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
119
123
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
120
124
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
121
125
|
vector_length = __riscv_vsetvl_e32m1(count_scalars);
|
|
@@ -125,16 +129,16 @@ NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scala
|
|
|
125
129
|
sum_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_f64m2, a_f32m1, b_f32m1, vector_length);
|
|
126
130
|
}
|
|
127
131
|
// Single horizontal reduction at the end
|
|
128
|
-
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
129
|
-
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1,
|
|
132
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
133
|
+
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, max_vector_length));
|
|
130
134
|
}
|
|
131
135
|
|
|
132
136
|
NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
|
|
133
137
|
nk_f64_t *result) {
|
|
134
138
|
// Dot2 (Ogita-Rump-Oishi) compensated accumulation via TwoProd + TwoSum
|
|
135
|
-
nk_size_t
|
|
136
|
-
vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
137
|
-
vfloat64m1_t compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
139
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
140
|
+
vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
141
|
+
vfloat64m1_t compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
138
142
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
139
143
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
140
144
|
vector_length = __riscv_vsetvl_e64m1(count_scalars);
|
|
@@ -163,8 +167,8 @@ NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scala
|
|
|
163
167
|
|
|
164
168
|
NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
165
169
|
nk_f32_t *result) {
|
|
166
|
-
nk_size_t
|
|
167
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
170
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
171
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
168
172
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
169
173
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
170
174
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -179,14 +183,14 @@ NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scala
|
|
|
179
183
|
sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
|
|
180
184
|
}
|
|
181
185
|
// Single horizontal reduction at the end
|
|
182
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
183
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
186
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
187
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
184
188
|
}
|
|
185
189
|
|
|
186
190
|
NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
187
191
|
nk_f32_t *result) {
|
|
188
|
-
nk_size_t
|
|
189
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
192
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
193
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
190
194
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
191
195
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
192
196
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -201,14 +205,14 @@ NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_sc
|
|
|
201
205
|
sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
|
|
202
206
|
}
|
|
203
207
|
// Single horizontal reduction at the end
|
|
204
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
205
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
208
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
209
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
206
210
|
}
|
|
207
211
|
|
|
208
212
|
NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
209
213
|
nk_f32_t *result) {
|
|
210
|
-
nk_size_t
|
|
211
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
214
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
215
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
212
216
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
213
217
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
214
218
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -223,14 +227,14 @@ NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_sc
|
|
|
223
227
|
sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
|
|
224
228
|
}
|
|
225
229
|
// Single horizontal reduction at the end
|
|
226
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
227
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
230
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
231
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
228
232
|
}
|
|
229
233
|
|
|
230
234
|
NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
231
235
|
nk_f32_t *result) {
|
|
232
|
-
nk_size_t
|
|
233
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
236
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
237
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
234
238
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
235
239
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
236
240
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -245,8 +249,8 @@ NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_sc
|
|
|
245
249
|
sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
|
|
246
250
|
}
|
|
247
251
|
// Single horizontal reduction at the end
|
|
248
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
249
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
252
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
253
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
250
254
|
}
|
|
251
255
|
|
|
252
256
|
NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -257,8 +261,8 @@ NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_sc
|
|
|
257
261
|
static nk_u8_t const lut_magnitude[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
|
|
258
262
|
32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120};
|
|
259
263
|
|
|
260
|
-
nk_size_t
|
|
261
|
-
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
264
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
265
|
+
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
262
266
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
263
267
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
264
268
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -285,8 +289,8 @@ NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_sc
|
|
|
285
289
|
vint16m2_t products_i16m2 = __riscv_vwmul_vv_i16m2(a_signed_i8m1, b_signed_i8m1, vector_length);
|
|
286
290
|
sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, products_i16m2, vector_length);
|
|
287
291
|
}
|
|
288
|
-
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0,
|
|
289
|
-
nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1,
|
|
292
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
|
|
293
|
+
nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
|
|
290
294
|
*result = (nk_f32_t)sum / 256.0f;
|
|
291
295
|
}
|
|
292
296
|
|
|
@@ -298,8 +302,8 @@ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_sc
|
|
|
298
302
|
static nk_u16_t const lut_magnitude[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28,
|
|
299
303
|
32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
|
|
300
304
|
|
|
301
|
-
nk_size_t
|
|
302
|
-
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
305
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
306
|
+
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
303
307
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
304
308
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
305
309
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -333,8 +337,8 @@ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_sc
|
|
|
333
337
|
// Widening multiply-accumulate: i16×i16 → i32
|
|
334
338
|
sum_i32m4 = __riscv_vwmacc_vv_i32m4_tu(sum_i32m4, a_signed_i16m2, b_signed_i16m2, vector_length);
|
|
335
339
|
}
|
|
336
|
-
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0,
|
|
337
|
-
nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1,
|
|
340
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
|
|
341
|
+
nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
|
|
338
342
|
*result = (nk_f32_t)sum / 256.0f;
|
|
339
343
|
}
|
|
340
344
|
|
|
@@ -344,8 +348,8 @@ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scal
|
|
|
344
348
|
count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
|
|
345
349
|
nk_size_t n_full_bytes = count_dimensions / 2;
|
|
346
350
|
|
|
347
|
-
nk_size_t
|
|
348
|
-
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
351
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
352
|
+
vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
349
353
|
for (nk_size_t vector_length; n_full_bytes > 0;
|
|
350
354
|
n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
351
355
|
vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
|
|
@@ -377,8 +381,8 @@ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scal
|
|
|
377
381
|
sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_low_i16m2, vector_length);
|
|
378
382
|
}
|
|
379
383
|
// Single horizontal reduction at the end
|
|
380
|
-
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0,
|
|
381
|
-
*result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1,
|
|
384
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
|
|
385
|
+
*result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
|
|
382
386
|
}
|
|
383
387
|
|
|
384
388
|
NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_dimensions,
|
|
@@ -387,8 +391,8 @@ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scal
|
|
|
387
391
|
count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
|
|
388
392
|
nk_size_t n_full_bytes = count_dimensions / 2;
|
|
389
393
|
|
|
390
|
-
nk_size_t
|
|
391
|
-
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
394
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
395
|
+
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
392
396
|
for (nk_size_t vector_length; n_full_bytes > 0;
|
|
393
397
|
n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
394
398
|
vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
|
|
@@ -410,8 +414,8 @@ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scal
|
|
|
410
414
|
sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_low_u16m2, vector_length);
|
|
411
415
|
}
|
|
412
416
|
// Single horizontal reduction at the end
|
|
413
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
414
|
-
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1,
|
|
417
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
418
|
+
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
|
|
415
419
|
}
|
|
416
420
|
|
|
417
421
|
NK_PUBLIC void nk_dot_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
|
|
@@ -443,9 +447,9 @@ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pair
|
|
|
443
447
|
nk_f64c_t *results) {
|
|
444
448
|
nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
|
|
445
449
|
nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
|
|
446
|
-
nk_size_t
|
|
447
|
-
vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
448
|
-
vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
450
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
451
|
+
vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
452
|
+
vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
449
453
|
for (nk_size_t vector_length; count_pairs > 0;
|
|
450
454
|
count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
|
|
451
455
|
vector_length = __riscv_vsetvl_e32m1(count_pairs);
|
|
@@ -462,18 +466,20 @@ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pair
|
|
|
462
466
|
sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
|
|
463
467
|
sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
|
|
464
468
|
}
|
|
465
|
-
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
466
|
-
results->real = __riscv_vfmv_f_s_f64m1_f64(
|
|
467
|
-
|
|
469
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
470
|
+
results->real = __riscv_vfmv_f_s_f64m1_f64(
|
|
471
|
+
__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, max_vector_length));
|
|
472
|
+
results->imag = __riscv_vfmv_f_s_f64m1_f64(
|
|
473
|
+
__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, max_vector_length));
|
|
468
474
|
}
|
|
469
475
|
|
|
470
476
|
NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
471
477
|
nk_f64c_t *results) {
|
|
472
478
|
nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
|
|
473
479
|
nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
|
|
474
|
-
nk_size_t
|
|
475
|
-
vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
476
|
-
vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
480
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
481
|
+
vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
482
|
+
vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
477
483
|
for (nk_size_t vector_length; count_pairs > 0;
|
|
478
484
|
count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
|
|
479
485
|
vector_length = __riscv_vsetvl_e32m1(count_pairs);
|
|
@@ -490,9 +496,11 @@ NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pai
|
|
|
490
496
|
sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
|
|
491
497
|
sum_imag_f64m2 = __riscv_vfwnmsac_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
|
|
492
498
|
}
|
|
493
|
-
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
494
|
-
results->real = __riscv_vfmv_f_s_f64m1_f64(
|
|
495
|
-
|
|
499
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
500
|
+
results->real = __riscv_vfmv_f_s_f64m1_f64(
|
|
501
|
+
__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, max_vector_length));
|
|
502
|
+
results->imag = __riscv_vfmv_f_s_f64m1_f64(
|
|
503
|
+
__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, max_vector_length));
|
|
496
504
|
}
|
|
497
505
|
|
|
498
506
|
NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
|
|
@@ -500,11 +508,11 @@ NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pair
|
|
|
500
508
|
// Dot2 (Ogita-Rump-Oishi) compensated complex dot product
|
|
501
509
|
nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
|
|
502
510
|
nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
|
|
503
|
-
nk_size_t
|
|
504
|
-
vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
505
|
-
vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
506
|
-
vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
507
|
-
vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
511
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
512
|
+
vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
513
|
+
vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
514
|
+
vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
515
|
+
vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
508
516
|
for (nk_size_t vector_length; count_pairs > 0;
|
|
509
517
|
count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
|
|
510
518
|
vector_length = __riscv_vsetvl_e64m1(count_pairs);
|
|
@@ -602,11 +610,11 @@ NK_PUBLIC void nk_vdot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pai
|
|
|
602
610
|
// Dot2 (Ogita-Rump-Oishi) compensated conjugate complex dot product
|
|
603
611
|
nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
|
|
604
612
|
nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
|
|
605
|
-
nk_size_t
|
|
606
|
-
vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
607
|
-
vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
608
|
-
vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
609
|
-
vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0,
|
|
613
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
|
|
614
|
+
vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
615
|
+
vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
616
|
+
vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
617
|
+
vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
|
|
610
618
|
for (nk_size_t vector_length; count_pairs > 0;
|
|
611
619
|
count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
|
|
612
620
|
vector_length = __riscv_vsetvl_e64m1(count_pairs);
|
|
@@ -37,8 +37,8 @@ extern "C" {
|
|
|
37
37
|
|
|
38
38
|
NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
39
39
|
nk_f32_t *result) {
|
|
40
|
-
nk_size_t
|
|
41
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
40
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
41
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
42
42
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
43
43
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
44
44
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -50,8 +50,8 @@ NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *
|
|
|
50
50
|
sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sum_f32m2, a_bf16m1, b_bf16m1, vector_length);
|
|
51
51
|
}
|
|
52
52
|
// Single horizontal reduction at the end
|
|
53
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
54
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
53
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
54
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
55
55
|
}
|
|
56
56
|
|
|
57
57
|
/** @brief Convert e2m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
@@ -76,8 +76,8 @@ NK_INTERNAL vbfloat16m2_t nk_e5m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_s
|
|
|
76
76
|
|
|
77
77
|
NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
78
78
|
nk_f32_t *result) {
|
|
79
|
-
nk_size_t
|
|
80
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
79
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
80
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
81
81
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
82
82
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
83
83
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -87,14 +87,14 @@ NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
|
|
|
87
87
|
vbfloat16m2_t b_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
|
|
88
88
|
sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
|
|
89
89
|
}
|
|
90
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
91
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
90
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
91
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
92
92
|
}
|
|
93
93
|
|
|
94
94
|
NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
95
95
|
nk_f32_t *result) {
|
|
96
|
-
nk_size_t
|
|
97
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
96
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
97
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
98
98
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
99
99
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
100
100
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -104,8 +104,8 @@ NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
|
|
|
104
104
|
vbfloat16m2_t b_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
|
|
105
105
|
sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
|
|
106
106
|
}
|
|
107
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
108
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
107
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
108
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
109
109
|
}
|
|
110
110
|
|
|
111
111
|
#if defined(__cplusplus)
|
|
@@ -38,8 +38,8 @@ extern "C" {
|
|
|
38
38
|
|
|
39
39
|
NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
40
40
|
nk_f32_t *result) {
|
|
41
|
-
nk_size_t
|
|
42
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
41
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
42
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
43
43
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
44
44
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
45
45
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -51,8 +51,8 @@ NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_s
|
|
|
51
51
|
sum_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(sum_f32m2, a_f16m1, b_f16m1, vector_length);
|
|
52
52
|
}
|
|
53
53
|
// Single horizontal reduction at the end
|
|
54
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
55
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
54
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
55
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
56
56
|
}
|
|
57
57
|
|
|
58
58
|
/** @brief Convert e2m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
@@ -82,8 +82,8 @@ NK_INTERNAL vfloat16m2_t nk_e5m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_siz
|
|
|
82
82
|
|
|
83
83
|
NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
84
84
|
nk_f32_t *result) {
|
|
85
|
-
nk_size_t
|
|
86
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
85
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
86
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
87
87
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
88
88
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
89
89
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -93,14 +93,14 @@ NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
|
|
|
93
93
|
vfloat16m2_t b_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
|
|
94
94
|
sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
|
|
95
95
|
}
|
|
96
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
97
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
96
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
97
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
98
98
|
}
|
|
99
99
|
|
|
100
100
|
NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
101
101
|
nk_f32_t *result) {
|
|
102
|
-
nk_size_t
|
|
103
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
102
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
103
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
104
104
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
105
105
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
106
106
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -110,8 +110,8 @@ NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
|
|
|
110
110
|
vfloat16m2_t b_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
|
|
111
111
|
sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
|
|
112
112
|
}
|
|
113
|
-
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f,
|
|
114
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
113
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
|
|
114
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
115
115
|
}
|
|
116
116
|
|
|
117
117
|
#if defined(__cplusplus)
|
|
@@ -8,10 +8,10 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_sapphire_instructions Key AVX-512 FP16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_fmadd_ph
|
|
13
|
-
* _mm512_fmadd_ps
|
|
14
|
-
* _mm512_cvtph_ps
|
|
11
|
+
* Intrinsic Instruction Sapphire Rapids
|
|
12
|
+
* _mm512_fmadd_ph VFMADDPH (ZMM, ZMM, ZMM) 4cy @ p01
|
|
13
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p01
|
|
14
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7cy @ p01
|
|
15
15
|
*
|
|
16
16
|
* Sapphire Rapids introduces native AVX-512 FP16 support, enabling 32 FP16 FMAs per instruction at the same
|
|
17
17
|
* throughput as 16 FP32 FMAs — effectively 2x compute density. For FP6 types (E2M3 and E3M2) whose products
|