numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,975 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for Haswell.
|
|
3
|
+
* @file include/numkong/cast/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 2, 2026
|
|
6
|
+
*
|
|
7
|
+
* @section haswell_cast_instructions Key F16C/AVX2 Conversion Instructions
|
|
8
|
+
*
|
|
9
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
10
|
+
* _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
|
|
11
|
+
* _mm256_cvtps_ph VCVTPS2PH (XMM, YMM, I8) 4cy 1/cy p01+p5
|
|
12
|
+
* _mm256_cvtepi16_epi32 VPMOVSXWD (YMM, XMM) 3cy 1/cy p5
|
|
13
|
+
* _mm256_slli_epi32 VPSLLD (YMM, YMM, I8) 1cy 0.5/cy p01
|
|
14
|
+
* _mm256_blendv_ps VBLENDVPS (YMM, YMM, YMM, YMM) 2cy 1/cy p015
|
|
15
|
+
*
|
|
16
|
+
* F16C provides hardware F16<->F32 conversion. BF16 lacks hardware support and is emulated via
|
|
17
|
+
* bit manipulation (shift upper 16 bits). FP8 formats (E4M3/E5M2) use lookup tables for subnormal
|
|
18
|
+
* handling combined with arithmetic for normal values. All conversions hub through F32.
|
|
19
|
+
*/
|
|
20
|
+
#ifndef NK_CAST_HASWELL_H
|
|
21
|
+
#define NK_CAST_HASWELL_H
|
|
22
|
+
|
|
23
|
+
#if NK_TARGET_X86_
|
|
24
|
+
#if NK_TARGET_HASWELL
|
|
25
|
+
|
|
26
|
+
#include "numkong/types.h"
|
|
27
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b16x16_serial_`
|
|
28
|
+
|
|
29
|
+
#if defined(__cplusplus)
|
|
30
|
+
extern "C" {
|
|
31
|
+
#endif
|
|
32
|
+
|
|
33
|
+
#if defined(__clang__)
|
|
34
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
35
|
+
#elif defined(__GNUC__)
|
|
36
|
+
#pragma GCC push_options
|
|
37
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
38
|
+
#endif
|
|
39
|
+
|
|
40
|
+
/** @brief Converts one f32 scalar to f16 via hardware F16C, rounding to nearest-even. */
NK_PUBLIC void nk_f32_to_f16_haswell(nk_f32_t const *from, nk_f16_t *to) {
    __m128 scalar_f32x4 = _mm_set_ss(*from);
    __m128i packed_f16x8 = _mm_cvtps_ph(scalar_f32x4, _MM_FROUND_TO_NEAREST_INT);
    *to = _mm_cvtsi128_si32(packed_f16x8);
}
|
|
43
|
+
|
|
44
|
+
/** @brief Converts one f16 scalar to f32 via hardware F16C (exact; every f16 is representable). */
NK_PUBLIC void nk_f16_to_f32_haswell(nk_f16_t const *from, nk_f32_t *to) {
    __m128i packed_f16x8 = _mm_cvtsi32_si128(*from);
    __m128 scalar_f32x4 = _mm_cvtph_ps(packed_f16x8);
    *to = _mm_cvtss_f32(scalar_f32x4);
}
|
|
47
|
+
|
|
48
|
+
#pragma region - Type Punned Loads and Stores
|
|
49
|
+
|
|
50
|
+
/** @brief Type-agnostic 256-bit full load (Haswell AVX2). */
|
|
51
|
+
NK_INTERNAL void nk_load_b256_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
52
|
+
dst->ymm = _mm256_loadu_si256((const __m256i *)src);
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
/** @brief Type-agnostic 256-bit full store (Haswell AVX2). */
|
|
56
|
+
NK_INTERNAL void nk_store_b256_haswell_(nk_b256_vec_t const *src, void *dst) {
|
|
57
|
+
_mm256_storeu_si256((__m256i *)dst, src->ymm);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
/** @brief Type-agnostic 128-bit full load (Haswell AVX2). */
|
|
61
|
+
NK_INTERNAL void nk_load_b128_haswell_(void const *src, nk_b128_vec_t *dst) {
|
|
62
|
+
dst->xmm = _mm_loadu_si128((const __m128i *)src);
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/** @brief Type-agnostic 128-bit full store (SSE2). */
|
|
66
|
+
NK_INTERNAL void nk_store_b128_haswell_(nk_b128_vec_t const *src, void *dst) {
|
|
67
|
+
_mm_storeu_si128((__m128i *)dst, src->xmm);
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
/** @brief Type-agnostic 128-bit partial load with AVX maskload. */
|
|
71
|
+
NK_INTERNAL void nk_partial_load_b32x4_haswell_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
72
|
+
__m128i idx_i32x4 = _mm_setr_epi32(0, 1, 2, 3);
|
|
73
|
+
__m128i limit_i32x4 = _mm_set1_epi32((int)n);
|
|
74
|
+
__m128i mask_i32x4 = _mm_cmpgt_epi32(limit_i32x4, idx_i32x4);
|
|
75
|
+
dst->xmm = _mm_castps_si128(_mm_maskload_ps((float const *)src, mask_i32x4));
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
/** @brief Type-agnostic 128-bit partial store with AVX maskstore. */
|
|
79
|
+
NK_INTERNAL void nk_partial_store_b32x4_haswell_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
80
|
+
__m128i idx_i32x4 = _mm_setr_epi32(0, 1, 2, 3);
|
|
81
|
+
__m128i limit_i32x4 = _mm_set1_epi32((int)n);
|
|
82
|
+
__m128i mask_i32x4 = _mm_cmpgt_epi32(limit_i32x4, idx_i32x4);
|
|
83
|
+
_mm_maskstore_ps((float *)dst, mask_i32x4, _mm_castsi128_ps(src->xmm));
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
/** @brief Type-agnostic 256-bit partial load with AVX2 maskload. */
|
|
87
|
+
NK_INTERNAL void nk_partial_load_b64x4_haswell_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
88
|
+
__m256i idx_i64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
|
|
89
|
+
__m256i limit_i64x4 = _mm256_set1_epi64x((long long)n);
|
|
90
|
+
__m256i mask_i64x4 = _mm256_cmpgt_epi64(limit_i64x4, idx_i64x4);
|
|
91
|
+
dst->ymm = _mm256_castpd_si256(_mm256_maskload_pd((double const *)src, mask_i64x4));
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/** @brief Type-agnostic 256-bit partial store with AVX2 maskstore. */
|
|
95
|
+
NK_INTERNAL void nk_partial_store_b64x4_haswell_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
96
|
+
__m256i idx_i64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
|
|
97
|
+
__m256i limit_i64x4 = _mm256_set1_epi64x((long long)n);
|
|
98
|
+
__m256i mask_i64x4 = _mm256_cmpgt_epi64(limit_i64x4, idx_i64x4);
|
|
99
|
+
_mm256_maskstore_pd((double *)dst, mask_i64x4, _mm256_castsi256_pd(src->ymm));
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
#pragma endregion - Type Punned Loads and Stores
|
|
103
|
+
|
|
104
|
+
#pragma region - Vectorized Conversions
|
|
105
|
+
|
|
106
|
+
/** @brief Convert 8x bf16 → 8x f32 by shifting left 16 bits (AVX2). */
|
|
107
|
+
NK_INTERNAL __m256 nk_bf16x8_to_f32x8_haswell_(__m128i bf16_i16x8) {
|
|
108
|
+
return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(bf16_i16x8), 16));
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
/** @brief Convert 8x f32 → 8x bf16 by truncating with RNE rounding (AVX2). */
|
|
112
|
+
NK_INTERNAL __m128i nk_f32x8_to_bf16x8_haswell_(__m256 f32x8) {
|
|
113
|
+
__m256i bits_i32x8 = _mm256_castps_si256(f32x8);
|
|
114
|
+
// RNE rounding: add (0x7FFF + lsb) where lsb is bit 16
|
|
115
|
+
__m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 16), _mm256_set1_epi32(1));
|
|
116
|
+
__m256i rounded_i32x8 = _mm256_add_epi32(bits_i32x8, _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_i32x8));
|
|
117
|
+
__m256i bf16_i32x8 = _mm256_srli_epi32(rounded_i32x8, 16);
|
|
118
|
+
// Pack 8x i32 to 8x i16
|
|
119
|
+
__m128i lo_i32x4 = _mm256_castsi256_si128(bf16_i32x8);
|
|
120
|
+
__m128i hi_i32x4 = _mm256_extracti128_si256(bf16_i32x8, 1);
|
|
121
|
+
return _mm_packus_epi32(lo_i32x4, hi_i32x4);
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
/** @brief Integer upcasts to f32x8 (AVX2). */
|
|
125
|
+
/** @brief Sign-extend 8x i8 to i32 lanes, then convert to f32 exactly (AVX2). */
NK_INTERNAL __m256 nk_i8x8_to_f32x8_haswell_(__m128i i8x8) {
    __m256i widened_i32x8 = _mm256_cvtepi8_epi32(i8x8);
    return _mm256_cvtepi32_ps(widened_i32x8);
}
|
|
126
|
+
/** @brief Zero-extend 8x u8 to i32 lanes, then convert to f32 exactly (AVX2). */
NK_INTERNAL __m256 nk_u8x8_to_f32x8_haswell_(__m128i u8x8) {
    __m256i widened_i32x8 = _mm256_cvtepu8_epi32(u8x8);
    return _mm256_cvtepi32_ps(widened_i32x8);
}
|
|
127
|
+
/** @brief Sign-extend 8x i16 to i32 lanes, then convert to f32 exactly (AVX2). */
NK_INTERNAL __m256 nk_i16x8_to_f32x8_haswell_(__m128i i16x8) {
    __m256i widened_i32x8 = _mm256_cvtepi16_epi32(i16x8);
    return _mm256_cvtepi32_ps(widened_i32x8);
}
|
|
130
|
+
/** @brief Zero-extend 8x u16 to i32 lanes, then convert to f32 exactly (AVX2). */
NK_INTERNAL __m256 nk_u16x8_to_f32x8_haswell_(__m128i u16x8) {
    __m256i widened_i32x8 = _mm256_cvtepu16_epi32(u16x8);
    return _mm256_cvtepi32_ps(widened_i32x8);
}
|
|
133
|
+
/** @brief Convert 8x i32 lanes to f32 with round-to-nearest-even (AVX2). */
NK_INTERNAL __m256 nk_i32x8_to_f32x8_haswell_(__m256i i32x8) {
    __m256 converted_f32x8 = _mm256_cvtepi32_ps(i32x8);
    return converted_f32x8;
}
|
|
134
|
+
/** @brief Convert 8x u32 lanes to f32 (AVX2).
 *  AVX2 has no unsigned conversion, so each lane is split into 16-bit halves:
 *  both halves convert exactly, and `hi * 2^16 + lo` rounds only once in the
 *  final add — matching a direct unsigned round-to-nearest conversion. */
NK_INTERNAL __m256 nk_u32x8_to_f32x8_haswell_(__m256i u32x8) {
    __m256i low_halves_i32x8 = _mm256_and_si256(u32x8, _mm256_set1_epi32(0xFFFF));
    __m256i high_halves_i32x8 = _mm256_srli_epi32(u32x8, 16);
    __m256 low_f32x8 = _mm256_cvtepi32_ps(low_halves_i32x8);
    __m256 high_f32x8 = _mm256_cvtepi32_ps(high_halves_i32x8);
    return _mm256_add_ps(low_f32x8, _mm256_mul_ps(high_f32x8, _mm256_set1_ps(65536.0f)));
}
|
|
140
|
+
|
|
141
|
+
/** @brief Saturating f32x8 downcasts to integers (AVX2). */
|
|
142
|
+
/** @brief Convert 8x f32 lanes to i32 with round-to-nearest-even (AVX2). */
NK_INTERNAL __m256i nk_f32x8_to_i32x8_haswell_(__m256 f32x8) {
    __m256i converted_i32x8 = _mm256_cvtps_epi32(f32x8);
    return converted_i32x8;
}
|
|
143
|
+
/** @brief Saturating conversion of 8x f32 → 8x u32 (AVX2).
 *  AVX2 only offers a signed conversion, so lanes >= 2^31 are biased down by
 *  2^31 before converting, and the bias is restored in the integer domain. */
NK_INTERNAL __m256i nk_f32x8_to_u32x8_haswell_(__m256 f32x8) {
    // Clamp to [0, 4294967040.0f] — the largest float strictly below 2^32.
    // Clamping with `(float)NK_U32_MAX` is a bug: that constant rounds up to
    // exactly 2^32, so a saturated lane becomes 2^32 - 2^31 == 2^31 after the
    // bias subtraction, VCVTPS2DQ overflows it to the 0x80000000 "integer
    // indefinite", and re-adding 0x80000000 wraps the lane to 0 instead of
    // the maximum representable value.
    __m256 clamped_f32x8 = _mm256_max_ps(_mm256_min_ps(f32x8, _mm256_set1_ps(4294967040.0f)), _mm256_setzero_ps());
    __m256 threshold_f32x8 = _mm256_set1_ps(2147483648.0f); // 2^31
    __m256i mask_i32x8 = _mm256_castps_si256(_mm256_cmp_ps(clamped_f32x8, threshold_f32x8, _CMP_GE_OQ));
    // Subtract 2^31 from the large lanes so the signed conversion stays in range...
    __m256 adjusted_f32x8 = _mm256_sub_ps(clamped_f32x8,
                                          _mm256_and_ps(_mm256_castsi256_ps(mask_i32x8), threshold_f32x8));
    // ...then add it back as an integer, which just sets the top bit.
    return _mm256_add_epi32(_mm256_cvtps_epi32(adjusted_f32x8),
                            _mm256_and_si256(mask_i32x8, _mm256_set1_epi32((int)0x80000000)));
}
|
|
152
|
+
/** @brief Saturating conversion of 8x f32 → 8x i16 (AVX2).
 *  Clamps to [-32768, 32767] in the float domain, converts, then narrows. */
NK_INTERNAL __m128i nk_f32x8_to_i16x8_haswell_(__m256 f32x8) {
    __m256 floor_f32x8 = _mm256_set1_ps(-32768.0f);
    __m256 ceiling_f32x8 = _mm256_set1_ps(32767.0f);
    __m256 saturated_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, floor_f32x8), ceiling_f32x8);
    __m256i wide_i32x8 = _mm256_cvtps_epi32(saturated_f32x8);
    __m128i low_i32x4 = _mm256_castsi256_si128(wide_i32x8);
    __m128i high_i32x4 = _mm256_extracti128_si256(wide_i32x8, 1);
    return _mm_packs_epi32(low_i32x4, high_i32x4);
}
|
|
157
|
+
/** @brief Saturating conversion of 8x f32 → 8x u16 (AVX2).
 *  Clamps to [0, 65535] in the float domain, converts, then narrows unsigned. */
NK_INTERNAL __m128i nk_f32x8_to_u16x8_haswell_(__m256 f32x8) {
    __m256 ceiling_f32x8 = _mm256_set1_ps(65535.0f);
    __m256 saturated_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_setzero_ps()), ceiling_f32x8);
    __m256i wide_i32x8 = _mm256_cvtps_epi32(saturated_f32x8);
    __m128i low_i32x4 = _mm256_castsi256_si128(wide_i32x8);
    __m128i high_i32x4 = _mm256_extracti128_si256(wide_i32x8, 1);
    return _mm_packus_epi32(low_i32x4, high_i32x4);
}
|
|
162
|
+
/** @brief Saturating conversion of 8x f32 → 8x i8 in the low half of the result (AVX2).
 *  Clamps to [-128, 127], converts, then narrows i32 → i16 → i8 with signed packs. */
NK_INTERNAL __m128i nk_f32x8_to_i8x8_haswell_(__m256 f32x8) {
    __m256 floor_f32x8 = _mm256_set1_ps(-128.0f);
    __m256 ceiling_f32x8 = _mm256_set1_ps(127.0f);
    __m256 saturated_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, floor_f32x8), ceiling_f32x8);
    __m256i wide_i32x8 = _mm256_cvtps_epi32(saturated_f32x8);
    __m128i narrow_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(wide_i32x8), _mm256_extracti128_si256(wide_i32x8, 1));
    return _mm_packs_epi16(narrow_i16x8, _mm_setzero_si128());
}
|
|
168
|
+
/** @brief Saturating conversion of 8x f32 → 8x u8 in the low half of the result (AVX2).
 *  Clamps to [0, 255], converts, then narrows i32 → u16 → u8 with unsigned packs. */
NK_INTERNAL __m128i nk_f32x8_to_u8x8_haswell_(__m256 f32x8) {
    __m256 ceiling_f32x8 = _mm256_set1_ps(255.0f);
    __m256 saturated_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_setzero_ps()), ceiling_f32x8);
    __m256i wide_i32x8 = _mm256_cvtps_epi32(saturated_f32x8);
    __m128i narrow_u16x8 = _mm_packus_epi32(_mm256_castsi256_si128(wide_i32x8), _mm256_extracti128_si256(wide_i32x8, 1));
    return _mm_packus_epi16(narrow_u16x8, _mm_setzero_si128());
}
|
|
174
|
+
|
|
175
|
+
/** @brief Convert 16x e4m3 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
 *  E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
 *  Normal values: BF16 = sign | ((lower7 << 4) + 0x3C00), where 0x3C00 == (127-7) << 7
 *  is the exponent re-bias and the left shift by 4 aligns the 3-bit mantissa into
 *  BF16's 7-bit mantissa field.
 *  Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
 *  Handles all corner cases: zero, subnormals, normals, and NaN. */
NK_INTERNAL __m256i nk_e4m3x16_to_bf16x16_haswell_(__m128i e4m3x16) {
    // Widen each input byte to a 16-bit lane: low byte = e4m3 bits, high byte = 0.
    __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
    __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
    __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));

    // Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
    __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 4), _mm256_set1_epi16(0x3C00));
    sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8); // move sign from bit 7 to bit 15
    __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);

    // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → BF16).
    // E4M3 subnormal m equals m * 2^-9; as BF16: 0x0000, 0x3B00, 0x3B80, 0x3BC0,
    // 0x3C00, 0x3C20, 0x3C40, 0x3C60.
    // Split into low bytes and high bytes for reconstruction; _mm_set_epi8 lists
    // entries b15..b0, so index 0 is the rightmost byte of each 8-entry half.
    __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
        0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00,        //
        0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00));      //
    __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
        0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00,                    //
        0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00));                  //

    // Extract mantissa (bits 0-2) as byte indices for shuffle. The zeroed high
    // bytes of each lane index entry 0 of both LUTs, which is 0x00 in each.
    __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
    __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
    __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);

    // Combine low and high bytes into 16-bit values
    __m256i subnorm_abs_i16x16 = _mm256_or_si256(                  //
        _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
        _mm256_slli_epi16(hi_bytes_i8x32, 8));                       //
    __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);

    // Blend: if exponent (bits 3-6, mask 0x78) == 0, use subnormal result; else normal.
    __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
    __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
    __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);

    // Handle NaN: E4M3 has no infinity; the single NaN encoding 0x7F maps to a
    // sign-preserving quiet BF16 NaN (0x7FC0).
    __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
    __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
    return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
}
|
|
221
|
+
|
|
222
|
+
/** @brief Convert 16x e5m2 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
 *  E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
 *  Normal values: BF16 = sign | ((lower7 << 5) + 0x3800), where 0x3800 == (127-15) << 7
 *  is the exponent re-bias and the left shift by 5 aligns the 2-bit mantissa into
 *  BF16's 7-bit mantissa field.
 *  Subnormals (4 values): looked up via vpshufb from a 4-entry LUT.
 *  Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
NK_INTERNAL __m256i nk_e5m2x16_to_bf16x16_haswell_(__m128i e5m2x16) {
    // Widen each input byte to a 16-bit lane: low byte = e5m2 bits, high byte = 0.
    __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
    __m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
    __m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));

    // Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
    __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 5), _mm256_set1_epi16(0x3800));
    sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8); // move sign from bit 7 to bit 15
    __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);

    // Subnormal LUT via shuffle_epi8 (4 entries: mantissa 0-3 → BF16).
    // E5M2 subnormal m equals m * 2^-16; as BF16: 0x0000, 0x3780, 0x3800, 0x3840.
    // _mm_set_epi8 lists entries b15..b0, so index 0 is the rightmost byte.
    __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
        0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00,              //
        0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00));            //
    __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
        0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00,                    //
        0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00));                  //

    // Extract mantissa (bits 0-1) as byte indices for shuffle. The zeroed high
    // bytes of each lane index entry 0 of both LUTs, which is 0x00 in each.
    __m256i byte_idx_i8x32 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi8(0x03));
    __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
    __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);

    // Combine low and high bytes into 16-bit values
    __m256i subnorm_abs_i16x16 = _mm256_or_si256(                  //
        _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
        _mm256_slli_epi16(hi_bytes_i8x32, 8));                       //
    __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);

    // Blend: if exponent (bits 2-6, mask 0x7C) == 0, use subnormal result; else normal.
    __m256i exp_bits_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7C));
    __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
    __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);

    // Handle Inf (lower7 == 0x7C → BF16 0x7F80) and NaN (lower7 0x7D-0x7F → quiet
    // BF16 NaN 0x7FC0), both keeping the sign bit. NaN blend is applied last, so
    // it wins over any earlier result for those lanes.
    __m256i is_inf_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
    __m256i is_nan_i16x16 = _mm256_cmpgt_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
    __m256i inf_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7F80));
    __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
    result_i16x16 = _mm256_blendv_epi8(result_i16x16, inf_i16x16, is_inf_i16x16);
    return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
}
|
|
270
|
+
|
|
271
|
+
/** @brief Convert 16x e4m3 → 16x f16 via arithmetic + small LUT for subnormals (AVX2).
 *  E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
 *  Normal values: F16 = sign | ((lower7 << 7) + 0x2000); the +0x2000 adds the
 *  bias difference (15 - 7 = 8) into the F16 exponent field (8 << 10 == 0x2000).
 *  Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
 *  Handles all corner cases: zero, subnormals, normals, and NaN (0x7F → 0x7E00). */
NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_haswell_(__m128i e4m3x16) {
    // Widen each byte into a 16-bit lane; the high byte of every lane is zero.
    __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
    __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
    __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));

    // Normal path: F16 = ((lower7 << 7) + 0x2000) | (sign << 8)
    __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
    sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8); // Sign moves from bit 7 to bit 15
    __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);

    // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → F16)
    // E4M3 subnormal F16 values: 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300
    // All low bytes are 0x00, high bytes: 0x00, 0x18, 0x1C, 0x1E, 0x20, 0x21, 0x22, 0x23
    // _mm_set_epi8 order: b15..b8 unused (zero), b7=idx7, b6=idx6, ..., b0=idx0
    __m256i const lo_lut_i8x32 = _mm256_setzero_si256();
    __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,                   //
        0x23, 0x22, 0x21, 0x20, 0x1E, 0x1C, 0x18, 0x00));                 //

    // Extract mantissa (bits 0-2) as byte indices for shuffle. The odd (high)
    // bytes of each lane are zero and therefore fetch LUT entry 0, which is zero.
    __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
    __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
    __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);

    // Combine low and high bytes into 16-bit values
    __m256i subnorm_abs_i16x16 = _mm256_or_si256(                    //
        _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
        _mm256_slli_epi16(hi_bytes_i8x32, 8));                       //
    __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);

    // Blend: if exponent (bits 3-6, mask 0x78) == 0, use subnormal result; else use normal result
    __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
    __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
    __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);

    // Handle NaN: E4M3FN's only NaN encoding is lower7 == 0x7F → F16 NaN 0x7E00, sign preserved
    __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
    __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00));
    return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
}
|
|
316
|
+
|
|
317
|
+
/** @brief Convert 16x e5m2 → 16x f16 via simple bit shift (AVX2).
|
|
318
|
+
* E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
319
|
+
* Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
|
|
320
|
+
* Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
|
|
321
|
+
NK_INTERNAL __m256i nk_e5m2x16_to_f16x16_haswell_(__m128i e5m2x16) {
|
|
322
|
+
__m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
|
|
323
|
+
__m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
|
|
324
|
+
__m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
|
|
325
|
+
|
|
326
|
+
// F16 = (lower7 << 8) | (sign << 15)
|
|
327
|
+
// Works for all cases: subnormals, normals, infinity, and NaN
|
|
328
|
+
__m256i result_i16x16 = _mm256_slli_epi16(lower7_i16x16, 8);
|
|
329
|
+
sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
330
|
+
return _mm256_or_si256(result_i16x16, sign_i16x16);
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
/** @brief Convert 8x e4m3 → 8x f32 via bit manipulation (AVX2).
 *  E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mant<<20.
 *  The +120 rebias is the F32 bias (127) minus the E4M3 bias (7).
 *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512.
 *  NaN: E4M3FN encodes NaN only as exp=15, mant=7 (0x7F/0xFF) → F32 quiet NaN. */
NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
    // Widen each byte to a 32-bit lane so all field math happens per lane.
    __m256i e4m3_i32x8 = _mm256_cvtepu8_epi32(e4m3_i8x8);

    // Extract fields
    __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e4m3_i32x8, 3), _mm256_set1_epi32(0x0F));
    __m256i mant_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x07));

    // Build F32 sign bit (bit 7 of the source byte → bit 31)
    __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e4m3_i32x8, 7), 31);

    // Normal path: sign | ((exp+120)<<23) | (mant<<20)
    __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(120)), 23);
    __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
    __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));

    // Subnormal path: value = mantissa / 512.0f, then apply sign bitwise.
    // The multiply is exact: mantissa ∈ [0,7] and 1/512 is a power of two.
    __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 512.0f));
    __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));

    // Blend: if exp==0, use subnormal result; otherwise use normal bits
    __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
    __m256 result = _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8,
                                     _mm256_castsi256_ps(exp_zero_mask));

    // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
    __m256i is_nan_mask = _mm256_and_si256(                           //
        _mm256_cmpeq_epi32(exp_i32x8, _mm256_set1_epi32(15)),         //
        _mm256_cmpeq_epi32(mant_i32x8, _mm256_set1_epi32(7)));        //
    __m256i nan_bits = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000)); // F32 quiet NaN
    return _mm256_blendv_ps(result, _mm256_castsi256_ps(nan_bits), _mm256_castsi256_ps(is_nan_mask));
}
|
|
367
|
+
|
|
368
|
+
/** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
|
|
369
|
+
* E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
|
|
370
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
|
|
371
|
+
NK_INTERNAL __m256 nk_e5m2x8_to_f32x8_haswell_(__m128i e5m2_i8x8) {
|
|
372
|
+
__m256i e5m2_i32x8 = _mm256_cvtepu8_epi32(e5m2_i8x8);
|
|
373
|
+
|
|
374
|
+
// Extract fields
|
|
375
|
+
__m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e5m2_i32x8, 2), _mm256_set1_epi32(0x1F));
|
|
376
|
+
__m256i mant_i32x8 = _mm256_and_si256(e5m2_i32x8, _mm256_set1_epi32(0x03));
|
|
377
|
+
|
|
378
|
+
// Build F32 sign bit
|
|
379
|
+
__m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e5m2_i32x8, 7), 31);
|
|
380
|
+
|
|
381
|
+
// Normal path: sign | ((exp+112)<<23) | (mant<<21)
|
|
382
|
+
__m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(112)), 23);
|
|
383
|
+
__m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
|
|
384
|
+
__m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
|
|
385
|
+
|
|
386
|
+
// Subnormal path: value = mantissa / 65536.0f, then apply sign
|
|
387
|
+
__m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 65536.0f));
|
|
388
|
+
__m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
|
|
389
|
+
|
|
390
|
+
// Blend: if exp==0, use subnormal result; otherwise use normal bits
|
|
391
|
+
__m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
|
|
392
|
+
return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
/** @brief Convert 8x f32 → 8x e4m3 via bit manipulation (AVX2).
|
|
396
|
+
* E4M3 format: S EEEE MMM (bias=7). Handles normal, subnormal, and overflow cases.
|
|
397
|
+
* Subnormals (f32_exp ≤ 120): mantissa = round(abs_f32 * 512), clamped to [0,7]. */
|
|
398
|
+
NK_INTERNAL __m128i nk_f32x8_to_e4m3x8_haswell_(__m256 f32x8) {
|
|
399
|
+
__m256i bits_i32x8 = _mm256_castps_si256(f32x8);
|
|
400
|
+
__m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
|
|
401
|
+
__m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
|
|
402
|
+
|
|
403
|
+
// Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
|
|
404
|
+
// RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
|
|
405
|
+
__m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
|
|
406
|
+
_mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
|
|
407
|
+
__m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 20), _mm256_set1_epi32(1));
|
|
408
|
+
__m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x0007FFFF), lsb_i32x8);
|
|
409
|
+
__m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
|
|
410
|
+
__m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
|
|
411
|
+
__m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 20), _mm256_set1_epi32(0x07));
|
|
412
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
413
|
+
f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
|
|
414
|
+
__m256i e4m3_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(120));
|
|
415
|
+
|
|
416
|
+
// Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 15)
|
|
417
|
+
__m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e4m3_exp_i32x8);
|
|
418
|
+
__m256i overflow_i32x8 = _mm256_cmpgt_epi32(e4m3_exp_i32x8, _mm256_set1_epi32(15));
|
|
419
|
+
|
|
420
|
+
// Normal path: clamp exp to [1,15], extract mantissa bits
|
|
421
|
+
// e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
|
|
422
|
+
__m256i clamped_exp_i32x8 = _mm256_max_epi32(e4m3_exp_i32x8, _mm256_set1_epi32(1));
|
|
423
|
+
clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(15));
|
|
424
|
+
__m256i is_max_exp_i32x8 = _mm256_cmpeq_epi32(clamped_exp_i32x8, _mm256_set1_epi32(15));
|
|
425
|
+
__m256i max_mantissa_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32(7), _mm256_set1_epi32(6), is_max_exp_i32x8);
|
|
426
|
+
__m256i normal_mantissa_i32x8 = _mm256_min_epi32(f32_mantissa_i32x8, max_mantissa_i32x8);
|
|
427
|
+
normal_mantissa_i32x8 = _mm256_blendv_epi8(normal_mantissa_i32x8, _mm256_set1_epi32(0x06), overflow_i32x8);
|
|
428
|
+
__m256i normal_e4m3_i32x8 = _mm256_or_si256(
|
|
429
|
+
_mm256_slli_epi32(sign_i32x8, 7),
|
|
430
|
+
_mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 3), normal_mantissa_i32x8));
|
|
431
|
+
|
|
432
|
+
// Subnormal path: mantissa = round(abs_f32 * 512)
|
|
433
|
+
// If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
|
|
434
|
+
__m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
|
|
435
|
+
__m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(512.0f));
|
|
436
|
+
__m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
|
|
437
|
+
__m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
|
|
438
|
+
subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
|
|
439
|
+
subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
|
|
440
|
+
__m256i subnorm_e4m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), subnorm_mantissa_i32x8);
|
|
441
|
+
// When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
|
|
442
|
+
__m256i first_normal_e4m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), _mm256_set1_epi32(0x08));
|
|
443
|
+
subnorm_e4m3_i32x8 = _mm256_blendv_epi8(subnorm_e4m3_i32x8, first_normal_e4m3_i32x8, promotes_to_normal_i32x8);
|
|
444
|
+
|
|
445
|
+
// Blend: use subnormal result when exp <= 0, else normal
|
|
446
|
+
__m256i e4m3_i32x8 = _mm256_blendv_epi8(normal_e4m3_i32x8, subnorm_e4m3_i32x8, is_subnormal_i32x8);
|
|
447
|
+
|
|
448
|
+
// Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
|
|
449
|
+
__m128i low_i32x4 = _mm256_castsi256_si128(e4m3_i32x8);
|
|
450
|
+
__m128i high_i32x4 = _mm256_extracti128_si256(e4m3_i32x8, 1);
|
|
451
|
+
__m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
|
|
452
|
+
__m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
|
|
453
|
+
return packed_i8x8;
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
/** @brief Convert 8x f32 → 8x e5m2 via bit manipulation (AVX2).
|
|
457
|
+
* E5M2 format: S EEEEE MM (bias=15). Handles normal, subnormal, and overflow cases.
|
|
458
|
+
* Uses RNE (round to nearest even) for mantissa rounding. */
|
|
459
|
+
NK_INTERNAL __m128i nk_f32x8_to_e5m2x8_haswell_(__m256 f32x8) {
|
|
460
|
+
__m256i bits_i32x8 = _mm256_castps_si256(f32x8);
|
|
461
|
+
__m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
|
|
462
|
+
__m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
|
|
463
|
+
|
|
464
|
+
// Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
|
|
465
|
+
// RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
|
|
466
|
+
__m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
|
|
467
|
+
_mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
|
|
468
|
+
__m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 21), _mm256_set1_epi32(1));
|
|
469
|
+
__m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x000FFFFF), lsb_i32x8); // half = 0x100000
|
|
470
|
+
__m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
|
|
471
|
+
__m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
|
|
472
|
+
__m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 21), _mm256_set1_epi32(0x03));
|
|
473
|
+
// If carry, mantissa becomes 0 (we rounded up to next power of 2)
|
|
474
|
+
f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
|
|
475
|
+
__m256i e5m2_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(112));
|
|
476
|
+
|
|
477
|
+
// Detect subnormal (exp <= 0) and overflow (exp > 31)
|
|
478
|
+
__m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e5m2_exp_i32x8);
|
|
479
|
+
__m256i overflow_i32x8 = _mm256_cmpgt_epi32(e5m2_exp_i32x8, _mm256_set1_epi32(31));
|
|
480
|
+
|
|
481
|
+
// Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
|
|
482
|
+
__m256i clamped_exp_i32x8 = _mm256_max_epi32(e5m2_exp_i32x8, _mm256_set1_epi32(1));
|
|
483
|
+
clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(31));
|
|
484
|
+
__m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_setzero_si256(), overflow_i32x8);
|
|
485
|
+
__m256i normal_e5m2_i32x8 = _mm256_or_si256(
|
|
486
|
+
_mm256_slli_epi32(sign_i32x8, 7),
|
|
487
|
+
_mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 2), normal_mantissa_i32x8));
|
|
488
|
+
|
|
489
|
+
// Subnormal path: mantissa = round(abs_f32 * 65536)
|
|
490
|
+
// If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
|
|
491
|
+
__m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
|
|
492
|
+
__m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(65536.0f));
|
|
493
|
+
__m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
|
|
494
|
+
__m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
|
|
495
|
+
subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
|
|
496
|
+
subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
|
|
497
|
+
__m256i subnorm_e5m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), subnorm_mantissa_i32x8);
|
|
498
|
+
// When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
|
|
499
|
+
__m256i first_normal_e5m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), _mm256_set1_epi32(0x04));
|
|
500
|
+
subnorm_e5m2_i32x8 = _mm256_blendv_epi8(subnorm_e5m2_i32x8, first_normal_e5m2_i32x8, promotes_to_normal_i32x8);
|
|
501
|
+
|
|
502
|
+
// Blend: use subnormal result when exp <= 0
|
|
503
|
+
__m256i e5m2_i32x8 = _mm256_blendv_epi8(normal_e5m2_i32x8, subnorm_e5m2_i32x8, is_subnormal_i32x8);
|
|
504
|
+
|
|
505
|
+
// Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
|
|
506
|
+
__m128i low_i32x4 = _mm256_castsi256_si128(e5m2_i32x8);
|
|
507
|
+
__m128i high_i32x4 = _mm256_extracti128_si256(e5m2_i32x8, 1);
|
|
508
|
+
__m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
|
|
509
|
+
__m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
|
|
510
|
+
return packed_i8x8;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
/** @brief Convert 8x e2m3 → 8x f32 via bit manipulation (AVX2).
 *  E2M3 format: S EE MMM (bias=1). F32: sign<<31, (exp+126)<<23, mantissa<<20.
 *  The +126 rebias is the F32 bias (127) minus the E2M3 bias (1).
 *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁾ × 2⁻³ = mantissa ÷ 8.
 *  No special-value path: this code treats the format as having no Inf/NaN
 *  encodings (consistent with the f32 → e2m3 direction below). */
NK_INTERNAL __m256 nk_e2m3x8_to_f32x8_haswell_(__m128i e2m3_i8x8) {
    __m256i e2m3_i32x8 = _mm256_cvtepu8_epi32(e2m3_i8x8);

    // Extract fields (only 6 bits used: S EE MMM)
    __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e2m3_i32x8, 3), _mm256_set1_epi32(0x03));
    __m256i mant_i32x8 = _mm256_and_si256(e2m3_i32x8, _mm256_set1_epi32(0x07));

    // Build F32 sign bit (bit 5; the <<31 discards any stray higher bits left by >>5)
    __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e2m3_i32x8, 5), 31);

    // Normal path: sign | ((exp+126)<<23) | (mant<<20)
    __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(126)), 23);
    __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
    __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));

    // Subnormal path: value = mantissa / 8.0f, then apply sign bitwise.
    // The multiply is exact: mantissa ∈ [0,7] and 1/8 is a power of two.
    __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 8.0f));
    __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));

    // Blend: if exp==0, use subnormal result; otherwise use normal bits
    __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
    return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
}
|
|
539
|
+
|
|
540
|
+
/** @brief Convert 8x e3m2 → 8x f32 via bit manipulation (AVX2).
|
|
541
|
+
* E3M2 format: S EEE MM (bias=3). F32: sign<<31, (exp+124)<<23, mantissa<<21.
|
|
542
|
+
* Subnormals (exp=0): value = mantissa × 2⁽¹⁻³⁾ × 2⁻² = mantissa ÷ 16. */
|
|
543
|
+
NK_INTERNAL __m256 nk_e3m2x8_to_f32x8_haswell_(__m128i e3m2_i8x8) {
|
|
544
|
+
__m256i e3m2_i32x8 = _mm256_cvtepu8_epi32(e3m2_i8x8);
|
|
545
|
+
|
|
546
|
+
// Extract fields (only 6 bits used: S EEE MM)
|
|
547
|
+
__m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e3m2_i32x8, 2), _mm256_set1_epi32(0x07));
|
|
548
|
+
__m256i mant_i32x8 = _mm256_and_si256(e3m2_i32x8, _mm256_set1_epi32(0x03));
|
|
549
|
+
|
|
550
|
+
// Build F32 sign bit
|
|
551
|
+
__m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e3m2_i32x8, 5), 31);
|
|
552
|
+
|
|
553
|
+
// Normal path: sign | ((exp+124)<<23) | (mant<<21)
|
|
554
|
+
__m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(124)), 23);
|
|
555
|
+
__m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
|
|
556
|
+
__m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
|
|
557
|
+
|
|
558
|
+
// Subnormal path: value = mantissa / 16.0f, then apply sign
|
|
559
|
+
__m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 16.0f));
|
|
560
|
+
__m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
|
|
561
|
+
|
|
562
|
+
// Blend: if exp==0, use subnormal result; otherwise use normal bits
|
|
563
|
+
__m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
|
|
564
|
+
return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
/** @brief Convert 8x f32 → 8x e2m3 via bit manipulation (AVX2).
 *  E2M3 format: S EE MMM (bias=1). Handles normal, subnormal, and overflow cases.
 *  Mantissa rounding is RNE (round to nearest, ties to even).
 *  Subnormals (f32_exp ≤ 126): mantissa = round(abs_f32 * 8), clamped to [0,7].
 *  NOTE(review): F32 Inf/NaN inputs fall into the overflow path and saturate to
 *  the max finite magnitude — this code treats E2M3 as having no Inf/NaN
 *  encodings; confirm against the format specification. */
NK_INTERNAL __m128i nk_f32x8_to_e2m3x8_haswell_(__m256 f32x8) {
    __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
    __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
    __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));

    // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
    // RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
    __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
                                                _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
    __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 20), _mm256_set1_epi32(1));
    __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x0007FFFF), lsb_i32x8); // half - 1 = 0x7FFFF
    __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
    __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
    __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 20), _mm256_set1_epi32(0x07));
    // If carry, mantissa becomes 0 (we rounded up to next power of 2)
    f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
    __m256i e2m3_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(126));

    // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 3)
    __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e2m3_exp_i32x8);
    __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e2m3_exp_i32x8, _mm256_set1_epi32(3));

    // Normal path: clamp exp to [1,3]; on overflow force the max finite mantissa (7)
    __m256i clamped_exp_i32x8 = _mm256_max_epi32(e2m3_exp_i32x8, _mm256_set1_epi32(1));
    clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(3));
    __m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_set1_epi32(0x07), overflow_i32x8);
    __m256i normal_e2m3_i32x8 = _mm256_or_si256(
        _mm256_slli_epi32(sign_i32x8, 5),
        _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 3), normal_mantissa_i32x8));

    // Subnormal path: mantissa = round(abs_f32 * 8)
    // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
    __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
    __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(8.0f));
    __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
    __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
    subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
    subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
    __m256i subnorm_e2m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), subnorm_mantissa_i32x8);
    // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
    __m256i first_normal_e2m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), _mm256_set1_epi32(0x08));
    subnorm_e2m3_i32x8 = _mm256_blendv_epi8(subnorm_e2m3_i32x8, first_normal_e2m3_i32x8, promotes_to_normal_i32x8);

    // Blend: use subnormal result when exp <= 0, else normal
    __m256i e2m3_i32x8 = _mm256_blendv_epi8(normal_e2m3_i32x8, subnorm_e2m3_i32x8, is_subnormal_i32x8);

    // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
    __m128i low_i32x4 = _mm256_castsi256_si128(e2m3_i32x8);
    __m128i high_i32x4 = _mm256_extracti128_si256(e2m3_i32x8, 1);
    __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
    __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
    return packed_i8x8;
}
|
|
622
|
+
|
|
623
|
+
/** @brief Convert 8x f32 → 8x e3m2 via bit manipulation (AVX2).
 *  E3M2 format: S EEE MM (bias=3). Handles normal, subnormal, and overflow cases.
 *  Mantissa rounding is RNE (round to nearest, ties to even).
 *  Subnormals (f32_exp ≤ 124): mantissa = round(abs_f32 * 16), clamped to [0,3].
 *  NOTE(review): F32 Inf/NaN inputs fall into the overflow path and saturate to
 *  the max finite magnitude — this code treats E3M2 as having no Inf/NaN
 *  encodings; confirm against the format specification. */
NK_INTERNAL __m128i nk_f32x8_to_e3m2x8_haswell_(__m256 f32x8) {
    __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
    __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
    __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));

    // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
    __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
                                                _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
    __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 21), _mm256_set1_epi32(1));
    __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x000FFFFF), lsb_i32x8); // half - 1 = 0xFFFFF
    __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
    __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
    __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 21), _mm256_set1_epi32(0x03));
    // If carry, mantissa becomes 0 (we rounded up to next power of 2)
    f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
    __m256i e3m2_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(124));

    // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 7)
    __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e3m2_exp_i32x8);
    __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e3m2_exp_i32x8, _mm256_set1_epi32(7));

    // Normal path: clamp exp to [1,7]; on overflow force the max finite mantissa (3)
    __m256i clamped_exp_i32x8 = _mm256_max_epi32(e3m2_exp_i32x8, _mm256_set1_epi32(1));
    clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(7));
    __m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_set1_epi32(0x03), overflow_i32x8);
    __m256i normal_e3m2_i32x8 = _mm256_or_si256(
        _mm256_slli_epi32(sign_i32x8, 5),
        _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 2), normal_mantissa_i32x8));

    // Subnormal path: mantissa = round(abs_f32 * 16)
    // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
    __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
    __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(16.0f));
    __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
    __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
    subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
    subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
    __m256i subnorm_e3m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), subnorm_mantissa_i32x8);
    // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
    __m256i first_normal_e3m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), _mm256_set1_epi32(0x04));
    subnorm_e3m2_i32x8 = _mm256_blendv_epi8(subnorm_e3m2_i32x8, first_normal_e3m2_i32x8, promotes_to_normal_i32x8);

    // Blend: use subnormal result when exp <= 0
    __m256i e3m2_i32x8 = _mm256_blendv_epi8(normal_e3m2_i32x8, subnorm_e3m2_i32x8, is_subnormal_i32x8);

    // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
    __m128i low_i32x4 = _mm256_castsi256_si128(e3m2_i32x8);
    __m128i high_i32x4 = _mm256_extracti128_si256(e3m2_i32x8, 1);
    __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
    __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
    return packed_i8x8;
}
|
|
678
|
+
|
|
679
|
+
#pragma endregion - Vectorized Conversions
|
|
680
|
+
|
|
681
|
+
#pragma region - Converting Loads and Stores
|
|
682
|
+
|
|
683
|
+
/** @brief Full load for f16 elements (8) with conversion to f32 via F16C. */
|
|
684
|
+
NK_INTERNAL void nk_load_f16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
685
|
+
dst->ymm_ps = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)src));
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
/** @brief Partial load for f16 elements (up to 8) with conversion to f32 via F16C. */
|
|
689
|
+
NK_INTERNAL void nk_partial_load_f16x8_to_f32x8_haswell_(nk_f16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
690
|
+
nk_b128_vec_t vec;
|
|
691
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
692
|
+
dst->ymm_ps = _mm256_cvtph_ps(vec.xmm);
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
/** @brief Full load for bf16 elements (8) with conversion to f32. */
|
|
696
|
+
NK_INTERNAL void nk_load_bf16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
697
|
+
dst->ymm_ps = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)src));
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
/** @brief Partial load for bf16 elements (up to 8) with conversion to f32. */
|
|
701
|
+
NK_INTERNAL void nk_partial_load_bf16x8_to_f32x8_haswell_(nk_bf16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
702
|
+
nk_b128_vec_t vec;
|
|
703
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
704
|
+
dst->ymm_ps = nk_bf16x8_to_f32x8_haswell_(vec.xmm);
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
/** @brief Full load for e4m3 elements (8) with conversion to f32. */
|
|
708
|
+
NK_INTERNAL void nk_load_e4m3x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
709
|
+
dst->ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
/** @brief Partial load for e4m3 elements (up to 8) with conversion to f32. */
|
|
713
|
+
NK_INTERNAL void nk_partial_load_e4m3x8_to_f32x8_haswell_(nk_e4m3_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
714
|
+
nk_b64_vec_t vec;
|
|
715
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
716
|
+
dst->ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
/** @brief Full load for e5m2 elements (8) with conversion to f32. */
|
|
720
|
+
NK_INTERNAL void nk_load_e5m2x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
721
|
+
dst->ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
/** @brief Partial load for e5m2 elements (up to 8) with conversion to f32. */
|
|
725
|
+
NK_INTERNAL void nk_partial_load_e5m2x8_to_f32x8_haswell_(nk_e5m2_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
726
|
+
nk_b64_vec_t vec;
|
|
727
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
728
|
+
dst->ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
/** @brief Full load for e2m3 elements (8) with conversion to f32. */
|
|
732
|
+
NK_INTERNAL void nk_load_e2m3x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
733
|
+
dst->ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
/** @brief Partial load for e2m3 elements (up to 8) with conversion to f32. */
|
|
737
|
+
NK_INTERNAL void nk_partial_load_e2m3x8_to_f32x8_haswell_(nk_e2m3_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
738
|
+
nk_b64_vec_t vec;
|
|
739
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
740
|
+
dst->ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
/** @brief Full load for e3m2 elements (8) with conversion to f32. */
|
|
744
|
+
NK_INTERNAL void nk_load_e3m2x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
745
|
+
dst->ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
/** @brief Partial load for e3m2 elements (up to 8) with conversion to f32. */
|
|
749
|
+
NK_INTERNAL void nk_partial_load_e3m2x8_to_f32x8_haswell_(nk_e3m2_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
750
|
+
nk_b64_vec_t vec;
|
|
751
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
752
|
+
dst->ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
/** @brief Partial load for i8 elements (up to 8) with conversion to f32. */
|
|
756
|
+
NK_INTERNAL void nk_partial_load_i8x8_to_f32x8_haswell_(nk_i8_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
757
|
+
nk_b64_vec_t vec;
|
|
758
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
759
|
+
dst->ymm_ps = nk_i8x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
/** @brief Partial load for u8 elements (up to 8) with conversion to f32. */
|
|
763
|
+
NK_INTERNAL void nk_partial_load_u8x8_to_f32x8_haswell_(nk_u8_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
764
|
+
nk_b64_vec_t vec;
|
|
765
|
+
nk_partial_load_b8x8_serial_(src, &vec, n);
|
|
766
|
+
dst->ymm_ps = nk_u8x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
|
|
767
|
+
}
|
|
768
|
+
|
|
769
|
+
/** @brief Partial load for i16 elements (up to 8) with conversion to f32. */
|
|
770
|
+
NK_INTERNAL void nk_partial_load_i16x8_to_f32x8_haswell_(nk_i16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
771
|
+
nk_b128_vec_t vec;
|
|
772
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
773
|
+
dst->ymm_ps = nk_i16x8_to_f32x8_haswell_(vec.xmm);
|
|
774
|
+
}
|
|
775
|
+
|
|
776
|
+
/** @brief Partial load for u16 elements (up to 8) with conversion to f32. */
|
|
777
|
+
NK_INTERNAL void nk_partial_load_u16x8_to_f32x8_haswell_(nk_u16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
778
|
+
nk_b128_vec_t vec;
|
|
779
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
780
|
+
dst->ymm_ps = nk_u16x8_to_f32x8_haswell_(vec.xmm);
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
/** @brief Partial load for i32 elements (up to 8) with conversion to f32. */
|
|
784
|
+
NK_INTERNAL void nk_partial_load_i32x8_to_f32x8_haswell_(nk_i32_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
785
|
+
nk_b256_vec_t vec;
|
|
786
|
+
nk_partial_load_b32x8_serial_(src, &vec, n);
|
|
787
|
+
dst->ymm_ps = nk_i32x8_to_f32x8_haswell_(vec.ymm);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
/** @brief Partial load for u32 elements (up to 8) with conversion to f32. */
|
|
791
|
+
NK_INTERNAL void nk_partial_load_u32x8_to_f32x8_haswell_(nk_u32_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
792
|
+
nk_b256_vec_t vec;
|
|
793
|
+
nk_partial_load_b32x8_serial_(src, &vec, n);
|
|
794
|
+
dst->ymm_ps = nk_u32x8_to_f32x8_haswell_(vec.ymm);
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
#pragma endregion - Converting Loads and Stores
|
|
798
|
+
|
|
799
|
+
#pragma region - Public API
|
|
800
|
+
|
|
801
|
+
/**
 *  @brief AVX2/F16C casting kernel: converts @p n elements from @p from_type to @p to_type.
 *
 *  Strategy: every supported conversion is routed through an f32x8 "hub" register —
 *  8 elements are up-cast to f32, then down-cast to the destination type, 8 at a time,
 *  with serial partial loads/stores covering the final sub-8 tail. Pairs not expressible
 *  through f32 without loss (or types unsupported here) fall back to `nk_cast_serial`.
 *
 *  @param from       Source buffer (read-only).
 *  @param from_type  Source element dtype.
 *  @param n          Number of logical elements to convert.
 *  @param to         Destination buffer; must not alias @p from unless types match.
 *  @param to_type    Destination element dtype.
 */
NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
    // Same-type fast path: a raw byte copy, rounding sub-byte dtypes up to whole bytes.
    if (from_type == to_type) {
        nk_size_t size_bits = nk_dtype_bits(from_type);
        if (size_bits > 0) nk_copy_bytes_(to, from, nk_size_divide_round_up_(n * size_bits, NK_BITS_PER_BYTE));
        return;
    }

    // Supported types: floats (f32, f16, bf16, e4m3, e5m2, e2m3, e3m2) and integers (i8, u8, i16, u16, i32, u32)
    int from_supported = (from_type == nk_f32_k || from_type == nk_f16_k || from_type == nk_bf16_k ||
                          from_type == nk_e4m3_k || from_type == nk_e5m2_k || from_type == nk_e2m3_k ||
                          from_type == nk_e3m2_k || from_type == nk_i8_k || from_type == nk_u8_k ||
                          from_type == nk_i16_k || from_type == nk_u16_k || from_type == nk_i32_k ||
                          from_type == nk_u32_k);
    int to_supported = (to_type == nk_f32_k || to_type == nk_f16_k || to_type == nk_bf16_k || to_type == nk_e4m3_k ||
                        to_type == nk_e5m2_k || to_type == nk_e2m3_k || to_type == nk_e3m2_k || to_type == nk_i8_k ||
                        to_type == nk_u8_k || to_type == nk_i16_k || to_type == nk_u16_k || to_type == nk_i32_k ||
                        to_type == nk_u32_k);
    if (!from_supported || !to_supported) {
        nk_cast_serial(from, from_type, n, to, to_type);
        return;
    }

    // Fall back to serial for i32/u32↔i32/u32 (f32 intermediate loses precision for large values)
    int from_32bit_int = (from_type == nk_i32_k || from_type == nk_u32_k);
    int to_32bit_int = (to_type == nk_i32_k || to_type == nk_u32_k);
    if (from_32bit_int && to_32bit_int) {
        nk_cast_serial(from, from_type, n, to, to_type);
        return;
    }

    // Byte steps per 8 elements
    nk_size_t from_step = 8 * nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
    nk_size_t to_step = 8 * nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;

    nk_u8_t const *from_ptr = (nk_u8_t const *)from;
    nk_u8_t *to_ptr = (nk_u8_t *)to;
    nk_size_t batches = n / 8;
    nk_size_t tail = n % 8;
    nk_b256_vec_t hub; // f32x8 intermediate shared by the up-cast and down-cast stages

    for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
        // Upcast to f32x8. Exactly one branch fires: `from_supported` above covers every case.
        if (from_type == nk_f32_k) hub.ymm_ps = _mm256_loadu_ps((float const *)from_ptr);
        else if (from_type == nk_f16_k) hub.ymm_ps = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)from_ptr));
        else if (from_type == nk_bf16_k)
            hub.ymm_ps = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
        else if (from_type == nk_e4m3_k)
            hub.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_e5m2_k)
            hub.ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_e2m3_k)
            hub.ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_e3m2_k)
            hub.ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_i8_k)
            hub.ymm_ps = nk_i8x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_u8_k)
            hub.ymm_ps = nk_u8x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
        else if (from_type == nk_i16_k)
            hub.ymm_ps = nk_i16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
        else if (from_type == nk_u16_k)
            hub.ymm_ps = nk_u16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
        else if (from_type == nk_i32_k)
            hub.ymm_ps = nk_i32x8_to_f32x8_haswell_(_mm256_loadu_si256((__m256i const *)from_ptr));
        else if (from_type == nk_u32_k)
            hub.ymm_ps = nk_u32x8_to_f32x8_haswell_(_mm256_loadu_si256((__m256i const *)from_ptr));

        // Downcast from f32x8. 8-bit results use a 64-bit store, 16-bit a 128-bit store, 32-bit a 256-bit store.
        if (to_type == nk_f32_k) _mm256_storeu_ps((float *)to_ptr, hub.ymm_ps);
        else if (to_type == nk_f16_k)
            _mm_storeu_si128((__m128i *)to_ptr, _mm256_cvtps_ph(hub.ymm_ps, _MM_FROUND_TO_NEAREST_INT));
        else if (to_type == nk_bf16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_bf16x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_e4m3_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e4m3x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_e5m2_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e5m2x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_e2m3_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e2m3x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_e3m2_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e3m2x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_i8_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_i8x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_u8_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_u8x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_i16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_i16x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_u16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_u16x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_i32_k) _mm256_storeu_si256((__m256i *)to_ptr, nk_f32x8_to_i32x8_haswell_(hub.ymm_ps));
        else if (to_type == nk_u32_k) _mm256_storeu_si256((__m256i *)to_ptr, nk_f32x8_to_u32x8_haswell_(hub.ymm_ps));
    }

    // Handle tail with partial loads/stores
    if (tail) {
        // Upcast tail to f32x8 (f32 itself needs only a masked raw load — no conversion)
        if (from_type == nk_f32_k) nk_partial_load_b32x8_serial_(from_ptr, &hub, tail);
        else if (from_type == nk_f16_k) nk_partial_load_f16x8_to_f32x8_haswell_((nk_f16_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_bf16_k)
            nk_partial_load_bf16x8_to_f32x8_haswell_((nk_bf16_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_e4m3_k)
            nk_partial_load_e4m3x8_to_f32x8_haswell_((nk_e4m3_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_e5m2_k)
            nk_partial_load_e5m2x8_to_f32x8_haswell_((nk_e5m2_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_e2m3_k)
            nk_partial_load_e2m3x8_to_f32x8_haswell_((nk_e2m3_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_e3m2_k)
            nk_partial_load_e3m2x8_to_f32x8_haswell_((nk_e3m2_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_i8_k) nk_partial_load_i8x8_to_f32x8_haswell_((nk_i8_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_u8_k) nk_partial_load_u8x8_to_f32x8_haswell_((nk_u8_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_i16_k) nk_partial_load_i16x8_to_f32x8_haswell_((nk_i16_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_u16_k) nk_partial_load_u16x8_to_f32x8_haswell_((nk_u16_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_i32_k) nk_partial_load_i32x8_to_f32x8_haswell_((nk_i32_t const *)from_ptr, &hub, tail);
        else if (from_type == nk_u32_k) nk_partial_load_u32x8_to_f32x8_haswell_((nk_u32_t const *)from_ptr, &hub, tail);

        // Downcast and store tail: convert in-register, then partial-store only `tail` elements.
        if (to_type == nk_f32_k) nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
        else if (to_type == nk_f16_k) {
            hub.xmms[0] = _mm256_cvtps_ph(hub.ymm_ps, _MM_FROUND_TO_NEAREST_INT);
            nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_bf16_k) {
            hub.xmms[0] = nk_f32x8_to_bf16x8_haswell_(hub.ymm_ps);
            nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_e4m3_k) {
            hub.xmms[0] = nk_f32x8_to_e4m3x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_e5m2_k) {
            hub.xmms[0] = nk_f32x8_to_e5m2x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_e2m3_k) {
            hub.xmms[0] = nk_f32x8_to_e2m3x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_e3m2_k) {
            hub.xmms[0] = nk_f32x8_to_e3m2x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_i8_k) {
            hub.xmms[0] = nk_f32x8_to_i8x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_u8_k) {
            hub.xmms[0] = nk_f32x8_to_u8x8_haswell_(hub.ymm_ps);
            nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_i16_k) {
            hub.xmms[0] = nk_f32x8_to_i16x8_haswell_(hub.ymm_ps);
            nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_u16_k) {
            hub.xmms[0] = nk_f32x8_to_u16x8_haswell_(hub.ymm_ps);
            nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
        }
        else if (to_type == nk_i32_k) {
            hub.ymm = nk_f32x8_to_i32x8_haswell_(hub.ymm_ps);
            nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
        }
        else if (to_type == nk_u32_k) {
            hub.ymm = nk_f32x8_to_u32x8_haswell_(hub.ymm_ps);
            nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
        }
    }
}
|
|
960
|
+
|
|
961
|
+
#pragma endregion - Public API
|
|
962
|
+
|
|
963
|
+
#if defined(__clang__)
|
|
964
|
+
#pragma clang attribute pop
|
|
965
|
+
#elif defined(__GNUC__)
|
|
966
|
+
#pragma GCC pop_options
|
|
967
|
+
#endif
|
|
968
|
+
|
|
969
|
+
#if defined(__cplusplus)
|
|
970
|
+
} // extern "C"
|
|
971
|
+
#endif
|
|
972
|
+
|
|
973
|
+
#endif // NK_TARGET_HASWELL
|
|
974
|
+
#endif // NK_TARGET_X86_
|
|
975
|
+
#endif // NK_CAST_HASWELL_H
|