numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for Sapphire Rapids.
|
|
3
|
+
* @file include/numkong/cast/sapphire.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 2, 2026
|
|
6
|
+
*
|
|
7
|
+
* @section sapphire_cast_instructions Relevant Instructions
|
|
8
|
+
*
|
|
9
|
+
* Intrinsic Instruction Sapphire Genoa
|
|
10
|
+
* _mm_cvtss_sh VCVTSS2SH (XMM, XMM, XMM) 5cy @ p05 5cy @ p01
|
|
11
|
+
* _mm_cvtsh_ss VCVTSH2SS (XMM, XMM, XMM) 5cy @ p05 5cy @ p01
|
|
12
|
+
* _mm256_cvtepu8_epi16 VPMOVZXBW (YMM, XMM) 3cy @ p5 3cy @ p12
|
|
13
|
+
* _mm256_mul_ph VMULPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
|
|
14
|
+
* _mm256_cvtepi16_ph VCVTW2PH (YMM, YMM) 4cy @ p05 4cy @ p01
|
|
15
|
+
* _mm256_cvtph_epi16 VCVTPH2W (YMM, YMM) 4cy @ p05 4cy @ p01
|
|
16
|
+
* _mm256_mask_blend_epi16 VPBLENDMW (YMM, K, YMM, YMM) 1cy @ p05 1cy @ p0123
|
|
17
|
+
* _mm256_testn_epi16_mask VPTESTNMW (K, YMM, YMM) 3cy @ p5 3cy @ p0
|
|
18
|
+
* _mm256_cvtepi16_epi8 VPMOVWB (XMM, YMM) 4cy @ p5 4cy @ p12
|
|
19
|
+
* _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
|
|
20
|
+
* _mm256_mask_storeu_epi16 VMOVDQU16 (M256 {K}, YMM) 4cy @ p4 4cy @ p4
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_CAST_SAPPHIRE_H
|
|
23
|
+
#define NK_CAST_SAPPHIRE_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X86_
|
|
26
|
+
#if NK_TARGET_SAPPHIRE
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/cast/icelake.h" // `nk_cast_icelake`
|
|
30
|
+
|
|
31
|
+
#if defined(__cplusplus)
|
|
32
|
+
extern "C" {
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
#if defined(__clang__)
|
|
36
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512fp16,f16c,fma,bmi,bmi2"))), \
|
|
37
|
+
apply_to = function)
|
|
38
|
+
#elif defined(__GNUC__)
|
|
39
|
+
#pragma GCC push_options
|
|
40
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
|
|
41
|
+
#endif
|
|
42
|
+
|
|
43
|
+
/** @brief Converts one f32 scalar to f16 with the AVX-512 FP16 scalar `VCVTSS2SH` instruction.
 *  @param from Pointer to the source single-precision value.
 *  @param to   Pointer receiving the half-precision bit pattern. */
NK_PUBLIC void nk_f32_to_f16_sapphire(nk_f32_t const *from, nk_f16_t *to) {
    __m128 const single_f32x4 = _mm_set_ss(*from);
    // Upper lanes come from the zero vector; lane 0 holds the converted half.
    __m128h const half_f16x8 = _mm_cvtss_sh(_mm_setzero_ph(), single_f32x4);
    // Move the low 32 bits (lane 0 plus a zero lane 1) to a GPR and store the 16-bit pattern.
    int const low_bits = _mm_cvtsi128_si32(_mm_castph_si128(half_f16x8));
    *to = low_bits;
}
|
|
46
|
+
|
|
47
|
+
/** @brief Converts one f16 scalar to f32 with the AVX-512 FP16 scalar `VCVTSH2SS` instruction.
 *  @param from Pointer to the source half-precision bit pattern.
 *  @param to   Pointer receiving the single-precision value. */
NK_PUBLIC void nk_f16_to_f32_sapphire(nk_f16_t const *from, nk_f32_t *to) {
    // Place the raw 16-bit pattern in lane 0 and reinterpret the register as half precision.
    __m128h const half_f16x8 = _mm_castsi128_ph(_mm_cvtsi32_si128(*from));
    __m128 const single_f32x4 = _mm_cvtsh_ss(_mm_setzero_ps(), half_f16x8);
    *to = _mm_cvtss_f32(single_f32x4);
}
|
|
50
|
+
|
|
51
|
+
#pragma region - Vectorized Conversions
|
|
52
|
+
|
|
53
|
+
/** @brief Convert 16x e4m3 → 16x f16 via bit manipulation (AVX-512 FP16).
|
|
54
|
+
* E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
55
|
+
* Normal: sign | ((exp+8)<<10) | (mant<<7).
|
|
56
|
+
* Subnormals (exp=0): value = mantissa ÷ 512, computed via f16 arithmetic. */
|
|
57
|
+
NK_INTERNAL __m256h nk_e4m3x16_to_f16x16_sapphire_(__m128i e4m3_i8x16) {
|
|
58
|
+
__m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3_i8x16);
|
|
59
|
+
|
|
60
|
+
// Extract fields
|
|
61
|
+
__m256i mantissa_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x07));
|
|
62
|
+
__m256i sign_i16x16 = _mm256_and_si256(_mm256_slli_epi16(e4m3_i16x16, 8), _mm256_set1_epi16((short)0x8000));
|
|
63
|
+
|
|
64
|
+
// Normal path: sign | ((exp+8)<<10) | (mantissa<<7) via single shift + bias add
|
|
65
|
+
__m256i exp_mantissa_i16x16 = _mm256_slli_epi16(_mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F)), 7);
|
|
66
|
+
__m256i exp_mantissa_biased_i16x16 = _mm256_add_epi16(exp_mantissa_i16x16, _mm256_set1_epi16(0x2000));
|
|
67
|
+
__m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, exp_mantissa_biased_i16x16);
|
|
68
|
+
|
|
69
|
+
// Subnormal fix: for exp==0 lanes, use (subnorm_abs | sign); else keep normal
|
|
70
|
+
__mmask16 is_subnormal = _mm256_testn_epi16_mask(e4m3_i16x16, _mm256_set1_epi16(0x78));
|
|
71
|
+
__m256h subnorm_abs_f16x16 = _mm256_mul_ph(_mm256_cvtepi16_ph(mantissa_i16x16),
|
|
72
|
+
_mm256_castsi256_ph(_mm256_set1_epi16(0x1800))); // 1/512
|
|
73
|
+
__m256i subnorm_signed_i16x16 = _mm256_or_si256(_mm256_castph_si256(subnorm_abs_f16x16), sign_i16x16);
|
|
74
|
+
__m256i result_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_i16x16, subnorm_signed_i16x16);
|
|
75
|
+
|
|
76
|
+
// NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (lower 7 bits == 0x7F)
|
|
77
|
+
__mmask16 is_nan = _mm256_cmpeq_epi16_mask( //
|
|
78
|
+
_mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F)), //
|
|
79
|
+
_mm256_set1_epi16(0x7F)); //
|
|
80
|
+
__m256i nan_bits = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00)); // F16 quiet NaN
|
|
81
|
+
return _mm256_castsi256_ph(_mm256_mask_blend_epi16(is_nan, result_i16x16, nan_bits));
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
/** @brief Convert 16x e5m2 → 16x f16 via bit manipulation (AVX-512 FP16).
|
|
85
|
+
* E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
86
|
+
* Normal: sign | (exp<<10) | (mant<<8) (same exponent bias).
|
|
87
|
+
* Subnormals (exp=0): value = mantissa ÷ 65536, computed via f16 arithmetic. */
|
|
88
|
+
NK_INTERNAL __m256h nk_e5m2x16_to_f16x16_sapphire_(__m128i e5m2_i8x16) {
|
|
89
|
+
__m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2_i8x16);
|
|
90
|
+
|
|
91
|
+
// Extract fields
|
|
92
|
+
__m256i mantissa_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x03));
|
|
93
|
+
__m256i sign_i16x16 = _mm256_and_si256(_mm256_slli_epi16(e5m2_i16x16, 8), _mm256_set1_epi16((short)0x8000));
|
|
94
|
+
|
|
95
|
+
// Normal path: sign | (exp<<10) | (mant<<8) - same exponent bias so just shift lower7 by 8
|
|
96
|
+
__m256i exp_mantissa_i16x16 = _mm256_slli_epi16(_mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F)), 8);
|
|
97
|
+
__m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, exp_mantissa_i16x16);
|
|
98
|
+
|
|
99
|
+
// Subnormal fix: for exp==0 lanes, use (subnorm_abs | sign); else keep normal
|
|
100
|
+
__mmask16 is_subnormal = _mm256_testn_epi16_mask(e5m2_i16x16, _mm256_set1_epi16(0x7C));
|
|
101
|
+
__m256h subnorm_abs_f16x16 = _mm256_mul_ph(_mm256_cvtepi16_ph(mantissa_i16x16),
|
|
102
|
+
_mm256_castsi256_ph(_mm256_set1_epi16(0x0100))); // 1/65536
|
|
103
|
+
__m256i subnorm_signed_i16x16 = _mm256_or_si256(_mm256_castph_si256(subnorm_abs_f16x16), sign_i16x16);
|
|
104
|
+
return _mm256_castsi256_ph(_mm256_mask_blend_epi16(is_subnormal, normal_i16x16, subnorm_signed_i16x16));
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
/** @brief Convert 16x f16 → 16x e4m3 via bit manipulation (AVX-512 FP16).
|
|
108
|
+
* F16: S EEEEE MMMMMMMMMM (bias=15). E4M3: S EEEE MMM (bias=7).
|
|
109
|
+
* Handles normal, subnormal, and overflow cases with RNE rounding. */
|
|
110
|
+
NK_INTERNAL __m128i nk_f16x16_to_e4m3x16_sapphire_(__m256h f16x16) {
|
|
111
|
+
__m256i bits_i16x16 = _mm256_castph_si256(f16x16);
|
|
112
|
+
__m256i sign_i16x16 = _mm256_srli_epi16(bits_i16x16, 15);
|
|
113
|
+
__m256i f16_exp_i16x16 = _mm256_and_si256(_mm256_srli_epi16(bits_i16x16, 10), _mm256_set1_epi16(0x1F));
|
|
114
|
+
|
|
115
|
+
// Round mantissa from 10 to 3 bits using RNE (round to nearest, ties to even)
|
|
116
|
+
__m256i significand_i16x16 = _mm256_or_si256(_mm256_and_si256(bits_i16x16, _mm256_set1_epi16(0x03FF)),
|
|
117
|
+
_mm256_set1_epi16(0x0400)); // Add implicit 1 bit
|
|
118
|
+
__m256i lsb_i16x16 = _mm256_and_si256(_mm256_srli_epi16(significand_i16x16, 7), _mm256_set1_epi16(1));
|
|
119
|
+
__m256i rounding_bias_i16x16 = _mm256_add_epi16(_mm256_set1_epi16(0x003F), lsb_i16x16);
|
|
120
|
+
__m256i rounded_sig_i16x16 = _mm256_add_epi16(significand_i16x16, rounding_bias_i16x16);
|
|
121
|
+
__m256i carry_i16x16 = _mm256_srli_epi16(rounded_sig_i16x16, 11); // Carry into exponent if bit 11 set
|
|
122
|
+
__m256i f16_mantissa_i16x16 = _mm256_and_si256(_mm256_srli_epi16(rounded_sig_i16x16, 7), _mm256_set1_epi16(0x07));
|
|
123
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
124
|
+
f16_mantissa_i16x16 = _mm256_andnot_si256(_mm256_slli_epi16(carry_i16x16, 15), f16_mantissa_i16x16);
|
|
125
|
+
__m256i e4m3_exp_i16x16 = _mm256_sub_epi16(_mm256_add_epi16(f16_exp_i16x16, carry_i16x16), _mm256_set1_epi16(8));
|
|
126
|
+
|
|
127
|
+
// Detect underflow (exp <= 0) and overflow (exp > 15)
|
|
128
|
+
__mmask16 is_subnormal = _mm256_cmpgt_epi16_mask(_mm256_set1_epi16(1), e4m3_exp_i16x16);
|
|
129
|
+
__mmask16 overflow = _mm256_cmpgt_epi16_mask(e4m3_exp_i16x16, _mm256_set1_epi16(15));
|
|
130
|
+
|
|
131
|
+
// Normal path: clamp exp to [1,15]
|
|
132
|
+
// e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
|
|
133
|
+
__m256i clamped_exp_i16x16 = _mm256_max_epi16(e4m3_exp_i16x16, _mm256_set1_epi16(1));
|
|
134
|
+
clamped_exp_i16x16 = _mm256_min_epi16(clamped_exp_i16x16, _mm256_set1_epi16(15));
|
|
135
|
+
__mmask16 is_max_exp = _mm256_cmpeq_epi16_mask(clamped_exp_i16x16, _mm256_set1_epi16(15));
|
|
136
|
+
__m256i max_mantissa_i16x16 = _mm256_mask_blend_epi16(is_max_exp, _mm256_set1_epi16(7), _mm256_set1_epi16(6));
|
|
137
|
+
__m256i normal_mantissa_i16x16 = _mm256_min_epi16(f16_mantissa_i16x16, max_mantissa_i16x16);
|
|
138
|
+
normal_mantissa_i16x16 = _mm256_mask_blend_epi16(overflow, normal_mantissa_i16x16, _mm256_set1_epi16(0x06));
|
|
139
|
+
__m256i normal_e4m3_i16x16 = _mm256_or_si256(
|
|
140
|
+
_mm256_slli_epi16(sign_i16x16, 7),
|
|
141
|
+
_mm256_or_si256(_mm256_slli_epi16(clamped_exp_i16x16, 3), normal_mantissa_i16x16));
|
|
142
|
+
|
|
143
|
+
// Subnormal path: mantissa = round(abs_f16 * 512)
|
|
144
|
+
__m256h abs_f16x16 = _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(f16x16), _mm256_set1_epi16(0x7FFF)));
|
|
145
|
+
__m256h scaled_f16x16 = _mm256_mul_ph(abs_f16x16, _mm256_castsi256_ph(_mm256_set1_epi16(0x6000))); // 512
|
|
146
|
+
__m256i subnorm_mantissa_i16x16 = _mm256_cvtph_epi16(scaled_f16x16);
|
|
147
|
+
__mmask16 promotes_to_normal = _mm256_cmpgt_epi16_mask(subnorm_mantissa_i16x16, _mm256_set1_epi16(7));
|
|
148
|
+
subnorm_mantissa_i16x16 = _mm256_min_epi16(subnorm_mantissa_i16x16, _mm256_set1_epi16(7));
|
|
149
|
+
subnorm_mantissa_i16x16 = _mm256_max_epi16(subnorm_mantissa_i16x16, _mm256_setzero_si256());
|
|
150
|
+
__m256i subnorm_e4m3_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), subnorm_mantissa_i16x16);
|
|
151
|
+
__m256i first_normal_e4m3_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), _mm256_set1_epi16(0x08));
|
|
152
|
+
subnorm_e4m3_i16x16 = _mm256_mask_blend_epi16(promotes_to_normal, subnorm_e4m3_i16x16, first_normal_e4m3_i16x16);
|
|
153
|
+
|
|
154
|
+
// Blend: use subnormal result when exp <= 0
|
|
155
|
+
__m256i e4m3_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_e4m3_i16x16, subnorm_e4m3_i16x16);
|
|
156
|
+
|
|
157
|
+
// Pack 16 i16s to 16 unsigned i8s via AVX-512BW
|
|
158
|
+
return _mm256_cvtepi16_epi8(e4m3_i16x16);
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
/** @brief Convert 16x f16 → 16x e5m2 via bit manipulation (AVX-512 FP16).
|
|
162
|
+
* F16: S EEEEE MMMMMMMMMM (bias=15). E5M2: S EEEEE MM (bias=15).
|
|
163
|
+
* Same exponent bias, so just round mantissa from 10 to 2 bits. */
|
|
164
|
+
NK_INTERNAL __m128i nk_f16x16_to_e5m2x16_sapphire_(__m256h f16x16) {
|
|
165
|
+
__m256i bits_i16x16 = _mm256_castph_si256(f16x16);
|
|
166
|
+
__m256i sign_i16x16 = _mm256_srli_epi16(bits_i16x16, 15);
|
|
167
|
+
__m256i f16_exp_i16x16 = _mm256_and_si256(_mm256_srli_epi16(bits_i16x16, 10), _mm256_set1_epi16(0x1F));
|
|
168
|
+
|
|
169
|
+
// Round mantissa from 10 to 2 bits using RNE (round to nearest, ties to even)
|
|
170
|
+
__m256i significand_i16x16 = _mm256_or_si256(_mm256_and_si256(bits_i16x16, _mm256_set1_epi16(0x03FF)),
|
|
171
|
+
_mm256_set1_epi16(0x0400)); // Add implicit 1 bit
|
|
172
|
+
__m256i lsb_i16x16 = _mm256_and_si256(_mm256_srli_epi16(significand_i16x16, 8), _mm256_set1_epi16(1));
|
|
173
|
+
__m256i rounding_bias_i16x16 = _mm256_add_epi16(_mm256_set1_epi16(0x007F), lsb_i16x16);
|
|
174
|
+
__m256i rounded_sig_i16x16 = _mm256_add_epi16(significand_i16x16, rounding_bias_i16x16);
|
|
175
|
+
__m256i carry_i16x16 = _mm256_srli_epi16(rounded_sig_i16x16, 11); // Carry into exponent if bit 11 set
|
|
176
|
+
__m256i f16_mantissa_i16x16 = _mm256_and_si256(_mm256_srli_epi16(rounded_sig_i16x16, 8), _mm256_set1_epi16(0x03));
|
|
177
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
178
|
+
f16_mantissa_i16x16 = _mm256_andnot_si256(_mm256_slli_epi16(carry_i16x16, 15), f16_mantissa_i16x16);
|
|
179
|
+
__m256i e5m2_exp_i16x16 = _mm256_add_epi16(f16_exp_i16x16, carry_i16x16);
|
|
180
|
+
|
|
181
|
+
// Detect subnormal (exp <= 0) and overflow (exp > 31)
|
|
182
|
+
__mmask16 is_subnormal = _mm256_cmpeq_epi16_mask(f16_exp_i16x16, _mm256_setzero_si256());
|
|
183
|
+
__mmask16 overflow = _mm256_cmpgt_epi16_mask(e5m2_exp_i16x16, _mm256_set1_epi16(31));
|
|
184
|
+
|
|
185
|
+
// Normal path: clamp exp to [1,31], on overflow return infinity
|
|
186
|
+
__m256i clamped_exp_i16x16 = _mm256_max_epi16(e5m2_exp_i16x16, _mm256_set1_epi16(1));
|
|
187
|
+
clamped_exp_i16x16 = _mm256_min_epi16(clamped_exp_i16x16, _mm256_set1_epi16(31));
|
|
188
|
+
__m256i normal_mantissa_i16x16 = _mm256_mask_blend_epi16(overflow, f16_mantissa_i16x16, _mm256_setzero_si256());
|
|
189
|
+
__m256i normal_e5m2_i16x16 = _mm256_or_si256(
|
|
190
|
+
_mm256_slli_epi16(sign_i16x16, 7),
|
|
191
|
+
_mm256_or_si256(_mm256_slli_epi16(clamped_exp_i16x16, 2), normal_mantissa_i16x16));
|
|
192
|
+
|
|
193
|
+
// Subnormal path: mantissa = round(abs_f16 * 65536)
|
|
194
|
+
__m256h abs_f16x16 = _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(f16x16), _mm256_set1_epi16(0x7FFF)));
|
|
195
|
+
__m256h scaled_f16x16 = _mm256_mul_ph(abs_f16x16, _mm256_castsi256_ph(_mm256_set1_epi16(0x7C00))); // 65536 (inf)
|
|
196
|
+
__m256i subnorm_mantissa_i16x16 = _mm256_cvtph_epi16(scaled_f16x16);
|
|
197
|
+
__mmask16 promotes_to_normal = _mm256_cmpgt_epi16_mask(subnorm_mantissa_i16x16, _mm256_set1_epi16(3));
|
|
198
|
+
subnorm_mantissa_i16x16 = _mm256_min_epi16(subnorm_mantissa_i16x16, _mm256_set1_epi16(3));
|
|
199
|
+
subnorm_mantissa_i16x16 = _mm256_max_epi16(subnorm_mantissa_i16x16, _mm256_setzero_si256());
|
|
200
|
+
__m256i subnorm_e5m2_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), subnorm_mantissa_i16x16);
|
|
201
|
+
__m256i first_normal_e5m2_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), _mm256_set1_epi16(0x04));
|
|
202
|
+
subnorm_e5m2_i16x16 = _mm256_mask_blend_epi16(promotes_to_normal, subnorm_e5m2_i16x16, first_normal_e5m2_i16x16);
|
|
203
|
+
|
|
204
|
+
// Blend: use subnormal result when exp == 0
|
|
205
|
+
__m256i e5m2_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_e5m2_i16x16, subnorm_e5m2_i16x16);
|
|
206
|
+
|
|
207
|
+
// Pack 16 i16s to 16 unsigned i8s via AVX-512BW
|
|
208
|
+
return _mm256_cvtepi16_epi8(e5m2_i16x16);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
#pragma endregion - Vectorized Conversions
|
|
212
|
+
|
|
213
|
+
#pragma region - Public API
|
|
214
|
+
|
|
215
|
+
/** @brief Converts `n` elements of type `from_type` at `from` into `to_type` at `to`.
 *  FP8 (e4m3 / e5m2) ↔ f16 pairs are processed 16 lanes per iteration with AVX-512 FP16 kernels. The final
 *  partial chunk uses masked loads and stores. Every other type pair is forwarded unchanged to `nk_cast_icelake`. */
NK_PUBLIC void nk_cast_sapphire(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
    int const fp8_to_f16 = to_type == nk_f16_k && (from_type == nk_e4m3_k || from_type == nk_e5m2_k);
    int const f16_to_fp8 = from_type == nk_f16_k && (to_type == nk_e4m3_k || to_type == nk_e5m2_k);

    // No Sapphire-specific kernel for this pair: defer to the Ice Lake implementation.
    if (!fp8_to_f16 && !f16_to_fp8) {
        nk_cast_icelake(from, from_type, n, to, to_type);
        return;
    }

    for (nk_size_t offset = 0; offset < n; offset += 16) {
        nk_size_t const left = n - offset;
        // All 16 lanes for whole chunks; only the low `left` lanes for the trailing partial chunk.
        __mmask16 const lanes = left < 16 ? (unsigned short)_bzhi_u32(0xFFFF, (unsigned)left) : 0xFFFF;

        if (fp8_to_f16) {
            // e4m3 and e5m2 are both one byte wide, so the same byte pointer type serves either source.
            __m128i const packed_f8x16 = _mm_maskz_loadu_epi8(lanes, (nk_e4m3_t const *)from + offset);
            __m256h const wide_f16x16 = from_type == nk_e4m3_k ? nk_e4m3x16_to_f16x16_sapphire_(packed_f8x16)
                                                                : nk_e5m2x16_to_f16x16_sapphire_(packed_f8x16);
            _mm256_mask_storeu_epi16((nk_f16_t *)to + offset, lanes, _mm256_castph_si256(wide_f16x16));
        }
        else {
            __m256h const wide_f16x16 =
                _mm256_castsi256_ph(_mm256_maskz_loadu_epi16(lanes, (nk_f16_t const *)from + offset));
            __m128i const packed_f8x16 = to_type == nk_e4m3_k ? nk_f16x16_to_e4m3x16_sapphire_(wide_f16x16)
                                                              : nk_f16x16_to_e5m2x16_sapphire_(wide_f16x16);
            _mm_mask_storeu_epi8((nk_e4m3_t *)to + offset, lanes, packed_f8x16);
        }
    }
}
|
|
247
|
+
|
|
248
|
+
#pragma endregion - Public API
|
|
249
|
+
|
|
250
|
+
#if defined(__clang__)
|
|
251
|
+
#pragma clang attribute pop
|
|
252
|
+
#elif defined(__GNUC__)
|
|
253
|
+
#pragma GCC pop_options
|
|
254
|
+
#endif
|
|
255
|
+
|
|
256
|
+
#if defined(__cplusplus)
|
|
257
|
+
} // extern "C"
|
|
258
|
+
#endif
|
|
259
|
+
|
|
260
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
261
|
+
#endif // NK_TARGET_X86_
|
|
262
|
+
#endif // NK_CAST_SAPPHIRE_H
|