numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,856 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for Skylake.
|
|
3
|
+
* @file include/numkong/cast/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/cast.h
|
|
8
|
+
*
|
|
9
|
+
* @section skylake_cast_instructions AVX-512 Conversion Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction SKL ICL Genoa
|
|
12
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p05 4cy @ p01
|
|
13
|
+
* _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, imm) 5cy @ p05 5cy @ p05 4cy @ p01
|
|
14
|
+
* _mm512_cvtps_epi32 VCVTPS2DQ (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
|
|
15
|
+
* _mm512_cvtepi32_ps VCVTDQ2PS (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
|
|
16
|
+
* _mm512_cvtepi32_epi16 VPMOVDW (YMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
|
|
17
|
+
* _mm512_cvtsepi32_epi8 VPMOVSDB (XMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
|
|
18
|
+
*
|
|
19
|
+
* F16 conversions use hardware F16C via VCVTPH2PS/VCVTPS2PH. BF16 lacks hardware support on Skylake,
|
|
20
|
+
* requiring emulation via VPMOVZXWD + VPSLLD for bf16-to-f32, achieving ~4cy total. FP8 (E4M3/E5M2)
|
|
21
|
+
* conversions (and the FP6 E2M3/E3M2 ones) use bit manipulation with VPTERNLOGD for sign/exp/mantissa composition.
|
|
22
|
+
*/
|
|
23
|
+
#ifndef NK_CAST_SKYLAKE_H
|
|
24
|
+
#define NK_CAST_SKYLAKE_H
|
|
25
|
+
|
|
26
|
+
#if NK_TARGET_X86_
|
|
27
|
+
#if NK_TARGET_SKYLAKE
|
|
28
|
+
|
|
29
|
+
#include "numkong/types.h"
|
|
30
|
+
#include "numkong/cast/serial.h" // `nk_dtype_bits`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
38
|
+
apply_to = function)
|
|
39
|
+
#elif defined(__GNUC__)
|
|
40
|
+
#pragma GCC push_options
|
|
41
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
#pragma region - Type Punned Loads and Stores
|
|
45
|
+
|
|
46
|
+
/** @brief Type-agnostic 512-bit full load (Skylake AVX-512). */
|
|
47
|
+
NK_INTERNAL void nk_load_b512_skylake_(void const *src, nk_b512_vec_t *dst) { dst->zmm = _mm512_loadu_si512(src); }
|
|
48
|
+
|
|
49
|
+
/** @brief Type-agnostic partial load for 64-bit elements (8 elements max) into 512-bit vector (Skylake AVX-512). */
|
|
50
|
+
NK_INTERNAL void nk_partial_load_b64x8_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
51
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
|
|
52
|
+
dst->zmm = _mm512_maskz_loadu_epi64(mask, src);
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
/** @brief Type-agnostic partial load for 32-bit elements (16 elements max) into 512-bit vector (Skylake AVX-512). */
|
|
56
|
+
NK_INTERNAL void nk_partial_load_b32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
57
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
|
|
58
|
+
dst->zmm = _mm512_maskz_loadu_epi32(mask, src);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
/** @brief Type-agnostic partial load for 16-bit elements (32 elements max) into 512-bit vector (Skylake AVX-512). */
|
|
62
|
+
NK_INTERNAL void nk_partial_load_b16x32_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
63
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
64
|
+
dst->zmm = _mm512_maskz_loadu_epi16(mask, src);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/** @brief Partial load for 8-bit elements (64 max) into 512-bit vector (zeros in remaining slots). */
|
|
68
|
+
NK_INTERNAL void nk_partial_load_b8x64_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
69
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
|
|
70
|
+
dst->zmm = _mm512_maskz_loadu_epi8(mask, src);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
/** @brief Partial load for 4-bit nibbles (128 max = 64 bytes) into 512-bit vector (Skylake AVX-512). */
|
|
74
|
+
NK_INTERNAL void nk_partial_load_b4x128_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
75
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
76
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n_bytes);
|
|
77
|
+
dst->zmm = _mm512_maskz_loadu_epi8(mask, src);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
/** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector (Skylake AVX-512). */
|
|
81
|
+
NK_INTERNAL void nk_partial_load_b32x8_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
82
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
|
|
83
|
+
dst->ymm = _mm256_maskz_loadu_epi32(mask, src);
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
/** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector (Skylake AVX-512). */
|
|
87
|
+
NK_INTERNAL void nk_partial_load_b16x16_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
88
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
|
|
89
|
+
dst->ymm = _mm256_maskz_loadu_epi16(mask, src);
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
/** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector (Skylake AVX-512). */
|
|
93
|
+
NK_INTERNAL void nk_partial_load_b8x16_skylake_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
94
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
|
|
95
|
+
dst->xmm = _mm_maskz_loadu_epi8(mask, src);
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
/** @brief Partial load for 1-bit elements (512 max bits = 64 bytes) into 512-bit vector (Skylake AVX-512).
|
|
99
|
+
* Wrapper that converts bit count to byte count and delegates to byte-level masked load. */
|
|
100
|
+
NK_INTERNAL void nk_partial_load_b1x512_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n_bits) {
|
|
101
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
|
|
102
|
+
nk_partial_load_b8x64_skylake_(src, dst, n_bytes);
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
/** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector (Skylake AVX-512). */
|
|
106
|
+
NK_INTERNAL void nk_partial_load_b32x4_skylake_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
107
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
|
|
108
|
+
dst->xmm = _mm_maskz_loadu_epi32(mask, src);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
/** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector (Skylake AVX-512). */
|
|
112
|
+
NK_INTERNAL void nk_partial_load_b64x4_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
113
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
|
|
114
|
+
dst->ymm = _mm256_maskz_loadu_epi64(mask, src);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
/** @brief Type-agnostic partial store for 32-bit elements (16 elements max) from 512-bit vector (Skylake AVX-512). */
|
|
118
|
+
NK_INTERNAL void nk_partial_store_b32x16_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
|
|
119
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
|
|
120
|
+
_mm512_mask_storeu_epi32(dst, mask, src->zmm);
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
/** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector (Skylake AVX-512). */
|
|
124
|
+
NK_INTERNAL void nk_partial_store_b32x4_skylake_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
125
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
|
|
126
|
+
_mm_mask_storeu_epi32(dst, mask, src->xmm);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector (Skylake AVX-512). */
|
|
130
|
+
NK_INTERNAL void nk_partial_store_b64x4_skylake_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
131
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
|
|
132
|
+
_mm256_mask_storeu_epi64(dst, mask, src->ymm);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
#pragma endregion - Type Punned Loads and Stores
|
|
136
|
+
|
|
137
|
+
#pragma region - Vectorized Conversions
|
|
138
|
+
|
|
139
|
+
/** @brief Convert 16x bf16 → 16x f32 (Skylake AVX-512). */
|
|
140
|
+
NK_INTERNAL __m512 nk_bf16x16_to_f32x16_skylake_(__m256i a) {
|
|
141
|
+
// Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like:
|
|
142
|
+
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
/** @brief Convert 16x f32 → 16x bf16 (Skylake AVX-512). */
|
|
146
|
+
NK_INTERNAL __m256i nk_f32x16_to_bf16x16_skylake_(__m512 a) {
|
|
147
|
+
// Round-to-nearest-even: add (0x7FFF + lsb) to match hardware BF16 behavior
|
|
148
|
+
__m512i bits = _mm512_castps_si512(a);
|
|
149
|
+
__m512i lsb = _mm512_and_si512(_mm512_srli_epi32(bits, 16), _mm512_set1_epi32(1));
|
|
150
|
+
__m512i rounded = _mm512_add_epi32(bits, _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb));
|
|
151
|
+
__m512i x = _mm512_srli_epi32(rounded, 16);
|
|
152
|
+
return _mm512_cvtepi32_epi16(x);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
/** @brief Convert 16x e4m3 → 16x f32 via bit manipulation (AVX-512).
|
|
156
|
+
* E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mantissa<<20.
|
|
157
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512. */
|
|
158
|
+
NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
|
|
159
|
+
__m512i e4m3_i32x16 = _mm512_cvtepu8_epi32(e4m3_i8x16);
|
|
160
|
+
|
|
161
|
+
// Extract fields
|
|
162
|
+
__m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e4m3_i32x16, 3), _mm512_set1_epi32(0x0F));
|
|
163
|
+
__m512i mantissa_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x07));
|
|
164
|
+
__m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e4m3_i32x16, 7), 31);
|
|
165
|
+
|
|
166
|
+
// Normal path: sign | ((exp+120)<<23) | (mantissa<<20)
|
|
167
|
+
__m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(120)), 23);
|
|
168
|
+
__m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 20);
|
|
169
|
+
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
170
|
+
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
171
|
+
|
|
172
|
+
// Subnormal fix: for exp==0 lanes, replace with (mantissa / 512) | sign using masked OR
|
|
173
|
+
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e4m3_i32x16, _mm512_set1_epi32(0x78));
|
|
174
|
+
__m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 512.0f));
|
|
175
|
+
result_f32x16 = _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16,
|
|
176
|
+
_mm512_castsi512_ps(sign_i32x16));
|
|
177
|
+
|
|
178
|
+
// NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
|
|
179
|
+
__mmask16 is_nan = _mm512_mask_cmpeq_epi32_mask( //
|
|
180
|
+
_mm512_cmpeq_epi32_mask(exp_i32x16, _mm512_set1_epi32(15)), //
|
|
181
|
+
mantissa_i32x16, _mm512_set1_epi32(7)); //
|
|
182
|
+
__m512i nan_bits = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000)); // F32 quiet NaN
|
|
183
|
+
return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_bits));
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
/** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
|
|
187
|
+
* E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mantissa<<21.
|
|
188
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
|
|
189
|
+
NK_INTERNAL __m512 nk_e5m2x16_to_f32x16_skylake_(__m128i e5m2_i8x16) {
|
|
190
|
+
__m512i e5m2_i32x16 = _mm512_cvtepu8_epi32(e5m2_i8x16);
|
|
191
|
+
|
|
192
|
+
// Extract fields
|
|
193
|
+
__m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e5m2_i32x16, 2), _mm512_set1_epi32(0x1F));
|
|
194
|
+
__m512i mantissa_i32x16 = _mm512_and_si512(e5m2_i32x16, _mm512_set1_epi32(0x03));
|
|
195
|
+
__m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e5m2_i32x16, 7), 31);
|
|
196
|
+
|
|
197
|
+
// Normal path: sign | ((exp+112)<<23) | (mantissa<<21)
|
|
198
|
+
__m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(112)), 23);
|
|
199
|
+
__m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
|
|
200
|
+
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
201
|
+
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
202
|
+
|
|
203
|
+
// Subnormal fix: for exp==0 lanes, replace with (mantissa / 65536) | sign using masked OR
|
|
204
|
+
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e5m2_i32x16, _mm512_set1_epi32(0x7C));
|
|
205
|
+
__m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 65536.0f));
|
|
206
|
+
return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
/** @brief Convert 16x e2m3 → 16x f32 via bit manipulation (AVX-512).
|
|
210
|
+
* E2M3 format: S EE MMM (bias=1, only 6 bits used). F32: sign<<31, (exp+126)<<23, mantissa<<20.
|
|
211
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁾ × 2⁻³ = mantissa ÷ 8. */
|
|
212
|
+
NK_INTERNAL __m512 nk_e2m3x16_to_f32x16_skylake_(__m128i e2m3_i8x16) {
|
|
213
|
+
__m512i e2m3_i32x16 = _mm512_cvtepu8_epi32(e2m3_i8x16);
|
|
214
|
+
|
|
215
|
+
// Extract fields (only 6 bits used: S EE MMM)
|
|
216
|
+
__m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e2m3_i32x16, 3), _mm512_set1_epi32(0x03));
|
|
217
|
+
__m512i mantissa_i32x16 = _mm512_and_si512(e2m3_i32x16, _mm512_set1_epi32(0x07));
|
|
218
|
+
__m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e2m3_i32x16, 5), 31);
|
|
219
|
+
|
|
220
|
+
// Normal path: sign | ((exp+126)<<23) | (mantissa<<20)
|
|
221
|
+
__m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(126)), 23);
|
|
222
|
+
__m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 20);
|
|
223
|
+
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
224
|
+
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
225
|
+
|
|
226
|
+
// Subnormal fix: for exp==0 lanes, replace with (mantissa / 8) | sign using masked OR
|
|
227
|
+
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e2m3_i32x16, _mm512_set1_epi32(0x18));
|
|
228
|
+
__m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 8.0f));
|
|
229
|
+
return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
/** @brief Convert 16x e3m2 → 16x f32 via bit manipulation (AVX-512).
|
|
233
|
+
* E3M2 format: S EEE MM (bias=3, only 6 bits used). F32: sign<<31, (exp+124)<<23, mantissa<<21.
|
|
234
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻³⁾ × 2⁻² = mantissa ÷ 16. */
|
|
235
|
+
NK_INTERNAL __m512 nk_e3m2x16_to_f32x16_skylake_(__m128i e3m2_i8x16) {
|
|
236
|
+
__m512i e3m2_i32x16 = _mm512_cvtepu8_epi32(e3m2_i8x16);
|
|
237
|
+
|
|
238
|
+
// Extract fields (only 6 bits used: S EEE MM)
|
|
239
|
+
__m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e3m2_i32x16, 2), _mm512_set1_epi32(0x07));
|
|
240
|
+
__m512i mantissa_i32x16 = _mm512_and_si512(e3m2_i32x16, _mm512_set1_epi32(0x03));
|
|
241
|
+
__m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e3m2_i32x16, 5), 31);
|
|
242
|
+
|
|
243
|
+
// Normal path: sign | ((exp+124)<<23) | (mantissa<<21)
|
|
244
|
+
__m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(124)), 23);
|
|
245
|
+
__m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
|
|
246
|
+
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
247
|
+
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
248
|
+
|
|
249
|
+
// Subnormal fix: for exp==0 lanes, replace with (mantissa / 16) | sign using masked OR
|
|
250
|
+
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e3m2_i32x16, _mm512_set1_epi32(0x1C));
|
|
251
|
+
__m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 16.0f));
|
|
252
|
+
return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
/** @brief Convert 16x f32 → 16x e2m3 via bit manipulation (AVX-512).
|
|
256
|
+
* E2M3 format: S EE MMM (bias=1). Handles normal, subnormal, and overflow cases.
|
|
257
|
+
* Subnormals (f32_exp ≤ 126): mantissa = round(abs_f32 * 8), clamped to [0,7]. */
|
|
258
|
+
NK_INTERNAL __m128i nk_f32x16_to_e2m3x16_skylake_(__m512 f32x16) {
|
|
259
|
+
__m512i bits_i32x16 = _mm512_castps_si512(f32x16);
|
|
260
|
+
__m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
|
|
261
|
+
__m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
|
|
262
|
+
|
|
263
|
+
// Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
|
|
264
|
+
__m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
|
|
265
|
+
_mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
|
|
266
|
+
__m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 20), _mm512_set1_epi32(1));
|
|
267
|
+
__m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x0007FFFF), lsb_i32x16);
|
|
268
|
+
__m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
|
|
269
|
+
__m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
|
|
270
|
+
__m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 20), _mm512_set1_epi32(0x07));
|
|
271
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
272
|
+
f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
|
|
273
|
+
__m512i e2m3_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(126));
|
|
274
|
+
|
|
275
|
+
// Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 3)
|
|
276
|
+
__mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e2m3_exp_i32x16);
|
|
277
|
+
__mmask16 overflow = _mm512_cmpgt_epi32_mask(e2m3_exp_i32x16, _mm512_set1_epi32(3));
|
|
278
|
+
|
|
279
|
+
// Normal path: clamp exp to [1,3], extract mantissa bits
|
|
280
|
+
__m512i clamped_exp_i32x16 = _mm512_max_epi32(e2m3_exp_i32x16, _mm512_set1_epi32(1));
|
|
281
|
+
clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(3));
|
|
282
|
+
__m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_set1_epi32(0x07));
|
|
283
|
+
__m512i normal_e2m3_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 5),
|
|
284
|
+
_mm512_slli_epi32(clamped_exp_i32x16, 3),
|
|
285
|
+
normal_mantissa_i32x16, 0xFE); // a | b | c
|
|
286
|
+
|
|
287
|
+
// Subnormal path: mantissa = round(abs_f32 * 8)
|
|
288
|
+
// If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
|
|
289
|
+
__m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
290
|
+
__m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(8.0f));
|
|
291
|
+
__m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
|
|
292
|
+
__mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
|
|
293
|
+
subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
|
|
294
|
+
subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
|
|
295
|
+
__m512i subnorm_e2m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), subnorm_mantissa_i32x16);
|
|
296
|
+
// When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
|
|
297
|
+
__m512i first_normal_e2m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), _mm512_set1_epi32(0x08));
|
|
298
|
+
subnorm_e2m3_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e2m3_i32x16, first_normal_e2m3_i32x16);
|
|
299
|
+
|
|
300
|
+
// Blend: use subnormal result when exp <= 0, else normal
|
|
301
|
+
__m512i e2m3_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e2m3_i32x16, subnorm_e2m3_i32x16);
|
|
302
|
+
|
|
303
|
+
// Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
|
|
304
|
+
return _mm512_cvtepi32_epi8(e2m3_i32x16);
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
/** @brief Convert 16x f32 → 16x e3m2 via bit manipulation (AVX-512).
|
|
308
|
+
* E3M2 format: S EEE MM (bias=3). Handles normal, subnormal, and overflow cases.
|
|
309
|
+
* Subnormals (f32_exp ≤ 124): mantissa = round(abs_f32 * 16), clamped to [0,3]. */
|
|
310
|
+
NK_INTERNAL __m128i nk_f32x16_to_e3m2x16_skylake_(__m512 f32x16) {
|
|
311
|
+
__m512i bits_i32x16 = _mm512_castps_si512(f32x16);
|
|
312
|
+
__m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
|
|
313
|
+
__m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
|
|
314
|
+
|
|
315
|
+
// Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
|
|
316
|
+
__m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
|
|
317
|
+
_mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
|
|
318
|
+
__m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 21), _mm512_set1_epi32(1));
|
|
319
|
+
__m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x000FFFFF), lsb_i32x16);
|
|
320
|
+
__m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
|
|
321
|
+
__m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
|
|
322
|
+
__m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 21), _mm512_set1_epi32(0x03));
|
|
323
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
324
|
+
f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
|
|
325
|
+
__m512i e3m2_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(124));
|
|
326
|
+
|
|
327
|
+
// Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 7)
|
|
328
|
+
__mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e3m2_exp_i32x16);
|
|
329
|
+
__mmask16 overflow = _mm512_cmpgt_epi32_mask(e3m2_exp_i32x16, _mm512_set1_epi32(7));
|
|
330
|
+
|
|
331
|
+
// Normal path: clamp exp to [1,7], extract mantissa bits
|
|
332
|
+
__m512i clamped_exp_i32x16 = _mm512_max_epi32(e3m2_exp_i32x16, _mm512_set1_epi32(1));
|
|
333
|
+
clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(7));
|
|
334
|
+
__m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_set1_epi32(0x03));
|
|
335
|
+
__m512i normal_e3m2_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 5),
|
|
336
|
+
_mm512_slli_epi32(clamped_exp_i32x16, 2),
|
|
337
|
+
normal_mantissa_i32x16, 0xFE); // a | b | c
|
|
338
|
+
|
|
339
|
+
// Subnormal path: mantissa = round(abs_f32 * 16)
|
|
340
|
+
// If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
|
|
341
|
+
__m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
342
|
+
__m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(16.0f));
|
|
343
|
+
__m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
|
|
344
|
+
__mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
|
|
345
|
+
subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
|
|
346
|
+
subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
|
|
347
|
+
__m512i subnorm_e3m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), subnorm_mantissa_i32x16);
|
|
348
|
+
// When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
|
|
349
|
+
__m512i first_normal_e3m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), _mm512_set1_epi32(0x04));
|
|
350
|
+
subnorm_e3m2_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e3m2_i32x16, first_normal_e3m2_i32x16);
|
|
351
|
+
|
|
352
|
+
// Blend: use subnormal result when exp <= 0, else normal
|
|
353
|
+
__m512i e3m2_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e3m2_i32x16, subnorm_e3m2_i32x16);
|
|
354
|
+
|
|
355
|
+
// Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
|
|
356
|
+
return _mm512_cvtepi32_epi8(e3m2_i32x16);
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
/** @brief Convert 16x f32 → 16x e4m3 via bit manipulation (AVX-512).
|
|
360
|
+
* E4M3 format: S EEEE MMM (bias=7). Handles normal, subnormal, and overflow cases.
|
|
361
|
+
* Subnormals (f32_exp ≤ 120): mantissa = round(abs_f32 * 512), clamped to [0,7]. */
|
|
362
|
+
NK_INTERNAL __m128i nk_f32x16_to_e4m3x16_skylake_(__m512 f32x16) {
|
|
363
|
+
__m512i bits_i32x16 = _mm512_castps_si512(f32x16);
|
|
364
|
+
__m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
|
|
365
|
+
__m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
|
|
366
|
+
|
|
367
|
+
// Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
|
|
368
|
+
// RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
|
|
369
|
+
__m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
|
|
370
|
+
_mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
|
|
371
|
+
__m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 20), _mm512_set1_epi32(1));
|
|
372
|
+
__m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x0007FFFF), lsb_i32x16);
|
|
373
|
+
__m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
|
|
374
|
+
__m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
|
|
375
|
+
__m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 20), _mm512_set1_epi32(0x07));
|
|
376
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
377
|
+
f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
|
|
378
|
+
__m512i e4m3_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(120));
|
|
379
|
+
|
|
380
|
+
// Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 15)
|
|
381
|
+
__mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e4m3_exp_i32x16);
|
|
382
|
+
__mmask16 overflow = _mm512_cmpgt_epi32_mask(e4m3_exp_i32x16, _mm512_set1_epi32(15));
|
|
383
|
+
|
|
384
|
+
// Normal path: clamp exp to [1,15], extract mantissa bits
|
|
385
|
+
// e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
|
|
386
|
+
__m512i clamped_exp_i32x16 = _mm512_max_epi32(e4m3_exp_i32x16, _mm512_set1_epi32(1));
|
|
387
|
+
clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(15));
|
|
388
|
+
__mmask16 is_max_exp = _mm512_cmpeq_epi32_mask(clamped_exp_i32x16, _mm512_set1_epi32(15));
|
|
389
|
+
__m512i max_mantissa_i32x16 = _mm512_mask_blend_epi32(is_max_exp, _mm512_set1_epi32(7), _mm512_set1_epi32(6));
|
|
390
|
+
__m512i normal_mantissa_i32x16 = _mm512_min_epi32(f32_mantissa_i32x16, max_mantissa_i32x16);
|
|
391
|
+
normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, normal_mantissa_i32x16, _mm512_set1_epi32(0x06));
|
|
392
|
+
__m512i normal_e4m3_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 7),
|
|
393
|
+
_mm512_slli_epi32(clamped_exp_i32x16, 3),
|
|
394
|
+
normal_mantissa_i32x16, 0xFE); // a | b | c
|
|
395
|
+
|
|
396
|
+
// Subnormal path: mantissa = round(abs_f32 * 512)
|
|
397
|
+
// If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
|
|
398
|
+
__m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
399
|
+
__m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(512.0f));
|
|
400
|
+
__m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
|
|
401
|
+
__mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
|
|
402
|
+
subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
|
|
403
|
+
subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
|
|
404
|
+
__m512i subnorm_e4m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), subnorm_mantissa_i32x16);
|
|
405
|
+
// When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
|
|
406
|
+
__m512i first_normal_e4m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), _mm512_set1_epi32(0x08));
|
|
407
|
+
subnorm_e4m3_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e4m3_i32x16, first_normal_e4m3_i32x16);
|
|
408
|
+
|
|
409
|
+
// Blend: use subnormal result when exp <= 0, else normal
|
|
410
|
+
__m512i e4m3_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e4m3_i32x16, subnorm_e4m3_i32x16);
|
|
411
|
+
|
|
412
|
+
// Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
|
|
413
|
+
return _mm512_cvtepi32_epi8(e4m3_i32x16);
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
/** @brief Convert 16x f32 → 16x e5m2 via bit manipulation (AVX-512).
|
|
417
|
+
* E5M2 format: S EEEEE MM (bias=15). Handles normal, subnormal, and overflow cases.
|
|
418
|
+
* Uses RNE (round to nearest even) for mantissa rounding. */
|
|
419
|
+
NK_INTERNAL __m128i nk_f32x16_to_e5m2x16_skylake_(__m512 f32x16) {
|
|
420
|
+
__m512i bits_i32x16 = _mm512_castps_si512(f32x16);
|
|
421
|
+
__m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
|
|
422
|
+
__m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
|
|
423
|
+
|
|
424
|
+
// Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
|
|
425
|
+
// RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
|
|
426
|
+
__m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
|
|
427
|
+
_mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
|
|
428
|
+
__m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 21), _mm512_set1_epi32(1));
|
|
429
|
+
__m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x000FFFFF), lsb_i32x16); // half = 0x100000
|
|
430
|
+
__m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
|
|
431
|
+
__m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
|
|
432
|
+
__m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 21), _mm512_set1_epi32(0x03));
|
|
433
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
434
|
+
f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
|
|
435
|
+
__m512i e5m2_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(112));
|
|
436
|
+
|
|
437
|
+
// Detect subnormal (exp <= 0) and overflow (exp > 31)
|
|
438
|
+
__mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e5m2_exp_i32x16);
|
|
439
|
+
__mmask16 overflow = _mm512_cmpgt_epi32_mask(e5m2_exp_i32x16, _mm512_set1_epi32(31));
|
|
440
|
+
|
|
441
|
+
// Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
|
|
442
|
+
__m512i clamped_exp_i32x16 = _mm512_max_epi32(e5m2_exp_i32x16, _mm512_set1_epi32(1));
|
|
443
|
+
clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(31));
|
|
444
|
+
__m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_setzero_si512());
|
|
445
|
+
__m512i normal_e5m2_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 7),
|
|
446
|
+
_mm512_slli_epi32(clamped_exp_i32x16, 2),
|
|
447
|
+
normal_mantissa_i32x16, 0xFE); // a | b | c
|
|
448
|
+
|
|
449
|
+
// Subnormal path: mantissa = round(abs_f32 * 65536)
|
|
450
|
+
// If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
|
|
451
|
+
__m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
452
|
+
__m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(65536.0f));
|
|
453
|
+
__m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
|
|
454
|
+
__mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
|
|
455
|
+
subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
|
|
456
|
+
subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
|
|
457
|
+
__m512i subnorm_e5m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), subnorm_mantissa_i32x16);
|
|
458
|
+
// When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
|
|
459
|
+
__m512i first_normal_e5m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), _mm512_set1_epi32(0x04));
|
|
460
|
+
subnorm_e5m2_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e5m2_i32x16, first_normal_e5m2_i32x16);
|
|
461
|
+
|
|
462
|
+
// Blend: use subnormal result when exp <= 0
|
|
463
|
+
__m512i e5m2_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e5m2_i32x16, subnorm_e5m2_i32x16);
|
|
464
|
+
|
|
465
|
+
// Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
|
|
466
|
+
return _mm512_cvtepi32_epi8(e5m2_i32x16);
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
NK_INTERNAL __m512 nk_i8x16_to_f32x16_skylake_(__m128i i8x16) {
|
|
470
|
+
return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(i8x16));
|
|
471
|
+
}
|
|
472
|
+
NK_INTERNAL __m512 nk_u8x16_to_f32x16_skylake_(__m128i u8x16) {
|
|
473
|
+
return _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8x16));
|
|
474
|
+
}
|
|
475
|
+
NK_INTERNAL __m512 nk_i16x16_to_f32x16_skylake_(__m256i i16x16) {
|
|
476
|
+
return _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(i16x16));
|
|
477
|
+
}
|
|
478
|
+
NK_INTERNAL __m512 nk_u16x16_to_f32x16_skylake_(__m256i u16x16) {
|
|
479
|
+
return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16x16));
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
NK_INTERNAL __m128i nk_f32x16_to_i8x16_skylake_(__m512 f32x16) {
|
|
483
|
+
__m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_set1_ps(-128.0f)), _mm512_set1_ps(127.0f));
|
|
484
|
+
return _mm512_cvtsepi32_epi8(_mm512_cvtps_epi32(clamped));
|
|
485
|
+
}
|
|
486
|
+
NK_INTERNAL __m128i nk_f32x16_to_u8x16_skylake_(__m512 f32x16) {
|
|
487
|
+
__m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_setzero_ps()), _mm512_set1_ps(255.0f));
|
|
488
|
+
return _mm512_cvtusepi32_epi8(_mm512_cvtps_epu32(clamped));
|
|
489
|
+
}
|
|
490
|
+
NK_INTERNAL __m256i nk_f32x16_to_i16x16_skylake_(__m512 f32x16) {
|
|
491
|
+
__m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_set1_ps(-32768.0f)), _mm512_set1_ps(32767.0f));
|
|
492
|
+
return _mm512_cvtsepi32_epi16(_mm512_cvtps_epi32(clamped));
|
|
493
|
+
}
|
|
494
|
+
NK_INTERNAL __m256i nk_f32x16_to_u16x16_skylake_(__m512 f32x16) {
|
|
495
|
+
__m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_setzero_ps()), _mm512_set1_ps(65535.0f));
|
|
496
|
+
return _mm512_cvtusepi32_epi16(_mm512_cvtps_epu32(clamped));
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
NK_INTERNAL __m512i nk_u8x8_to_u64x8_skylake_(__m128i u8x8) { return _mm512_cvtepu8_epi64(u8x8); }
|
|
500
|
+
NK_INTERNAL __m512i nk_u16x8_to_u64x8_skylake_(__m128i u16x8) { return _mm512_cvtepu16_epi64(u16x8); }
|
|
501
|
+
NK_INTERNAL __m512i nk_u32x8_to_u64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_epi64(u32x8); }
|
|
502
|
+
|
|
503
|
+
NK_INTERNAL __m128i nk_u64x8_to_u8x8_skylake_(__m512i u64x8) {
|
|
504
|
+
__m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(255));
|
|
505
|
+
return _mm512_cvtepi64_epi8(clamped);
|
|
506
|
+
}
|
|
507
|
+
NK_INTERNAL __m128i nk_u64x8_to_u16x8_skylake_(__m512i u64x8) {
|
|
508
|
+
__m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(65535));
|
|
509
|
+
return _mm512_cvtepi64_epi16(clamped);
|
|
510
|
+
}
|
|
511
|
+
NK_INTERNAL __m256i nk_u64x8_to_u32x8_skylake_(__m512i u64x8) {
|
|
512
|
+
__m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(0xFFFFFFFFULL));
|
|
513
|
+
return _mm512_cvtepi64_epi32(clamped);
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
NK_INTERNAL __m512i nk_i8x8_to_i64x8_skylake_(__m128i i8x8) { return _mm512_cvtepi8_epi64(i8x8); }
|
|
517
|
+
NK_INTERNAL __m512i nk_i16x8_to_i64x8_skylake_(__m128i i16x8) { return _mm512_cvtepi16_epi64(i16x8); }
|
|
518
|
+
NK_INTERNAL __m512i nk_i32x8_to_i64x8_skylake_(__m256i i32x8) { return _mm512_cvtepi32_epi64(i32x8); }
|
|
519
|
+
NK_INTERNAL __m512i nk_u8x8_to_i64x8_skylake_(__m128i u8x8) { return _mm512_cvtepu8_epi64(u8x8); }
|
|
520
|
+
NK_INTERNAL __m512i nk_u16x8_to_i64x8_skylake_(__m128i u16x8) { return _mm512_cvtepu16_epi64(u16x8); }
|
|
521
|
+
NK_INTERNAL __m512i nk_u32x8_to_i64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_epi64(u32x8); }
|
|
522
|
+
|
|
523
|
+
NK_INTERNAL __m128i nk_i64x8_to_i8x8_skylake_(__m512i i64x8) {
|
|
524
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(127)), _mm512_set1_epi64(-128));
|
|
525
|
+
return _mm512_cvtepi64_epi8(clamped);
|
|
526
|
+
}
|
|
527
|
+
NK_INTERNAL __m128i nk_i64x8_to_u8x8_skylake_(__m512i i64x8) {
|
|
528
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(255)), _mm512_setzero_si512());
|
|
529
|
+
return _mm512_cvtepi64_epi8(clamped);
|
|
530
|
+
}
|
|
531
|
+
NK_INTERNAL __m128i nk_i64x8_to_i16x8_skylake_(__m512i i64x8) {
|
|
532
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(32767)), _mm512_set1_epi64(-32768));
|
|
533
|
+
return _mm512_cvtepi64_epi16(clamped);
|
|
534
|
+
}
|
|
535
|
+
NK_INTERNAL __m128i nk_i64x8_to_u16x8_skylake_(__m512i i64x8) {
|
|
536
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(65535)), _mm512_setzero_si512());
|
|
537
|
+
return _mm512_cvtepi64_epi16(clamped);
|
|
538
|
+
}
|
|
539
|
+
NK_INTERNAL __m256i nk_i64x8_to_i32x8_skylake_(__m512i i64x8) {
|
|
540
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(NK_I32_MAX)),
|
|
541
|
+
_mm512_set1_epi64(NK_I32_MIN));
|
|
542
|
+
return _mm512_cvtepi64_epi32(clamped);
|
|
543
|
+
}
|
|
544
|
+
NK_INTERNAL __m256i nk_i64x8_to_u32x8_skylake_(__m512i i64x8) {
|
|
545
|
+
__m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(NK_U32_MAX)), _mm512_setzero_si512());
|
|
546
|
+
return _mm512_cvtepi64_epi32(clamped);
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
NK_INTERNAL __m512d nk_f32x8_to_f64x8_skylake_(__m256 f32x8) { return _mm512_cvtps_pd(f32x8); }
|
|
550
|
+
NK_INTERNAL __m512d nk_i32x8_to_f64x8_skylake_(__m256i i32x8) { return _mm512_cvtepi32_pd(i32x8); }
|
|
551
|
+
NK_INTERNAL __m512d nk_u32x8_to_f64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_pd(u32x8); }
|
|
552
|
+
|
|
553
|
+
NK_INTERNAL __m256 nk_f64x8_to_f32x8_skylake_(__m512d f64x8) { return _mm512_cvtpd_ps(f64x8); }
|
|
554
|
+
NK_INTERNAL __m256i nk_f64x8_to_i32x8_skylake_(__m512d f64x8) {
|
|
555
|
+
__m512d clamped = _mm512_min_pd(_mm512_max_pd(f64x8, _mm512_set1_pd((double)NK_I32_MIN)),
|
|
556
|
+
_mm512_set1_pd((double)NK_I32_MAX));
|
|
557
|
+
return _mm512_cvtpd_epi32(clamped);
|
|
558
|
+
}
|
|
559
|
+
NK_INTERNAL __m256i nk_f64x8_to_u32x8_skylake_(__m512d f64x8) {
|
|
560
|
+
__m512d clamped = _mm512_min_pd(_mm512_max_pd(f64x8, _mm512_setzero_pd()), _mm512_set1_pd((double)NK_U32_MAX));
|
|
561
|
+
return _mm512_cvtpd_epu32(clamped);
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
#pragma endregion - Vectorized Conversions
|
|
565
|
+
|
|
566
|
+
#pragma region - Converting Loads and Stores
|
|
567
|
+
|
|
568
|
+
/** @brief Load 16 f16 values and convert to 16 f32 (Skylake AVX-512). */
|
|
569
|
+
NK_INTERNAL void nk_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
570
|
+
dst->zmm_ps = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)src));
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
/** @brief Partial load of up to 16 f16 values with conversion to f32 (Skylake AVX-512). */
|
|
574
|
+
NK_INTERNAL void nk_partial_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
575
|
+
nk_b256_vec_t f16_partial;
|
|
576
|
+
nk_partial_load_b16x16_skylake_(src, &f16_partial, n);
|
|
577
|
+
dst->zmm_ps = _mm512_cvtph_ps(f16_partial.ymm);
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
/** @brief Load 16 bf16 values and convert to 16 f32 (Skylake AVX-512). */
|
|
581
|
+
NK_INTERNAL void nk_load_bf16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
582
|
+
dst->zmm_ps = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)src));
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
/** @brief Partial load of up to 16 bf16 values with conversion to f32 (Skylake AVX-512). */
|
|
586
|
+
NK_INTERNAL void nk_partial_load_bf16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
587
|
+
nk_b256_vec_t bf16_partial;
|
|
588
|
+
nk_partial_load_b16x16_skylake_(src, &bf16_partial, n);
|
|
589
|
+
dst->zmm_ps = nk_bf16x16_to_f32x16_skylake_(bf16_partial.ymm);
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
/** @brief Load 16 e4m3 values and convert to 16 f32 (Skylake AVX-512). */
|
|
593
|
+
NK_INTERNAL void nk_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
594
|
+
dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
/** @brief Partial load of up to 16 e4m3 values with conversion to f32 (Skylake AVX-512). */
|
|
598
|
+
NK_INTERNAL void nk_partial_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
599
|
+
nk_b128_vec_t e4m3_partial;
|
|
600
|
+
nk_partial_load_b8x16_skylake_(src, &e4m3_partial, n);
|
|
601
|
+
dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(e4m3_partial.xmm);
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
/** @brief Load 16 e5m2 values and convert to 16 f32 (Skylake AVX-512). */
|
|
605
|
+
NK_INTERNAL void nk_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
606
|
+
dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
/** @brief Partial load of up to 16 e5m2 values with conversion to f32 (Skylake AVX-512). */
|
|
610
|
+
NK_INTERNAL void nk_partial_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
611
|
+
nk_b128_vec_t e5m2_partial;
|
|
612
|
+
nk_partial_load_b8x16_skylake_(src, &e5m2_partial, n);
|
|
613
|
+
dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(e5m2_partial.xmm);
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
/** @brief Load 16 e2m3 values and convert to 16 f32 (Skylake AVX-512). */
|
|
617
|
+
NK_INTERNAL void nk_load_e2m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
618
|
+
dst->zmm_ps = nk_e2m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
/** @brief Partial load of up to 16 e2m3 values with conversion to f32 (Skylake AVX-512). */
|
|
622
|
+
NK_INTERNAL void nk_partial_load_e2m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
623
|
+
nk_b128_vec_t e2m3_partial;
|
|
624
|
+
nk_partial_load_b8x16_skylake_(src, &e2m3_partial, n);
|
|
625
|
+
dst->zmm_ps = nk_e2m3x16_to_f32x16_skylake_(e2m3_partial.xmm);
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
/** @brief Load 16 e3m2 values and convert to 16 f32 (Skylake AVX-512). */
|
|
629
|
+
NK_INTERNAL void nk_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
630
|
+
dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
/** @brief Partial load of up to 16 e3m2 values with conversion to f32 (Skylake AVX-512). */
|
|
634
|
+
NK_INTERNAL void nk_partial_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
635
|
+
nk_b128_vec_t e3m2_partial;
|
|
636
|
+
nk_partial_load_b8x16_skylake_(src, &e3m2_partial, n);
|
|
637
|
+
dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(e3m2_partial.xmm);
|
|
638
|
+
}
|
|
639
|
+
|
|
640
|
+
#pragma endregion - Converting Loads and Stores
|
|
641
|
+
|
|
642
|
+
#pragma region - Public API
|
|
643
|
+
|
|
644
|
+
NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
645
|
+
// Same-type fast path
|
|
646
|
+
if (from_type == to_type) {
|
|
647
|
+
nk_size_t size_bits = nk_dtype_bits(from_type);
|
|
648
|
+
if (size_bits > 0) nk_copy_bytes_(to, from, nk_size_divide_round_up_(n * size_bits, 8));
|
|
649
|
+
return;
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
// Type classification for hub selection
|
|
653
|
+
int from_f32_hub = (from_type == nk_f32_k || from_type == nk_f16_k || from_type == nk_bf16_k ||
|
|
654
|
+
from_type == nk_e4m3_k || from_type == nk_e5m2_k || from_type == nk_e2m3_k ||
|
|
655
|
+
from_type == nk_e3m2_k || from_type == nk_i8_k || from_type == nk_u8_k ||
|
|
656
|
+
from_type == nk_i16_k || from_type == nk_u16_k);
|
|
657
|
+
int to_f32_hub = (to_type == nk_f32_k || to_type == nk_f16_k || to_type == nk_bf16_k || to_type == nk_e4m3_k ||
|
|
658
|
+
to_type == nk_e5m2_k || to_type == nk_e2m3_k || to_type == nk_e3m2_k || to_type == nk_i8_k ||
|
|
659
|
+
to_type == nk_u8_k || to_type == nk_i16_k || to_type == nk_u16_k);
|
|
660
|
+
int from_unsigned = (from_type == nk_u8_k || from_type == nk_u16_k || from_type == nk_u32_k ||
|
|
661
|
+
from_type == nk_u64_k);
|
|
662
|
+
int to_unsigned = (to_type == nk_u8_k || to_type == nk_u16_k || to_type == nk_u32_k || to_type == nk_u64_k);
|
|
663
|
+
int from_signed = (from_type == nk_i8_k || from_type == nk_i16_k || from_type == nk_i32_k || from_type == nk_i64_k);
|
|
664
|
+
int to_signed = (to_type == nk_i8_k || to_type == nk_i16_k || to_type == nk_i32_k || to_type == nk_i64_k);
|
|
665
|
+
int from_f64 = (from_type == nk_f64_k);
|
|
666
|
+
int to_f64 = (to_type == nk_f64_k);
|
|
667
|
+
|
|
668
|
+
nk_u8_t const *from_ptr = (nk_u8_t const *)from;
|
|
669
|
+
nk_u8_t *to_ptr = (nk_u8_t *)to;
|
|
670
|
+
|
|
671
|
+
// Hub 1: f32x16 - float types + small integers (16 elements/batch)
|
|
672
|
+
if (from_f32_hub && to_f32_hub) {
|
|
673
|
+
nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
|
|
674
|
+
nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
|
|
675
|
+
while (n > 0) {
|
|
676
|
+
nk_size_t batch = n < 16 ? n : 16;
|
|
677
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)batch);
|
|
678
|
+
__m512 hub_f32x16;
|
|
679
|
+
|
|
680
|
+
// Upcast to f32x16
|
|
681
|
+
if (from_type == nk_f32_k) hub_f32x16 = _mm512_maskz_loadu_ps(mask, from_ptr);
|
|
682
|
+
else if (from_type == nk_f16_k) hub_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, from_ptr));
|
|
683
|
+
else if (from_type == nk_bf16_k)
|
|
684
|
+
hub_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
|
|
685
|
+
else if (from_type == nk_e4m3_k)
|
|
686
|
+
hub_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
687
|
+
else if (from_type == nk_e5m2_k)
|
|
688
|
+
hub_f32x16 = nk_e5m2x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
689
|
+
else if (from_type == nk_e2m3_k)
|
|
690
|
+
hub_f32x16 = nk_e2m3x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
691
|
+
else if (from_type == nk_e3m2_k)
|
|
692
|
+
hub_f32x16 = nk_e3m2x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
693
|
+
else if (from_type == nk_i8_k)
|
|
694
|
+
hub_f32x16 = nk_i8x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
695
|
+
else if (from_type == nk_u8_k)
|
|
696
|
+
hub_f32x16 = nk_u8x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
697
|
+
else if (from_type == nk_i16_k)
|
|
698
|
+
hub_f32x16 = nk_i16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
|
|
699
|
+
else if (from_type == nk_u16_k)
|
|
700
|
+
hub_f32x16 = nk_u16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
|
|
701
|
+
else hub_f32x16 = _mm512_setzero_ps();
|
|
702
|
+
|
|
703
|
+
// Downcast from f32x16
|
|
704
|
+
if (to_type == nk_f32_k) _mm512_mask_storeu_ps(to_ptr, mask, hub_f32x16);
|
|
705
|
+
else if (to_type == nk_f16_k)
|
|
706
|
+
_mm256_mask_storeu_epi16(to_ptr, mask, _mm512_cvtps_ph(hub_f32x16, _MM_FROUND_TO_NEAREST_INT));
|
|
707
|
+
else if (to_type == nk_bf16_k)
|
|
708
|
+
_mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_bf16x16_skylake_(hub_f32x16));
|
|
709
|
+
else if (to_type == nk_e4m3_k)
|
|
710
|
+
_mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e4m3x16_skylake_(hub_f32x16));
|
|
711
|
+
else if (to_type == nk_e5m2_k)
|
|
712
|
+
_mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e5m2x16_skylake_(hub_f32x16));
|
|
713
|
+
else if (to_type == nk_e2m3_k)
|
|
714
|
+
_mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e2m3x16_skylake_(hub_f32x16));
|
|
715
|
+
else if (to_type == nk_e3m2_k)
|
|
716
|
+
_mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e3m2x16_skylake_(hub_f32x16));
|
|
717
|
+
else if (to_type == nk_i8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_i8x16_skylake_(hub_f32x16));
|
|
718
|
+
else if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_u8x16_skylake_(hub_f32x16));
|
|
719
|
+
else if (to_type == nk_i16_k)
|
|
720
|
+
_mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_i16x16_skylake_(hub_f32x16));
|
|
721
|
+
else if (to_type == nk_u16_k)
|
|
722
|
+
_mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_u16x16_skylake_(hub_f32x16));
|
|
723
|
+
|
|
724
|
+
from_ptr += batch * from_bytes;
|
|
725
|
+
to_ptr += batch * to_bytes;
|
|
726
|
+
n -= batch;
|
|
727
|
+
}
|
|
728
|
+
return;
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
// Hub 2: u64x8 - unsigned ↔ unsigned integers (8 elements/batch)
|
|
732
|
+
if (from_unsigned && to_unsigned) {
|
|
733
|
+
nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
|
|
734
|
+
nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
|
|
735
|
+
while (n > 0) {
|
|
736
|
+
nk_size_t batch = n < 8 ? n : 8;
|
|
737
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
|
|
738
|
+
__m512i hub_u64x8;
|
|
739
|
+
|
|
740
|
+
// Upcast to u64x8
|
|
741
|
+
if (from_type == nk_u8_k) hub_u64x8 = nk_u8x8_to_u64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
742
|
+
else if (from_type == nk_u16_k)
|
|
743
|
+
hub_u64x8 = nk_u16x8_to_u64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
|
|
744
|
+
else if (from_type == nk_u32_k)
|
|
745
|
+
hub_u64x8 = nk_u32x8_to_u64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
|
|
746
|
+
else if (from_type == nk_u64_k) hub_u64x8 = _mm512_maskz_loadu_epi64(mask, from_ptr);
|
|
747
|
+
else hub_u64x8 = _mm512_setzero_si512();
|
|
748
|
+
|
|
749
|
+
// Downcast from u64x8
|
|
750
|
+
if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_u64x8_to_u8x8_skylake_(hub_u64x8));
|
|
751
|
+
else if (to_type == nk_u16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_u64x8_to_u16x8_skylake_(hub_u64x8));
|
|
752
|
+
else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_u64x8_to_u32x8_skylake_(hub_u64x8));
|
|
753
|
+
else if (to_type == nk_u64_k) _mm512_mask_storeu_epi64(to_ptr, mask, hub_u64x8);
|
|
754
|
+
|
|
755
|
+
from_ptr += batch * from_bytes;
|
|
756
|
+
to_ptr += batch * to_bytes;
|
|
757
|
+
n -= batch;
|
|
758
|
+
}
|
|
759
|
+
return;
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
// Hub 3: i64x8 - signed/mixed integer conversions (8 elements/batch)
|
|
763
|
+
if ((from_signed || from_unsigned) && (to_signed || to_unsigned)) {
|
|
764
|
+
nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
|
|
765
|
+
nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
|
|
766
|
+
while (n > 0) {
|
|
767
|
+
nk_size_t batch = n < 8 ? n : 8;
|
|
768
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
|
|
769
|
+
__m512i hub_i64x8;
|
|
770
|
+
|
|
771
|
+
// Upcast to i64x8
|
|
772
|
+
if (from_type == nk_i8_k) hub_i64x8 = nk_i8x8_to_i64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
773
|
+
else if (from_type == nk_u8_k) hub_i64x8 = nk_u8x8_to_i64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
|
|
774
|
+
else if (from_type == nk_i16_k)
|
|
775
|
+
hub_i64x8 = nk_i16x8_to_i64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
|
|
776
|
+
else if (from_type == nk_u16_k)
|
|
777
|
+
hub_i64x8 = nk_u16x8_to_i64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
|
|
778
|
+
else if (from_type == nk_i32_k)
|
|
779
|
+
hub_i64x8 = nk_i32x8_to_i64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
|
|
780
|
+
else if (from_type == nk_u32_k)
|
|
781
|
+
hub_i64x8 = nk_u32x8_to_i64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
|
|
782
|
+
else if (from_type == nk_i64_k || from_type == nk_u64_k)
|
|
783
|
+
hub_i64x8 = _mm512_maskz_loadu_epi64(mask, from_ptr);
|
|
784
|
+
else hub_i64x8 = _mm512_setzero_si512();
|
|
785
|
+
|
|
786
|
+
// Downcast from i64x8
|
|
787
|
+
if (to_type == nk_i8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_i64x8_to_i8x8_skylake_(hub_i64x8));
|
|
788
|
+
else if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_i64x8_to_u8x8_skylake_(hub_i64x8));
|
|
789
|
+
else if (to_type == nk_i16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_i64x8_to_i16x8_skylake_(hub_i64x8));
|
|
790
|
+
else if (to_type == nk_u16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_i64x8_to_u16x8_skylake_(hub_i64x8));
|
|
791
|
+
else if (to_type == nk_i32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_i64x8_to_i32x8_skylake_(hub_i64x8));
|
|
792
|
+
else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_i64x8_to_u32x8_skylake_(hub_i64x8));
|
|
793
|
+
else if (to_type == nk_i64_k || to_type == nk_u64_k) _mm512_mask_storeu_epi64(to_ptr, mask, hub_i64x8);
|
|
794
|
+
|
|
795
|
+
from_ptr += batch * from_bytes;
|
|
796
|
+
to_ptr += batch * to_bytes;
|
|
797
|
+
n -= batch;
|
|
798
|
+
}
|
|
799
|
+
return;
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
// Hub 4: f64x8 - f64 conversions (8 elements/batch)
|
|
803
|
+
// Only enter when both sides are types we can actually handle: f64, f32, i32, u32.
|
|
804
|
+
// Unsupported pairs (e.g. i8→f64, f16→f64) fall through to serial fallback.
|
|
805
|
+
if ((from_f64 || to_f64) && //
|
|
806
|
+
(from_type == nk_f64_k || from_type == nk_f32_k || from_type == nk_i32_k || from_type == nk_u32_k) && //
|
|
807
|
+
(to_type == nk_f64_k || to_type == nk_f32_k || to_type == nk_i32_k || to_type == nk_u32_k)) {
|
|
808
|
+
nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
|
|
809
|
+
nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
|
|
810
|
+
while (n > 0) {
|
|
811
|
+
nk_size_t batch = n < 8 ? n : 8;
|
|
812
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
|
|
813
|
+
__m512d hub_f64x8;
|
|
814
|
+
|
|
815
|
+
// Upcast to f64x8
|
|
816
|
+
if (from_type == nk_f64_k) hub_f64x8 = _mm512_maskz_loadu_pd(mask, from_ptr);
|
|
817
|
+
else if (from_type == nk_f32_k)
|
|
818
|
+
hub_f64x8 = nk_f32x8_to_f64x8_skylake_(_mm256_maskz_loadu_ps(mask, from_ptr));
|
|
819
|
+
else if (from_type == nk_i32_k)
|
|
820
|
+
hub_f64x8 = nk_i32x8_to_f64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
|
|
821
|
+
else if (from_type == nk_u32_k)
|
|
822
|
+
hub_f64x8 = nk_u32x8_to_f64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
|
|
823
|
+
else hub_f64x8 = _mm512_setzero_pd();
|
|
824
|
+
|
|
825
|
+
// Downcast from f64x8
|
|
826
|
+
if (to_type == nk_f64_k) _mm512_mask_storeu_pd(to_ptr, mask, hub_f64x8);
|
|
827
|
+
else if (to_type == nk_f32_k) _mm256_mask_storeu_ps(to_ptr, mask, nk_f64x8_to_f32x8_skylake_(hub_f64x8));
|
|
828
|
+
else if (to_type == nk_i32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_f64x8_to_i32x8_skylake_(hub_f64x8));
|
|
829
|
+
else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_f64x8_to_u32x8_skylake_(hub_f64x8));
|
|
830
|
+
|
|
831
|
+
from_ptr += batch * from_bytes;
|
|
832
|
+
to_ptr += batch * to_bytes;
|
|
833
|
+
n -= batch;
|
|
834
|
+
}
|
|
835
|
+
return;
|
|
836
|
+
}
|
|
837
|
+
|
|
838
|
+
// Fallback: complex types, i4/u4/u1, unsupported combinations
|
|
839
|
+
nk_cast_serial(from, from_type, n, to, to_type);
|
|
840
|
+
}
|
|
841
|
+
|
|
842
|
+
#pragma endregion - Public API
|
|
843
|
+
|
|
844
|
+
#if defined(__clang__)
|
|
845
|
+
#pragma clang attribute pop
|
|
846
|
+
#elif defined(__GNUC__)
|
|
847
|
+
#pragma GCC pop_options
|
|
848
|
+
#endif
|
|
849
|
+
|
|
850
|
+
#if defined(__cplusplus)
|
|
851
|
+
} // extern "C"
|
|
852
|
+
#endif
|
|
853
|
+
|
|
854
|
+
#endif // NK_TARGET_SKYLAKE
|
|
855
|
+
#endif // NK_TARGET_X86_
|
|
856
|
+
#endif // NK_CAST_SKYLAKE_H
|