numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Elementwise Arithmetic for Sapphire Rapids.
|
|
3
|
+
* @file include/numkong/each/sapphire.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/each.h
|
|
8
|
+
*
|
|
9
|
+
* @section sapphire_elementwise_instructions Relevant Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Sapphire Genoa
|
|
12
|
+
* _mm512_add_ph VADDPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
|
|
13
|
+
* _mm512_mul_ph VMULPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
|
|
14
|
+
* _mm512_fmadd_ph VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
15
|
+
* _mm512_cvtepi16_ph VCVTW2PH (ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
16
|
+
* _mm512_cvtph_epi16 VCVTPH2W (ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
17
|
+
* _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
|
|
18
|
+
* _mm512_cvtsepi16_epi8 VPMOVSWB (YMM, ZMM) 4cy @ p5 4cy @ p12
|
|
19
|
+
* _mm512_packus_epi16 VPACKUSWB (ZMM, ZMM, ZMM) 1cy @ p5 1cy @ p12
|
|
20
|
+
* _mm256_add_ph VADDPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
|
|
21
|
+
* _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM {K}, M512) 7cy @ p23 7cy @ p23
|
|
22
|
+
* _mm512_mask_storeu_epi16 VMOVDQU16 (M512 {K}, ZMM) 4cy @ p4 4cy @ p4
|
|
23
|
+
*/
|
|
24
|
+
#ifndef NK_EACH_SAPPHIRE_H
|
|
25
|
+
#define NK_EACH_SAPPHIRE_H
|
|
26
|
+
|
|
27
|
+
#if NK_TARGET_X86_
|
|
28
|
+
#if NK_TARGET_SAPPHIRE
|
|
29
|
+
|
|
30
|
+
#include "numkong/types.h"
|
|
31
|
+
#include "numkong/cast/sapphire.h" // `nk_f32_to_f16_sapphire`
|
|
32
|
+
|
|
33
|
+
#if defined(__cplusplus)
|
|
34
|
+
extern "C" {
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
#if defined(__clang__)
|
|
38
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512fp16,f16c,fma,bmi,bmi2"))), \
|
|
39
|
+
apply_to = function)
|
|
40
|
+
#elif defined(__GNUC__)
|
|
41
|
+
#pragma GCC push_options
|
|
42
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
/**
 *  @brief Element-wise sum of two `f16` vectors: `result[i] = a[i] + b[i]`.
 *
 *  Processes 32 half-precision lanes per iteration. The final partial chunk
 *  (and the very first one, when `n < 32`) goes through zero-masked loads and
 *  a masked store, so reads and writes never cross the buffer ends.
 */
NK_PUBLIC void nk_each_sum_f16_sapphire(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result) {
    __mmask32 tail_mask = 0xFFFFFFFF;
    do {
        __m512h lhs_f16x32, rhs_f16x32;
        if (n < 32) {
            // Partial chunk: mask off the lanes past `n`.
            tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tail_mask, a));
            rhs_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tail_mask, b));
            n = 0;
        }
        else {
            lhs_f16x32 = _mm512_loadu_ph(a);
            rhs_f16x32 = _mm512_loadu_ph(b);
            a += 32, b += 32, n -= 32;
        }
        __m512h sum_f16x32 = _mm512_add_ph(lhs_f16x32, rhs_f16x32);
        _mm512_mask_storeu_epi16(result, tail_mask, _mm512_castph_si512(sum_f16x32));
        result += 32;
    } while (n);
}
|
|
66
|
+
|
|
67
|
+
/**
 *  @brief Scales a `u8` vector into `u8`: `result[i] = saturate(alpha * a[i] + beta)`.
 *
 *  The arithmetic runs in half-precision: bytes are zero-extended to 16-bit
 *  lanes within each 128-bit block (VPUNPCKLBW/VPUNPCKHBW against zero),
 *  converted to `f16`, transformed with a fused multiply-add, converted back,
 *  and narrowed with unsigned saturation. VPACKUSWB interleaves within lanes
 *  exactly like the unpacks did, so the original element order is restored.
 */
NK_PUBLIC void nk_each_scale_u8_sapphire(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_u8_t *result) {
    short alpha_bits, beta_bits;
    nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_bits);
    nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_bits);
    __m512h alphas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_bits));
    __m512h betas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_bits));
    __mmask64 tail_mask = 0xFFFFFFFFFFFFFFFFull;
    do {
        __m512i src_u8x64;
        if (n < 64) {
            // Partial chunk: mask off the lanes past `n`.
            tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
            src_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, a);
            n = 0;
        }
        else {
            src_u8x64 = _mm512_loadu_epi8(a);
            a += 64, n -= 64;
        }
        // Zero-extend bytes to 16-bit lanes, then convert to `f16`:
        __m512h lo_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(src_u8x64, _mm512_setzero_si512()));
        __m512h hi_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(src_u8x64, _mm512_setzero_si512()));
        // Affine transform in `f16`:
        __m512h lo_scaled_f16x32 = _mm512_fmadd_ph(lo_f16x32, alphas_f16x32, betas_f16x32);
        __m512h hi_scaled_f16x32 = _mm512_fmadd_ph(hi_f16x32, alphas_f16x32, betas_f16x32);
        // Convert back to 16-bit integers, then pack with unsigned saturation:
        __m512i lo_i16x32 = _mm512_cvtph_epi16(lo_scaled_f16x32);
        __m512i hi_i16x32 = _mm512_cvtph_epi16(hi_scaled_f16x32);
        __m512i packed_u8x64 = _mm512_packus_epi16(lo_i16x32, hi_i16x32);
        _mm512_mask_storeu_epi8(result, tail_mask, packed_u8x64);
        result += 64;
    } while (n);
}
|
|
103
|
+
|
|
104
|
+
/**
 *  @brief Blends two `u8` vectors: `result[i] = saturate(alpha * a[i] + beta * b[i])`.
 *
 *  Dispatches to cheaper kernels in two common situations — a plain saturating
 *  sum when both weights equal one, and a single-input scale when one weight is
 *  zero — before falling through to the general half-precision path.
 */
NK_PUBLIC void nk_each_blend_u8_sapphire( //
    nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {

    nk_f32_t alpha_value = *alpha;
    nk_f32_t beta_value = *beta;

    // Special case 1: both weights are one — no multiplications needed.
    if (alpha_value == 1 && beta_value == 1) {
        nk_each_sum_u8_icelake(a, b, n, result);
        return;
    }
    // Special case 2: one weight is zero — only one input must be loaded.
    else if (alpha_value == 0 || beta_value == 0) {
        nk_f32_t zero = 0;
        if (beta_value == 0) { nk_each_scale_u8_sapphire(a, n, alpha, &zero, result); }
        else { nk_each_scale_u8_sapphire(b, n, beta, &zero, result); }
        return;
    }

    // General case: upcast both inputs to `f16`, combine, and pack back down.
    short alpha_bits, beta_bits;
    nk_f32_to_f16_sapphire(&alpha_value, (nk_f16_t *)&alpha_bits);
    nk_f32_to_f16_sapphire(&beta_value, (nk_f16_t *)&beta_bits);
    __m512h alphas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_bits));
    __m512h betas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_bits));
    __mmask64 tail_mask = 0xFFFFFFFFFFFFFFFFull;
    do {
        __m512i a_u8x64, b_u8x64;
        if (n < 64) {
            // Partial chunk: mask off the lanes past `n`.
            tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
            a_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, a);
            b_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }
        else {
            a_u8x64 = _mm512_loadu_epi8(a);
            b_u8x64 = _mm512_loadu_epi8(b);
            a += 64, b += 64, n -= 64;
        }
        // Zero-extend bytes to 16-bit lanes, then convert to `f16`:
        __m512h a_lo_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
        __m512h a_hi_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
        __m512h b_lo_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
        __m512h b_hi_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
        // alpha * a + beta * b, with the second product fused into the add:
        __m512h a_scaled_lo_f16x32 = _mm512_mul_ph(a_lo_f16x32, alphas_f16x32);
        __m512h a_scaled_hi_f16x32 = _mm512_mul_ph(a_hi_f16x32, alphas_f16x32);
        __m512h sum_lo_f16x32 = _mm512_fmadd_ph(b_lo_f16x32, betas_f16x32, a_scaled_lo_f16x32);
        __m512h sum_hi_f16x32 = _mm512_fmadd_ph(b_hi_f16x32, betas_f16x32, a_scaled_hi_f16x32);
        // Convert back to 16-bit integers, then pack with unsigned saturation:
        __m512i sum_lo_i16x32 = _mm512_cvtph_epi16(sum_lo_f16x32);
        __m512i sum_hi_i16x32 = _mm512_cvtph_epi16(sum_hi_f16x32);
        __m512i packed_u8x64 = _mm512_packus_epi16(sum_lo_i16x32, sum_hi_i16x32);
        _mm512_mask_storeu_epi8(result, tail_mask, packed_u8x64);
        result += 64;
    } while (n);
}
|
|
169
|
+
|
|
170
|
+
/**
 *  @brief Scales an `i8` vector into `i8`: `result[i] = saturate(alpha * a[i] + beta)`.
 *
 *  Bytes are sign-extended to 16-bit lanes, converted to `f16`, transformed
 *  with a fused multiply-add, converted back to 16-bit integers, and narrowed
 *  with signed saturation (VPMOVSWB) before the two 256-bit halves are merged.
 */
NK_PUBLIC void nk_each_scale_i8_sapphire(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_i8_t *result) {
    short alpha_bits, beta_bits;
    nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_bits);
    nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_bits);
    __m512h alphas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_bits));
    __m512h betas_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_bits));
    __mmask64 tail_mask = 0xFFFFFFFFFFFFFFFFull;
    do {
        __m256i lo_i8x32, hi_i8x32;
        if (n < 64) {
            // Tail: one masked 512-bit load split into halves (runs once).
            tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
            __m512i src_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, a);
            lo_i8x32 = _mm512_castsi512_si256(src_i8x64);
            hi_i8x32 = _mm512_extracti64x4_epi64(src_i8x64, 1);
            n = 0;
        }
        else {
            // Hot path: two 256-bit loads sidestep VEXTRACTI64X4 (Port 5).
            lo_i8x32 = _mm256_loadu_epi8(a);
            hi_i8x32 = _mm256_loadu_epi8(a + 32);
            a += 64, n -= 64;
        }
        // Sign-extend bytes to 16-bit lanes, then convert to `f16`:
        __m512h lo_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(lo_i8x32));
        __m512h hi_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(hi_i8x32));
        // Affine transform in `f16`:
        __m512h lo_scaled_f16x32 = _mm512_fmadd_ph(lo_f16x32, alphas_f16x32, betas_f16x32);
        __m512h hi_scaled_f16x32 = _mm512_fmadd_ph(hi_f16x32, alphas_f16x32, betas_f16x32);
        // Convert back, narrow with signed saturation, and merge the halves:
        __m512i lo_i16x32 = _mm512_cvtph_epi16(lo_scaled_f16x32);
        __m512i hi_i16x32 = _mm512_cvtph_epi16(hi_scaled_f16x32);
        __m512i packed_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(lo_i16x32)),
                                                  _mm512_cvtsepi16_epi8(hi_i16x32), 1);
        _mm512_mask_storeu_epi8(result, tail_mask, packed_i8x64);
        result += 64;
    } while (n);
}
|
|
213
|
+
|
|
214
|
+
/**
 *  @brief Blends two `i8` vectors: `result[i] = saturate(alpha * a[i] + beta * b[i])`.
 *
 *  Fast paths: a saturating sum when both weights equal one, and a single-input
 *  scale when one weight is zero. The general path sign-extends bytes to 16-bit
 *  lanes, converts to `f16`, computes `alpha * a + beta * b` (second product
 *  fused into the add), and narrows back with signed saturation (VPMOVSWB).
 */
NK_PUBLIC void nk_each_blend_i8_sapphire( //
    nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {

    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // There are several special cases we may want to implement:
    // 1. Simple addition, when both weights are equal to 1.0.
    if (alpha_val == 1 && beta_val == 1) {
        // In this case we can avoid expensive multiplications.
        nk_each_sum_i8_icelake(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero.
    else if (alpha_val == 0 || beta_val == 0) {
        // In this case we can avoid half of the load instructions.
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_i8_sapphire(a, n, alpha, &zero, result); }
        else { nk_each_scale_i8_sapphire(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    // Broadcast both weights as `f16` bit patterns into all 32 lanes.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(&alpha_val, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(&beta_val, (nk_f16_t *)&beta_short);
    __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32;
    __m512i result_i8x64;
    __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
    __m512h a_scaled_low_f16x32, a_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
nk_each_blend_i8_sapphire_cycle:
    if (n < 64) {
        // Tail: use masked 512-bit loads and extract (runs once)
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
        __m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
        a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
        a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
        b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
        b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
        n = 0;
    }
    else {
        // Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
        a_low_i8x32 = _mm256_loadu_epi8(a);
        a_high_i8x32 = _mm256_loadu_epi8(a + 32);
        b_low_i8x32 = _mm256_loadu_epi8(b);
        b_high_i8x32 = _mm256_loadu_epi8(b + 32);
        a += 64, b += 64, n -= 64;
    }
    // Upcast from 256-bit halves: sign-extend to 16 bits, then convert to `f16`.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
    b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
    b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
    // Scale:
    a_scaled_low_f16x32 = _mm512_mul_ph(a_low_f16x32, alpha_f16x32);
    a_scaled_high_f16x32 = _mm512_mul_ph(a_high_f16x32, alpha_f16x32);
    // Add: fuse the `beta * b` product into the accumulation.
    result_low_f16x32 = _mm512_fmadd_ph(b_low_f16x32, beta_f16x32, a_scaled_low_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(b_high_f16x32, beta_f16x32, a_scaled_high_f16x32);
    // Downcast: VPMOVSWB saturates each 16-bit lane into the `i8` range.
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
                                      _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
    _mm512_mask_storeu_epi8(result, mask, result_i8x64);
    result += 64;
    if (n) goto nk_each_blend_i8_sapphire_cycle;
}
|
|
289
|
+
|
|
290
|
+
/**
 *  @brief Fused multiply-add over `i8` vectors:
 *         `result[i] = saturate(alpha * a[i] * b[i] + beta * c[i])`.
 *
 *  Bytes are sign-extended to 16-bit lanes and converted to `f16`; the product
 *  `a * b` is scaled by `alpha` and combined with `beta * c` via a fused
 *  multiply-add. The `f16` result is clamped to [-128, 127] before the integer
 *  conversion, then narrowed with signed saturation and merged back to 64 bytes.
 */
NK_PUBLIC void nk_each_fma_i8_sapphire( //
    nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {

    // Broadcast both weights as `f16` bit patterns into all 32 lanes.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
    // Use the `ull` suffix, matching the sibling kernels: the unsuffixed hex
    // literal's type depends on the platform's integer ranks and draws
    // implicit-conversion warnings under `-Wconversion`-style flags.
    __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32, c_low_i8x32, c_high_i8x32;
    __m512i result_i8x64;
    __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
    __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
    __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
    // Clamp bounds for the `i8` range, prepared once outside the loop.
    __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(-128));
    __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(127));

nk_each_fma_i8_sapphire_cycle:
    if (n < 64) {
        // Tail: use masked 512-bit loads and extract (runs once)
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
        __m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
        __m512i c_i8x64 = _mm512_maskz_loadu_epi8(mask, c);
        a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
        a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
        b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
        b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
        c_low_i8x32 = _mm512_castsi512_si256(c_i8x64);
        c_high_i8x32 = _mm512_extracti64x4_epi64(c_i8x64, 1);
        n = 0;
    }
    else {
        // Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
        a_low_i8x32 = _mm256_loadu_epi8(a);
        a_high_i8x32 = _mm256_loadu_epi8(a + 32);
        b_low_i8x32 = _mm256_loadu_epi8(b);
        b_high_i8x32 = _mm256_loadu_epi8(b + 32);
        c_low_i8x32 = _mm256_loadu_epi8(c);
        c_high_i8x32 = _mm256_loadu_epi8(c + 32);
        a += 64, b += 64, c += 64, n -= 64;
    }
    // Upcast from 256-bit halves: sign-extend to 16 bits, then convert to `f16`.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
    b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
    b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
    c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_low_i8x32));
    c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_high_i8x32));
    // Multiply:
    ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
    ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
    // Scale:
    ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
    ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
    // Add: fuse the `beta * c` product into the accumulation.
    result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
    // Clip the 16-bit result to 8-bit range while still in `f16`, keeping the
    // later integer conversion well inside its defined domain:
    result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
    result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
    // Downcast:
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    // Merge back:
    result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
                                      _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
    _mm512_mask_storeu_epi8(result, mask, result_i8x64);
    result += 64;
    if (n) goto nk_each_fma_i8_sapphire_cycle;
}
|
|
363
|
+
|
|
364
|
+
// Element-wise fused multiply-add over `u8` inputs on Sapphire Rapids:
//     result[i] = clamp_u8(alpha * a[i] * b[i] + beta * c[i])
// Arithmetic is performed in `f16` (AVX-512FP16), 64 bytes per iteration, with a
// masked tail pass for the final `n % 64` elements.
// NOTE(review): intermediate products are computed in f16; for large u8 operands
// `a*b` can reach 65025, near the f16 max (65504), so rounding error in the
// low bits is possible — presumably acceptable for this precision tier; confirm.
NK_PUBLIC void nk_each_fma_u8_sapphire(                                //
    nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {

    // Round the f32 scaling factors down to f16 once, up-front, using the
    // file's f32->f16 conversion helper; the bit patterns land in `short`s
    // so they can be broadcast with integer `set1` below.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
    // Full store-mask for the hot path; the tail branch narrows it exactly once,
    // and the same mask is reused by the store at the bottom of the loop.
    __mmask64 mask = 0xFFFFFFFFFFFFFFFF;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m512i a_u8x64, b_u8x64, c_u8x64, result_u8x64;
    __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
    __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
    __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
    // Saturation bounds for the unsigned 8-bit output range, as f16 constants.
    __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(0));
    __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(255));

nk_each_fma_u8_sapphire_cycle:
    if (n < 64) {
        // Tail: masked loads zero the lanes past `n`; runs at most once, since
        // `n` is forced to zero and the loop exits after the store below.
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        c_u8x64 = _mm512_maskz_loadu_epi8(mask, c);
        n = 0;
    }
    else {
        // Hot path: full unmasked 64-byte loads.
        a_u8x64 = _mm512_loadu_epi8(a);
        b_u8x64 = _mm512_loadu_epi8(b);
        c_u8x64 = _mm512_loadu_epi8(c);
        a += 64, b += 64, c += 64, n -= 64;
    }
    // Upcast: interleave with zero to zero-extend bytes to 16-bit words.
    // `unpacklo`/`unpackhi` operate within each 128-bit lane, so the elements
    // end up lane-shuffled — but the closing `packus_epi16` narrows lane-wise
    // too, so the original byte order round-trips exactly.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
    b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
    b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
    c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(c_u8x64, _mm512_setzero_si512()));
    c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(c_u8x64, _mm512_setzero_si512()));
    // Multiply: a * b
    ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
    ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
    // Scale: alpha * (a * b)
    ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
    ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
    // Add: beta * c + alpha * a * b, fused in a single FMA.
    result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
    // Clip the 16-bit result to the unsigned 8-bit range [0, 255]:
    result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
    result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
    // Downcast f16 -> i16 (values already clipped, so the conversion is in range):
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    // Merge back: lane-wise unsigned saturating pack, undoing the unpack above.
    result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
    _mm512_mask_storeu_epi8(result, mask, result_u8x64);
    result += 64;
    if (n) goto nk_each_fma_u8_sapphire_cycle;
}
|
|
424
|
+
|
|
425
|
+
// Element-wise sum over `e4m3` inputs: `result[i] = a[i] + b[i]`, 32 bytes per
// step, with a masked tail pass for the trailing `n % 32` elements. Arithmetic
// happens in f16: the largest e4m3 magnitude is 448, so a two-term sum (<= 896)
// stays well inside the f16 range (65504) and cannot overflow.
NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
    // Full mask for whole 32-byte strides; the tail narrows it exactly once.
    __mmask32 store_mask = 0xFFFFFFFF;
    do {
        __m256i lhs_e4m3x32, rhs_e4m3x32;
        if (n < 32) {
            // Tail: masked loads zero the lanes past `n`; loop ends after the store.
            store_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
            lhs_e4m3x32 = _mm256_maskz_loadu_epi8(store_mask, a);
            rhs_e4m3x32 = _mm256_maskz_loadu_epi8(store_mask, b);
            n = 0;
        }
        else {
            lhs_e4m3x32 = _mm256_loadu_si256((__m256i const *)a);
            rhs_e4m3x32 = _mm256_loadu_si256((__m256i const *)b);
            a += 32, b += 32, n -= 32;
        }

        // Widen each 128-bit half from e4m3 to f16 and add in one expression.
        __m256h sum_lo_f16x16 = _mm256_add_ph( //
            nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(lhs_e4m3x32)),
            nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(rhs_e4m3x32)));
        __m256h sum_hi_f16x16 = _mm256_add_ph( //
            nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(lhs_e4m3x32, 1)),
            nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(rhs_e4m3x32, 1)));

        // Narrow both halves back to e4m3, reassemble the 32-byte vector, store.
        __m256i sum_e4m3x32 = _mm256_inserti128_si256(
            _mm256_castsi128_si256(nk_f16x16_to_e4m3x16_sapphire_(sum_lo_f16x16)),
            nk_f16x16_to_e4m3x16_sapphire_(sum_hi_f16x16), 1);
        _mm256_mask_storeu_epi8(result, store_mask, sum_e4m3x32);
        result += 32;
    } while (n);
}
|
|
464
|
+
|
|
465
|
+
#if defined(__clang__)
|
|
466
|
+
#pragma clang attribute pop
|
|
467
|
+
#elif defined(__GNUC__)
|
|
468
|
+
#pragma GCC pop_options
|
|
469
|
+
#endif
|
|
470
|
+
|
|
471
|
+
#if defined(__cplusplus)
|
|
472
|
+
} // extern "C"
|
|
473
|
+
#endif
|
|
474
|
+
|
|
475
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
476
|
+
#endif // NK_TARGET_X86_
|
|
477
|
+
#endif // NK_EACH_SAPPHIRE_H
|