numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for Ice Lake.
|
|
3
|
+
* @file include/numkong/cast/icelake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 2, 2026
|
|
6
|
+
*
|
|
7
|
+
 * @section ice_cast_instructions AVX-512BW Instructions
|
|
8
|
+
*
|
|
9
|
+
* Intrinsic Instruction Ice Genoa
|
|
10
|
+
* _mm512_permutex2var_epi16 VPERMI2W (ZMM, ZMM, ZMM) 3cy @ p5 2cy @ p12
|
|
11
|
+
* _mm512_test_epi16_mask VPTESTMW (k, ZMM, ZMM) 3cy @ p5 2cy @ p01
|
|
12
|
+
* _mm512_mask_mov_epi16 VMOVDQU16 (ZMM{k}, ZMM) 1cy @ p05 1cy @ p05
|
|
13
|
+
* _mm512_cvtepi16_epi8 VPMOVWB (YMM, ZMM) 3cy @ p5 2cy @ p12
|
|
14
|
+
*
|
|
15
|
+
 * Ice Lake's AVX-512BW word permutes (VPERMW, VPERMI2W) enable efficient in-register LUT lookups; 128-entry tables use dual VPERMI2W operations.
|
|
16
|
+
 * E4M3-to-F16 conversion uses 4 ZMM LUT registers with VPTESTMW for range selection, achieving
|
|
17
|
+
* ~6 cycles for 32 FP8 conversions. E5M2-to-F16 simplifies to VPSLLW due to matching exponent bias.
|
|
18
|
+
*/
|
|
19
|
+
#ifndef NK_CAST_ICELAKE_H
|
|
20
|
+
#define NK_CAST_ICELAKE_H
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_X86_
|
|
23
|
+
#if NK_TARGET_ICELAKE
|
|
24
|
+
|
|
25
|
+
#include "numkong/types.h"
|
|
26
|
+
#include "numkong/cast/skylake.h"
|
|
27
|
+
|
|
28
|
+
#if defined(__cplusplus)
|
|
29
|
+
extern "C" {
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
#if defined(__clang__)
|
|
33
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
34
|
+
apply_to = function)
|
|
35
|
+
#elif defined(__GNUC__)
|
|
36
|
+
#pragma GCC push_options
|
|
37
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
38
|
+
#endif
|
|
39
|
+
|
|
40
|
+
#pragma region - Vectorized Conversions
|
|
41
|
+
|
|
42
|
+
/** @brief Convert 32x e4m3 → 32x bf16 via arithmetic + 8-entry subnormal LUT (AVX-512BW).
|
|
43
|
+
* E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
|
|
44
|
+
* Normal values (exp != 0): BF16 = sign | ((lower7 << 4) + 0x3C00).
|
|
45
|
+
* Subnormals (exp == 0, 8 values): looked up from 8-entry LUT via permutexvar.
|
|
46
|
+
* Memory: 16 bytes (8 × 16-bit entries) vs 256 bytes (128-entry LUT). OCP FP8 v1.0. */
|
|
47
|
+
NK_INTERNAL __m512i nk_e4m3x32_to_bf16x32_icelake_(__m256i e4m3x32) {
|
|
48
|
+
__m512i e4m3_i16x32 = _mm512_cvtepu8_epi16(e4m3x32);
|
|
49
|
+
__m512i sign_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16((short)0x80));
|
|
50
|
+
__m512i lower7_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x7F));
|
|
51
|
+
|
|
52
|
+
// Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
|
|
53
|
+
// Formula: E4M3 exp=e, mant=m → BF16 exp = e+120 (bias 7→127), mant = m<<4
|
|
54
|
+
__m512i normal_abs_i16x32 = _mm512_add_epi16(_mm512_slli_epi16(lower7_i16x32, 4), _mm512_set1_epi16(0x3C00));
|
|
55
|
+
|
|
56
|
+
// Subnormal LUT (8 entries, repeated 4x for all lanes): E4M3 subnormals are mant × 2^(-9)
|
|
57
|
+
// Values: 0, 1/512, 2/512, 3/512, 4/512, 5/512, 6/512, 7/512
|
|
58
|
+
__m512i subn_lut_i16x32 = _mm512_set_epi16( //
|
|
59
|
+
0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 3
|
|
60
|
+
0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 2
|
|
61
|
+
0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 1
|
|
62
|
+
0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000); // lane 0
|
|
63
|
+
|
|
64
|
+
// Lookup subnormals via permutexvar (use lower 3 bits of mantissa as index)
|
|
65
|
+
__m512i mant_idx_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x07));
|
|
66
|
+
__m512i subnorm_abs_i16x32 = _mm512_permutexvar_epi16(mant_idx_i16x32, subn_lut_i16x32);
|
|
67
|
+
|
|
68
|
+
// Blend: if exponent == 0, use subnormal; else use normal
|
|
69
|
+
__m512i exp_bits_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x78));
|
|
70
|
+
__mmask32 is_subnormal = _mm512_cmpeq_epi16_mask(exp_bits_i16x32, _mm512_setzero_si512());
|
|
71
|
+
__m512i result_abs_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_abs_i16x32, subnorm_abs_i16x32);
|
|
72
|
+
|
|
73
|
+
// Apply sign: shift E4M3 bit 7 to BF16 bit 15
|
|
74
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
|
|
75
|
+
return _mm512_or_si512(result_abs_i16x32, sign_i16x32);
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
/** @brief Convert 32x e5m2 → 32x bf16 via arithmetic + 4-entry subnormal LUT (AVX-512BW).
|
|
79
|
+
* E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
|
|
80
|
+
* Normal values (exp != 0): BF16 = sign | ((lower7 << 5) + 0x3800).
|
|
81
|
+
* Subnormals (exp == 0, 4 values): looked up from 4-entry LUT via permutexvar.
|
|
82
|
+
* Memory: 8 bytes (4 × 16-bit entries) vs 256 bytes (128-entry LUT). OCP FP8 v1.0. */
|
|
83
|
+
NK_INTERNAL __m512i nk_e5m2x32_to_bf16x32_icelake_(__m256i e5m2x32) {
|
|
84
|
+
__m512i e5m2_i16x32 = _mm512_cvtepu8_epi16(e5m2x32);
|
|
85
|
+
__m512i sign_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16((short)0x80));
|
|
86
|
+
__m512i lower7_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7F));
|
|
87
|
+
|
|
88
|
+
// Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
|
|
89
|
+
// Formula: E5M2 exp=e, mant=m → BF16 exp = e+112 (bias 15→127), mant = m<<5
|
|
90
|
+
__m512i normal_abs_i16x32 = _mm512_add_epi16(_mm512_slli_epi16(lower7_i16x32, 5), _mm512_set1_epi16(0x3800));
|
|
91
|
+
|
|
92
|
+
// Subnormal LUT (4 entries, repeated 8x for all lanes): E5M2 subnormals are mant × 2^(-16)
|
|
93
|
+
// Values: 0, 1/65536, 2/65536, 3/65536 (4 entries, then zeros for padding to 8)
|
|
94
|
+
__m512i subn_lut_i16x32 = _mm512_set_epi16( //
|
|
95
|
+
0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // lanes 3-2 (16 entries)
|
|
96
|
+
0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // lanes 1-0 (16 entries)
|
|
97
|
+
0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // repeat for remaining
|
|
98
|
+
0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000); // all 32 entries
|
|
99
|
+
|
|
100
|
+
// Lookup subnormals via permutexvar (use lower 2 bits of mantissa as index)
|
|
101
|
+
__m512i mant_idx_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x03));
|
|
102
|
+
__m512i subnorm_abs_i16x32 = _mm512_permutexvar_epi16(mant_idx_i16x32, subn_lut_i16x32);
|
|
103
|
+
|
|
104
|
+
// Blend: if exponent == 0, use subnormal; else use normal
|
|
105
|
+
__m512i exp_bits_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7C));
|
|
106
|
+
__mmask32 is_subnormal = _mm512_cmpeq_epi16_mask(exp_bits_i16x32, _mm512_setzero_si512());
|
|
107
|
+
__m512i result_abs_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_abs_i16x32, subnorm_abs_i16x32);
|
|
108
|
+
|
|
109
|
+
// Apply sign: shift E5M2 bit 7 to BF16 bit 15
|
|
110
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
|
|
111
|
+
return _mm512_or_si512(result_abs_i16x32, sign_i16x32);
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
/** @brief Convert 32x e2m3 → 32x bf16 via 32-entry LUT lookup (AVX-512BW).
|
|
115
|
+
* E2M3 format: S EE MMM (bias=1, 6 bits total: sign at bit 5, magnitude bits 4-0).
|
|
116
|
+
* BF16: S EEEEEEEE MMMMMMM (bias=127). Uses single permutexvar; sign handled separately.
|
|
117
|
+
* Subnormals (exp=0): value = mant/8. OCP Microscaling Formats v1.0. */
|
|
118
|
+
NK_INTERNAL __m512i nk_e2m3x32_to_bf16x32_icelake_(__m256i e2m3x32) {
|
|
119
|
+
__m512i e2m3_i16x32 = _mm512_cvtepu8_epi16(e2m3x32);
|
|
120
|
+
__m512i sign_i16x32 = _mm512_and_si512(e2m3_i16x32, _mm512_set1_epi16(0x20)); // E2M3 sign at bit 5
|
|
121
|
+
__m512i idx_i16x32 = _mm512_and_si512(e2m3_i16x32, _mm512_set1_epi16(0x1F));
|
|
122
|
+
|
|
123
|
+
// 32-entry LUT for E2M3 magnitude (5 bits: bits [4:3]=exp, bits [2:0]=mant)
|
|
124
|
+
// E2M3: bias=1, range [0, 7.5] for positive, subnormals = mant/8 (OCP MX v1.0)
|
|
125
|
+
// BF16 = (bf16_exp << 7) | (bf16_mant), where bf16_exp = e2m3_exp + 126, bf16_mant = e2m3_mant << 4
|
|
126
|
+
__m512i const lut_i16x32 = _mm512_set_epi16( //
|
|
127
|
+
0x40F0, 0x40E0, 0x40D0, 0x40C0, 0x40B0, 0x40A0, 0x4090, 0x4080, // [31-24] exp=3: bf16_exp=129
|
|
128
|
+
0x4070, 0x4060, 0x4050, 0x4040, 0x4030, 0x4020, 0x4010, 0x4000, // [23-16] exp=2: bf16_exp=128
|
|
129
|
+
0x3FF0, 0x3FE0, 0x3FD0, 0x3FC0, 0x3FB0, 0x3FA0, 0x3F90, 0x3F80, // [15-8] exp=1: bf16_exp=127
|
|
130
|
+
0x3F60, 0x3F40, 0x3F20, 0x3F00, 0x3EC0, 0x3E80, 0x3E00, 0x0000); // [7-0] exp=0: subnormals 7/8..1/8, 0
|
|
131
|
+
|
|
132
|
+
// Single permutexvar for 32-entry lookup
|
|
133
|
+
__m512i result_i16x32 = _mm512_permutexvar_epi16(idx_i16x32, lut_i16x32);
|
|
134
|
+
|
|
135
|
+
// Apply sign: shift E2M3 bit 5 to BF16 bit 15, then OR
|
|
136
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 10);
|
|
137
|
+
return _mm512_or_si512(result_i16x32, sign_i16x32);
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
/** @brief Convert 32x e3m2 → 32x bf16 via 32-entry LUT lookup (AVX-512BW).
|
|
141
|
+
* E3M2 format: S EEE MM (bias=3, 6 bits total: sign at bit 7, magnitude bits 4-0).
|
|
142
|
+
* BF16: S EEEEEEEE MMMMMMM (bias=127). Uses single permutexvar; sign handled separately. */
|
|
143
|
+
NK_INTERNAL __m512i nk_e3m2x32_to_bf16x32_icelake_(__m256i e3m2x32) {
|
|
144
|
+
__m512i e3m2_i16x32 = _mm512_cvtepu8_epi16(e3m2x32);
|
|
145
|
+
__m512i sign_i16x32 = _mm512_and_si512(e3m2_i16x32, _mm512_set1_epi16(0x20)); // E3M2 sign at bit 5
|
|
146
|
+
__m512i idx_i16x32 = _mm512_and_si512(e3m2_i16x32, _mm512_set1_epi16(0x1F));
|
|
147
|
+
|
|
148
|
+
// 32-entry LUT for E3M2 magnitude (5 bits: bits [4:2]=exp, bits [1:0]=mant)
|
|
149
|
+
// E3M2: bias=3, range [0, 28] for positive, subnormals = mant/16 (OCP Microscaling v1.0)
|
|
150
|
+
// BF16 = (bf16_exp << 7) | (bf16_mant), where bf16_exp = e3m2_exp + 124, bf16_mant = e3m2_mant << 5
|
|
151
|
+
__m512i const lut_i16x32 = _mm512_set_epi16( //
|
|
152
|
+
0x41E0, 0x41C0, 0x41A0, 0x4180, // [31-28] exp=7, mant=3-0: bf16_exp=131
|
|
153
|
+
0x4160, 0x4140, 0x4120, 0x4100, // [27-24] exp=6, mant=3-0: bf16_exp=130
|
|
154
|
+
0x40E0, 0x40C0, 0x40A0, 0x4080, // [23-20] exp=5, mant=3-0: bf16_exp=129
|
|
155
|
+
0x4060, 0x4040, 0x4020, 0x4000, // [19-16] exp=4, mant=3-0: bf16_exp=128
|
|
156
|
+
0x3FE0, 0x3FC0, 0x3FA0, 0x3F80, // [15-12] exp=3, mant=3-0: bf16_exp=127
|
|
157
|
+
0x3F60, 0x3F40, 0x3F20, 0x3F00, // [11-8] exp=2, mant=3-0: bf16_exp=126
|
|
158
|
+
0x3EE0, 0x3EC0, 0x3EA0, 0x3E80, // [7-4] exp=1, mant=3-0: bf16_exp=125
|
|
159
|
+
0x3E40, 0x3E00, 0x3D80, 0x0000); // [3-0] exp=0: subnormals 3/16, 2/16, 1/16, 0
|
|
160
|
+
|
|
161
|
+
// Single permutexvar for 32-entry lookup
|
|
162
|
+
__m512i result_i16x32 = _mm512_permutexvar_epi16(idx_i16x32, lut_i16x32);
|
|
163
|
+
|
|
164
|
+
// Apply sign: shift E3M2 bit 5 to BF16 bit 15, then OR
|
|
165
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 10);
|
|
166
|
+
return _mm512_or_si512(result_i16x32, sign_i16x32);
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
/** @brief Convert 32x e4m3 → 32x f16 via 128-entry LUT lookup (AVX-512BW).
|
|
170
|
+
* E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
171
|
+
* Uses permutex2var for fast LUT lookup; sign handled separately via shift+OR.
|
|
172
|
+
* Handles all corner cases: zero, subnormals, normals, and NaN. */
|
|
173
|
+
NK_INTERNAL __m512i nk_e4m3x32_to_f16x32_icelake_(__m256i e4m3x32) {
|
|
174
|
+
__m512i e4m3_i16x32 = _mm512_cvtepu8_epi16(e4m3x32);
|
|
175
|
+
__m512i sign_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16((short)0x80));
|
|
176
|
+
__m512i idx_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x7F));
|
|
177
|
+
|
|
178
|
+
// 128-entry LUT for E4M3 absolute values to F16, split into 4x32 chunks
|
|
179
|
+
// Subnormals (idx 0-7): 0, 1/512, ..., 7/512 mapped to F16
|
|
180
|
+
// Normals (idx 8-126): F16 = (lower7 << 7) + 0x2000
|
|
181
|
+
// NaN (idx 127): 0x7E00
|
|
182
|
+
__m512i const lut0_i16x32 = _mm512_set_epi16( // indices 0-31
|
|
183
|
+
0x2F80, 0x2F00, 0x2E80, 0x2E00, 0x2D80, 0x2D00, 0x2C80, 0x2C00, // idx 31-24
|
|
184
|
+
0x2B80, 0x2B00, 0x2A80, 0x2A00, 0x2980, 0x2900, 0x2880, 0x2800, // idx 23-16
|
|
185
|
+
0x2780, 0x2700, 0x2680, 0x2600, 0x2580, 0x2500, 0x2480, 0x2400, // idx 15-8
|
|
186
|
+
0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00, 0x1800, 0x0000); // idx 7-0
|
|
187
|
+
__m512i const lut1_i16x32 = _mm512_set_epi16( // indices 32-63
|
|
188
|
+
0x3F80, 0x3F00, 0x3E80, 0x3E00, 0x3D80, 0x3D00, 0x3C80, 0x3C00, // idx 63-56
|
|
189
|
+
0x3B80, 0x3B00, 0x3A80, 0x3A00, 0x3980, 0x3900, 0x3880, 0x3800, // idx 55-48
|
|
190
|
+
0x3780, 0x3700, 0x3680, 0x3600, 0x3580, 0x3500, 0x3480, 0x3400, // idx 47-40
|
|
191
|
+
0x3380, 0x3300, 0x3280, 0x3200, 0x3180, 0x3100, 0x3080, 0x3000); // idx 39-32
|
|
192
|
+
__m512i const lut2_i16x32 = _mm512_set_epi16( // indices 64-95
|
|
193
|
+
0x4F80, 0x4F00, 0x4E80, 0x4E00, 0x4D80, 0x4D00, 0x4C80, 0x4C00, // idx 95-88
|
|
194
|
+
0x4B80, 0x4B00, 0x4A80, 0x4A00, 0x4980, 0x4900, 0x4880, 0x4800, // idx 87-80
|
|
195
|
+
0x4780, 0x4700, 0x4680, 0x4600, 0x4580, 0x4500, 0x4480, 0x4400, // idx 79-72
|
|
196
|
+
0x4380, 0x4300, 0x4280, 0x4200, 0x4180, 0x4100, 0x4080, 0x4000); // idx 71-64
|
|
197
|
+
__m512i const lut3_i16x32 = _mm512_set_epi16( // indices 96-127
|
|
198
|
+
0x7E00, 0x5F00, 0x5E80, 0x5E00, 0x5D80, 0x5D00, 0x5C80, 0x5C00, // idx 127-120
|
|
199
|
+
0x5B80, 0x5B00, 0x5A80, 0x5A00, 0x5980, 0x5900, 0x5880, 0x5800, // idx 119-112
|
|
200
|
+
0x5780, 0x5700, 0x5680, 0x5600, 0x5580, 0x5500, 0x5480, 0x5400, // idx 111-104
|
|
201
|
+
0x5380, 0x5300, 0x5280, 0x5200, 0x5180, 0x5100, 0x5080, 0x5000); // idx 103-96
|
|
202
|
+
|
|
203
|
+
// 2x permutex2var for 64-entry lookup each, then select based on bit 6
|
|
204
|
+
__m512i result_low_i16x32 = _mm512_permutex2var_epi16(lut0_i16x32, idx_i16x32, lut1_i16x32);
|
|
205
|
+
__m512i result_high_i16x32 = _mm512_permutex2var_epi16(lut2_i16x32, idx_i16x32, lut3_i16x32);
|
|
206
|
+
|
|
207
|
+
// Select between low (idx 0-63) and high (idx 64-127) based on bit 6
|
|
208
|
+
__mmask32 use_high_mask = _mm512_test_epi16_mask(idx_i16x32, _mm512_set1_epi16(0x40));
|
|
209
|
+
__m512i result_i16x32 = _mm512_mask_mov_epi16(result_low_i16x32, use_high_mask, result_high_i16x32);
|
|
210
|
+
|
|
211
|
+
// Apply sign: shift sign bit to bit 15, then OR
|
|
212
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
|
|
213
|
+
return _mm512_or_si512(result_i16x32, sign_i16x32);
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
/** @brief Convert 32x e5m2 → 32x f16 via simple bit shift (AVX-512BW).
|
|
217
|
+
* E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
218
|
+
* Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
|
|
219
|
+
* Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
|
|
220
|
+
NK_INTERNAL __m512i nk_e5m2x32_to_f16x32_icelake_(__m256i e5m2x32) {
|
|
221
|
+
__m512i e5m2_i16x32 = _mm512_cvtepu8_epi16(e5m2x32);
|
|
222
|
+
__m512i sign_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16((short)0x80));
|
|
223
|
+
__m512i lower7_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7F));
|
|
224
|
+
|
|
225
|
+
// F16 = (lower7 << 8) | (sign << 15)
|
|
226
|
+
// Works for all cases: subnormals, normals, infinity, and NaN
|
|
227
|
+
__m512i result_i16x32 = _mm512_slli_epi16(lower7_i16x32, 8);
|
|
228
|
+
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
|
|
229
|
+
return _mm512_or_si512(result_i16x32, sign_i16x32);
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
/** @brief Convert 32x bf16 → 32x e4m3 via bit manipulation (AVX-512BW).
|
|
233
|
+
* BF16: S EEEEEEEE MMMMMMM (bias=127). E4M3: S EEEE MMM (bias=7).
|
|
234
|
+
* Handles normal, subnormal, and overflow cases with RNE rounding. */
|
|
235
|
+
NK_INTERNAL __m256i nk_bf16x32_to_e4m3x32_icelake_(__m512i bf16x32) {
|
|
236
|
+
__m512i sign_i16x32 = _mm512_srli_epi16(bf16x32, 15);
|
|
237
|
+
__m512i bf16_exp_i16x32 = _mm512_and_si512(_mm512_srli_epi16(bf16x32, 7), _mm512_set1_epi16(0xFF));
|
|
238
|
+
|
|
239
|
+
// Round mantissa from 7 to 3 bits using RNE (round to nearest, ties to even)
|
|
240
|
+
__m512i significand_i16x32 = _mm512_or_si512(_mm512_and_si512(bf16x32, _mm512_set1_epi16(0x7F)),
|
|
241
|
+
_mm512_set1_epi16(0x80)); // Add implicit 1 bit
|
|
242
|
+
__m512i lsb_i16x32 = _mm512_and_si512(_mm512_srli_epi16(significand_i16x32, 4), _mm512_set1_epi16(1));
|
|
243
|
+
__m512i rounding_bias_i16x32 = _mm512_add_epi16(_mm512_set1_epi16(0x07), lsb_i16x32);
|
|
244
|
+
__m512i rounded_sig_i16x32 = _mm512_add_epi16(significand_i16x32, rounding_bias_i16x32);
|
|
245
|
+
__m512i carry_i16x32 = _mm512_srli_epi16(rounded_sig_i16x32, 8); // Carry into exponent if bit 8 set
|
|
246
|
+
__m512i bf16_mantissa_i16x32 = _mm512_and_si512(_mm512_srli_epi16(rounded_sig_i16x32, 4), _mm512_set1_epi16(0x07));
|
|
247
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
248
|
+
bf16_mantissa_i16x32 = _mm512_andnot_si512(_mm512_slli_epi16(carry_i16x32, 15), bf16_mantissa_i16x32);
|
|
249
|
+
__m512i e4m3_exp_i16x32 = _mm512_sub_epi16(_mm512_add_epi16(bf16_exp_i16x32, carry_i16x32), _mm512_set1_epi16(120));
|
|
250
|
+
|
|
251
|
+
// Detect underflow (exp <= 0) and overflow (exp > 15)
|
|
252
|
+
__mmask32 is_subnormal = _mm512_cmpgt_epi16_mask(_mm512_set1_epi16(1), e4m3_exp_i16x32);
|
|
253
|
+
__mmask32 overflow = _mm512_cmpgt_epi16_mask(e4m3_exp_i16x32, _mm512_set1_epi16(15));
|
|
254
|
+
|
|
255
|
+
// Normal path: clamp exp to [1,15]
|
|
256
|
+
// e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
|
|
257
|
+
__m512i clamped_exp_i16x32 = _mm512_max_epi16(e4m3_exp_i16x32, _mm512_set1_epi16(1));
|
|
258
|
+
clamped_exp_i16x32 = _mm512_min_epi16(clamped_exp_i16x32, _mm512_set1_epi16(15));
|
|
259
|
+
__mmask32 is_max_exp = _mm512_cmpeq_epi16_mask(clamped_exp_i16x32, _mm512_set1_epi16(15));
|
|
260
|
+
__m512i max_mantissa_i16x32 = _mm512_mask_blend_epi16(is_max_exp, _mm512_set1_epi16(7), _mm512_set1_epi16(6));
|
|
261
|
+
__m512i normal_mantissa_i16x32 = _mm512_min_epi16(bf16_mantissa_i16x32, max_mantissa_i16x32);
|
|
262
|
+
normal_mantissa_i16x32 = _mm512_mask_blend_epi16(overflow, normal_mantissa_i16x32, _mm512_set1_epi16(0x06));
|
|
263
|
+
__m512i normal_e4m3_i16x32 = _mm512_or_si512(
|
|
264
|
+
_mm512_slli_epi16(sign_i16x32, 7),
|
|
265
|
+
_mm512_or_si512(_mm512_slli_epi16(clamped_exp_i16x32, 3), normal_mantissa_i16x32));
|
|
266
|
+
|
|
267
|
+
// Subnormal path: compute via f32 to get correct rounding
|
|
268
|
+
// bf16 to f32 is just left shift by 16
|
|
269
|
+
__m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
|
|
270
|
+
__m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
|
|
271
|
+
__m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
|
|
272
|
+
__m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
|
|
273
|
+
__m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
274
|
+
__m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
275
|
+
__m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(512.0f));
|
|
276
|
+
__m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(512.0f));
|
|
277
|
+
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
|
|
278
|
+
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
|
|
279
|
+
__m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
|
|
280
|
+
__m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
|
|
281
|
+
__m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
|
|
282
|
+
subnorm_mant_high_i16x16, 1);
|
|
283
|
+
__mmask32 promotes_to_normal = _mm512_cmpgt_epi16_mask(subnorm_mantissa_i16x32, _mm512_set1_epi16(7));
|
|
284
|
+
subnorm_mantissa_i16x32 = _mm512_min_epi16(subnorm_mantissa_i16x32, _mm512_set1_epi16(7));
|
|
285
|
+
subnorm_mantissa_i16x32 = _mm512_max_epi16(subnorm_mantissa_i16x32, _mm512_setzero_si512());
|
|
286
|
+
__m512i subnorm_e4m3_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), subnorm_mantissa_i16x32);
|
|
287
|
+
__m512i first_normal_e4m3_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), _mm512_set1_epi16(0x08));
|
|
288
|
+
subnorm_e4m3_i16x32 = _mm512_mask_blend_epi16(promotes_to_normal, subnorm_e4m3_i16x32, first_normal_e4m3_i16x32);
|
|
289
|
+
|
|
290
|
+
// Blend: use subnormal result when exp <= 0
|
|
291
|
+
__m512i e4m3_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_e4m3_i16x32, subnorm_e4m3_i16x32);
|
|
292
|
+
|
|
293
|
+
// Pack 32 i16s to 32 unsigned i8s via AVX-512BW
|
|
294
|
+
return _mm512_cvtepi16_epi8(e4m3_i16x32);
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
/** @brief Convert 32x bf16 → 32x e5m2 via bit manipulation (AVX-512BW).
|
|
298
|
+
* BF16: S EEEEEEEE MMMMMMM (bias=127). E5M2: S EEEEE MM (bias=15).
|
|
299
|
+
* Handles normal, subnormal, and overflow cases with RNE rounding. */
|
|
300
|
+
NK_INTERNAL __m256i nk_bf16x32_to_e5m2x32_icelake_(__m512i bf16x32) {
|
|
301
|
+
__m512i sign_i16x32 = _mm512_srli_epi16(bf16x32, 15);
|
|
302
|
+
__m512i bf16_exp_i16x32 = _mm512_and_si512(_mm512_srli_epi16(bf16x32, 7), _mm512_set1_epi16(0xFF));
|
|
303
|
+
|
|
304
|
+
// Round mantissa from 7 to 2 bits using RNE (round to nearest, ties to even)
|
|
305
|
+
__m512i significand_i16x32 = _mm512_or_si512(_mm512_and_si512(bf16x32, _mm512_set1_epi16(0x7F)),
|
|
306
|
+
_mm512_set1_epi16(0x80)); // Add implicit 1 bit
|
|
307
|
+
__m512i lsb_i16x32 = _mm512_and_si512(_mm512_srli_epi16(significand_i16x32, 5), _mm512_set1_epi16(1));
|
|
308
|
+
__m512i rounding_bias_i16x32 = _mm512_add_epi16(_mm512_set1_epi16(0x0F), lsb_i16x32);
|
|
309
|
+
__m512i rounded_sig_i16x32 = _mm512_add_epi16(significand_i16x32, rounding_bias_i16x32);
|
|
310
|
+
__m512i carry_i16x32 = _mm512_srli_epi16(rounded_sig_i16x32, 8); // Carry into exponent if bit 8 set
|
|
311
|
+
__m512i bf16_mantissa_i16x32 = _mm512_and_si512(_mm512_srli_epi16(rounded_sig_i16x32, 5), _mm512_set1_epi16(0x03));
|
|
312
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
313
|
+
bf16_mantissa_i16x32 = _mm512_andnot_si512(_mm512_slli_epi16(carry_i16x32, 15), bf16_mantissa_i16x32);
|
|
314
|
+
__m512i e5m2_exp_i16x32 = _mm512_sub_epi16(_mm512_add_epi16(bf16_exp_i16x32, carry_i16x32), _mm512_set1_epi16(112));
|
|
315
|
+
|
|
316
|
+
// Detect subnormal (exp <= 0) and overflow (exp > 31)
|
|
317
|
+
__mmask32 is_subnormal = _mm512_cmpgt_epi16_mask(_mm512_set1_epi16(1), e5m2_exp_i16x32);
|
|
318
|
+
__mmask32 overflow = _mm512_cmpgt_epi16_mask(e5m2_exp_i16x32, _mm512_set1_epi16(31));
|
|
319
|
+
|
|
320
|
+
// Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
|
|
321
|
+
__m512i clamped_exp_i16x32 = _mm512_max_epi16(e5m2_exp_i16x32, _mm512_set1_epi16(1));
|
|
322
|
+
clamped_exp_i16x32 = _mm512_min_epi16(clamped_exp_i16x32, _mm512_set1_epi16(31));
|
|
323
|
+
__m512i normal_mantissa_i16x32 = _mm512_mask_blend_epi16(overflow, bf16_mantissa_i16x32, _mm512_setzero_si512());
|
|
324
|
+
__m512i normal_e5m2_i16x32 = _mm512_or_si512(
|
|
325
|
+
_mm512_slli_epi16(sign_i16x32, 7),
|
|
326
|
+
_mm512_or_si512(_mm512_slli_epi16(clamped_exp_i16x32, 2), normal_mantissa_i16x32));
|
|
327
|
+
|
|
328
|
+
// Subnormal path: compute via f32 to get correct rounding
|
|
329
|
+
__m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
|
|
330
|
+
__m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
|
|
331
|
+
__m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
|
|
332
|
+
__m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
|
|
333
|
+
__m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
334
|
+
__m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
335
|
+
__m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(65536.0f));
|
|
336
|
+
__m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(65536.0f));
|
|
337
|
+
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
|
|
338
|
+
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
|
|
339
|
+
__m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
|
|
340
|
+
__m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
|
|
341
|
+
__m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
|
|
342
|
+
subnorm_mant_high_i16x16, 1);
|
|
343
|
+
__mmask32 promotes_to_normal = _mm512_cmpgt_epi16_mask(subnorm_mantissa_i16x32, _mm512_set1_epi16(3));
|
|
344
|
+
subnorm_mantissa_i16x32 = _mm512_min_epi16(subnorm_mantissa_i16x32, _mm512_set1_epi16(3));
|
|
345
|
+
subnorm_mantissa_i16x32 = _mm512_max_epi16(subnorm_mantissa_i16x32, _mm512_setzero_si512());
|
|
346
|
+
__m512i subnorm_e5m2_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), subnorm_mantissa_i16x32);
|
|
347
|
+
__m512i first_normal_e5m2_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), _mm512_set1_epi16(0x04));
|
|
348
|
+
subnorm_e5m2_i16x32 = _mm512_mask_blend_epi16(promotes_to_normal, subnorm_e5m2_i16x32, first_normal_e5m2_i16x32);
|
|
349
|
+
|
|
350
|
+
// Blend: use subnormal result when exp <= 0
|
|
351
|
+
__m512i e5m2_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_e5m2_i16x32, subnorm_e5m2_i16x32);
|
|
352
|
+
|
|
353
|
+
// Pack 32 i16s to 32 unsigned i8s via AVX-512BW
|
|
354
|
+
return _mm512_cvtepi16_epi8(e5m2_i16x32);
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
/** @brief Load 32x e4m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
358
|
+
NK_INTERNAL void nk_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
|
|
359
|
+
dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
/** @brief Partial load n e4m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
363
|
+
NK_INTERNAL void nk_partial_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
364
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
365
|
+
__m256i e4m3_partial = _mm256_maskz_loadu_epi8(mask, src);
|
|
366
|
+
dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(e4m3_partial);
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
/** @brief Load 32x e5m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
370
|
+
NK_INTERNAL void nk_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
|
|
371
|
+
dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
/** @brief Partial load n e5m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
375
|
+
NK_INTERNAL void nk_partial_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
376
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
377
|
+
__m256i e5m2_partial = _mm256_maskz_loadu_epi8(mask, src);
|
|
378
|
+
dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(e5m2_partial);
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
/** @brief Load 32x e2m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
382
|
+
NK_INTERNAL void nk_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
|
|
383
|
+
dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
/** @brief Partial load n e2m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
387
|
+
NK_INTERNAL void nk_partial_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
388
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
389
|
+
__m256i e2m3_partial = _mm256_maskz_loadu_epi8(mask, src);
|
|
390
|
+
dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(e2m3_partial);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
/** @brief Load 32x e3m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
394
|
+
NK_INTERNAL void nk_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
|
|
395
|
+
dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
/** @brief Partial load n e3m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
399
|
+
NK_INTERNAL void nk_partial_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
400
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
401
|
+
__m256i e3m2_partial = _mm256_maskz_loadu_epi8(mask, src);
|
|
402
|
+
dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(e3m2_partial);
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
#pragma endregion - Vectorized Conversions
|
|
406
|
+
|
|
407
|
+
#pragma region - Public API
|
|
408
|
+
|
|
409
|
+
/** @brief Element-wise cast of `n` values from `from_type` to `to_type` (Ice Lake AVX-512).
 *  Handled here with 32-lane kernels: e4m3/e5m2 → bf16, bf16 → e4m3/e5m2, and e4m3/e5m2 → f16.
 *  Each kernel walks the input 32 elements at a time; the final partial chunk uses masked loads/stores,
 *  so nothing past element n-1 is read or written. Every other type pair is forwarded to nk_cast_skylake. */
NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
    int const from_is_fp8 = from_type == nk_e4m3_k || from_type == nk_e5m2_k;
    int const to_is_fp8 = to_type == nk_e4m3_k || to_type == nk_e5m2_k;

    // e4m3 / e5m2 → bf16 (both 8-bit inputs are addressed through the same byte-sized pointer type)
    if (from_is_fp8 && to_type == nk_bf16_k) {
        nk_e4m3_t const *src = (nk_e4m3_t const *)from;
        nk_bf16_t *dst = (nk_bf16_t *)to;
        for (nk_size_t offset = 0; offset < n; offset += 32) {
            nk_size_t const left = n - offset;
            __mmask32 const lanes = left < 32 ? _bzhi_u32(0xFFFFFFFF, (unsigned)left) : 0xFFFFFFFF;
            __m256i const bytes = _mm256_maskz_loadu_epi8(lanes, src + offset);
            __m512i halves;
            if (from_type == nk_e4m3_k) { halves = nk_e4m3x32_to_bf16x32_icelake_(bytes); }
            else { halves = nk_e5m2x32_to_bf16x32_icelake_(bytes); }
            _mm512_mask_storeu_epi16(dst + offset, lanes, halves);
        }
        return;
    }

    // bf16 → e4m3 / e5m2
    if (from_type == nk_bf16_k && to_is_fp8) {
        nk_bf16_t const *src = (nk_bf16_t const *)from;
        nk_e4m3_t *dst = (nk_e4m3_t *)to;
        for (nk_size_t offset = 0; offset < n; offset += 32) {
            nk_size_t const left = n - offset;
            __mmask32 const lanes = left < 32 ? _bzhi_u32(0xFFFFFFFF, (unsigned)left) : 0xFFFFFFFF;
            __m512i const halves = _mm512_maskz_loadu_epi16(lanes, src + offset);
            __m256i bytes;
            if (to_type == nk_e4m3_k) { bytes = nk_bf16x32_to_e4m3x32_icelake_(halves); }
            else { bytes = nk_bf16x32_to_e5m2x32_icelake_(halves); }
            _mm256_mask_storeu_epi8(dst + offset, lanes, bytes);
        }
        return;
    }

    // e4m3 / e5m2 → f16
    if (from_is_fp8 && to_type == nk_f16_k) {
        nk_e4m3_t const *src = (nk_e4m3_t const *)from;
        nk_f16_t *dst = (nk_f16_t *)to;
        for (nk_size_t offset = 0; offset < n; offset += 32) {
            nk_size_t const left = n - offset;
            __mmask32 const lanes = left < 32 ? _bzhi_u32(0xFFFFFFFF, (unsigned)left) : 0xFFFFFFFF;
            __m256i const bytes = _mm256_maskz_loadu_epi8(lanes, src + offset);
            __m512i halves;
            if (from_type == nk_e4m3_k) { halves = nk_e4m3x32_to_f16x32_icelake_(bytes); }
            else { halves = nk_e5m2x32_to_f16x32_icelake_(bytes); }
            _mm512_mask_storeu_epi16(dst + offset, lanes, halves);
        }
        return;
    }

    // Anything else is handled by the Skylake implementation
    nk_cast_skylake(from, from_type, n, to, to_type);
}
|
|
455
|
+
|
|
456
|
+
#pragma endregion - Public API
|
|
457
|
+
|
|
458
|
+
#if defined(__clang__)
|
|
459
|
+
#pragma clang attribute pop
|
|
460
|
+
#elif defined(__GNUC__)
|
|
461
|
+
#pragma GCC pop_options
|
|
462
|
+
#endif
|
|
463
|
+
|
|
464
|
+
#if defined(__cplusplus)
|
|
465
|
+
} // extern "C"
|
|
466
|
+
#endif
|
|
467
|
+
|
|
468
|
+
#endif // NK_TARGET_ICELAKE
|
|
469
|
+
#endif // NK_TARGET_X86_
|
|
470
|
+
#endif // NK_CAST_ICELAKE_H
|