numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for RISC-V with Zvbb.
|
|
3
|
+
* @file include/numkong/dot/rvvbb.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 22, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* Zvbb (Vector Basic Bit-manipulation) provides native per-element popcount via `vcpop.v`,
|
|
10
|
+
* replacing the 11-instruction SWAR approach with a single instruction for u1 dot products.
|
|
11
|
+
*
|
|
12
|
+
* Only `nk_dot_u1` benefits from Zvbb (it needs byte-level popcount of AND results).
|
|
13
|
+
* Requires: RVV 1.0 + Zvbb extension (GCC 14+ or Clang 18+)
|
|
14
|
+
*/
|
|
15
|
+
#ifndef NK_DOT_RVVBB_H
|
|
16
|
+
#define NK_DOT_RVVBB_H
|
|
17
|
+
|
|
18
|
+
#if NK_TARGET_RISCV_
|
|
19
|
+
#if NK_TARGET_RVVBB
|
|
20
|
+
|
|
21
|
+
#include "numkong/types.h"
|
|
22
|
+
#include "numkong/set/rvvbb.h" // `nk_popcount_u8m4_rvvbb_`
|
|
23
|
+
|
|
24
|
+
#if defined(__clang__)
|
|
25
|
+
#pragma clang attribute push(__attribute__((target("arch=+v,+zvbb"))), apply_to = function)
|
|
26
|
+
#elif defined(__GNUC__)
|
|
27
|
+
#pragma GCC push_options
|
|
28
|
+
#pragma GCC target("arch=+v,+zvbb")
|
|
29
|
+
#endif
|
|
30
|
+
|
|
31
|
+
#if defined(__cplusplus)
|
|
32
|
+
extern "C" {
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
/**
 * @brief Binary dot product: number of bit positions set in both `a` and `b`, using Zvbb per-byte popcount.
 *
 * Processes ceil(n_bits / 8) bytes in LMUL=4 strips. Each strip is ANDed, popcounted per byte,
 * zero-extended to u16 and folded into a single u32 scalar with a widening reduction.
 * NOTE(review): every bit of the final partial byte is counted; assumes bits beyond `n_bits` are zero — confirm with callers.
 *
 * @param a       First packed bit-vector.
 * @param b       Second packed bit-vector.
 * @param n_bits  Number of bits in each input.
 * @param result  Receives the count of shared set bits.
 */
NK_PUBLIC void nk_dot_u1_rvvbb(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
    nk_size_t const count_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);

    // Only element 0 of a reduction seed/destination is meaningful, so one lane is enough.
    vuint32m1_t total_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);

    for (nk_size_t offset = 0, vl; offset < count_bytes; offset += vl) {
        vl = __riscv_vsetvl_e8m4(count_bytes - offset);
        vuint8m4_t shared_u8m4 = __riscv_vand_vv_u8m4(__riscv_vle8_v_u8m4(a + offset, vl),
                                                      __riscv_vle8_v_u8m4(b + offset, vl), vl);

        // Single-instruction per-byte popcount (vcpop.v) instead of a SWAR bit-twiddling sequence.
        vuint8m4_t bits_u8m4 = nk_popcount_u8m4_rvvbb_(shared_u8m4);

        // Widen before reducing so the running total cannot wrap in a narrow type.
        vuint16m8_t bits_u16m8 = __riscv_vwaddu_vx_u16m8(bits_u8m4, 0, vl);
        total_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(bits_u16m8, total_u32m1, vl);
    }

    *result = __riscv_vmv_x_s_u32m1_u32(total_u32m1);
}
|
|
59
|
+
|
|
60
|
+
#if defined(__cplusplus)
|
|
61
|
+
} // extern "C"
|
|
62
|
+
#endif
|
|
63
|
+
|
|
64
|
+
#if defined(__clang__)
|
|
65
|
+
#pragma clang attribute pop
|
|
66
|
+
#elif defined(__GNUC__)
|
|
67
|
+
#pragma GCC pop_options
|
|
68
|
+
#endif
|
|
69
|
+
|
|
70
|
+
#endif // NK_TARGET_RVVBB
|
|
71
|
+
#endif // NK_TARGET_RISCV_
|
|
72
|
+
#endif // NK_DOT_RVVBB_H
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for RISC-V BF16.
|
|
3
|
+
* @file include/numkong/dot/rvvbf16.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 5, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* Alibaba XuanTie C930 and similar chips implement RVV 1.0 with Zvfbfwma extension.
|
|
10
|
+
* Zvfbfwma provides widening bf16 fused multiply-accumulate to f32:
|
|
11
|
+
* vfwmaccbf16: f32 += bf16 ⨯ bf16 (widening fused multiply-accumulate)
|
|
12
|
+
*
|
|
13
|
+
* All mini-float types use 256-entry VLUXEI16 LUT gathers from cast/rvv.h (3 instructions each).
|
|
14
|
+
* All variants then use vfwmaccbf16_vv for fused bf16 ⨯ bf16 → f32 multiply-accumulate.
|
|
15
|
+
*
|
|
16
|
+
* Requires: RVV 1.0 + Zvfbfwma extension (GCC 14+ or Clang 18+)
|
|
17
|
+
*/
|
|
18
|
+
#ifndef NK_DOT_RVVBF16_H
|
|
19
|
+
#define NK_DOT_RVVBF16_H
|
|
20
|
+
|
|
21
|
+
#if NK_TARGET_RISCV_
|
|
22
|
+
#if NK_TARGET_RVVBF16
|
|
23
|
+
|
|
24
|
+
#include "numkong/types.h"
|
|
25
|
+
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_bf16m2_rvv_`, `nk_e5m2m1_to_bf16m2_rvv_`, etc.
|
|
26
|
+
|
|
27
|
+
#if defined(__clang__)
|
|
28
|
+
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfbfwma"))), apply_to = function)
|
|
29
|
+
#elif defined(__GNUC__)
|
|
30
|
+
#pragma GCC push_options
|
|
31
|
+
#pragma GCC target("arch=+v,+zvfbfwma")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
#if defined(__cplusplus)
|
|
35
|
+
extern "C" {
|
|
36
|
+
#endif
|
|
37
|
+
|
|
38
|
+
NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
39
|
+
nk_f32_t *result) {
|
|
40
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
41
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
42
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
43
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
44
|
+
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
45
|
+
vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
|
|
46
|
+
vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
|
|
47
|
+
vbfloat16m1_t a_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(a_u16m1);
|
|
48
|
+
vbfloat16m1_t b_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(b_u16m1);
|
|
49
|
+
// Widening bf16 FMA: f32 ← bf16 ⨯ bf16, per-lane accumulation
|
|
50
|
+
sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sum_f32m2, a_bf16m1, b_bf16m1, vector_length);
|
|
51
|
+
}
|
|
52
|
+
// Single horizontal reduction at the end
|
|
53
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
54
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
/** @brief Convert e2m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
58
|
+
NK_INTERNAL vbfloat16m2_t nk_e2m3m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
59
|
+
return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e2m3m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
/** @brief Convert e3m2 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
63
|
+
NK_INTERNAL vbfloat16m2_t nk_e3m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
64
|
+
return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e3m2m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/** @brief Convert e4m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
68
|
+
NK_INTERNAL vbfloat16m2_t nk_e4m3m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
69
|
+
return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e4m3m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
/** @brief Convert e5m2 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
73
|
+
NK_INTERNAL vbfloat16m2_t nk_e5m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
74
|
+
return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e5m2m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
78
|
+
nk_f32_t *result) {
|
|
79
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
80
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
81
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
82
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
83
|
+
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
84
|
+
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
85
|
+
vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
86
|
+
vbfloat16m2_t a_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(a_u8m1, vector_length);
|
|
87
|
+
vbfloat16m2_t b_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
|
|
88
|
+
sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
|
|
89
|
+
}
|
|
90
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
91
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
95
|
+
nk_f32_t *result) {
|
|
96
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
97
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
98
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
99
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
100
|
+
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
101
|
+
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
102
|
+
vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
103
|
+
vbfloat16m2_t a_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(a_u8m1, vector_length);
|
|
104
|
+
vbfloat16m2_t b_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
|
|
105
|
+
sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
|
|
106
|
+
}
|
|
107
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
108
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
#if defined(__cplusplus)
|
|
112
|
+
} // extern "C"
|
|
113
|
+
#endif
|
|
114
|
+
|
|
115
|
+
#if defined(__clang__)
|
|
116
|
+
#pragma clang attribute pop
|
|
117
|
+
#elif defined(__GNUC__)
|
|
118
|
+
#pragma GCC pop_options
|
|
119
|
+
#endif
|
|
120
|
+
|
|
121
|
+
#endif // NK_TARGET_RVVBF16
|
|
122
|
+
#endif // NK_TARGET_RISCV_
|
|
123
|
+
#endif // NK_DOT_RVVBF16_H
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for RISC-V FP16.
|
|
3
|
+
* @file include/numkong/dot/rvvhalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 5, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* SiFive P670/X280 and similar chips implement RVV 1.0 with Zvfh extension.
|
|
10
|
+
* Zvfh provides native half-precision (f16) vector operations.
|
|
11
|
+
* Uses widening multiply (f16 ⨯ f16 → f32) for precision, then reduces to f32.
|
|
12
|
+
*
|
|
13
|
+
* For e2m3, e3m2, e4m3: conversion uses 256-entry VLUXEI16 LUT gathers from cast/rvv.h (3 instructions each).
|
|
14
|
+
* For e5m2: conversion uses pure shift (vzext + vsll) since e5m2 and f16 share the same exponent bias.
|
|
15
|
+
* All variants then use vfwmacc_vv for widening fused f16 ⨯ f16 → f32 multiply-accumulate.
|
|
16
|
+
*
|
|
17
|
+
* Requires: RVV 1.0 + Zvfh extension (GCC 14+ or Clang 18+)
|
|
18
|
+
*/
|
|
19
|
+
#ifndef NK_DOT_RVVHALF_H
|
|
20
|
+
#define NK_DOT_RVVHALF_H
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_RISCV_
|
|
23
|
+
#if NK_TARGET_RVVHALF
|
|
24
|
+
|
|
25
|
+
#include "numkong/types.h"
|
|
26
|
+
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f16m2_rvv_`, `nk_e2m3m1_to_f16m2_rvv_`, etc.
|
|
27
|
+
|
|
28
|
+
#if defined(__clang__)
|
|
29
|
+
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfh"))), apply_to = function)
|
|
30
|
+
#elif defined(__GNUC__)
|
|
31
|
+
#pragma GCC push_options
|
|
32
|
+
#pragma GCC target("arch=+v,+zvfh")
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
40
|
+
nk_f32_t *result) {
|
|
41
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
42
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
43
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
44
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
45
|
+
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
46
|
+
vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
|
|
47
|
+
vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
|
|
48
|
+
vfloat16m1_t a_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(a_u16m1);
|
|
49
|
+
vfloat16m1_t b_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(b_u16m1);
|
|
50
|
+
// Widening FMA: f32 += f16 ⨯ f16, per-lane accumulation
|
|
51
|
+
sum_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(sum_f32m2, a_f16m1, b_f16m1, vector_length);
|
|
52
|
+
}
|
|
53
|
+
// Single horizontal reduction at the end
|
|
54
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
55
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/** @brief Convert e2m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
59
|
+
NK_INTERNAL vfloat16m2_t nk_e2m3m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
60
|
+
return __riscv_vreinterpret_v_u16m2_f16m2(nk_e2m3m1_to_f16m2_rvv_(raw_u8m1, vector_length));
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
/** @brief Convert e3m2 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
64
|
+
NK_INTERNAL vfloat16m2_t nk_e3m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
65
|
+
return __riscv_vreinterpret_v_u16m2_f16m2(nk_e3m2m1_to_f16m2_rvv_(raw_u8m1, vector_length));
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
/** @brief Convert e4m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
|
|
69
|
+
NK_INTERNAL vfloat16m2_t nk_e4m3m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
70
|
+
return __riscv_vreinterpret_v_u16m2_f16m2(nk_e4m3m1_to_f16m2_rvv_(raw_u8m1, vector_length));
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
/**
|
|
74
|
+
* @brief Convert e5m2 (1-5-2 sign-exp-mantissa, 8-bit) to f16 via pure shift (no LUT).
|
|
75
|
+
* Same exponent bias (15) means f16 = (lower7 << 8) | (sign << 15). Handles all cases.
|
|
76
|
+
*/
|
|
77
|
+
NK_INTERNAL vfloat16m2_t nk_e5m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
|
|
78
|
+
vuint16m2_t wide_u16m2 = __riscv_vzext_vf2_u16m2(raw_u8m1, vector_length);
|
|
79
|
+
vuint16m2_t result_u16m2 = __riscv_vsll_vx_u16m2(wide_u16m2, 8, vector_length);
|
|
80
|
+
return __riscv_vreinterpret_v_u16m2_f16m2(result_u16m2);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
84
|
+
nk_f32_t *result) {
|
|
85
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
86
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
87
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
88
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
89
|
+
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
90
|
+
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
91
|
+
vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
92
|
+
vfloat16m2_t a_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(a_u8m1, vector_length);
|
|
93
|
+
vfloat16m2_t b_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
|
|
94
|
+
sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
|
|
95
|
+
}
|
|
96
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
97
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
101
|
+
nk_f32_t *result) {
|
|
102
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
|
|
103
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
|
|
104
|
+
for (nk_size_t vector_length; count_scalars > 0;
|
|
105
|
+
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
106
|
+
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
107
|
+
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
108
|
+
vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
109
|
+
vfloat16m2_t a_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(a_u8m1, vector_length);
|
|
110
|
+
vfloat16m2_t b_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
|
|
111
|
+
sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
|
|
112
|
+
}
|
|
113
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
|
|
114
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
#if defined(__cplusplus)
|
|
118
|
+
} // extern "C"
|
|
119
|
+
#endif
|
|
120
|
+
|
|
121
|
+
#if defined(__clang__)
|
|
122
|
+
#pragma clang attribute pop
|
|
123
|
+
#elif defined(__GNUC__)
|
|
124
|
+
#pragma GCC pop_options
|
|
125
|
+
#endif
|
|
126
|
+
|
|
127
|
+
#endif // NK_TARGET_RVVHALF
|
|
128
|
+
#endif // NK_TARGET_RISCV_
|
|
129
|
+
#endif // NK_DOT_RVVHALF_H
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Sapphire Rapids.
|
|
3
|
+
* @file include/numkong/dot/sapphire.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 7, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_sapphire_instructions Key AVX-512 FP16 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm512_fmadd_ph VFMADDPH (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
|
|
13
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
|
|
14
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7cy 1/cy p01
|
|
15
|
+
*
|
|
16
|
+
* Sapphire Rapids introduces native AVX-512 FP16 support, enabling 32 FP16 FMAs per instruction at the same
|
|
17
|
+
* throughput as 16 FP32 FMAs — effectively 2x compute density. For FP6 types (E2M3 and E3M2) whose products
|
|
18
|
+
* are small enough to accumulate safely in FP16, this provides near-2x speedup over the Genoa BF16 path.
|
|
19
|
+
*
|
|
20
|
+
* @section dot_sapphire_accumulation Safe FP16 Accumulation
|
|
21
|
+
*
|
|
22
|
+
* E2M3 max product: 7.5² = 56.25; flush every 4 iterations → max lane sum ~225, FP16 ULP ~0.125.
|
|
23
|
+
* E3M2 max product: 28² = 784; flush every 4 iterations → max lane sum ~3136, FP16 ULP ~2.0.
|
|
24
|
+
* After the flush window, we widen the FP16 accumulator to FP32 and reset.
|
|
25
|
+
*
|
|
26
|
+
* @section dot_sapphire_stateful Stateful Streaming Logic
|
|
27
|
+
*
|
|
28
|
+
* Typed wrappers control the flush cadence:
|
|
29
|
+
* - nk_dot_e2m3x32_state_sapphire_t flushes every 4 iterations (128 elements)
|
|
30
|
+
* - nk_dot_e3m2x32_state_sapphire_t flushes every 4 iterations (128 elements)
|
|
31
|
+
*/
|
|
32
|
+
#ifndef NK_DOT_SAPPHIRE_H
|
|
33
|
+
#define NK_DOT_SAPPHIRE_H
|
|
34
|
+
|
|
35
|
+
#if NK_TARGET_X86_
|
|
36
|
+
#if NK_TARGET_SAPPHIRE
|
|
37
|
+
|
|
38
|
+
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
40
|
+
#include "numkong/dot/skylake.h" // `nk_dot_through_f32_finalize_skylake_`
|
|
41
|
+
|
|
42
|
+
#if defined(__cplusplus)
|
|
43
|
+
extern "C" {
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
#if defined(__clang__)
|
|
47
|
+
#pragma clang attribute push( \
|
|
48
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
|
|
49
|
+
apply_to = function)
|
|
50
|
+
#elif defined(__GNUC__)
|
|
51
|
+
#pragma GCC push_options
|
|
52
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
|
|
53
|
+
#endif
|
|
54
|
+
|
|
55
|
+
/** @brief Convert 32x e2m3 → 32x f16 via 64-entry signed LUT lookup (AVX-512BW).
|
|
56
|
+
* E2M3 format: S EE MMM (bias=1, 6 bits total: sign at bit 5, magnitude bits 4-0).
|
|
57
|
+
* F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
58
|
+
*
|
|
59
|
+
* Uses permutex2var with two 32-entry LUTs (positive and negative F16 values).
|
|
60
|
+
* The E2M3 sign bit (bit 5) naturally becomes the source-select bit of the 6-bit index,
|
|
61
|
+
* so no separate sign extraction, shift, or OR is needed. After cvtepu8_epi16, bits 15:6
|
|
62
|
+
* are zero and permutex2var only reads bits 5:0, so no AND mask is required either. */
|
|
63
|
+
NK_INTERNAL __m512h nk_e2m3x32_to_f16x32_sapphire_(__m256i e2m3x32) {
|
|
64
|
+
__m512i idx_i16x32 = _mm512_cvtepu8_epi16(e2m3x32);
|
|
65
|
+
|
|
66
|
+
// 32-entry LUT for positive E2M3 magnitudes → F16
|
|
67
|
+
__m512i const lut_pos_i16x32 = _mm512_set_epi16( //
|
|
68
|
+
0x4780, 0x4700, 0x4680, 0x4600, 0x4580, 0x4500, 0x4480, 0x4400, // [31-24] exp=3: f16_exp=17
|
|
69
|
+
0x4380, 0x4300, 0x4280, 0x4200, 0x4180, 0x4100, 0x4080, 0x4000, // [23-16] exp=2: f16_exp=16
|
|
70
|
+
0x3F80, 0x3F00, 0x3E80, 0x3E00, 0x3D80, 0x3D00, 0x3C80, 0x3C00, // [15-8] exp=1: f16_exp=15
|
|
71
|
+
0x3B00, 0x3A00, 0x3900, 0x3800, 0x3600, 0x3400, 0x3000, 0x0000); // [7-0] exp=0: subnormals (0, 1/8..7/8)
|
|
72
|
+
|
|
73
|
+
// 32-entry LUT for negative E2M3 magnitudes → F16 (= positive | 0x8000)
|
|
74
|
+
__m512i const lut_neg_i16x32 = _mm512_set_epi16( //
|
|
75
|
+
(short)0xC780, (short)0xC700, (short)0xC680, (short)0xC600, //
|
|
76
|
+
(short)0xC580, (short)0xC500, (short)0xC480, (short)0xC400, // [31-24] exp=3
|
|
77
|
+
(short)0xC380, (short)0xC300, (short)0xC280, (short)0xC200, //
|
|
78
|
+
(short)0xC180, (short)0xC100, (short)0xC080, (short)0xC000, // [23-16] exp=2
|
|
79
|
+
(short)0xBF80, (short)0xBF00, (short)0xBE80, (short)0xBE00, //
|
|
80
|
+
(short)0xBD80, (short)0xBD00, (short)0xBC80, (short)0xBC00, // [15-8] exp=1
|
|
81
|
+
(short)0xBB00, (short)0xBA00, (short)0xB900, (short)0xB800, //
|
|
82
|
+
(short)0xB600, (short)0xB400, (short)0xB000, (short)0x8000); // [7-0] exp=0
|
|
83
|
+
|
|
84
|
+
return nk_m512h_from_m512i_(_mm512_permutex2var_epi16(lut_pos_i16x32, idx_i16x32, lut_neg_i16x32));
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
/** @brief Convert 32x e3m2 → 32x f16 via 64-entry signed LUT lookup (AVX-512BW).
|
|
88
|
+
* E3M2 format: S EEE MM (bias=3, 6 bits total: sign at bit 5, magnitude bits 4-0).
|
|
89
|
+
* F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
90
|
+
*
|
|
91
|
+
* Same permutex2var technique as E2M3 — sign bit 5 selects the LUT source. */
|
|
92
|
+
NK_INTERNAL __m512h nk_e3m2x32_to_f16x32_sapphire_(__m256i e3m2x32) {
|
|
93
|
+
__m512i idx_i16x32 = _mm512_cvtepu8_epi16(e3m2x32);
|
|
94
|
+
|
|
95
|
+
// 32-entry LUT for positive E3M2 magnitudes → F16
|
|
96
|
+
__m512i const lut_pos_i16x32 = _mm512_set_epi16( //
|
|
97
|
+
0x4F00, 0x4E00, 0x4D00, 0x4C00, // [31-28] exp=7: f16_exp=19
|
|
98
|
+
0x4B00, 0x4A00, 0x4900, 0x4800, // [27-24] exp=6: f16_exp=18
|
|
99
|
+
0x4700, 0x4600, 0x4500, 0x4400, // [23-20] exp=5: f16_exp=17
|
|
100
|
+
0x4300, 0x4200, 0x4100, 0x4000, // [19-16] exp=4: f16_exp=16
|
|
101
|
+
0x3F00, 0x3E00, 0x3D00, 0x3C00, // [15-12] exp=3: f16_exp=15
|
|
102
|
+
0x3B00, 0x3A00, 0x3900, 0x3800, // [11-8] exp=2: f16_exp=14
|
|
103
|
+
0x3700, 0x3600, 0x3500, 0x3400, // [7-4] exp=1: f16_exp=13
|
|
104
|
+
0x3200, 0x3000, 0x2C00, 0x0000); // [3-0] exp=0: subnormals
|
|
105
|
+
|
|
106
|
+
// 32-entry LUT for negative E3M2 magnitudes → F16 (= positive | 0x8000)
|
|
107
|
+
__m512i const lut_neg_i16x32 = _mm512_set_epi16( //
|
|
108
|
+
(short)0xCF00, (short)0xCE00, (short)0xCD00, (short)0xCC00, // [31-28] exp=7
|
|
109
|
+
(short)0xCB00, (short)0xCA00, (short)0xC900, (short)0xC800, // [27-24] exp=6
|
|
110
|
+
(short)0xC700, (short)0xC600, (short)0xC500, (short)0xC400, // [23-20] exp=5
|
|
111
|
+
(short)0xC300, (short)0xC200, (short)0xC100, (short)0xC000, // [19-16] exp=4
|
|
112
|
+
(short)0xBF00, (short)0xBE00, (short)0xBD00, (short)0xBC00, // [15-12] exp=3
|
|
113
|
+
(short)0xBB00, (short)0xBA00, (short)0xB900, (short)0xB800, // [11-8] exp=2
|
|
114
|
+
(short)0xB700, (short)0xB600, (short)0xB500, (short)0xB400, // [7-4] exp=1
|
|
115
|
+
(short)0xB200, (short)0xB000, (short)0xAC00, (short)0x8000); // [3-0] exp=0
|
|
116
|
+
|
|
117
|
+
return nk_m512h_from_m512i_(_mm512_permutex2var_epi16(lut_pos_i16x32, idx_i16x32, lut_neg_i16x32));
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
/** @brief Flush 32 FP16 values to FP32 accumulator by splitting into 2x16 halves. */
|
|
121
|
+
NK_INTERNAL __m512 nk_flush_f16_to_f32_sapphire_(__m512h acc_f16x32, __m512 sum_f32x16) {
|
|
122
|
+
__m256i low_f16x16 = _mm512_castsi512_si256(nk_m512i_from_m512h_(acc_f16x32));
|
|
123
|
+
__m256i high_f16x16 = _mm512_extracti64x4_epi64(nk_m512i_from_m512h_(acc_f16x32), 1);
|
|
124
|
+
sum_f32x16 = _mm512_add_ps(sum_f32x16, _mm512_cvtph_ps(low_f16x16));
|
|
125
|
+
sum_f32x16 = _mm512_add_ps(sum_f32x16, _mm512_cvtph_ps(high_f16x16));
|
|
126
|
+
return sum_f32x16;
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
#if defined(__clang__)
|
|
130
|
+
#pragma clang attribute pop
|
|
131
|
+
#elif defined(__GNUC__)
|
|
132
|
+
#pragma GCC pop_options
|
|
133
|
+
#endif
|
|
134
|
+
|
|
135
|
+
#if defined(__cplusplus)
|
|
136
|
+
} // extern "C"
|
|
137
|
+
#endif
|
|
138
|
+
|
|
139
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
140
|
+
#endif // NK_TARGET_X86_
|
|
141
|
+
#endif // NK_DOT_SAPPHIRE_H
|