numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,13 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section skylake_cast_instructions AVX-512 Conversion Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_cvtph_ps
|
|
13
|
-
* _mm512_cvtps_ph
|
|
14
|
-
* _mm512_cvtps_epi32
|
|
15
|
-
* _mm512_cvtepi32_ps
|
|
16
|
-
* _mm512_cvtepi32_epi16
|
|
17
|
-
* _mm512_cvtsepi32_epi8
|
|
11
|
+
* Intrinsic Instruction SKL ICL Genoa
|
|
12
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p05 4cy @ p01
|
|
13
|
+
* _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, imm) 5cy @ p05 5cy @ p05 4cy @ p01
|
|
14
|
+
* _mm512_cvtps_epi32 VCVTPS2DQ (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
|
|
15
|
+
* _mm512_cvtepi32_ps VCVTDQ2PS (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
|
|
16
|
+
* _mm512_cvtepi32_epi16 VPMOVDW (YMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
|
|
17
|
+
* _mm512_cvtsepi32_epi8 VPMOVSDB (XMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
|
|
18
18
|
*
|
|
19
19
|
* F16 conversions use hardware F16C via VCVTPH2PS/VCVTPS2PH. BF16 lacks hardware support on Skylake,
|
|
20
20
|
* requiring emulation via VPMOVZXWD + VPSLLD for bf16-to-f32, achieving ~4cy total. FP8 (E4M3/E5M2)
|
|
@@ -41,7 +41,7 @@ extern "C" {
|
|
|
41
41
|
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
42
42
|
#endif
|
|
43
43
|
|
|
44
|
-
#pragma region
|
|
44
|
+
#pragma region Type Punned Loads and Stores
|
|
45
45
|
|
|
46
46
|
/** @brief Type-agnostic 512-bit full load (Skylake AVX-512). */
|
|
47
47
|
NK_INTERNAL void nk_load_b512_skylake_(void const *src, nk_b512_vec_t *dst) { dst->zmm = _mm512_loadu_si512(src); }
|
|
@@ -132,9 +132,32 @@ NK_INTERNAL void nk_partial_store_b64x4_skylake_(nk_b256_vec_t const *src, void
|
|
|
132
132
|
_mm256_mask_storeu_epi64(dst, mask, src->ymm);
|
|
133
133
|
}
|
|
134
134
|
|
|
135
|
-
|
|
135
|
+
/** @brief Type-agnostic full store for 512-bit vector (Skylake AVX-512). */
|
|
136
|
+
NK_INTERNAL void nk_store_b512_skylake_(nk_b512_vec_t const *src, void *dst) {
|
|
137
|
+
_mm512_storeu_si512((__m512i *)dst, src->zmm);
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
/** @brief Type-agnostic partial store for 16-bit elements (32 elements max) from 512-bit vector (Skylake AVX-512). */
|
|
141
|
+
NK_INTERNAL void nk_partial_store_b16x32_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
|
|
142
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
143
|
+
_mm512_mask_storeu_epi16(dst, mask, src->zmm);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
/** @brief Type-agnostic partial store for 8-bit elements (64 elements max) from 512-bit vector (Skylake AVX-512). */
|
|
147
|
+
NK_INTERNAL void nk_partial_store_b8x64_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
|
|
148
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
|
|
149
|
+
_mm512_mask_storeu_epi8(dst, mask, src->zmm);
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
/** @brief Type-agnostic partial store for 64-bit elements (8 elements max) from 512-bit vector (Skylake AVX-512). */
|
|
153
|
+
NK_INTERNAL void nk_partial_store_b64x8_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
|
|
154
|
+
__mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
|
|
155
|
+
_mm512_mask_storeu_epi64(dst, mask, src->zmm);
|
|
156
|
+
}
|
|
136
157
|
|
|
137
|
-
#pragma
|
|
158
|
+
#pragma endregion Type Punned Loads and Stores
|
|
159
|
+
|
|
160
|
+
#pragma region Vectorized Conversions
|
|
138
161
|
|
|
139
162
|
/** @brief Convert 16x bf16 → 16x f32 (Skylake AVX-512). */
|
|
140
163
|
NK_INTERNAL __m512 nk_bf16x16_to_f32x16_skylake_(__m256i a) {
|
|
@@ -169,18 +192,20 @@ NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
|
|
|
169
192
|
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
170
193
|
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
171
194
|
|
|
172
|
-
// Subnormal fix:
|
|
195
|
+
// Subnormal fix: vpermps from 8-entry LUT (repeated to fill 16 lanes)
|
|
173
196
|
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e4m3_i32x16, _mm512_set1_epi32(0x78));
|
|
174
|
-
__m512
|
|
197
|
+
__m512 subnorm_lut_f32x16 = _mm512_setr_ps( //
|
|
198
|
+
0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512, //
|
|
199
|
+
0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
|
|
200
|
+
__m512 subnorm_abs_f32x16 = _mm512_permutexvar_ps(mantissa_i32x16, subnorm_lut_f32x16);
|
|
175
201
|
result_f32x16 = _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16,
|
|
176
202
|
_mm512_castsi512_ps(sign_i32x16));
|
|
177
203
|
|
|
178
|
-
// NaN
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_bits));
|
|
204
|
+
// NaN: E4M3FN has NaN only at magnitude 0x7F (single mask comparison)
|
|
205
|
+
__m512i lower7_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x7F));
|
|
206
|
+
__mmask16 is_nan = _mm512_cmpeq_epi32_mask(lower7_i32x16, _mm512_set1_epi32(0x7F));
|
|
207
|
+
__m512i nan_i32x16 = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000));
|
|
208
|
+
return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_i32x16));
|
|
184
209
|
}
|
|
185
210
|
|
|
186
211
|
/** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
|
|
@@ -561,9 +586,43 @@ NK_INTERNAL __m256i nk_f64x8_to_u32x8_skylake_(__m512d f64x8) {
|
|
|
561
586
|
return _mm512_cvtpd_epu32(clamped);
|
|
562
587
|
}
|
|
563
588
|
|
|
564
|
-
|
|
589
|
+
/**
|
|
590
|
+
* @brief Convert 64x E2M3 → 64x I8 using VPSHUFB LUT (Skylake AVX-512).
|
|
591
|
+
*
|
|
592
|
+
* E2M3 format: [sign:1][magnitude:5] where magnitude indexes a 32-entry LUT
|
|
593
|
+
* that produces the scaled integer value. Sign bit negates the result.
|
|
594
|
+
* The 32-entry LUT is split into two 16-entry halves for VPSHUFB (which
|
|
595
|
+
* indexes within 16-byte lanes). Bit 4 of the magnitude selects the half.
|
|
596
|
+
*/
|
|
597
|
+
NK_INTERNAL __m512i nk_e2m3x64_to_i8x64_skylake_(__m512i raw_i8x64) {
|
|
598
|
+
// lut_magnitude[0..15] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}
|
|
599
|
+
// lut_magnitude[16..31] = {32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120}
|
|
600
|
+
// _mm512_set4_epi32(d3, d2, d1, d0) fills bytes [0..3]=d0, [4..7]=d1, [8..11]=d2, [12..15]=d3
|
|
601
|
+
// per 128-bit lane, matching VPSHUFB's per-lane indexing.
|
|
602
|
+
__m512i lut_low_i8x64 = _mm512_set4_epi32( //
|
|
603
|
+
0x1E1C1A18, 0x16141210, 0x0E0C0A08, 0x06040200);
|
|
604
|
+
__m512i lut_high_i8x64 = _mm512_set4_epi32( //
|
|
605
|
+
0x78706860, 0x58504840, 0x3C383430, 0x2C282420);
|
|
606
|
+
|
|
607
|
+
__m512i magnitude_i8x64 = _mm512_and_si512(raw_i8x64, _mm512_set1_epi8(0x1F));
|
|
608
|
+
__m512i index_i8x64 = _mm512_and_si512(magnitude_i8x64, _mm512_set1_epi8(0x0F));
|
|
609
|
+
|
|
610
|
+
__m512i val_low_i8x64 = _mm512_shuffle_epi8(lut_low_i8x64, index_i8x64);
|
|
611
|
+
__m512i val_high_i8x64 = _mm512_shuffle_epi8(lut_high_i8x64, index_i8x64);
|
|
612
|
+
|
|
613
|
+
// Select high half when bit 4 of magnitude is set (magnitude >= 16)
|
|
614
|
+
__mmask64 use_high_mask = _mm512_test_epi8_mask(magnitude_i8x64, _mm512_set1_epi8(0x10));
|
|
615
|
+
__m512i val_i8x64 = _mm512_mask_blend_epi8(use_high_mask, val_low_i8x64, val_high_i8x64);
|
|
616
|
+
|
|
617
|
+
// Negate if sign bit (bit 5) is set
|
|
618
|
+
__mmask64 sign_mask = _mm512_test_epi8_mask(raw_i8x64, _mm512_set1_epi8(0x20));
|
|
619
|
+
__m512i negated_i8x64 = _mm512_sub_epi8(_mm512_setzero_si512(), val_i8x64);
|
|
620
|
+
return _mm512_mask_blend_epi8(sign_mask, val_i8x64, negated_i8x64);
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
#pragma endregion Vectorized Conversions
|
|
565
624
|
|
|
566
|
-
#pragma region
|
|
625
|
+
#pragma region Converting Loads and Stores
|
|
567
626
|
|
|
568
627
|
/** @brief Load 16 f16 values and convert to 16 f32 (Skylake AVX-512). */
|
|
569
628
|
NK_INTERNAL void nk_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
@@ -637,9 +696,9 @@ NK_INTERNAL void nk_partial_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_
|
|
|
637
696
|
dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(e3m2_partial.xmm);
|
|
638
697
|
}
|
|
639
698
|
|
|
640
|
-
#pragma endregion
|
|
699
|
+
#pragma endregion Converting Loads and Stores
|
|
641
700
|
|
|
642
|
-
#pragma region
|
|
701
|
+
#pragma region Public API
|
|
643
702
|
|
|
644
703
|
NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
645
704
|
// Same-type fast path
|
|
@@ -839,7 +898,7 @@ NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
839
898
|
nk_cast_serial(from, from_type, n, to, to_type);
|
|
840
899
|
}
|
|
841
900
|
|
|
842
|
-
#pragma endregion
|
|
901
|
+
#pragma endregion Public API
|
|
843
902
|
|
|
844
903
|
#if defined(__clang__)
|
|
845
904
|
#pragma clang attribute pop
|