numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1562 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Elementwise Arithmetic for Skylake.
|
|
3
|
+
* @file include/numkong/each/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/each.h
|
|
8
|
+
*
|
|
9
|
+
* @section skylake_elementwise_instructions Relevant Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction SKL ICL Genoa
|
|
12
|
+
* _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
|
|
13
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 4cy @ p01
|
|
14
|
+
* _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
|
|
15
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 7cy @ p0 5cy @ p01
|
|
16
|
+
* _mm512_maskz_loadu_ps VMOVUPS (ZMM {K}, M512) 7cy @ p23 7cy @ p23 7cy @ p23
|
|
17
|
+
* _mm512_mask_storeu_ps VMOVUPS (M512 {K}, ZMM) 4cy @ p4 4cy @ p4 4cy @ p4
|
|
18
|
+
*
|
|
19
|
+
* Skylake-X server chips have dual 512-bit FMA units enabling 0.5cy throughput for arithmetic operations.
|
|
20
|
+
* AVX-512 masked loads and stores eliminate branch misprediction penalties for partial vector processing.
|
|
21
|
+
* Note that client Skylake chips may throttle frequency when executing 512-bit instructions continuously.
|
|
22
|
+
*/
|
|
23
|
+
#ifndef NK_EACH_SKYLAKE_H
|
|
24
|
+
#define NK_EACH_SKYLAKE_H
|
|
25
|
+
|
|
26
|
+
#if NK_TARGET_X86_
|
|
27
|
+
#if NK_TARGET_SKYLAKE
|
|
28
|
+
|
|
29
|
+
#include "numkong/types.h"
|
|
30
|
+
#include "numkong/cast/skylake.h" // `nk_e4m3x16_to_f32x16_skylake_`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
38
|
+
apply_to = function)
|
|
39
|
+
#elif defined(__GNUC__)
|
|
40
|
+
#pragma GCC push_options
|
|
41
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
/**
 *  @brief Elementwise `result[i] = a[i] + b[i]` over `n` doubles, 8 lanes per step.
 *  The final partial vector is handled with a `bzhi`-derived store mask, so no scalar tail loop is needed.
 */
NK_PUBLIC void nk_each_sum_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    __mmask8 store_mask = 0xFF;
    for (;;) {
        __m512d lhs_f64x8, rhs_f64x8;
        if (n >= 8) {
            lhs_f64x8 = _mm512_loadu_pd(a);
            rhs_f64x8 = _mm512_loadu_pd(b);
            a += 8, b += 8, n -= 8;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for both loads and the store.
            store_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f64x8 = _mm512_maskz_loadu_pd(store_mask, a);
            rhs_f64x8 = _mm512_maskz_loadu_pd(store_mask, b);
            n = 0;
        }
        _mm512_mask_storeu_pd(result, store_mask, _mm512_add_pd(lhs_f64x8, rhs_f64x8));
        result += 8;
        if (!n) return;
    }
}
|
|
64
|
+
|
|
65
|
+
/**
 *  @brief Affine transform `result[i] = alpha * a[i] + beta` over `n` doubles, 8 lanes per step.
 *  Scale and shift are read once, broadcast, and fused into a single FMA per vector.
 */
NK_PUBLIC void nk_each_scale_f64_skylake(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_f64_t *result) {
    __m512d const alpha_broadcast = _mm512_set1_pd(*alpha);
    __m512d const beta_broadcast = _mm512_set1_pd(*beta);
    __mmask8 store_mask = 0xFF;
    for (;;) {
        __m512d lhs_f64x8;
        if (n >= 8) {
            lhs_f64x8 = _mm512_loadu_pd(a);
            a += 8, n -= 8;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for the load and the store.
            store_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f64x8 = _mm512_maskz_loadu_pd(store_mask, a);
            n = 0;
        }
        __m512d const scaled_f64x8 = _mm512_fmadd_pd(lhs_f64x8, alpha_broadcast, beta_broadcast);
        _mm512_mask_storeu_pd(result, store_mask, scaled_f64x8);
        result += 8;
        if (!n) return;
    }
}
|
|
88
|
+
|
|
89
|
+
/**
 *  @brief Weighted blend `result[i] = alpha * a[i] + beta * b[i]` over `n` doubles.
 *  Degenerate weights are routed to cheaper kernels before entering the vector loop.
 */
NK_PUBLIC void nk_each_blend_f64_skylake(                 //
    nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,    //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {
    nk_f64_t const scale_a = *alpha;
    nk_f64_t const scale_b = *beta;

    // Special case 1: both weights are 1 - a plain sum avoids the multiplications.
    if (scale_a == 1 && scale_b == 1) {
        nk_each_sum_f64_skylake(a, b, n, result);
        return;
    }
    // Special case 2: one weight is 0 - a single-input scale halves the loads.
    else if (scale_a == 0 || scale_b == 0) {
        nk_f64_t no_bias = 0;
        if (scale_b == 0) { nk_each_scale_f64_skylake(a, n, alpha, &no_bias, result); }
        else { nk_each_scale_f64_skylake(b, n, beta, &no_bias, result); }
        return;
    }

    // General path: one multiply plus one FMA per 8-lane vector.
    __m512d const alpha_broadcast = _mm512_set1_pd(scale_a);
    __m512d const beta_broadcast = _mm512_set1_pd(scale_b);
    __mmask8 store_mask = 0xFF;
    for (;;) {
        __m512d lhs_f64x8, rhs_f64x8;
        if (n >= 8) {
            lhs_f64x8 = _mm512_loadu_pd(a);
            rhs_f64x8 = _mm512_loadu_pd(b);
            a += 8, b += 8, n -= 8;
        }
        else {
            store_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f64x8 = _mm512_maskz_loadu_pd(store_mask, a);
            rhs_f64x8 = _mm512_maskz_loadu_pd(store_mask, b);
            n = 0;
        }
        __m512d const lhs_scaled_f64x8 = _mm512_mul_pd(lhs_f64x8, alpha_broadcast);
        __m512d const blended_f64x8 = _mm512_fmadd_pd(rhs_f64x8, beta_broadcast, lhs_scaled_f64x8);
        _mm512_mask_storeu_pd(result, store_mask, blended_f64x8);
        result += 8;
        if (!n) return;
    }
}
|
|
134
|
+
|
|
135
|
+
/**
 *  @brief Elementwise `result[i] = a[i] + b[i]` over `n` floats, 16 lanes per step.
 *  The final partial vector is handled with a `bzhi`-derived store mask, so no scalar tail loop is needed.
 */
NK_PUBLIC void nk_each_sum_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result) {
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m512 lhs_f32x16, rhs_f32x16;
        if (n >= 16) {
            lhs_f32x16 = _mm512_loadu_ps(a);
            rhs_f32x16 = _mm512_loadu_ps(b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for both loads and the store.
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f32x16 = _mm512_maskz_loadu_ps(store_mask, a);
            rhs_f32x16 = _mm512_maskz_loadu_ps(store_mask, b);
            n = 0;
        }
        _mm512_mask_storeu_ps(result, store_mask, _mm512_add_ps(lhs_f32x16, rhs_f32x16));
        result += 16;
        if (!n) return;
    }
}
|
|
156
|
+
|
|
157
|
+
/**
 *  @brief Affine transform `result[i] = alpha * a[i] + beta` over `n` floats, 16 lanes per step.
 *  Scale and shift are read once, broadcast, and fused into a single FMA per vector.
 */
NK_PUBLIC void nk_each_scale_f32_skylake(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_f32_t *result) {
    __m512 const alpha_broadcast = _mm512_set1_ps(*alpha);
    __m512 const beta_broadcast = _mm512_set1_ps(*beta);
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m512 lhs_f32x16;
        if (n >= 16) {
            lhs_f32x16 = _mm512_loadu_ps(a);
            a += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for the load and the store.
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f32x16 = _mm512_maskz_loadu_ps(store_mask, a);
            n = 0;
        }
        __m512 const scaled_f32x16 = _mm512_fmadd_ps(lhs_f32x16, alpha_broadcast, beta_broadcast);
        _mm512_mask_storeu_ps(result, store_mask, scaled_f32x16);
        result += 16;
        if (!n) return;
    }
}
|
|
181
|
+
|
|
182
|
+
/**
 *  @brief Weighted blend `result[i] = alpha * a[i] + beta * b[i]` over `n` floats.
 *  Degenerate weights are routed to cheaper kernels before entering the vector loop.
 */
NK_PUBLIC void nk_each_blend_f32_skylake(                 //
    nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,    //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {
    nk_f32_t const scale_a = *alpha;
    nk_f32_t const scale_b = *beta;

    // Special case 1: both weights are 1 - a plain sum avoids the multiplications.
    if (scale_a == 1 && scale_b == 1) {
        nk_each_sum_f32_skylake(a, b, n, result);
        return;
    }
    // Special case 2: one weight is 0 - a single-input scale halves the loads.
    else if (scale_a == 0 || scale_b == 0) {
        nk_f32_t no_bias = 0;
        if (scale_b == 0) { nk_each_scale_f32_skylake(a, n, alpha, &no_bias, result); }
        else { nk_each_scale_f32_skylake(b, n, beta, &no_bias, result); }
        return;
    }

    // General path: one multiply plus one FMA per 16-lane vector.
    __m512 const alpha_broadcast = _mm512_set1_ps(scale_a);
    __m512 const beta_broadcast = _mm512_set1_ps(scale_b);
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m512 lhs_f32x16, rhs_f32x16;
        if (n >= 16) {
            lhs_f32x16 = _mm512_loadu_ps(a);
            rhs_f32x16 = _mm512_loadu_ps(b);
            a += 16, b += 16, n -= 16;
        }
        else {
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f32x16 = _mm512_maskz_loadu_ps(store_mask, a);
            rhs_f32x16 = _mm512_maskz_loadu_ps(store_mask, b);
            n = 0;
        }
        __m512 const lhs_scaled_f32x16 = _mm512_mul_ps(lhs_f32x16, alpha_broadcast);
        __m512 const blended_f32x16 = _mm512_fmadd_ps(rhs_f32x16, beta_broadcast, lhs_scaled_f32x16);
        _mm512_mask_storeu_ps(result, store_mask, blended_f32x16);
        result += 16;
        if (!n) return;
    }
}
|
|
227
|
+
|
|
228
|
+
/**
 *  @brief Elementwise `result[i] = a[i] + b[i]` over `n` `bf16` values, 16 lanes per step.
 *  Inputs are widened to `f32` for the addition and narrowed back before the masked store.
 */
NK_PUBLIC void nk_each_sum_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result) {
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m256i lhs_bf16x16, rhs_bf16x16;
        if (n >= 16) {
            lhs_bf16x16 = _mm256_loadu_epi16(a);
            rhs_bf16x16 = _mm256_loadu_epi16(b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for both loads and the store.
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_bf16x16 = _mm256_maskz_loadu_epi16(store_mask, a);
            rhs_bf16x16 = _mm256_maskz_loadu_epi16(store_mask, b);
            n = 0;
        }
        // Upcast to `f32`, add, then narrow the sum back to `bf16`.
        __m512 const lhs_f32x16 = nk_bf16x16_to_f32x16_skylake_(lhs_bf16x16);
        __m512 const rhs_f32x16 = nk_bf16x16_to_f32x16_skylake_(rhs_bf16x16);
        __m256i const sum_bf16x16 = nk_f32x16_to_bf16x16_skylake_(_mm512_add_ps(lhs_f32x16, rhs_f32x16));
        _mm256_mask_storeu_epi16(result, store_mask, sum_bf16x16);
        result += 16;
        if (!n) return;
    }
}
|
|
252
|
+
|
|
253
|
+
/**
 *  @brief Affine transform `result[i] = alpha * a[i] + beta` over `n` `bf16` values, 16 lanes per step.
 *  Weights are `f32`; inputs are widened to `f32` for the FMA and narrowed back before the masked store.
 */
NK_PUBLIC void nk_each_scale_bf16_skylake(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_bf16_t *result) {
    __m512 const alpha_broadcast = _mm512_set1_ps(*alpha);
    __m512 const beta_broadcast = _mm512_set1_ps(*beta);
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m256i lhs_bf16x16;
        if (n >= 16) {
            lhs_bf16x16 = _mm256_loadu_epi16(a);
            a += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` lanes active for the load and the store.
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_bf16x16 = _mm256_maskz_loadu_epi16(store_mask, a);
            n = 0;
        }
        __m512 const lhs_f32x16 = nk_bf16x16_to_f32x16_skylake_(lhs_bf16x16);
        __m512 const scaled_f32x16 = _mm512_fmadd_ps(lhs_f32x16, alpha_broadcast, beta_broadcast);
        _mm256_mask_storeu_epi16(result, store_mask, nk_f32x16_to_bf16x16_skylake_(scaled_f32x16));
        result += 16;
        if (!n) return;
    }
}
|
|
279
|
+
|
|
280
|
+
/**
 *  @brief Weighted blend `result[i] = alpha * a[i] + beta * b[i]` over `n` `bf16` values.
 *  Weights are `f32`; degenerate weights are routed to cheaper kernels before the vector loop.
 */
NK_PUBLIC void nk_each_blend_bf16_skylake(                //
    nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,  //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
    nk_f32_t const scale_a = *alpha;
    nk_f32_t const scale_b = *beta;

    // Special case 1: both weights are 1 - a plain sum avoids the multiplications.
    if (scale_a == 1 && scale_b == 1) {
        nk_each_sum_bf16_skylake(a, b, n, result);
        return;
    }
    // Special case 2: one weight is 0 - a single-input scale halves the loads.
    else if (scale_a == 0 || scale_b == 0) {
        nk_f32_t no_bias = 0;
        if (scale_b == 0) { nk_each_scale_bf16_skylake(a, n, alpha, &no_bias, result); }
        else { nk_each_scale_bf16_skylake(b, n, beta, &no_bias, result); }
        return;
    }

    // General path: widen both inputs to `f32`, blend with a multiply and an FMA, then narrow back.
    __m512 const alpha_broadcast = _mm512_set1_ps(scale_a);
    __m512 const beta_broadcast = _mm512_set1_ps(scale_b);
    __mmask16 store_mask = 0xFFFF;
    for (;;) {
        __m256i lhs_bf16x16, rhs_bf16x16;
        if (n >= 16) {
            lhs_bf16x16 = _mm256_loadu_epi16(a);
            rhs_bf16x16 = _mm256_loadu_epi16(b);
            a += 16, b += 16, n -= 16;
        }
        else {
            store_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            lhs_bf16x16 = _mm256_maskz_loadu_epi16(store_mask, a);
            rhs_bf16x16 = _mm256_maskz_loadu_epi16(store_mask, b);
            n = 0;
        }
        __m512 const lhs_f32x16 = nk_bf16x16_to_f32x16_skylake_(lhs_bf16x16);
        __m512 const rhs_f32x16 = nk_bf16x16_to_f32x16_skylake_(rhs_bf16x16);
        __m512 const lhs_scaled_f32x16 = _mm512_mul_ps(lhs_f32x16, alpha_broadcast);
        __m512 const blended_f32x16 = _mm512_fmadd_ps(rhs_f32x16, beta_broadcast, lhs_scaled_f32x16);
        _mm256_mask_storeu_epi16(result, store_mask, nk_f32x16_to_bf16x16_skylake_(blended_f32x16));
        result += 16;
        if (!n) return;
    }
}
|
|
329
|
+
|
|
330
|
+
NK_PUBLIC void nk_each_fma_f64_skylake(                                    //
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {
    // result[i] = alpha * a[i] * b[i] + beta * c[i], 8 doubles per iteration;
    // the trailing partial step uses masked loads/stores.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __mmask8 tail = 0xFF;
    do {
        __m512d va, vb, vc;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            va = _mm512_maskz_loadu_pd(tail, a);
            vb = _mm512_maskz_loadu_pd(tail, b);
            vc = _mm512_maskz_loadu_pd(tail, c);
            n = 0;
        }
        else {
            va = _mm512_loadu_pd(a);
            vb = _mm512_loadu_pd(b);
            vc = _mm512_loadu_pd(c);
            a += 8, b += 8, c += 8, n -= 8;
        }
        __m512d scaled = _mm512_mul_pd(_mm512_mul_pd(va, vb), valpha);
        _mm512_mask_storeu_pd(result, tail, _mm512_fmadd_pd(vc, vbeta, scaled));
        result += 8;
    } while (n);
}
|
|
360
|
+
|
|
361
|
+
NK_PUBLIC void nk_each_fma_f32_skylake(                                    //
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {
    // result[i] = alpha * a[i] * b[i] + beta * c[i], 16 floats per iteration;
    // the trailing partial step uses masked loads/stores.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    __mmask16 tail = 0xFFFF;
    do {
        __m512 va, vb, vc;
        if (n < 16) {
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            va = _mm512_maskz_loadu_ps(tail, a);
            vb = _mm512_maskz_loadu_ps(tail, b);
            vc = _mm512_maskz_loadu_ps(tail, c);
            n = 0;
        }
        else {
            va = _mm512_loadu_ps(a);
            vb = _mm512_loadu_ps(b);
            vc = _mm512_loadu_ps(c);
            a += 16, b += 16, c += 16, n -= 16;
        }
        __m512 scaled = _mm512_mul_ps(_mm512_mul_ps(va, vb), valpha);
        _mm512_mask_storeu_ps(result, tail, _mm512_fmadd_ps(vc, vbeta, scaled));
        result += 16;
    } while (n);
}
|
|
391
|
+
|
|
392
|
+
NK_PUBLIC void nk_each_fma_bf16_skylake(                                      //
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
    // result[i] = bf16(alpha * a[i] * b[i] + beta * c[i]); each bf16 lane is
    // widened to f32 for the arithmetic and narrowed back before the store.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    __mmask16 tail = 0xFFFF;
    do {
        __m256i a_raw, b_raw, c_raw;
        if (n < 16) {
            // Only the low `n` lanes are live on the last step.
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm256_maskz_loadu_epi16(tail, a);
            b_raw = _mm256_maskz_loadu_epi16(tail, b);
            c_raw = _mm256_maskz_loadu_epi16(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm256_loadu_epi16(a);
            b_raw = _mm256_loadu_epi16(b);
            c_raw = _mm256_loadu_epi16(c);
            a += 16, b += 16, c += 16, n -= 16;
        }
        __m512 a_wide = nk_bf16x16_to_f32x16_skylake_(a_raw);
        __m512 b_wide = nk_bf16x16_to_f32x16_skylake_(b_raw);
        __m512 c_wide = nk_bf16x16_to_f32x16_skylake_(c_raw);
        __m512 scaled = _mm512_mul_ps(_mm512_mul_ps(a_wide, b_wide), valpha);
        __m512 fused = _mm512_fmadd_ps(c_wide, vbeta, scaled);
        _mm256_mask_storeu_epi16(result, tail, nk_f32x16_to_bf16x16_skylake_(fused));
        result += 16;
    } while (n);
}
|
|
427
|
+
|
|
428
|
+
NK_PUBLIC void nk_each_scale_i8_skylake(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                        nk_i8_t *result) {
    // result[i] = sat_i8(alpha * a[i] + beta), evaluated in f32, 16 lanes/step.
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m512 alpha_f32x16 = _mm512_set1_ps(alpha_val);
    __m512 beta_f32x16 = _mm512_set1_ps(beta_val);
    __m128i a_i8x16, result_i8x16;
    __m512 a_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
    __m512i result_i32x16;
    // Fix: clamp in the f32 domain *before* `_mm512_cvtps_epi32`, mirroring
    // the i16/i32 kernels. Converting first is wrong for results outside the
    // i32 range: the conversion then yields the INT32_MIN sentinel, so the old
    // integer-domain clamp collapsed even large *positive* results to -128.
    __m512 min_f32x16 = _mm512_set1_ps(-128.0f);
    __m512 max_f32x16 = _mm512_set1_ps(127.0f);

nk_each_scale_i8_skylake_cycle:
    if (n < 16) {
        // Trailing partial step: only the low `n` lanes are loaded/stored.
        mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
        a_i8x16 = _mm_maskz_loadu_epi8(mask, a);
        n = 0;
    }
    else {
        a_i8x16 = _mm_loadu_si128((__m128i *)a);
        a += 16, n -= 16;
    }
    a_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a_i8x16));
    result_f32x16 = _mm512_fmadd_ps(a_f32x16, alpha_f32x16, beta_f32x16);
    result_f32x16 = _mm512_max_ps(result_f32x16, min_f32x16);
    result_f32x16 = _mm512_min_ps(result_f32x16, max_f32x16);
    result_i32x16 = _mm512_cvtps_epi32(result_f32x16);
    result_i8x16 = _mm512_cvtepi32_epi8(result_i32x16);
    _mm_mask_storeu_epi8(result, mask, result_i8x16);
    result += 16;
    if (n) goto nk_each_scale_i8_skylake_cycle;
}
|
|
461
|
+
|
|
462
|
+
NK_PUBLIC void nk_each_fma_i8_skylake(                                 //
    nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
    // result[i] = sat_i8(alpha * a[i] * b[i] + beta * c[i]), in f32, 16 lanes/step.
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m512 alpha_f32x16 = _mm512_set1_ps(alpha_val);
    __m512 beta_f32x16 = _mm512_set1_ps(beta_val);
    __m128i a_i8x16, b_i8x16, c_i8x16, result_i8x16;
    __m512 a_f32x16, b_f32x16, c_f32x16, ab_f32x16, ab_scaled_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
    __m512i result_i32x16;
    // Fix: clamp in the f32 domain *before* `_mm512_cvtps_epi32`, mirroring
    // the i16/i32 kernels. Converting first is wrong for results outside the
    // i32 range: the conversion then yields the INT32_MIN sentinel, so the old
    // integer-domain clamp collapsed even large *positive* results to -128.
    __m512 min_f32x16 = _mm512_set1_ps(-128.0f);
    __m512 max_f32x16 = _mm512_set1_ps(127.0f);

nk_each_fma_i8_skylake_cycle:
    if (n < 16) {
        // Trailing partial step: only the low `n` lanes are loaded/stored.
        mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
        a_i8x16 = _mm_maskz_loadu_epi8(mask, a);
        b_i8x16 = _mm_maskz_loadu_epi8(mask, b);
        c_i8x16 = _mm_maskz_loadu_epi8(mask, c);
        n = 0;
    }
    else {
        a_i8x16 = _mm_loadu_si128((__m128i *)a);
        b_i8x16 = _mm_loadu_si128((__m128i *)b);
        c_i8x16 = _mm_loadu_si128((__m128i *)c);
        a += 16, b += 16, c += 16, n -= 16;
    }
    a_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a_i8x16));
    b_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(b_i8x16));
    c_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(c_i8x16));
    ab_f32x16 = _mm512_mul_ps(a_f32x16, b_f32x16);
    ab_scaled_f32x16 = _mm512_mul_ps(ab_f32x16, alpha_f32x16);
    result_f32x16 = _mm512_fmadd_ps(c_f32x16, beta_f32x16, ab_scaled_f32x16);
    result_f32x16 = _mm512_max_ps(result_f32x16, min_f32x16);
    result_f32x16 = _mm512_min_ps(result_f32x16, max_f32x16);
    result_i32x16 = _mm512_cvtps_epi32(result_f32x16);
    result_i8x16 = _mm512_cvtepi32_epi8(result_i32x16);
    _mm_mask_storeu_epi8(result, mask, result_i8x16);
    result += 16;
    if (n) goto nk_each_fma_i8_skylake_cycle;
}
|
|
504
|
+
|
|
505
|
+
NK_PUBLIC void nk_each_scale_u8_skylake(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                        nk_u8_t *result) {
    // result[i] = sat_u8(alpha * a[i] + beta), evaluated in f32, 16 lanes/step.
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m512 alpha_f32x16 = _mm512_set1_ps(alpha_val);
    __m512 beta_f32x16 = _mm512_set1_ps(beta_val);
    __m128i a_u8x16, result_u8x16;
    __m512 a_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
    __m512i result_u32x16;
    // Fix: clamp in the f32 domain *before* `_mm512_cvtps_epu32`, mirroring
    // the u16/u32 kernels. Converting first is wrong for negative results:
    // the unsigned conversion yields the all-ones sentinel, so the old
    // integer-domain clamp turned negative results into 255 instead of 0.
    __m512 min_f32x16 = _mm512_setzero_ps();
    __m512 max_f32x16 = _mm512_set1_ps(255.0f);

nk_each_scale_u8_skylake_cycle:
    if (n < 16) {
        // Trailing partial step: only the low `n` lanes are loaded/stored.
        mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
        a_u8x16 = _mm_maskz_loadu_epi8(mask, a);
        n = 0;
    }
    else {
        a_u8x16 = _mm_loadu_si128((__m128i *)a);
        a += 16, n -= 16;
    }
    a_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(a_u8x16));
    result_f32x16 = _mm512_fmadd_ps(a_f32x16, alpha_f32x16, beta_f32x16);
    result_f32x16 = _mm512_max_ps(result_f32x16, min_f32x16);
    result_f32x16 = _mm512_min_ps(result_f32x16, max_f32x16);
    result_u32x16 = _mm512_cvtps_epu32(result_f32x16);
    result_u8x16 = _mm512_cvtepi32_epi8(result_u32x16);
    _mm_mask_storeu_epi8(result, mask, result_u8x16);
    result += 16;
    if (n) goto nk_each_scale_u8_skylake_cycle;
}
|
|
538
|
+
|
|
539
|
+
NK_PUBLIC void nk_each_fma_u8_skylake(                                 //
    nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
    // result[i] = sat_u8(alpha * a[i] * b[i] + beta * c[i]), in f32, 16 lanes/step.
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m512 alpha_f32x16 = _mm512_set1_ps(alpha_val);
    __m512 beta_f32x16 = _mm512_set1_ps(beta_val);
    __m128i a_u8x16, b_u8x16, c_u8x16, result_u8x16;
    __m512 a_f32x16, b_f32x16, c_f32x16, ab_f32x16, ab_scaled_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
    __m512i result_u32x16;
    // Fix: clamp in the f32 domain *before* `_mm512_cvtps_epu32`, mirroring
    // the u16/u32 kernels. Converting first is wrong for negative results:
    // the unsigned conversion yields the all-ones sentinel, so the old
    // integer-domain clamp turned negative results into 255 instead of 0.
    __m512 min_f32x16 = _mm512_setzero_ps();
    __m512 max_f32x16 = _mm512_set1_ps(255.0f);

nk_each_fma_u8_skylake_cycle:
    if (n < 16) {
        // Trailing partial step: only the low `n` lanes are loaded/stored.
        mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
        a_u8x16 = _mm_maskz_loadu_epi8(mask, a);
        b_u8x16 = _mm_maskz_loadu_epi8(mask, b);
        c_u8x16 = _mm_maskz_loadu_epi8(mask, c);
        n = 0;
    }
    else {
        a_u8x16 = _mm_loadu_si128((__m128i *)a);
        b_u8x16 = _mm_loadu_si128((__m128i *)b);
        c_u8x16 = _mm_loadu_si128((__m128i *)c);
        a += 16, b += 16, c += 16, n -= 16;
    }
    a_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(a_u8x16));
    b_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(b_u8x16));
    c_f32x16 = _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(c_u8x16));
    ab_f32x16 = _mm512_mul_ps(a_f32x16, b_f32x16);
    ab_scaled_f32x16 = _mm512_mul_ps(ab_f32x16, alpha_f32x16);
    result_f32x16 = _mm512_fmadd_ps(c_f32x16, beta_f32x16, ab_scaled_f32x16);
    result_f32x16 = _mm512_max_ps(result_f32x16, min_f32x16);
    result_f32x16 = _mm512_min_ps(result_f32x16, max_f32x16);
    result_u32x16 = _mm512_cvtps_epu32(result_f32x16);
    result_u8x16 = _mm512_cvtepi32_epi8(result_u32x16);
    _mm_mask_storeu_epi8(result, mask, result_u8x16);
    result += 16;
    if (n) goto nk_each_fma_u8_skylake_cycle;
}
|
|
581
|
+
|
|
582
|
+
NK_PUBLIC void nk_each_scale_i16_skylake(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_i16_t *result) {
    // result[i] = sat_i16(alpha * a[i] + beta), evaluated in f32, 16 lanes/step.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    // Saturation bounds applied in the f32 domain before narrowing.
    __m512 const lo = _mm512_set1_ps(-32768.0f);
    __m512 const hi = _mm512_set1_ps(32767.0f);
    __mmask16 tail = 0xFFFF;
    do {
        __m256i raw;
        if (n < 16) {
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm256_maskz_loadu_epi16(tail, a);
            n = 0;
        }
        else {
            raw = _mm256_loadu_si256((__m256i *)a);
            a += 16, n -= 16;
        }
        __m512 vals = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(raw));
        __m512 scaled = _mm512_fmadd_ps(vals, valpha, vbeta);
        scaled = _mm512_min_ps(_mm512_max_ps(scaled, lo), hi);
        __m512i ints = _mm512_cvtps_epi32(scaled);
        _mm256_mask_storeu_epi16(result, tail, _mm512_cvtepi32_epi16(ints));
        result += 16;
    } while (n);
}
|
|
615
|
+
|
|
616
|
+
NK_PUBLIC void nk_each_fma_i16_skylake(                                    //
    nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result) {
    // result[i] = sat_i16(alpha * a[i] * b[i] + beta * c[i]), in f32, 16 lanes/step.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    // Saturation bounds applied in the f32 domain before narrowing.
    __m512 const lo = _mm512_set1_ps(-32768.0f);
    __m512 const hi = _mm512_set1_ps(32767.0f);
    __mmask16 tail = 0xFFFF;
    do {
        __m256i a_raw, b_raw, c_raw;
        if (n < 16) {
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm256_maskz_loadu_epi16(tail, a);
            b_raw = _mm256_maskz_loadu_epi16(tail, b);
            c_raw = _mm256_maskz_loadu_epi16(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm256_loadu_si256((__m256i *)a);
            b_raw = _mm256_loadu_si256((__m256i *)b);
            c_raw = _mm256_loadu_si256((__m256i *)c);
            a += 16, b += 16, c += 16, n -= 16;
        }
        __m512 a_wide = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(a_raw));
        __m512 b_wide = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(b_raw));
        __m512 c_wide = _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(c_raw));
        __m512 scaled = _mm512_mul_ps(_mm512_mul_ps(a_wide, b_wide), valpha);
        __m512 fused = _mm512_fmadd_ps(c_wide, vbeta, scaled);
        fused = _mm512_min_ps(_mm512_max_ps(fused, lo), hi);
        __m512i ints = _mm512_cvtps_epi32(fused);
        _mm256_mask_storeu_epi16(result, tail, _mm512_cvtepi32_epi16(ints));
        result += 16;
    } while (n);
}
|
|
658
|
+
|
|
659
|
+
NK_PUBLIC void nk_each_scale_u16_skylake(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_u16_t *result) {
    // result[i] = sat_u16(alpha * a[i] + beta), evaluated in f32, 16 lanes/step.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    // Saturation bounds applied in the f32 domain before narrowing.
    __m512 const lo = _mm512_setzero_ps();
    __m512 const hi = _mm512_set1_ps(65535.0f);
    __mmask16 tail = 0xFFFF;
    do {
        __m256i raw;
        if (n < 16) {
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm256_maskz_loadu_epi16(tail, a);
            n = 0;
        }
        else {
            raw = _mm256_loadu_si256((__m256i *)a);
            a += 16, n -= 16;
        }
        __m512 vals = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(raw));
        __m512 scaled = _mm512_fmadd_ps(vals, valpha, vbeta);
        scaled = _mm512_min_ps(_mm512_max_ps(scaled, lo), hi);
        __m512i ints = _mm512_cvtps_epu32(scaled);
        _mm256_mask_storeu_epi16(result, tail, _mm512_cvtepi32_epi16(ints));
        result += 16;
    } while (n);
}
|
|
692
|
+
|
|
693
|
+
NK_PUBLIC void nk_each_fma_u16_skylake(                                    //
    nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result) {
    // result[i] = sat_u16(alpha * a[i] * b[i] + beta * c[i]), in f32, 16 lanes/step.
    __m512 const valpha = _mm512_set1_ps(*alpha);
    __m512 const vbeta = _mm512_set1_ps(*beta);
    // Saturation bounds applied in the f32 domain before narrowing.
    __m512 const lo = _mm512_setzero_ps();
    __m512 const hi = _mm512_set1_ps(65535.0f);
    __mmask16 tail = 0xFFFF;
    do {
        __m256i a_raw, b_raw, c_raw;
        if (n < 16) {
            tail = (__mmask16)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm256_maskz_loadu_epi16(tail, a);
            b_raw = _mm256_maskz_loadu_epi16(tail, b);
            c_raw = _mm256_maskz_loadu_epi16(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm256_loadu_si256((__m256i *)a);
            b_raw = _mm256_loadu_si256((__m256i *)b);
            c_raw = _mm256_loadu_si256((__m256i *)c);
            a += 16, b += 16, c += 16, n -= 16;
        }
        __m512 a_wide = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(a_raw));
        __m512 b_wide = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(b_raw));
        __m512 c_wide = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(c_raw));
        __m512 scaled = _mm512_mul_ps(_mm512_mul_ps(a_wide, b_wide), valpha);
        __m512 fused = _mm512_fmadd_ps(c_wide, vbeta, scaled);
        fused = _mm512_min_ps(_mm512_max_ps(fused, lo), hi);
        __m512i ints = _mm512_cvtps_epu32(fused);
        _mm256_mask_storeu_epi16(result, tail, _mm512_cvtepi32_epi16(ints));
        result += 16;
    } while (n);
}
|
|
735
|
+
|
|
736
|
+
NK_PUBLIC void nk_each_scale_i32_skylake(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_i32_t *result) {
    // result[i] = sat_i32(alpha * a[i] + beta), computed in f64 (which holds
    // every i32 exactly) and clamped in the f64 domain before narrowing.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __m512d const lo = _mm512_set1_pd(-2147483648.0);
    __m512d const hi = _mm512_set1_pd(2147483647.0);
    __mmask8 tail = 0xFF;
    do {
        __m256i raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm256_maskz_loadu_epi32(tail, a);
            n = 0;
        }
        else {
            raw = _mm256_loadu_si256((__m256i *)a);
            a += 8, n -= 8;
        }
        __m512d scaled = _mm512_fmadd_pd(_mm512_cvtepi32_pd(raw), valpha, vbeta);
        scaled = _mm512_min_pd(_mm512_max_pd(scaled, lo), hi);
        _mm256_mask_storeu_epi32(result, tail, _mm512_cvtpd_epi32(scaled));
        result += 8;
    } while (n);
}
|
|
767
|
+
|
|
768
|
+
NK_PUBLIC void nk_each_fma_i32_skylake(                                    //
    nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result) {
    // result[i] = sat_i32(alpha * a[i] * b[i] + beta * c[i]), in f64, 8 lanes/step.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __m512d const lo = _mm512_set1_pd(-2147483648.0);
    __m512d const hi = _mm512_set1_pd(2147483647.0);
    __mmask8 tail = 0xFF;
    do {
        __m256i a_raw, b_raw, c_raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm256_maskz_loadu_epi32(tail, a);
            b_raw = _mm256_maskz_loadu_epi32(tail, b);
            c_raw = _mm256_maskz_loadu_epi32(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm256_loadu_si256((__m256i *)a);
            b_raw = _mm256_loadu_si256((__m256i *)b);
            c_raw = _mm256_loadu_si256((__m256i *)c);
            a += 8, b += 8, c += 8, n -= 8;
        }
        __m512d a_wide = _mm512_cvtepi32_pd(a_raw);
        __m512d b_wide = _mm512_cvtepi32_pd(b_raw);
        __m512d c_wide = _mm512_cvtepi32_pd(c_raw);
        __m512d scaled = _mm512_mul_pd(_mm512_mul_pd(a_wide, b_wide), valpha);
        __m512d fused = _mm512_fmadd_pd(c_wide, vbeta, scaled);
        fused = _mm512_min_pd(_mm512_max_pd(fused, lo), hi);
        _mm256_mask_storeu_epi32(result, tail, _mm512_cvtpd_epi32(fused));
        result += 8;
    } while (n);
}
|
|
808
|
+
|
|
809
|
+
NK_PUBLIC void nk_each_scale_u32_skylake(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_u32_t *result) {
    // result[i] = sat_u32(alpha * a[i] + beta), computed in f64 (which holds
    // every u32 exactly) and clamped in the f64 domain before narrowing.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __m512d const lo = _mm512_set1_pd(0.0);
    __m512d const hi = _mm512_set1_pd(4294967295.0);
    __mmask8 tail = 0xFF;
    do {
        __m256i raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm256_maskz_loadu_epi32(tail, a);
            n = 0;
        }
        else {
            raw = _mm256_loadu_si256((__m256i *)a);
            a += 8, n -= 8;
        }
        __m512d scaled = _mm512_fmadd_pd(_mm512_cvtepu32_pd(raw), valpha, vbeta);
        scaled = _mm512_min_pd(_mm512_max_pd(scaled, lo), hi);
        _mm256_mask_storeu_epi32(result, tail, _mm512_cvtpd_epu32(scaled));
        result += 8;
    } while (n);
}
|
|
840
|
+
|
|
841
|
+
NK_PUBLIC void nk_each_fma_u32_skylake(                                    //
    nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result) {
    // result[i] = sat_u32(alpha * a[i] * b[i] + beta * c[i]), in f64, 8 lanes/step.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __m512d const lo = _mm512_set1_pd(0.0);
    __m512d const hi = _mm512_set1_pd(4294967295.0);
    __mmask8 tail = 0xFF;
    do {
        __m256i a_raw, b_raw, c_raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm256_maskz_loadu_epi32(tail, a);
            b_raw = _mm256_maskz_loadu_epi32(tail, b);
            c_raw = _mm256_maskz_loadu_epi32(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm256_loadu_si256((__m256i *)a);
            b_raw = _mm256_loadu_si256((__m256i *)b);
            c_raw = _mm256_loadu_si256((__m256i *)c);
            a += 8, b += 8, c += 8, n -= 8;
        }
        __m512d a_wide = _mm512_cvtepu32_pd(a_raw);
        __m512d b_wide = _mm512_cvtepu32_pd(b_raw);
        __m512d c_wide = _mm512_cvtepu32_pd(c_raw);
        __m512d scaled = _mm512_mul_pd(_mm512_mul_pd(a_wide, b_wide), valpha);
        __m512d fused = _mm512_fmadd_pd(c_wide, vbeta, scaled);
        fused = _mm512_min_pd(_mm512_max_pd(fused, lo), hi);
        _mm256_mask_storeu_epi32(result, tail, _mm512_cvtpd_epu32(fused));
        result += 8;
    } while (n);
}
|
|
881
|
+
|
|
882
|
+
NK_PUBLIC void nk_each_scale_i64_skylake(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_i64_t *result) {
    // result[i] = (i64)(alpha * a[i] + beta), computed in f64, 8 lanes/step.
    // NOTE(review): unlike the narrower kernels, no clamping is applied here;
    // out-of-range results follow `_mm512_cvtpd_epi64` conversion behavior —
    // confirm that is the intended contract.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __mmask8 tail = 0xFF;
    do {
        __m512i raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm512_maskz_loadu_epi64(tail, a);
            n = 0;
        }
        else {
            raw = _mm512_loadu_si512((__m512i *)a);
            a += 8, n -= 8;
        }
        __m512d scaled = _mm512_fmadd_pd(_mm512_cvtepi64_pd(raw), valpha, vbeta);
        _mm512_mask_storeu_epi64(result, tail, _mm512_cvtpd_epi64(scaled));
        result += 8;
    } while (n);
}
|
|
910
|
+
|
|
911
|
+
NK_PUBLIC void nk_each_fma_i64_skylake(                                    //
    nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result) {
    // result[i] = (i64)(alpha * a[i] * b[i] + beta * c[i]), in f64, 8 lanes/step.
    // NOTE(review): no clamping is applied; out-of-range results follow
    // `_mm512_cvtpd_epi64` conversion behavior — confirm intended contract.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __mmask8 tail = 0xFF;
    do {
        __m512i a_raw, b_raw, c_raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm512_maskz_loadu_epi64(tail, a);
            b_raw = _mm512_maskz_loadu_epi64(tail, b);
            c_raw = _mm512_maskz_loadu_epi64(tail, c);
            n = 0;
        }
        else {
            a_raw = _mm512_loadu_si512((__m512i *)a);
            b_raw = _mm512_loadu_si512((__m512i *)b);
            c_raw = _mm512_loadu_si512((__m512i *)c);
            a += 8, b += 8, c += 8, n -= 8;
        }
        __m512d a_wide = _mm512_cvtepi64_pd(a_raw);
        __m512d b_wide = _mm512_cvtepi64_pd(b_raw);
        __m512d c_wide = _mm512_cvtepi64_pd(c_raw);
        __m512d scaled = _mm512_mul_pd(_mm512_mul_pd(a_wide, b_wide), valpha);
        __m512d fused = _mm512_fmadd_pd(c_wide, vbeta, scaled);
        _mm512_mask_storeu_epi64(result, tail, _mm512_cvtpd_epi64(fused));
        result += 8;
    } while (n);
}
|
|
946
|
+
|
|
947
|
+
NK_PUBLIC void nk_each_scale_u64_skylake(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_u64_t *result) {
    // result[i] = (u64)(alpha * a[i] + beta), computed in f64, 8 lanes/step.
    // NOTE(review): no clamping is applied; out-of-range results follow
    // `_mm512_cvtpd_epu64` conversion behavior — confirm intended contract.
    __m512d const valpha = _mm512_set1_pd(*alpha);
    __m512d const vbeta = _mm512_set1_pd(*beta);
    __mmask8 tail = 0xFF;
    do {
        __m512i raw;
        if (n < 8) {
            tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            raw = _mm512_maskz_loadu_epi64(tail, a);
            n = 0;
        }
        else {
            raw = _mm512_loadu_si512((__m512i *)a);
            a += 8, n -= 8;
        }
        __m512d scaled = _mm512_fmadd_pd(_mm512_cvtepu64_pd(raw), valpha, vbeta);
        _mm512_mask_storeu_epi64(result, tail, _mm512_cvtpd_epu64(scaled));
        result += 8;
    } while (n);
}
|
|
975
|
+
|
|
976
|
+
/*  Fused multiply-add over `u64` vectors: result[i] = α·a[i]·b[i] + β·c[i].
 *  Values are widened to `f64`, combined, and converted back; the tail chunk
 *  is handled with masked loads/stores.
 */
NK_PUBLIC void nk_each_fma_u64_skylake( //
    nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result) {
    __m512d const scale_vec = _mm512_set1_pd(*alpha);
    __m512d const blend_vec = _mm512_set1_pd(*beta);
    __mmask8 lane_mask = 0xFF;
    do {
        __m512i a_raw, b_raw, c_raw;
        if (n >= 8) {
            a_raw = _mm512_loadu_si512((__m512i *)a);
            b_raw = _mm512_loadu_si512((__m512i *)b);
            c_raw = _mm512_loadu_si512((__m512i *)c);
            a += 8, b += 8, c += 8, n -= 8;
        }
        else {
            // Tail: keep only the lowest `n` lanes active.
            lane_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            a_raw = _mm512_maskz_loadu_epi64(lane_mask, a);
            b_raw = _mm512_maskz_loadu_epi64(lane_mask, b);
            c_raw = _mm512_maskz_loadu_epi64(lane_mask, c);
            n = 0;
        }
        // Same operation order as the scalar form: (a*b), then *α, then + β*c.
        __m512d product = _mm512_mul_pd(_mm512_cvtepu64_pd(a_raw), _mm512_cvtepu64_pd(b_raw));
        __m512d product_scaled = _mm512_mul_pd(product, scale_vec);
        __m512d combined = _mm512_fmadd_pd(_mm512_cvtepu64_pd(c_raw), blend_vec, product_scaled);
        _mm512_mask_storeu_epi64(result, lane_mask, _mm512_cvtpd_epu64(combined));
        result += 8;
    } while (n);
}
|
|
1011
|
+
|
|
1012
|
+
/*  Element-wise sum of two `e4m3` (FP8) vectors: result[i] = a[i] + b[i].
 *  Each 16-byte chunk is widened to `f32`, added, and narrowed back; the
 *  final partial chunk uses masked byte loads/stores to stay in bounds.
 */
NK_PUBLIC void nk_each_sum_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i lhs_bytes, rhs_bytes;
        if (n >= 16) {
            lhs_bytes = _mm_loadu_si128((__m128i const *)a);
            rhs_bytes = _mm_loadu_si128((__m128i const *)b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            lhs_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            rhs_bytes = _mm_maskz_loadu_epi8(lane_mask, b);
            n = 0;
        }
        __m512 sum_f32 = _mm512_add_ps(nk_e4m3x16_to_f32x16_skylake_(lhs_bytes), //
                                       nk_e4m3x16_to_f32x16_skylake_(rhs_bytes));
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e4m3x16_skylake_(sum_f32));
        result += 16;
    } while (n);
}
|
|
1036
|
+
|
|
1037
|
+
/*  Element-wise sum of two `e5m2` (FP8) vectors: result[i] = a[i] + b[i].
 *  Each 16-byte chunk is widened to `f32`, added, and narrowed back; the
 *  final partial chunk uses masked byte loads/stores to stay in bounds.
 */
NK_PUBLIC void nk_each_sum_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result) {
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i lhs_bytes, rhs_bytes;
        if (n >= 16) {
            lhs_bytes = _mm_loadu_si128((__m128i const *)a);
            rhs_bytes = _mm_loadu_si128((__m128i const *)b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            lhs_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            rhs_bytes = _mm_maskz_loadu_epi8(lane_mask, b);
            n = 0;
        }
        __m512 sum_f32 = _mm512_add_ps(nk_e5m2x16_to_f32x16_skylake_(lhs_bytes), //
                                       nk_e5m2x16_to_f32x16_skylake_(rhs_bytes));
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e5m2x16_skylake_(sum_f32));
        result += 16;
    } while (n);
}
|
|
1061
|
+
|
|
1062
|
+
/*  Affine scaling of an `e4m3` (FP8) vector: result[i] = α · a[i] + β.
 *  Math happens in `f32` lanes; the last partial 16-byte chunk is processed
 *  with masked loads/stores.
 */
NK_PUBLIC void nk_each_scale_e4m3_skylake(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_e4m3_t *result) {
    __m512 const scale_f32 = _mm512_set1_ps(*alpha);
    __m512 const shift_f32 = _mm512_set1_ps(*beta);
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i src_bytes;
        if (n >= 16) {
            src_bytes = _mm_loadu_si128((__m128i const *)a);
            a += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            src_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            n = 0;
        }
        // FP8 rounding note: a fused multiply-add is acceptable here because
        // (α × a + β) is a single multiply-add, so single rounding preserves accuracy.
        __m512 affine_f32 = _mm512_fmadd_ps(nk_e4m3x16_to_f32x16_skylake_(src_bytes), scale_f32, shift_f32);
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e4m3x16_skylake_(affine_f32));
        result += 16;
    } while (n);
}
|
|
1088
|
+
|
|
1089
|
+
/*  Affine scaling of an `e5m2` (FP8) vector: result[i] = α · a[i] + β.
 *  Math happens in `f32` lanes; the last partial 16-byte chunk is processed
 *  with masked loads/stores.
 */
NK_PUBLIC void nk_each_scale_e5m2_skylake(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_e5m2_t *result) {
    __m512 const scale_f32 = _mm512_set1_ps(*alpha);
    __m512 const shift_f32 = _mm512_set1_ps(*beta);
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i src_bytes;
        if (n >= 16) {
            src_bytes = _mm_loadu_si128((__m128i const *)a);
            a += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            src_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            n = 0;
        }
        // FP8 rounding note: a fused multiply-add is acceptable here because
        // (α × a + β) is a single multiply-add, so single rounding preserves accuracy.
        __m512 affine_f32 = _mm512_fmadd_ps(nk_e5m2x16_to_f32x16_skylake_(src_bytes), scale_f32, shift_f32);
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e5m2x16_skylake_(affine_f32));
        result += 16;
    } while (n);
}
|
|
1115
|
+
|
|
1116
|
+
/*  Weighted blend of two `e4m3` (FP8) vectors: result[i] = α · a[i] + β · b[i].
 *  Math happens in `f32` lanes (separate MUL for the α term, then FMA for the
 *  β term); the last partial 16-byte chunk uses masked loads/stores.
 */
NK_PUBLIC void nk_each_blend_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                          nk_f32_t const *beta, nk_e4m3_t *result) {
    __m512 const weight_a_f32 = _mm512_set1_ps(*alpha);
    __m512 const weight_b_f32 = _mm512_set1_ps(*beta);
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i lhs_bytes, rhs_bytes;
        if (n >= 16) {
            lhs_bytes = _mm_loadu_si128((__m128i const *)a);
            rhs_bytes = _mm_loadu_si128((__m128i const *)b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            lhs_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            rhs_bytes = _mm_maskz_loadu_epi8(lane_mask, b);
            n = 0;
        }
        __m512 lhs_f32 = nk_e4m3x16_to_f32x16_skylake_(lhs_bytes);
        __m512 rhs_f32 = nk_e4m3x16_to_f32x16_skylake_(rhs_bytes);
        __m512 lhs_scaled_f32 = _mm512_mul_ps(lhs_f32, weight_a_f32);
        __m512 blended_f32 = _mm512_fmadd_ps(rhs_f32, weight_b_f32, lhs_scaled_f32);
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e4m3x16_skylake_(blended_f32));
        result += 16;
    } while (n);
}
|
|
1144
|
+
|
|
1145
|
+
/*  Weighted blend of two `e5m2` (FP8) vectors: result[i] = α · a[i] + β · b[i].
 *  Math happens in `f32` lanes (separate MUL for the α term, then FMA for the
 *  β term); the last partial 16-byte chunk uses masked loads/stores.
 */
NK_PUBLIC void nk_each_blend_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                          nk_f32_t const *beta, nk_e5m2_t *result) {
    __m512 const weight_a_f32 = _mm512_set1_ps(*alpha);
    __m512 const weight_b_f32 = _mm512_set1_ps(*beta);
    __mmask16 lane_mask = 0xFFFF;
    do {
        __m128i lhs_bytes, rhs_bytes;
        if (n >= 16) {
            lhs_bytes = _mm_loadu_si128((__m128i const *)a);
            rhs_bytes = _mm_loadu_si128((__m128i const *)b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Tail: keep only the lowest `n` byte lanes active.
            lane_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            lhs_bytes = _mm_maskz_loadu_epi8(lane_mask, a);
            rhs_bytes = _mm_maskz_loadu_epi8(lane_mask, b);
            n = 0;
        }
        __m512 lhs_f32 = nk_e5m2x16_to_f32x16_skylake_(lhs_bytes);
        __m512 rhs_f32 = nk_e5m2x16_to_f32x16_skylake_(rhs_bytes);
        __m512 lhs_scaled_f32 = _mm512_mul_ps(lhs_f32, weight_a_f32);
        __m512 blended_f32 = _mm512_fmadd_ps(rhs_f32, weight_b_f32, lhs_scaled_f32);
        _mm_mask_storeu_epi8(result, lane_mask, nk_f32x16_to_e5m2x16_skylake_(blended_f32));
        result += 16;
    } while (n);
}
|
|
1173
|
+
|
|
1174
|
+
/*  Fused multiply-add over `e4m3` (FP8) vectors:
 *      result[i] = α · a[i] · b[i] + β · c[i]
 *  All arithmetic is done in `f32` lanes after widening through the
 *  `nk_e4m3x16_to_f32x16_skylake_` helper; the last partial 16-byte chunk is
 *  processed with masked byte loads/stores so no out-of-bounds memory is touched.
 */
NK_PUBLIC void nk_each_fma_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
                                        nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result) {
    __m512 alpha_f32x16 = _mm512_set1_ps(*alpha);
    __m512 beta_f32x16 = _mm512_set1_ps(*beta);
    __m128i a_e4m3x16, b_e4m3x16, c_e4m3x16, result_e4m3x16;
    __m512 a_f32x16, b_f32x16, c_f32x16, ab_f32x16, ab_scaled_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
nk_each_fma_e4m3_skylake_cycle:
    if (n < 16) {
        // Tail: build a mask with only the lowest `n` bits set.
        mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
        a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b);
        c_e4m3x16 = _mm_maskz_loadu_epi8(mask, c);
        n = 0;
    }
    else {
        a_e4m3x16 = _mm_loadu_si128((__m128i const *)a);
        b_e4m3x16 = _mm_loadu_si128((__m128i const *)b);
        c_e4m3x16 = _mm_loadu_si128((__m128i const *)c);
        a += 16, b += 16, c += 16, n -= 16;
    }
    a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
    b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
    c_f32x16 = nk_e4m3x16_to_f32x16_skylake_(c_e4m3x16);
    // FP8 rounding note: Hybrid approach - use separate MUL for (a × b) and (α × a × b) to
    // preserve intermediate rounding, then FMA for final addition since it matches scalar
    // semantics of (α × a × b + β × c) when the multiply term is already computed.
    ab_f32x16 = _mm512_mul_ps(a_f32x16, b_f32x16);
    ab_scaled_f32x16 = _mm512_mul_ps(ab_f32x16, alpha_f32x16);
    result_f32x16 = _mm512_fmadd_ps(c_f32x16, beta_f32x16, ab_scaled_f32x16);
    result_e4m3x16 = nk_f32x16_to_e4m3x16_skylake_(result_f32x16);
    // `mask` is 0xFFFF on full chunks, the tail mask on the last one.
    _mm_mask_storeu_epi8(result, mask, result_e4m3x16);
    result += 16;
    if (n) goto nk_each_fma_e4m3_skylake_cycle;
}
|
|
1209
|
+
|
|
1210
|
+
/*  Fused multiply-add over `e5m2` (FP8) vectors:
 *      result[i] = α · a[i] · b[i] + β · c[i]
 *  All arithmetic is done in `f32` lanes after widening through the
 *  `nk_e5m2x16_to_f32x16_skylake_` helper; the last partial 16-byte chunk is
 *  processed with masked byte loads/stores so no out-of-bounds memory is touched.
 */
NK_PUBLIC void nk_each_fma_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
                                        nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result) {
    __m512 alpha_f32x16 = _mm512_set1_ps(*alpha);
    __m512 beta_f32x16 = _mm512_set1_ps(*beta);
    __m128i a_e5m2x16, b_e5m2x16, c_e5m2x16, result_e5m2x16;
    __m512 a_f32x16, b_f32x16, c_f32x16, ab_f32x16, ab_scaled_f32x16, result_f32x16;
    __mmask16 mask = 0xFFFF;
nk_each_fma_e5m2_skylake_cycle:
    if (n < 16) {
        // Tail: build a mask with only the lowest `n` bits set.
        mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
        a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b);
        c_e5m2x16 = _mm_maskz_loadu_epi8(mask, c);
        n = 0;
    }
    else {
        a_e5m2x16 = _mm_loadu_si128((__m128i const *)a);
        b_e5m2x16 = _mm_loadu_si128((__m128i const *)b);
        c_e5m2x16 = _mm_loadu_si128((__m128i const *)c);
        a += 16, b += 16, c += 16, n -= 16;
    }
    a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
    b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
    c_f32x16 = nk_e5m2x16_to_f32x16_skylake_(c_e5m2x16);
    // FP8 rounding note: Hybrid approach - use separate MUL for (a × b) and (α × a × b) to
    // preserve intermediate rounding, then FMA for final addition since it matches scalar
    // semantics of (α × a × b + β × c) when the multiply term is already computed.
    ab_f32x16 = _mm512_mul_ps(a_f32x16, b_f32x16);
    ab_scaled_f32x16 = _mm512_mul_ps(ab_f32x16, alpha_f32x16);
    result_f32x16 = _mm512_fmadd_ps(c_f32x16, beta_f32x16, ab_scaled_f32x16);
    result_e5m2x16 = nk_f32x16_to_e5m2x16_skylake_(result_f32x16);
    // `mask` is 0xFFFF on full chunks, the tail mask on the last one.
    _mm_mask_storeu_epi8(result, mask, result_e5m2x16);
    result += 16;
    if (n) goto nk_each_fma_e5m2_skylake_cycle;
}
|
|
1245
|
+
|
|
1246
|
+
/*  Complex `f32` scaling: result[i] = α · a[i] + β, where α, β, and a[i] are
 *  complex numbers stored as interleaved (real, imag) pairs.
 *  Processes 8 complex values (16 floats) per vector iteration; the remainder
 *  falls back to a scalar loop with identical arithmetic.
 */
NK_PUBLIC void nk_each_scale_f32c_skylake(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
                                          nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *a_f32 = (nk_f32_t const *)a;
    nk_f32_t *result_f32 = (nk_f32_t *)result;
    __m512 alpha_real_f32x16 = _mm512_set1_ps(alpha->real);
    __m512 alpha_imag_f32x16 = _mm512_set1_ps(alpha->imag);
    // β broadcast as interleaved (real, imag) pairs to match the memory layout.
    __m512 beta_f32x16 = _mm512_set_ps(beta->imag, beta->real, beta->imag, beta->real, beta->imag, beta->real,
                                       beta->imag, beta->real, beta->imag, beta->real, beta->imag, beta->real,
                                       beta->imag, beta->real, beta->imag, beta->real);
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m512 a_f32x16 = _mm512_loadu_ps(a_f32 + 2 * i);
        // 0xB1 swaps each (real, imag) pair, giving (imag, real).
        __m512 a_swapped_f32x16 = _mm512_permute_ps(a_f32x16, 0xB1);
        __m512 temp_f32x16 = _mm512_mul_ps(alpha_imag_f32x16, a_swapped_f32x16);
        // fmaddsub: even (real) lanes get α.re·a.re − α.im·a.im,
        // odd (imag) lanes get α.re·a.im + α.im·a.re — a full complex multiply.
        __m512 y_f32x16 = _mm512_fmaddsub_ps(alpha_real_f32x16, a_f32x16, temp_f32x16);
        y_f32x16 = _mm512_add_ps(y_f32x16, beta_f32x16);
        _mm512_storeu_ps(result_f32 + 2 * i, y_f32x16);
    }
    // Scalar tail for the remaining (< 8) complex elements.
    for (; i < n; i++) {
        nk_f32_t a_real = a[i].real, a_imag = a[i].imag;
        result[i].real = alpha->real * a_real - alpha->imag * a_imag + beta->real;
        result[i].imag = alpha->real * a_imag + alpha->imag * a_real + beta->imag;
    }
}
|
|
1270
|
+
|
|
1271
|
+
/*  Complex `f64` scaling: result[i] = α · a[i] + β, where α, β, and a[i] are
 *  complex numbers stored as interleaved (real, imag) pairs.
 *  Processes 4 complex values (8 doubles) per vector iteration; the remainder
 *  falls back to a scalar loop with identical arithmetic.
 */
NK_PUBLIC void nk_each_scale_f64c_skylake(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
                                          nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *a_f64 = (nk_f64_t const *)a;
    nk_f64_t *result_f64 = (nk_f64_t *)result;
    __m512d alpha_real_f64x8 = _mm512_set1_pd(alpha->real);
    __m512d alpha_imag_f64x8 = _mm512_set1_pd(alpha->imag);
    // β broadcast as interleaved (real, imag) pairs to match the memory layout.
    __m512d beta_f64x8 = _mm512_set_pd(beta->imag, beta->real, beta->imag, beta->real, beta->imag, beta->real,
                                       beta->imag, beta->real);
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m512d a_f64x8 = _mm512_loadu_pd(a_f64 + 2 * i);
        // 0x55 swaps each (real, imag) pair, giving (imag, real).
        __m512d a_swapped_f64x8 = _mm512_permute_pd(a_f64x8, 0x55);
        __m512d temp_f64x8 = _mm512_mul_pd(alpha_imag_f64x8, a_swapped_f64x8);
        // fmaddsub: even (real) lanes get α.re·a.re − α.im·a.im,
        // odd (imag) lanes get α.re·a.im + α.im·a.re — a full complex multiply.
        __m512d y_f64x8 = _mm512_fmaddsub_pd(alpha_real_f64x8, a_f64x8, temp_f64x8);
        y_f64x8 = _mm512_add_pd(y_f64x8, beta_f64x8);
        _mm512_storeu_pd(result_f64 + 2 * i, y_f64x8);
    }
    // Scalar tail for the remaining (< 4) complex elements.
    for (; i < n; i++) {
        nk_f64_t a_real = a[i].real, a_imag = a[i].imag;
        result[i].real = alpha->real * a_real - alpha->imag * a_imag + beta->real;
        result[i].imag = alpha->real * a_imag + alpha->imag * a_real + beta->imag;
    }
}
|
|
1294
|
+
|
|
1295
|
+
/*  Complex `f32` blend: result[i] = α · a[i] + β · b[i], with complex α and β
 *  and interleaved (real, imag) storage. Two complex multiplies (permute +
 *  mul + fmaddsub each) are summed per vector iteration; leftovers use a
 *  scalar loop with the same arithmetic.
 */
NK_PUBLIC void nk_each_blend_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
                                          nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *a_f32 = (nk_f32_t const *)a;
    nk_f32_t const *b_f32 = (nk_f32_t const *)b;
    nk_f32_t *result_f32 = (nk_f32_t *)result;
    __m512 alpha_real_f32x16 = _mm512_set1_ps(alpha->real);
    __m512 alpha_imag_f32x16 = _mm512_set1_ps(alpha->imag);
    __m512 beta_real_f32x16 = _mm512_set1_ps(beta->real);
    __m512 beta_imag_f32x16 = _mm512_set1_ps(beta->imag);
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m512 a_f32x16 = _mm512_loadu_ps(a_f32 + 2 * i);
        __m512 b_f32x16 = _mm512_loadu_ps(b_f32 + 2 * i);
        // α·a: 0xB1 swaps each (real, imag) pair; fmaddsub then yields the
        // complex product (sub in even/real lanes, add in odd/imag lanes).
        __m512 a_swapped_f32x16 = _mm512_permute_ps(a_f32x16, 0xB1);
        __m512 ta_f32x16 = _mm512_mul_ps(alpha_imag_f32x16, a_swapped_f32x16);
        __m512 ya_f32x16 = _mm512_fmaddsub_ps(alpha_real_f32x16, a_f32x16, ta_f32x16);
        // β·b: same pattern.
        __m512 b_swapped_f32x16 = _mm512_permute_ps(b_f32x16, 0xB1);
        __m512 tb_f32x16 = _mm512_mul_ps(beta_imag_f32x16, b_swapped_f32x16);
        __m512 yb_f32x16 = _mm512_fmaddsub_ps(beta_real_f32x16, b_f32x16, tb_f32x16);
        _mm512_storeu_ps(result_f32 + 2 * i, _mm512_add_ps(ya_f32x16, yb_f32x16));
    }
    // Scalar tail for the remaining (< 8) complex elements.
    for (; i < n; i++) {
        nk_f32_t a_real = a[i].real, a_imag = a[i].imag;
        nk_f32_t b_real = b[i].real, b_imag = b[i].imag;
        nk_f32_t ar = alpha->real * a_real - alpha->imag * a_imag;
        nk_f32_t ai = alpha->real * a_imag + alpha->imag * a_real;
        nk_f32_t br = beta->real * b_real - beta->imag * b_imag;
        nk_f32_t bi = beta->real * b_imag + beta->imag * b_real;
        result[i].real = ar + br;
        result[i].imag = ai + bi;
    }
}
|
|
1327
|
+
|
|
1328
|
+
/*  Complex `f64` blend: result[i] = α · a[i] + β · b[i], with complex α and β
 *  and interleaved (real, imag) storage. Two complex multiplies (permute +
 *  mul + fmaddsub each) are summed per vector iteration; leftovers use a
 *  scalar loop with the same arithmetic.
 */
NK_PUBLIC void nk_each_blend_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
                                          nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *a_f64 = (nk_f64_t const *)a;
    nk_f64_t const *b_f64 = (nk_f64_t const *)b;
    nk_f64_t *result_f64 = (nk_f64_t *)result;
    __m512d alpha_real_f64x8 = _mm512_set1_pd(alpha->real);
    __m512d alpha_imag_f64x8 = _mm512_set1_pd(alpha->imag);
    __m512d beta_real_f64x8 = _mm512_set1_pd(beta->real);
    __m512d beta_imag_f64x8 = _mm512_set1_pd(beta->imag);
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m512d a_f64x8 = _mm512_loadu_pd(a_f64 + 2 * i);
        __m512d b_f64x8 = _mm512_loadu_pd(b_f64 + 2 * i);
        // α·a: 0x55 swaps each (real, imag) pair; fmaddsub then yields the
        // complex product (sub in even/real lanes, add in odd/imag lanes).
        __m512d a_swapped_f64x8 = _mm512_permute_pd(a_f64x8, 0x55);
        __m512d ta_f64x8 = _mm512_mul_pd(alpha_imag_f64x8, a_swapped_f64x8);
        __m512d ya_f64x8 = _mm512_fmaddsub_pd(alpha_real_f64x8, a_f64x8, ta_f64x8);
        // β·b: same pattern.
        __m512d b_swapped_f64x8 = _mm512_permute_pd(b_f64x8, 0x55);
        __m512d tb_f64x8 = _mm512_mul_pd(beta_imag_f64x8, b_swapped_f64x8);
        __m512d yb_f64x8 = _mm512_fmaddsub_pd(beta_real_f64x8, b_f64x8, tb_f64x8);
        _mm512_storeu_pd(result_f64 + 2 * i, _mm512_add_pd(ya_f64x8, yb_f64x8));
    }
    // Scalar tail for the remaining (< 4) complex elements.
    for (; i < n; i++) {
        nk_f64_t a_real = a[i].real, a_imag = a[i].imag;
        nk_f64_t b_real = b[i].real, b_imag = b[i].imag;
        nk_f64_t ar = alpha->real * a_real - alpha->imag * a_imag;
        nk_f64_t ai = alpha->real * a_imag + alpha->imag * a_real;
        nk_f64_t br = beta->real * b_real - beta->imag * b_imag;
        nk_f64_t bi = beta->real * b_imag + beta->imag * b_real;
        result[i].real = ar + br;
        result[i].imag = ai + bi;
    }
}
|
|
1360
|
+
|
|
1361
|
+
/*  Complex `f32` fused multiply-add: result[i] = α · a[i] · b[i] + β · c[i],
 *  with complex α and β and interleaved (real, imag) storage. Each vector
 *  iteration performs three complex multiplies (a·b, then α·(a·b), then β·c)
 *  via the permute + mul + fmaddsub idiom; leftovers use a scalar loop.
 */
NK_PUBLIC void nk_each_fma_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                        nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *a_f32 = (nk_f32_t const *)a;
    nk_f32_t const *b_f32 = (nk_f32_t const *)b;
    nk_f32_t const *c_f32 = (nk_f32_t const *)c;
    nk_f32_t *result_f32 = (nk_f32_t *)result;
    __m512 alpha_real_f32x16 = _mm512_set1_ps(alpha->real);
    __m512 alpha_imag_f32x16 = _mm512_set1_ps(alpha->imag);
    __m512 beta_real_f32x16 = _mm512_set1_ps(beta->real);
    __m512 beta_imag_f32x16 = _mm512_set1_ps(beta->imag);
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m512 a_f32x16 = _mm512_loadu_ps(a_f32 + 2 * i);
        __m512 b_f32x16 = _mm512_loadu_ps(b_f32 + 2 * i);
        __m512 c_f32x16 = _mm512_loadu_ps(c_f32 + 2 * i);
        // a·b: moveldup/movehdup broadcast a's real/imag parts within each
        // pair; 0xB1 swaps b's (real, imag); fmaddsub combines them into the
        // complex product (sub in even/real lanes, add in odd/imag lanes).
        __m512 b_swapped_f32x16 = _mm512_permute_ps(b_f32x16, 0xB1);
        __m512 a_real_f32x16 = _mm512_moveldup_ps(a_f32x16);
        __m512 a_imag_f32x16 = _mm512_movehdup_ps(a_f32x16);
        __m512 tab_f32x16 = _mm512_mul_ps(a_imag_f32x16, b_swapped_f32x16);
        __m512 ab_f32x16 = _mm512_fmaddsub_ps(a_real_f32x16, b_f32x16, tab_f32x16);
        // α·(a·b): same idiom with the broadcast α parts.
        __m512 ab_swapped_f32x16 = _mm512_permute_ps(ab_f32x16, 0xB1);
        __m512 taa_f32x16 = _mm512_mul_ps(alpha_imag_f32x16, ab_swapped_f32x16);
        __m512 ya_f32x16 = _mm512_fmaddsub_ps(alpha_real_f32x16, ab_f32x16, taa_f32x16);
        // β·c: same idiom with the broadcast β parts.
        __m512 c_swapped_f32x16 = _mm512_permute_ps(c_f32x16, 0xB1);
        __m512 tbc_f32x16 = _mm512_mul_ps(beta_imag_f32x16, c_swapped_f32x16);
        __m512 yb_f32x16 = _mm512_fmaddsub_ps(beta_real_f32x16, c_f32x16, tbc_f32x16);
        _mm512_storeu_ps(result_f32 + 2 * i, _mm512_add_ps(ya_f32x16, yb_f32x16));
    }
    // Scalar tail for the remaining (< 8) complex elements.
    for (; i < n; i++) {
        nk_f32_t a_real = a[i].real, a_imag = a[i].imag;
        nk_f32_t b_real = b[i].real, b_imag = b[i].imag;
        nk_f32_t c_real = c[i].real, c_imag = c[i].imag;
        nk_f32_t ab_real = a_real * b_real - a_imag * b_imag;
        nk_f32_t ab_imag = a_real * b_imag + a_imag * b_real;
        nk_f32_t aab_real = alpha->real * ab_real - alpha->imag * ab_imag;
        nk_f32_t aab_imag = alpha->real * ab_imag + alpha->imag * ab_real;
        nk_f32_t bc_real = beta->real * c_real - beta->imag * c_imag;
        nk_f32_t bc_imag = beta->real * c_imag + beta->imag * c_real;
        result[i].real = aab_real + bc_real;
        result[i].imag = aab_imag + bc_imag;
    }
}
|
|
1403
|
+
|
|
1404
|
+
/*  Complex `f64` fused multiply-add: result[i] = α · a[i] · b[i] + β · c[i],
 *  with complex α and β and interleaved (real, imag) storage. Each vector
 *  iteration performs three complex multiplies (a·b, then α·(a·b), then β·c)
 *  via the permute + mul + fmaddsub idiom; leftovers use a scalar loop.
 */
NK_PUBLIC void nk_each_fma_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                        nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *a_f64 = (nk_f64_t const *)a;
    nk_f64_t const *b_f64 = (nk_f64_t const *)b;
    nk_f64_t const *c_f64 = (nk_f64_t const *)c;
    nk_f64_t *result_f64 = (nk_f64_t *)result;
    __m512d alpha_real_f64x8 = _mm512_set1_pd(alpha->real);
    __m512d alpha_imag_f64x8 = _mm512_set1_pd(alpha->imag);
    __m512d beta_real_f64x8 = _mm512_set1_pd(beta->real);
    __m512d beta_imag_f64x8 = _mm512_set1_pd(beta->imag);
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m512d a_f64x8 = _mm512_loadu_pd(a_f64 + 2 * i);
        __m512d b_f64x8 = _mm512_loadu_pd(b_f64 + 2 * i);
        __m512d c_f64x8 = _mm512_loadu_pd(c_f64 + 2 * i);
        // a·b: unpacklo/unpackhi duplicate a's real/imag parts within each
        // pair; 0x55 swaps b's (real, imag); fmaddsub combines them into the
        // complex product (sub in even/real lanes, add in odd/imag lanes).
        __m512d b_swapped_f64x8 = _mm512_permute_pd(b_f64x8, 0x55);
        __m512d a_real_f64x8 = _mm512_unpacklo_pd(a_f64x8, a_f64x8);
        __m512d a_imag_f64x8 = _mm512_unpackhi_pd(a_f64x8, a_f64x8);
        __m512d tab_f64x8 = _mm512_mul_pd(a_imag_f64x8, b_swapped_f64x8);
        __m512d ab_f64x8 = _mm512_fmaddsub_pd(a_real_f64x8, b_f64x8, tab_f64x8);
        // α·(a·b): same idiom with the broadcast α parts.
        __m512d ab_swapped_f64x8 = _mm512_permute_pd(ab_f64x8, 0x55);
        __m512d taa_f64x8 = _mm512_mul_pd(alpha_imag_f64x8, ab_swapped_f64x8);
        __m512d ya_f64x8 = _mm512_fmaddsub_pd(alpha_real_f64x8, ab_f64x8, taa_f64x8);
        // β·c: same idiom with the broadcast β parts.
        __m512d c_swapped_f64x8 = _mm512_permute_pd(c_f64x8, 0x55);
        __m512d tbc_f64x8 = _mm512_mul_pd(beta_imag_f64x8, c_swapped_f64x8);
        __m512d yb_f64x8 = _mm512_fmaddsub_pd(beta_real_f64x8, c_f64x8, tbc_f64x8);
        _mm512_storeu_pd(result_f64 + 2 * i, _mm512_add_pd(ya_f64x8, yb_f64x8));
    }
    // Scalar tail for the remaining (< 4) complex elements.
    for (; i < n; i++) {
        nk_f64_t a_real = a[i].real, a_imag = a[i].imag;
        nk_f64_t b_real = b[i].real, b_imag = b[i].imag;
        nk_f64_t c_real = c[i].real, c_imag = c[i].imag;
        nk_f64_t ab_real = a_real * b_real - a_imag * b_imag;
        nk_f64_t ab_imag = a_real * b_imag + a_imag * b_real;
        nk_f64_t aab_real = alpha->real * ab_real - alpha->imag * ab_imag;
        nk_f64_t aab_imag = alpha->real * ab_imag + alpha->imag * ab_real;
        nk_f64_t bc_real = beta->real * c_real - beta->imag * c_imag;
        nk_f64_t bc_imag = beta->real * c_imag + beta->imag * c_real;
        result[i].real = aab_real + bc_real;
        result[i].imag = aab_imag + bc_imag;
    }
}
|
|
1446
|
+
|
|
1447
|
+
/*  Affine scaling of an `f16` vector: result[i] = α · a[i] + β.
 *  Values are widened to `f32` for the math and narrowed back with
 *  round-to-nearest; the tail (< 16 elements) uses masked loads/stores.
 */
NK_PUBLIC void nk_each_scale_f16_skylake(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_f16_t *result) {
    __m512 const scale_f32 = _mm512_set1_ps(*alpha);
    __m512 const shift_f32 = _mm512_set1_ps(*beta);
    while (n >= 16) {
        __m512 widened = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)a));
        __m512 affine = _mm512_fmadd_ps(widened, scale_f32, shift_f32);
        _mm256_storeu_si256((__m256i *)result, _mm512_cvtps_ph(affine, _MM_FROUND_TO_NEAREST_INT));
        a += 16, result += 16, n -= 16;
    }
    if (n) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        __m512 widened = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(tail_mask, a));
        __m512 affine = _mm512_fmadd_ps(widened, scale_f32, shift_f32);
        _mm256_mask_storeu_epi16(result, tail_mask, _mm512_cvtps_ph(affine, _MM_FROUND_TO_NEAREST_INT));
    }
}
|
|
1468
|
+
|
|
1469
|
+
/*  Weighted blend of two `f16` vectors: result[i] = α · a[i] + β · b[i].
 *  Dispatches to cheaper kernels for the common α/β special cases; the
 *  general path widens to `f32`, blends, and narrows back with
 *  round-to-nearest, using masked loads/stores for the tail.
 */
NK_PUBLIC void nk_each_blend_f16_skylake( //
    nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {

    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // There are several special cases we may want to implement:
    // 1. Simple addition, when both weights are equal to 1.0.
    if (alpha_val == 1 && beta_val == 1) {
        // In this case we can avoid expensive multiplications.
        nk_each_sum_f16_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero.
    else if (alpha_val == 0 || beta_val == 0) {
        // In this case we can avoid half of the load instructions.
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_f16_skylake(a, n, alpha, &zero, result); }
        else { nk_each_scale_f16_skylake(b, n, beta, &zero, result); }
        return;
    }

    // The general case: compute in f32 — half-precision intermediates would
    // lose mantissa bits and can overflow past the f16 maximum (~65504).
    __m512 alpha_f32x16 = _mm512_set1_ps(alpha_val);
    __m512 beta_f32x16 = _mm512_set1_ps(beta_val);
    __m512 a_f32x16, b_f32x16;
nk_each_blend_f16_skylake_cycle:
    if (n < 16) {
        // Tail: build a mask with only the lowest `n` bits set.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
        b_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
        __m512 a_scaled_f32x16 = _mm512_mul_ps(a_f32x16, alpha_f32x16);
        __m512 result_f32x16 = _mm512_fmadd_ps(b_f32x16, beta_f32x16, a_scaled_f32x16);
        _mm256_mask_storeu_epi16(result, mask, _mm512_cvtps_ph(result_f32x16, _MM_FROUND_TO_NEAREST_INT));
        n = 0;
    }
    else {
        a_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)a));
        b_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)b));
        __m512 a_scaled_f32x16 = _mm512_mul_ps(a_f32x16, alpha_f32x16);
        __m512 result_f32x16 = _mm512_fmadd_ps(b_f32x16, beta_f32x16, a_scaled_f32x16);
        _mm256_storeu_si256((__m256i *)result, _mm512_cvtps_ph(result_f32x16, _MM_FROUND_TO_NEAREST_INT));
        a += 16, b += 16, result += 16, n -= 16;
    }
    if (n) goto nk_each_blend_f16_skylake_cycle;
}
|
|
1516
|
+
|
|
1517
|
+
/*  Fused multiply-add over `f16` vectors: result[i] = α · a[i] · b[i] + β · c[i].
 *  All math is done in `f32` lanes — half-precision intermediates would lose
 *  mantissa bits and can overflow past the f16 maximum (~65504). The tail
 *  (< 16 elements) uses masked loads/stores.
 */
NK_PUBLIC void nk_each_fma_f16_skylake( //
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {

    __m512 const scale_f32 = _mm512_set1_ps(*alpha);
    __m512 const blend_f32 = _mm512_set1_ps(*beta);
    while (n >= 16) {
        __m512 lhs = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)a));
        __m512 rhs = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)b));
        __m512 addend = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)c));
        // Same operation order as the scalar form: (a*b), then *α, then + β*c.
        __m512 prod = _mm512_mul_ps(lhs, rhs);
        __m512 prod_scaled = _mm512_mul_ps(prod, scale_f32);
        __m512 combined = _mm512_fmadd_ps(addend, blend_f32, prod_scaled);
        _mm256_storeu_si256((__m256i *)result, _mm512_cvtps_ph(combined, _MM_FROUND_TO_NEAREST_INT));
        a += 16, b += 16, c += 16, result += 16, n -= 16;
    }
    if (n) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        __m512 lhs = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(tail_mask, a));
        __m512 rhs = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(tail_mask, b));
        __m512 addend = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(tail_mask, c));
        __m512 prod = _mm512_mul_ps(lhs, rhs);
        __m512 prod_scaled = _mm512_mul_ps(prod, scale_f32);
        __m512 combined = _mm512_fmadd_ps(addend, blend_f32, prod_scaled);
        _mm256_mask_storeu_epi16(result, tail_mask, _mm512_cvtps_ph(combined, _MM_FROUND_TO_NEAREST_INT));
    }
}
|
|
1549
|
+
|
|
1550
|
+
#if defined(__clang__)
|
|
1551
|
+
#pragma clang attribute pop
|
|
1552
|
+
#elif defined(__GNUC__)
|
|
1553
|
+
#pragma GCC pop_options
|
|
1554
|
+
#endif
|
|
1555
|
+
|
|
1556
|
+
#if defined(__cplusplus)
|
|
1557
|
+
} // extern "C"
|
|
1558
|
+
#endif
|
|
1559
|
+
|
|
1560
|
+
#endif // NK_TARGET_SKYLAKE
|
|
1561
|
+
#endif // NK_TARGET_X86_
|
|
1562
|
+
#endif // NK_EACH_SKYLAKE_H
|