numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1658 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Elementwise Arithmetic for Haswell.
|
|
3
|
+
* @file include/numkong/each/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/each.h
|
|
8
|
+
*
|
|
9
|
+
* @section haswell_elementwise_instructions Key AVX2 Elementwise Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
|
|
13
|
+
* _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy 1/cy p1
|
|
14
|
+
* _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy 0.5/cy p01
|
|
15
|
+
* _mm256_cvtepi32_ps VCVTDQ2PS (YMM, YMM) 3cy 1/cy p1
|
|
16
|
+
* _mm256_cvtepi8_epi32 VPMOVSXBD (YMM, XMM) 3cy 1/cy p5
|
|
17
|
+
*
|
|
18
|
+
* Elementwise operations (sum, scale, blend, fma) stream their operands, so they are usually bound by load/store and memory bandwidth rather than by FMA throughput. For mixed-
|
|
19
|
+
* precision operations, type conversion chains (e.g., i8->i32->f32) add ~7-10 cycles overhead.
|
|
20
|
+
* The two FMA units (ports 0 and 1) also execute standalone multiplies, while standalone adds go to the separate port-1 adder.
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_EACH_HASWELL_H
|
|
23
|
+
#define NK_EACH_HASWELL_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X86_
|
|
26
|
+
#if NK_TARGET_HASWELL
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/cast/serial.h" // `nk_f32_to_i8_serial`
|
|
30
|
+
#include "numkong/reduce/haswell.h" // `nk_e4m3x8_to_f32x8_haswell_`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
38
|
+
#elif defined(__GNUC__)
|
|
39
|
+
#pragma GCC push_options
|
|
40
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
41
|
+
#endif
|
|
42
|
+
|
|
43
|
+
/**
 *  @brief  Elementwise addition of two single-precision vectors: `result[i] = a[i] + b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[out] result  Output vector of `n` elements.
 */
NK_PUBLIC void nk_each_sum_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t offset = 0;

    // Eight floats per 256-bit register; unaligned loads and stores.
    while (n - offset >= 8) {
        __m256 lhs_f32x8 = _mm256_loadu_ps(a + offset);
        __m256 rhs_f32x8 = _mm256_loadu_ps(b + offset);
        _mm256_storeu_ps(result + offset, _mm256_add_ps(lhs_f32x8, rhs_f32x8));
        offset += 8;
    }

    // At most seven leftover elements.
    while (offset < n) {
        result[offset] = a[offset] + b[offset];
        ++offset;
    }
}
|
|
56
|
+
|
|
57
|
+
/**
 *  @brief  Affine transform of a single-precision vector: `result[i] = alpha * a[i] + beta`.
 *  @param[in]  a       Input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the multiplicative coefficient.
 *  @param[in]  beta    Pointer to the additive offset.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Every element, in the 8-wide body and in the scalar tail alike, is produced by one fused
 *  multiply-add. The rounding therefore does not depend on the element's position or on the
 *  compiler's floating-point contraction setting.
 */
NK_PUBLIC void nk_each_scale_f32_haswell(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_f32_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 a_f32x8 = _mm256_loadu_ps(a + i);
        __m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
        _mm256_storeu_ps(result + i, result_f32x8);
    }

    // The tail: a plain `alpha_val * a[i] + beta_val` is only fused if the compiler contracts it,
    // so use the scalar FMA to round exactly like the vector body. Only lane 0 is used.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        __m128 result_f32x4 = _mm_fmadd_ss(_mm_set_ss(a[i]), alpha_f32x4, beta_f32x4);
        result[i] = _mm_cvtss_f32(result_f32x4);
    }
}
|
|
75
|
+
|
|
76
|
+
/**
 *  @brief  Weighted sum of two single-precision vectors: `result[i] = alpha * a[i] + beta * b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the weight of `a`.
 *  @param[in]  beta    Pointer to the weight of `b`.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Weights equal to exactly 1 or 0 are dispatched to cheaper kernels. In the zero-weight path the
 *  zero-weighted input is never read, so NaN/Inf values in it do not reach `result`. This differs
 *  from the general formula, where `0 * Inf` would yield NaN.
 *
 *  In the general path every element is computed as `fma(b, beta, round(a * alpha))`, in the
 *  8-wide body and in the scalar tail alike.
 */
NK_PUBLIC void nk_each_blend_f32_haswell( //
    nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_f32_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero: only one input is loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_f32_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_f32_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 a_f32x8 = _mm256_loadu_ps(a + i);
        __m256 b_f32x8 = _mm256_loadu_ps(b + i);
        __m256 a_scaled_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, a_scaled_f32x8);
        _mm256_storeu_ps(result + i, result_f32x8);
    }

    // The tail: mirror the vector body's multiply-then-FMA sequence with scalar intrinsics.
    // `alpha_val * a[i] + beta_val * b[i]` may or may not be fused, depending on compiler flags.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        __m128 a_scaled_f32x4 = _mm_mul_ss(_mm_set_ss(a[i]), alpha_f32x4);
        __m128 result_f32x4 = _mm_fmadd_ss(_mm_set_ss(b[i]), beta_f32x4, a_scaled_f32x4);
        result[i] = _mm_cvtss_f32(result_f32x4);
    }
}
|
|
115
|
+
|
|
116
|
+
/**
 *  @brief  Elementwise addition of two double-precision vectors: `result[i] = a[i] + b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[out] result  Output vector of `n` elements.
 */
NK_PUBLIC void nk_each_sum_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t offset = 0;

    // Four doubles per 256-bit register; unaligned loads and stores.
    while (n - offset >= 4) {
        __m256d lhs_f64x4 = _mm256_loadu_pd(a + offset);
        __m256d rhs_f64x4 = _mm256_loadu_pd(b + offset);
        _mm256_storeu_pd(result + offset, _mm256_add_pd(lhs_f64x4, rhs_f64x4));
        offset += 4;
    }

    // At most three leftover elements.
    while (offset < n) {
        result[offset] = a[offset] + b[offset];
        ++offset;
    }
}
|
|
129
|
+
|
|
130
|
+
/**
 *  @brief  Affine transform of a double-precision vector: `result[i] = alpha * a[i] + beta`.
 *  @param[in]  a       Input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the multiplicative coefficient.
 *  @param[in]  beta    Pointer to the additive offset.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Every element, in the 4-wide body and in the scalar tail alike, is produced by one fused
 *  multiply-add. The rounding therefore does not depend on the element's position or on the
 *  compiler's floating-point contraction setting.
 */
NK_PUBLIC void nk_each_scale_f64_haswell(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_f64_t *result) {
    nk_f64_t alpha_val = *alpha;
    nk_f64_t beta_val = *beta;
    __m256d alpha_f64x4 = _mm256_set1_pd(alpha_val);
    __m256d beta_f64x4 = _mm256_set1_pd(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d a_f64x4 = _mm256_loadu_pd(a + i);
        __m256d result_f64x4 = _mm256_fmadd_pd(a_f64x4, alpha_f64x4, beta_f64x4);
        _mm256_storeu_pd(result + i, result_f64x4);
    }

    // The tail: a plain `alpha_val * a[i] + beta_val` is only fused if the compiler contracts it,
    // so use the scalar FMA to round exactly like the vector body. Only lane 0 is used.
    __m128d alpha_f64x2 = _mm_set_sd(alpha_val);
    __m128d beta_f64x2 = _mm_set_sd(beta_val);
    for (; i < n; ++i) {
        __m128d result_f64x2 = _mm_fmadd_sd(_mm_set_sd(a[i]), alpha_f64x2, beta_f64x2);
        result[i] = _mm_cvtsd_f64(result_f64x2);
    }
}
|
|
148
|
+
|
|
149
|
+
/**
 *  @brief  Weighted sum of two double-precision vectors: `result[i] = alpha * a[i] + beta * b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the weight of `a`.
 *  @param[in]  beta    Pointer to the weight of `b`.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Weights equal to exactly 1 or 0 are dispatched to cheaper kernels. In the zero-weight path the
 *  zero-weighted input is never read, so NaN/Inf values in it do not reach `result`. This differs
 *  from the general formula, where `0 * Inf` would yield NaN.
 *
 *  In the general path every element is computed as `fma(b, beta, round(a * alpha))`, in the
 *  4-wide body and in the scalar tail alike.
 */
NK_PUBLIC void nk_each_blend_f64_haswell( //
    nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {
    nk_f64_t alpha_val = *alpha;
    nk_f64_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_f64_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero: only one input is loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f64_t zero = 0;
        if (beta_val == 0) { nk_each_scale_f64_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_f64_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256d alpha_f64x4 = _mm256_set1_pd(alpha_val);
    __m256d beta_f64x4 = _mm256_set1_pd(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d a_f64x4 = _mm256_loadu_pd(a + i);
        __m256d b_f64x4 = _mm256_loadu_pd(b + i);
        __m256d a_scaled_f64x4 = _mm256_mul_pd(a_f64x4, alpha_f64x4);
        __m256d result_f64x4 = _mm256_fmadd_pd(b_f64x4, beta_f64x4, a_scaled_f64x4);
        _mm256_storeu_pd(result + i, result_f64x4);
    }

    // The tail: mirror the vector body's multiply-then-FMA sequence with scalar intrinsics.
    // `alpha_val * a[i] + beta_val * b[i]` may or may not be fused, depending on compiler flags.
    __m128d alpha_f64x2 = _mm_set_sd(alpha_val);
    __m128d beta_f64x2 = _mm_set_sd(beta_val);
    for (; i < n; ++i) {
        __m128d a_scaled_f64x2 = _mm_mul_sd(_mm_set_sd(a[i]), alpha_f64x2);
        __m128d result_f64x2 = _mm_fmadd_sd(_mm_set_sd(b[i]), beta_f64x2, a_scaled_f64x2);
        result[i] = _mm_cvtsd_f64(result_f64x2);
    }
}
|
|
188
|
+
|
|
189
|
+
/**
 *  @brief  Elementwise addition of two half-precision vectors, computed in f32: `result[i] = a[i] + b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  The 8-wide body widens with F16C `_mm256_cvtph_ps`, adds in f32, and narrows back with
 *  `_mm256_cvtps_ph` in round-to-nearest mode. The scalar tail uses the
 *  `nk_f16_to_f32_haswell` / `nk_f32_to_f16_haswell` helpers.
 */
NK_PUBLIC void nk_each_sum_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result) {
    nk_size_t offset = 0;

    while (n - offset >= 8) {
        __m256 lhs_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(a + offset)));
        __m256 rhs_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(b + offset)));
        __m256 total_f32x8 = _mm256_add_ps(lhs_f32x8, rhs_f32x8);
        __m128i total_f16x8 = _mm256_cvtps_ph(total_f32x8, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        _mm_storeu_si128((__m128i *)(result + offset), total_f16x8);
        offset += 8;
    }

    // At most seven leftover elements, converted one at a time.
    while (offset < n) {
        nk_f32_t lhs, rhs;
        nk_f16_to_f32_haswell(a + offset, &lhs);
        nk_f16_to_f32_haswell(b + offset, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_f16_haswell(&total, result + offset);
        ++offset;
    }
}
|
|
212
|
+
|
|
213
|
+
/**
 *  @brief  Affine transform of a half-precision vector, computed in f32: `result[i] = alpha * a[i] + beta`.
 *  @param[in]  a       Input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the f32 multiplicative coefficient.
 *  @param[in]  beta    Pointer to the f32 additive offset.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  The f32 intermediate is one fused multiply-add in both the 8-wide body and the scalar tail.
 *  The body narrows with `_mm256_cvtps_ph`; the tail narrows through `nk_f32_to_f16_haswell`.
 */
NK_PUBLIC void nk_each_scale_f16_haswell(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_f16_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_f16x8 = _mm_loadu_si128((__m128i const *)(a + i));
        __m256 a_f32x8 = _mm256_cvtph_ps(a_f16x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
        __m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        _mm_storeu_si128((__m128i *)(result + i), result_f16x8);
    }

    // The tail: a plain `alpha_val * ai + beta_val` is only fused if the compiler contracts it,
    // so use the scalar FMA to produce the same f32 intermediate as the vector body.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        nk_f32_t ai;
        nk_f16_to_f32_haswell(a + i, &ai);
        nk_f32_t sum = _mm_cvtss_f32(_mm_fmadd_ss(_mm_set_ss(ai), alpha_f32x4, beta_f32x4));
        nk_f32_to_f16_haswell(&sum, result + i);
    }
}
|
|
238
|
+
|
|
239
|
+
/**
 *  @brief  Weighted sum of two half-precision vectors, computed in f32: `result[i] = alpha * a[i] + beta * b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the f32 weight of `a`.
 *  @param[in]  beta    Pointer to the f32 weight of `b`.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Weights equal to exactly 1 or 0 are dispatched to cheaper kernels. In the zero-weight path the
 *  zero-weighted input is never read, so NaN/Inf values in it do not reach `result`.
 *
 *  In the general path the f32 intermediate is `fma(b, beta, round(a * alpha))`, in the 8-wide body
 *  and in the scalar tail alike.
 */
NK_PUBLIC void nk_each_blend_f16_haswell( //
    nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_f16_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero: only one input is loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_f16_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_f16_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_f16x8 = _mm_loadu_si128((__m128i const *)(a + i));
        __m128i b_f16x8 = _mm_loadu_si128((__m128i const *)(b + i));
        __m256 a_f32x8 = _mm256_cvtph_ps(a_f16x8);
        __m256 b_f32x8 = _mm256_cvtph_ps(b_f16x8);
        __m256 a_scaled_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, a_scaled_f32x8);
        __m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        _mm_storeu_si128((__m128i *)(result + i), result_f16x8);
    }

    // The tail: mirror the vector body's multiply-then-FMA sequence with scalar intrinsics
    // instead of relying on the compiler to (maybe) contract `alpha_val * ai + beta_val * bi`.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        nk_f32_t ai, bi;
        nk_f16_to_f32_haswell(a + i, &ai);
        nk_f16_to_f32_haswell(b + i, &bi);
        __m128 a_scaled_f32x4 = _mm_mul_ss(_mm_set_ss(ai), alpha_f32x4);
        nk_f32_t sum = _mm_cvtss_f32(_mm_fmadd_ss(_mm_set_ss(bi), beta_f32x4, a_scaled_f32x4));
        nk_f32_to_f16_haswell(&sum, result + i);
    }
}
|
|
287
|
+
|
|
288
|
+
/**
 *  @brief  Elementwise addition of two bfloat16 vectors, computed in f32: `result[i] = a[i] + b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  The 8-wide body converts with `nk_bf16x8_to_f32x8_haswell_` / `nk_f32x8_to_bf16x8_haswell_`.
 *  The scalar tail uses the serial `nk_bf16_to_f32_serial` / `nk_f32_to_bf16_serial` helpers.
 */
NK_PUBLIC void nk_each_sum_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result) {
    nk_size_t offset = 0;

    while (n - offset >= 8) {
        __m256 lhs_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(a + offset)));
        __m256 rhs_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(b + offset)));
        __m128i total_bf16x8 = nk_f32x8_to_bf16x8_haswell_(_mm256_add_ps(lhs_f32x8, rhs_f32x8));
        _mm_storeu_si128((__m128i *)(result + offset), total_bf16x8);
        offset += 8;
    }

    // At most seven leftover elements, converted one at a time.
    while (offset < n) {
        nk_f32_t lhs, rhs;
        nk_bf16_to_f32_serial(a + offset, &lhs);
        nk_bf16_to_f32_serial(b + offset, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_bf16_serial(&total, result + offset);
        ++offset;
    }
}
|
|
310
|
+
|
|
311
|
+
/**
 *  @brief  Affine transform of a bfloat16 vector, computed in f32: `result[i] = alpha * a[i] + beta`.
 *  @param[in]  a       Input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the f32 multiplicative coefficient.
 *  @param[in]  beta    Pointer to the f32 additive offset.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  The f32 intermediate is one fused multiply-add in both the 8-wide body and the scalar tail.
 *  The body narrows with `nk_f32x8_to_bf16x8_haswell_`; the tail narrows with `nk_f32_to_bf16_serial`.
 */
NK_PUBLIC void nk_each_scale_bf16_haswell(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_bf16_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_bf16x8 = _mm_loadu_si128((__m128i const *)(a + i));
        __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
        __m128i result_bf16x8 = nk_f32x8_to_bf16x8_haswell_(result_f32x8);
        _mm_storeu_si128((__m128i *)(result + i), result_bf16x8);
    }

    // The tail: a plain `alpha_val * ai + beta_val` is only fused if the compiler contracts it,
    // so use the scalar FMA to produce the same f32 intermediate as the vector body.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        nk_f32_t ai;
        nk_bf16_to_f32_serial(a + i, &ai);
        nk_f32_t sum = _mm_cvtss_f32(_mm_fmadd_ss(_mm_set_ss(ai), alpha_f32x4, beta_f32x4));
        nk_f32_to_bf16_serial(&sum, result + i);
    }
}
|
|
336
|
+
|
|
337
|
+
/**
 *  @brief  Weighted sum of two bfloat16 vectors, computed in f32: `result[i] = alpha * a[i] + beta * b[i]`.
 *  @param[in]  a       First input vector of `n` elements.
 *  @param[in]  b       Second input vector of `n` elements.
 *  @param[in]  n       Number of elements.
 *  @param[in]  alpha   Pointer to the f32 weight of `a`.
 *  @param[in]  beta    Pointer to the f32 weight of `b`.
 *  @param[out] result  Output vector of `n` elements.
 *
 *  Weights equal to exactly 1 or 0 are dispatched to cheaper kernels. In the zero-weight path the
 *  zero-weighted input is never read, so NaN/Inf values in it do not reach `result`.
 *
 *  In the general path the f32 intermediate is `fma(b, beta, round(a * alpha))`, in the 8-wide body
 *  and in the scalar tail alike.
 */
NK_PUBLIC void nk_each_blend_bf16_haswell( //
    nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_bf16_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero: only one input is loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_bf16_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_bf16_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_bf16x8 = _mm_loadu_si128((__m128i const *)(a + i));
        __m128i b_bf16x8 = _mm_loadu_si128((__m128i const *)(b + i));
        __m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
        __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
        __m256 a_scaled_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, a_scaled_f32x8);
        __m128i result_bf16x8 = nk_f32x8_to_bf16x8_haswell_(result_f32x8);
        _mm_storeu_si128((__m128i *)(result + i), result_bf16x8);
    }

    // The tail: mirror the vector body's multiply-then-FMA sequence with scalar intrinsics
    // instead of relying on the compiler to (maybe) contract `alpha_val * ai + beta_val * bi`.
    __m128 alpha_f32x4 = _mm_set_ss(alpha_val);
    __m128 beta_f32x4 = _mm_set_ss(beta_val);
    for (; i < n; ++i) {
        nk_f32_t ai, bi;
        nk_bf16_to_f32_serial(a + i, &ai);
        nk_bf16_to_f32_serial(b + i, &bi);
        __m128 a_scaled_f32x4 = _mm_mul_ss(_mm_set_ss(ai), alpha_f32x4);
        nk_f32_t sum = _mm_cvtss_f32(_mm_fmadd_ss(_mm_set_ss(bi), beta_f32x4, a_scaled_f32x4));
        nk_f32_to_bf16_serial(&sum, result + i);
    }
}
|
|
385
|
+
|
|
386
|
+
NK_PUBLIC void nk_each_fma_f32_haswell( //
|
|
387
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, //
|
|
388
|
+
nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {
|
|
389
|
+
nk_f32_t alpha_val = *alpha;
|
|
390
|
+
nk_f32_t beta_val = *beta;
|
|
391
|
+
__m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
|
|
392
|
+
__m256 beta_f32x8 = _mm256_set1_ps(beta_val);
|
|
393
|
+
|
|
394
|
+
// The main loop:
|
|
395
|
+
nk_size_t i = 0;
|
|
396
|
+
for (; i + 8 <= n; i += 8) {
|
|
397
|
+
__m256 a_f32x8 = _mm256_loadu_ps(a + i);
|
|
398
|
+
__m256 b_f32x8 = _mm256_loadu_ps(b + i);
|
|
399
|
+
__m256 c_f32x8 = _mm256_loadu_ps(c + i);
|
|
400
|
+
__m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
|
|
401
|
+
__m256 ab_scaled_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
|
|
402
|
+
__m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, ab_scaled_f32x8);
|
|
403
|
+
_mm256_storeu_ps(result + i, result_f32x8);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
// The tail:
|
|
407
|
+
for (; i < n; ++i) result[i] = alpha_val * a[i] * b[i] + beta_val * c[i];
|
|
408
|
+
}
|
|
409
|
+
|
|
410
|
+
NK_PUBLIC void nk_each_fma_f64_haswell( //
|
|
411
|
+
nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, //
|
|
412
|
+
nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {
|
|
413
|
+
nk_f64_t alpha_val = *alpha;
|
|
414
|
+
nk_f64_t beta_val = *beta;
|
|
415
|
+
__m256d alpha_f64x4 = _mm256_set1_pd(alpha_val);
|
|
416
|
+
__m256d beta_f64x4 = _mm256_set1_pd(beta_val);
|
|
417
|
+
|
|
418
|
+
// The main loop:
|
|
419
|
+
nk_size_t i = 0;
|
|
420
|
+
for (; i + 4 <= n; i += 4) {
|
|
421
|
+
__m256d a_f64x4 = _mm256_loadu_pd(a + i);
|
|
422
|
+
__m256d b_f64x4 = _mm256_loadu_pd(b + i);
|
|
423
|
+
__m256d c_f64x4 = _mm256_loadu_pd(c + i);
|
|
424
|
+
__m256d ab_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
|
|
425
|
+
__m256d abc_f64x4 = _mm256_mul_pd(ab_f64x4, alpha_f64x4);
|
|
426
|
+
__m256d result_f64x4 = _mm256_fmadd_pd(c_f64x4, beta_f64x4, abc_f64x4);
|
|
427
|
+
_mm256_storeu_pd(result + i, result_f64x4);
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
// The tail:
|
|
431
|
+
for (; i < n; ++i) result[i] = alpha_val * a[i] * b[i] + beta_val * c[i];
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
NK_PUBLIC void nk_each_fma_f16_haswell( //
|
|
435
|
+
nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, //
|
|
436
|
+
nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {
|
|
437
|
+
nk_f32_t alpha_val = *alpha;
|
|
438
|
+
nk_f32_t beta_val = *beta;
|
|
439
|
+
__m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
|
|
440
|
+
__m256 beta_f32x8 = _mm256_set1_ps(beta_val);
|
|
441
|
+
|
|
442
|
+
// The main loop:
|
|
443
|
+
nk_size_t i = 0;
|
|
444
|
+
for (; i + 8 <= n; i += 8) {
|
|
445
|
+
__m128i a_f16x8 = _mm_loadu_si128((__m128i const *)(a + i));
|
|
446
|
+
__m128i b_f16x8 = _mm_loadu_si128((__m128i const *)(b + i));
|
|
447
|
+
__m128i c_f16x8 = _mm_loadu_si128((__m128i const *)(c + i));
|
|
448
|
+
__m256 a_f32x8 = _mm256_cvtph_ps(a_f16x8);
|
|
449
|
+
__m256 b_f32x8 = _mm256_cvtph_ps(b_f16x8);
|
|
450
|
+
__m256 c_f32x8 = _mm256_cvtph_ps(c_f16x8);
|
|
451
|
+
__m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
|
|
452
|
+
__m256 abc_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
|
|
453
|
+
__m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, abc_f32x8);
|
|
454
|
+
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
|
455
|
+
_mm_storeu_si128((__m128i *)(result + i), result_f16x8);
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
// The tail:
|
|
459
|
+
for (; i < n; ++i) {
|
|
460
|
+
nk_f32_t ai, bi, ci;
|
|
461
|
+
nk_f16_to_f32_haswell(a + i, &ai);
|
|
462
|
+
nk_f16_to_f32_haswell(b + i, &bi);
|
|
463
|
+
nk_f16_to_f32_haswell(c + i, &ci);
|
|
464
|
+
nk_f32_t sum = alpha_val * ai * bi + beta_val * ci;
|
|
465
|
+
nk_f32_to_f16_haswell(&sum, result + i);
|
|
466
|
+
}
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
NK_PUBLIC void nk_each_fma_bf16_haswell( //
|
|
470
|
+
nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, //
|
|
471
|
+
nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
|
|
472
|
+
nk_f32_t alpha_val = *alpha;
|
|
473
|
+
nk_f32_t beta_val = *beta;
|
|
474
|
+
__m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
|
|
475
|
+
__m256 beta_f32x8 = _mm256_set1_ps(beta_val);
|
|
476
|
+
|
|
477
|
+
// The main loop:
|
|
478
|
+
nk_size_t i = 0;
|
|
479
|
+
for (; i + 8 <= n; i += 8) {
|
|
480
|
+
__m128i a_bf16x8 = _mm_loadu_si128((__m128i const *)(a + i));
|
|
481
|
+
__m128i b_bf16x8 = _mm_loadu_si128((__m128i const *)(b + i));
|
|
482
|
+
__m128i c_bf16x8 = _mm_loadu_si128((__m128i const *)(c + i));
|
|
483
|
+
__m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16x8);
|
|
484
|
+
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16x8);
|
|
485
|
+
__m256 c_f32x8 = nk_bf16x8_to_f32x8_haswell_(c_bf16x8);
|
|
486
|
+
__m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
|
|
487
|
+
__m256 abc_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
|
|
488
|
+
__m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, abc_f32x8);
|
|
489
|
+
__m128i result_bf16x8 = nk_f32x8_to_bf16x8_haswell_(result_f32x8);
|
|
490
|
+
_mm_storeu_si128((__m128i *)(result + i), result_bf16x8);
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
// The tail:
|
|
494
|
+
for (; i < n; ++i) {
|
|
495
|
+
nk_f32_t ai, bi, ci;
|
|
496
|
+
nk_bf16_to_f32_serial(a + i, &ai);
|
|
497
|
+
nk_bf16_to_f32_serial(b + i, &bi);
|
|
498
|
+
nk_bf16_to_f32_serial(c + i, &ci);
|
|
499
|
+
nk_f32_t sum = alpha_val * ai * bi + beta_val * ci;
|
|
500
|
+
nk_f32_to_bf16_serial(&sum, result + i);
|
|
501
|
+
}
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
/**
 *  @brief  Element-wise saturating addition of signed 8-bit vectors.
 *
 *  Blocks of 32 bytes use `_mm256_adds_epi8`, which clamps each lane to [-128, 127] instead of wrapping.
 *  The remainder is summed in f32 and passed to `nk_f32_to_i8_serial` (presumably also saturating;
 *  confirm against its definition).
 */
NK_PUBLIC void nk_each_sum_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
    nk_size_t idx = 0;
    while (idx + 32 <= n) {
        __m256i const lhs_i8x32 = _mm256_loadu_si256((__m256i const *)(a + idx));
        __m256i const rhs_i8x32 = _mm256_loadu_si256((__m256i const *)(b + idx));
        _mm256_storeu_si256((__m256i *)(result + idx), _mm256_adds_epi8(lhs_i8x32, rhs_i8x32));
        idx += 32;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t sum = (nk_f32_t)a[idx] + (nk_f32_t)b[idx];
        nk_f32_to_i8_serial(&sum, result + idx);
    }
}
|
|
521
|
+
|
|
522
|
+
/**
 *  @brief  Affine scaling of signed 8-bit integers: result[i] = saturate_i8(alpha * a[i] + beta).
 *
 *  Eight bytes at a time are sign-extended to i32, converted to f32 and transformed with one FMA. The
 *  result is clamped to [-128, 127] in the float domain, rounded with `_mm256_cvtps_epi32`, and packed
 *  back to bytes.
 *
 *  Clamping happens *before* the integer conversion on purpose. `_mm256_cvtps_epi32` returns INT32_MIN
 *  (0x80000000) for out-of-range inputs, so clamping afterwards maps large positive results (or +inf)
 *  to -128 instead of 127. NaN lanes resolve to -128, because MAXPS returns its second operand when
 *  the first is NaN.
 *
 *  @param a       Input vector of @p n elements.
 *  @param n       Number of elements.
 *  @param alpha   Multiplicative weight.
 *  @param beta    Additive offset.
 *  @param result  Output vector of @p n elements.
 */
NK_PUBLIC void nk_each_scale_i8_haswell(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                        nk_i8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_set1_ps(-128.0f);
    __m256 max_f32x8 = _mm256_set1_ps(127.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes and sign-extend them directly into 32-bit lanes.
        __m256i a_i32x8 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(a + i)));
        __m256 a_f32x8 = _mm256_cvtepi32_ps(a_i32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in range, so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_i8x16 = _mm_packs_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_i8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i];
        nk_f32_t sum = alpha_val * ai + beta_val;
        nk_f32_to_i8_serial(&sum, result + i);
    }
}
|
|
565
|
+
|
|
566
|
+
/**
 *  @brief  Weighted blend of signed 8-bit vectors: result[i] = saturate_i8(alpha * a[i] + beta * b[i]).
 *
 *  Special cases:
 *  - alpha == beta == 1: dispatches to `nk_each_sum_i8_haswell` (saturating add, no multiplications).
 *  - either weight == 0: dispatches to `nk_each_scale_i8_haswell` on the other input with a zero
 *    offset, so only one input is read.
 *
 *  The general path clamps to [-128, 127] in the float domain *before* `_mm256_cvtps_epi32`. That
 *  instruction returns INT32_MIN for out-of-range inputs, so clamping afterwards turns large positive
 *  blends (or +inf) into -128 instead of 127.
 */
NK_PUBLIC void nk_each_blend_i8_haswell( //
    nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_i8_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is zero: only one input has to be loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_i8_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_i8_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_set1_ps(-128.0f);
    __m256 max_f32x8 = _mm256_set1_ps(127.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes per input and sign-extend them directly into 32-bit lanes.
        __m256 a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(a + i))));
        __m256 b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(b + i))));
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, ab_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in range, so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_i8x16 = _mm_packs_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_i8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i], bi = b[i];
        nk_f32_t sum = alpha_val * ai + beta_val * bi;
        nk_f32_to_i8_serial(&sum, result + i);
    }
}
|
|
632
|
+
|
|
633
|
+
/**
 *  @brief  Element-wise saturating addition of unsigned 8-bit vectors.
 *
 *  Blocks of 32 bytes use `_mm256_adds_epu8`, which clamps each lane to [0, 255] instead of wrapping.
 *  The remainder is summed in f32 and passed to `nk_f32_to_u8_serial` (presumably also saturating;
 *  confirm against its definition).
 */
NK_PUBLIC void nk_each_sum_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
    nk_size_t idx = 0;
    while (idx + 32 <= n) {
        __m256i const lhs_u8x32 = _mm256_loadu_si256((__m256i const *)(a + idx));
        __m256i const rhs_u8x32 = _mm256_loadu_si256((__m256i const *)(b + idx));
        _mm256_storeu_si256((__m256i *)(result + idx), _mm256_adds_epu8(lhs_u8x32, rhs_u8x32));
        idx += 32;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t sum = (nk_f32_t)a[idx] + (nk_f32_t)b[idx];
        nk_f32_to_u8_serial(&sum, result + idx);
    }
}
|
|
650
|
+
|
|
651
|
+
/**
 *  @brief  Affine scaling of unsigned 8-bit integers: result[i] = saturate_u8(alpha * a[i] + beta).
 *
 *  Eight bytes at a time are zero-extended to i32, converted to f32 and transformed with one FMA. The
 *  result is clamped to [0, 255] in the float domain, rounded with `_mm256_cvtps_epi32`, and packed back
 *  to bytes.
 *
 *  Clamping happens *before* the integer conversion on purpose. `_mm256_cvtps_epi32` returns INT32_MIN
 *  (0x80000000) for out-of-range inputs, so clamping afterwards maps large positive results (or +inf)
 *  to 0 instead of 255. NaN lanes resolve to 0, because MAXPS returns its second operand when the first
 *  is NaN.
 *
 *  @param a       Input vector of @p n elements.
 *  @param n       Number of elements.
 *  @param alpha   Multiplicative weight.
 *  @param beta    Additive offset.
 *  @param result  Output vector of @p n elements.
 */
NK_PUBLIC void nk_each_scale_u8_haswell(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                        nk_u8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_setzero_ps();
    __m256 max_f32x8 = _mm256_set1_ps(255.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes and zero-extend them directly into 32-bit lanes.
        __m256i a_i32x8 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(a + i)));
        __m256 a_f32x8 = _mm256_cvtepi32_ps(a_i32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in [0, 255], so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_u8x16 = _mm_packus_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_u8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i];
        nk_f32_t sum = alpha_val * ai + beta_val;
        nk_f32_to_u8_serial(&sum, result + i);
    }
}
|
|
694
|
+
|
|
695
|
+
/**
 *  @brief  Weighted blend of unsigned 8-bit vectors: result[i] = saturate_u8(alpha * a[i] + beta * b[i]).
 *
 *  Special cases:
 *  - alpha == beta == 1: dispatches to `nk_each_sum_u8_haswell` (saturating add, no multiplications).
 *  - either weight == 0: dispatches to `nk_each_scale_u8_haswell` on the other input with a zero
 *    offset, so only one input is read.
 *
 *  The general path clamps to [0, 255] in the float domain *before* `_mm256_cvtps_epi32`. That
 *  instruction returns INT32_MIN for out-of-range inputs, so clamping afterwards turns large positive
 *  blends (or +inf) into 0 instead of 255.
 */
NK_PUBLIC void nk_each_blend_u8_haswell( //
    nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // 1. Simple addition, when both weights are equal to 1.0: no multiplications needed.
    if (alpha_val == 1 && beta_val == 1) {
        nk_each_sum_u8_haswell(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is zero: only one input has to be loaded.
    if (alpha_val == 0 || beta_val == 0) {
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_u8_haswell(a, n, alpha, &zero, result); }
        else { nk_each_scale_u8_haswell(b, n, beta, &zero, result); }
        return;
    }

    // The general case.
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_setzero_ps();
    __m256 max_f32x8 = _mm256_set1_ps(255.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes per input and zero-extend them directly into 32-bit lanes.
        __m256 a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(a + i))));
        __m256 b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(b + i))));
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, ab_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in [0, 255], so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_u8x16 = _mm_packus_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_u8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i], bi = b[i];
        nk_f32_t sum = alpha_val * ai + beta_val * bi;
        nk_f32_to_u8_serial(&sum, result + i);
    }
}
|
|
761
|
+
|
|
762
|
+
/**
 *  @brief  Element-wise signed 8-bit update: result[i] = saturate_i8(alpha * a[i] * b[i] + beta * c[i]).
 *
 *  Eight bytes per input are sign-extended to i32 and converted to f32. The kernel computes
 *  (a * b) * alpha and adds beta * c with one FMA. The result is clamped to [-128, 127] in the float
 *  domain, rounded with `_mm256_cvtps_epi32`, and packed back to bytes.
 *
 *  Clamping happens *before* the integer conversion. `_mm256_cvtps_epi32` returns INT32_MIN for
 *  out-of-range inputs. With |a * b| up to 16384, any |alpha| above ~131072 can push a lane past 2^31;
 *  clamping afterwards would turn such positive results into -128 instead of 127.
 */
NK_PUBLIC void nk_each_fma_i8_haswell( //
    nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_set1_ps(-128.0f);
    __m256 max_f32x8 = _mm256_set1_ps(127.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes per input and sign-extend them directly into 32-bit lanes.
        __m256 a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(a + i))));
        __m256 b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(b + i))));
        __m256 c_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((__m128i const *)(c + i))));
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
        __m256 abc_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, abc_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in range, so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_i8x16 = _mm_packs_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_i8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i], bi = b[i], ci = c[i];
        nk_f32_t sum = alpha_val * ai * bi + beta_val * ci;
        nk_f32_to_i8_serial(&sum, result + i);
    }
}
|
|
814
|
+
|
|
815
|
+
/**
 *  @brief  Element-wise unsigned 8-bit update: result[i] = saturate_u8(alpha * a[i] * b[i] + beta * c[i]).
 *
 *  Eight bytes per input are zero-extended to i32 and converted to f32. The kernel computes
 *  (a * b) * alpha and adds beta * c with one FMA. The result is clamped to [0, 255] in the float
 *  domain, rounded with `_mm256_cvtps_epi32`, and packed back to bytes.
 *
 *  Clamping happens *before* the integer conversion. `_mm256_cvtps_epi32` returns INT32_MIN for
 *  out-of-range inputs. With a * b up to 65025, any alpha above ~33000 can push a lane past 2^31;
 *  clamping afterwards would turn such results into 0 instead of 255.
 */
NK_PUBLIC void nk_each_fma_u8_haswell( //
    nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;
    __m256 alpha_f32x8 = _mm256_set1_ps(alpha_val);
    __m256 beta_f32x8 = _mm256_set1_ps(beta_val);
    __m256 min_f32x8 = _mm256_setzero_ps();
    __m256 max_f32x8 = _mm256_set1_ps(255.0f);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        // Load 8 bytes per input and zero-extend them directly into 32-bit lanes.
        __m256 a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(a + i))));
        __m256 b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(b + i))));
        __m256 c_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i const *)(c + i))));
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
        __m256 abc_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, abc_f32x8);
        // Saturate in the float domain, before the integer conversion can overflow.
        result_f32x8 = _mm256_max_ps(result_f32x8, min_f32x8);
        result_f32x8 = _mm256_min_ps(result_f32x8, max_f32x8);
        __m256i result_i32x8 = _mm256_cvtps_epi32(result_f32x8);
        // Narrow 32 -> 16 -> 8 bits; lanes are already in [0, 255], so pack saturation never triggers.
        __m128i result_i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(result_i32x8),
                                               _mm256_extracti128_si256(result_i32x8, 1));
        __m128i result_u8x16 = _mm_packus_epi16(result_i16x8, result_i16x8);
        _mm_storel_epi64((__m128i *)(result + i), result_u8x16);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f32_t ai = a[i], bi = b[i], ci = c[i];
        nk_f32_t sum = alpha_val * ai * bi + beta_val * ci;
        nk_f32_to_u8_serial(&sum, result + i);
    }
}
|
|
867
|
+
|
|
868
|
+
/**
 *  @brief  Element-wise saturating addition of signed 16-bit vectors.
 *
 *  Sixteen lanes per step use `_mm256_adds_epi16`, which clamps to [-32768, 32767]. The remainder is
 *  summed in 64-bit integers, where it cannot overflow, and narrowed with `nk_i64_to_i16_serial`.
 */
NK_PUBLIC void nk_each_sum_i16_haswell(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result) {
    nk_size_t idx = 0;
    while (idx + 16 <= n) {
        __m256i const lhs_i16x16 = _mm256_loadu_si256((__m256i const *)(a + idx));
        __m256i const rhs_i16x16 = _mm256_loadu_si256((__m256i const *)(b + idx));
        _mm256_storeu_si256((__m256i *)(result + idx), _mm256_adds_epi16(lhs_i16x16, rhs_i16x16));
        idx += 16;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_i64_t sum = (nk_i64_t)a[idx] + (nk_i64_t)b[idx];
        nk_i64_to_i16_serial(&sum, result + idx);
    }
}
|
|
885
|
+
|
|
886
|
+
/**
 *  @brief  Affine scaling of signed 16-bit integers: result[i] = saturate_i16(alpha * a[i] + beta).
 *
 *  Eight lanes per step are sign-extended, converted to f32 and transformed with one FMA. Each lane is
 *  clamped to [-32768, 32767] while still in f32, rounded with `_mm256_cvtps_epi32`, and packed to
 *  16 bits with signed saturation. The remainder uses `nk_f32_to_i16_serial`.
 */
NK_PUBLIC void nk_each_scale_i16_haswell(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_i16_t *result) {
    nk_f32_t const alpha_f32 = *alpha, beta_f32 = *beta;
    __m256 const alpha_f32x8 = _mm256_set1_ps(alpha_f32);
    __m256 const beta_f32x8 = _mm256_set1_ps(beta_f32);
    __m256 const lower_f32x8 = _mm256_set1_ps(-32768.0f);
    __m256 const upper_f32x8 = _mm256_set1_ps(32767.0f);

    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        __m128i const raw_i16x8 = _mm_loadu_si128((__m128i const *)(a + idx));
        __m256 value_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(raw_i16x8));
        value_f32x8 = _mm256_fmadd_ps(value_f32x8, alpha_f32x8, beta_f32x8);
        value_f32x8 = _mm256_min_ps(_mm256_max_ps(value_f32x8, lower_f32x8), upper_f32x8);
        __m256i const value_i32x8 = _mm256_cvtps_epi32(value_f32x8);
        __m128i const low_i32x4 = _mm256_castsi256_si128(value_i32x8);
        __m128i const high_i32x4 = _mm256_extracti128_si256(value_i32x8, 1);
        _mm_storeu_si128((__m128i *)(result + idx), _mm_packs_epi32(low_i32x4, high_i32x4));
        idx += 8;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t sum = alpha_f32 * (nk_f32_t)a[idx] + beta_f32;
        nk_f32_to_i16_serial(&sum, result + idx);
    }
}
|
|
916
|
+
|
|
917
|
+
/**
 *  @brief  Element-wise signed 16-bit update: result[i] = saturate_i16(alpha * a[i] * b[i] + beta * c[i]).
 *
 *  Each block of eight lanes is widened to f32. The kernel computes (a * b) * alpha and adds beta * c
 *  with one FMA. Lanes are clamped to [-32768, 32767] in f32, rounded with `_mm256_cvtps_epi32`, and
 *  packed to 16 bits with signed saturation. The remainder uses `nk_f32_to_i16_serial`.
 */
NK_PUBLIC void nk_each_fma_i16_haswell( //
    nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result) {
    nk_f32_t const alpha_f32 = *alpha, beta_f32 = *beta;
    __m256 const alpha_f32x8 = _mm256_set1_ps(alpha_f32);
    __m256 const beta_f32x8 = _mm256_set1_ps(beta_f32);
    __m256 const lower_f32x8 = _mm256_set1_ps(-32768.0f);
    __m256 const upper_f32x8 = _mm256_set1_ps(32767.0f);

    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        __m256 const a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i const *)(a + idx))));
        __m256 const b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i const *)(b + idx))));
        __m256 const c_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i const *)(c + idx))));
        __m256 const scaled_f32x8 = _mm256_mul_ps(_mm256_mul_ps(a_f32x8, b_f32x8), alpha_f32x8);
        __m256 value_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, scaled_f32x8);
        value_f32x8 = _mm256_min_ps(_mm256_max_ps(value_f32x8, lower_f32x8), upper_f32x8);
        __m256i const value_i32x8 = _mm256_cvtps_epi32(value_f32x8);
        __m128i const low_i32x4 = _mm256_castsi256_si128(value_i32x8);
        __m128i const high_i32x4 = _mm256_extracti128_si256(value_i32x8, 1);
        _mm_storeu_si128((__m128i *)(result + idx), _mm_packs_epi32(low_i32x4, high_i32x4));
        idx += 8;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t const a_f32 = a[idx], b_f32 = b[idx], c_f32 = c[idx];
        nk_f32_t sum = alpha_f32 * a_f32 * b_f32 + beta_f32 * c_f32;
        nk_f32_to_i16_serial(&sum, result + idx);
    }
}
|
|
952
|
+
|
|
953
|
+
/**
 *  @brief  Element-wise saturating addition of unsigned 16-bit vectors.
 *
 *  Sixteen lanes per step use `_mm256_adds_epu16`, which clamps to [0, 65535]. The remainder is summed
 *  in 64-bit unsigned integers, where it cannot overflow, and narrowed with `nk_u64_to_u16_serial`.
 */
NK_PUBLIC void nk_each_sum_u16_haswell(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result) {
    nk_size_t idx = 0;
    while (idx + 16 <= n) {
        __m256i const lhs_u16x16 = _mm256_loadu_si256((__m256i const *)(a + idx));
        __m256i const rhs_u16x16 = _mm256_loadu_si256((__m256i const *)(b + idx));
        _mm256_storeu_si256((__m256i *)(result + idx), _mm256_adds_epu16(lhs_u16x16, rhs_u16x16));
        idx += 16;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_u64_t sum = (nk_u64_t)a[idx] + (nk_u64_t)b[idx];
        nk_u64_to_u16_serial(&sum, result + idx);
    }
}
|
|
970
|
+
|
|
971
|
+
/**
 *  @brief  Affine scaling of unsigned 16-bit integers: result[i] = saturate_u16(alpha * a[i] + beta).
 *
 *  Eight lanes per step are zero-extended, converted to f32 and transformed with one FMA. Each lane is
 *  clamped to [0, 65535] while still in f32, rounded with `_mm256_cvtps_epi32`, and packed to 16 bits
 *  with unsigned saturation (`_mm_packus_epi32`). The remainder uses `nk_f32_to_u16_serial`.
 */
NK_PUBLIC void nk_each_scale_u16_haswell(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_u16_t *result) {
    nk_f32_t const alpha_f32 = *alpha, beta_f32 = *beta;
    __m256 const alpha_f32x8 = _mm256_set1_ps(alpha_f32);
    __m256 const beta_f32x8 = _mm256_set1_ps(beta_f32);
    __m256 const lower_f32x8 = _mm256_setzero_ps();
    __m256 const upper_f32x8 = _mm256_set1_ps(65535.0f);

    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        __m128i const raw_u16x8 = _mm_loadu_si128((__m128i const *)(a + idx));
        __m256 value_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(raw_u16x8));
        value_f32x8 = _mm256_fmadd_ps(value_f32x8, alpha_f32x8, beta_f32x8);
        value_f32x8 = _mm256_min_ps(_mm256_max_ps(value_f32x8, lower_f32x8), upper_f32x8);
        __m256i const value_i32x8 = _mm256_cvtps_epi32(value_f32x8);
        __m128i const low_i32x4 = _mm256_castsi256_si128(value_i32x8);
        __m128i const high_i32x4 = _mm256_extracti128_si256(value_i32x8, 1);
        _mm_storeu_si128((__m128i *)(result + idx), _mm_packus_epi32(low_i32x4, high_i32x4));
        idx += 8;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t sum = alpha_f32 * (nk_f32_t)a[idx] + beta_f32;
        nk_f32_to_u16_serial(&sum, result + idx);
    }
}
|
|
1001
|
+
|
|
1002
|
+
/**
 *  @brief  Element-wise unsigned 16-bit update: result[i] = saturate_u16(alpha * a[i] * b[i] + beta * c[i]).
 *
 *  Each block of eight lanes is zero-extended and widened to f32. The kernel computes (a * b) * alpha
 *  and adds beta * c with one FMA. Lanes are clamped to [0, 65535] in f32, rounded with
 *  `_mm256_cvtps_epi32`, and packed to 16 bits with `_mm_packus_epi32`. The remainder uses
 *  `nk_f32_to_u16_serial`.
 */
NK_PUBLIC void nk_each_fma_u16_haswell( //
    nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result) {
    nk_f32_t const alpha_f32 = *alpha, beta_f32 = *beta;
    __m256 const alpha_f32x8 = _mm256_set1_ps(alpha_f32);
    __m256 const beta_f32x8 = _mm256_set1_ps(beta_f32);
    __m256 const lower_f32x8 = _mm256_setzero_ps();
    __m256 const upper_f32x8 = _mm256_set1_ps(65535.0f);

    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        __m256 const a_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(a + idx))));
        __m256 const b_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(b + idx))));
        __m256 const c_f32x8 = _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(c + idx))));
        __m256 const scaled_f32x8 = _mm256_mul_ps(_mm256_mul_ps(a_f32x8, b_f32x8), alpha_f32x8);
        __m256 value_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, scaled_f32x8);
        value_f32x8 = _mm256_min_ps(_mm256_max_ps(value_f32x8, lower_f32x8), upper_f32x8);
        __m256i const value_i32x8 = _mm256_cvtps_epi32(value_f32x8);
        __m128i const low_i32x4 = _mm256_castsi256_si128(value_i32x8);
        __m128i const high_i32x4 = _mm256_extracti128_si256(value_i32x8, 1);
        _mm_storeu_si128((__m128i *)(result + idx), _mm_packus_epi32(low_i32x4, high_i32x4));
        idx += 8;
    }

    // Scalar remainder.
    for (; idx < n; ++idx) {
        nk_f32_t const a_f32 = a[idx], b_f32 = b[idx], c_f32 = c[idx];
        nk_f32_t sum = alpha_f32 * a_f32 * b_f32 + beta_f32 * c_f32;
        nk_f32_to_u16_serial(&sum, result + idx);
    }
}
|
|
1037
|
+
|
|
1038
|
+
/**
 *  Lane-wise signed saturating addition of 32-bit integers, which AVX2 lacks natively.
 *  Overflowing lanes are replaced with INT32_MAX when `a` is non-negative and with INT32_MIN
 *  when `a` is negative. All other lanes keep the wrapped sum.
 */
NK_INTERNAL __m256i _mm256_adds_epi32_haswell(__m256i a, __m256i b) {
    __m256i const wrapped_i32x8 = _mm256_add_epi32(a, b);
    // Signed overflow <=> operands share a sign but the sum's sign differs:
    // the sign bit of ((sum ^ a) & ~(a ^ b)), broadcast across the lane by an arithmetic shift.
    __m256i const operand_signs_differ_i32x8 = _mm256_xor_si256(a, b);
    __m256i const result_sign_flipped_i32x8 = _mm256_xor_si256(wrapped_i32x8, a);
    __m256i const overflowed_i32x8 =
        _mm256_srai_epi32(_mm256_andnot_si256(operand_signs_differ_i32x8, result_sign_flipped_i32x8), 31);
    // 0x7FFFFFFF ^ 0xFFFFFFFF == 0x80000000, so XOR-ing with `a`'s broadcast sign picks the bound.
    __m256i const a_sign_i32x8 = _mm256_srai_epi32(a, 31);
    __m256i const bound_i32x8 = _mm256_xor_si256(_mm256_set1_epi32(0x7FFFFFFF), a_sign_i32x8);
    return _mm256_blendv_epi8(wrapped_i32x8, bound_i32x8, overflowed_i32x8);
}
|
|
1050
|
+
|
|
1051
|
+
/**
 *  Element-wise sum of two signed 32-bit integer arrays.
 *  Eight lanes per step use the saturating `_mm256_adds_epi32_haswell`. Leftover elements are summed
 *  in 64 bits (so the addition itself cannot overflow) and narrowed by `nk_i64_to_i32_serial`.
 */
NK_PUBLIC void nk_each_sum_i32_haswell(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result) {
    nk_size_t offset = 0;
    while (offset + 8 <= n) {
        __m256i const lhs_i32x8 = _mm256_loadu_si256((__m256i const *)(a + offset));
        __m256i const rhs_i32x8 = _mm256_loadu_si256((__m256i const *)(b + offset));
        _mm256_storeu_si256((__m256i *)(result + offset), _mm256_adds_epi32_haswell(lhs_i32x8, rhs_i32x8));
        offset += 8;
    }
    while (offset < n) {
        nk_i64_t wide_sum = (nk_i64_t)a[offset] + (nk_i64_t)b[offset];
        nk_i64_to_i32_serial(&wide_sum, result + offset);
        ++offset;
    }
}
|
|
1068
|
+
|
|
1069
|
+
/**
 *  Computes `result[i] = saturate_i32(alpha × a[i] + beta)` in double precision.
 *  Four lanes per step: widen `i32 -> f64`, fused multiply-add, clamp into [-2³¹, 2³¹ - 1], then
 *  convert with `_mm256_cvtpd_epi32` (rounds per the current MXCSR mode). Leftovers go through
 *  `nk_f64_to_i32_serial`.
 */
NK_PUBLIC void nk_each_scale_i32_haswell(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_i32_t *result) {
    nk_f64_t const factor = *alpha;
    nk_f64_t const offset = *beta;
    __m256d const factor_f64x4 = _mm256_set1_pd(factor);
    __m256d const offset_f64x4 = _mm256_set1_pd(offset);
    __m256d const lo_f64x4 = _mm256_set1_pd(-2147483648.0);
    __m256d const hi_f64x4 = _mm256_set1_pd(2147483647.0);

    nk_size_t const head = n & ~(nk_size_t)3;
    nk_size_t k = 0;
    for (; k < head; k += 4) {
        __m256d const x_f64x4 = _mm256_cvtepi32_pd(_mm_loadu_si128((__m128i const *)(a + k)));
        __m256d y_f64x4 = _mm256_fmadd_pd(x_f64x4, factor_f64x4, offset_f64x4);
        // Saturate to the `i32` range; `max` first, so NaN lanes fall to the lower bound.
        y_f64x4 = _mm256_max_pd(y_f64x4, lo_f64x4);
        y_f64x4 = _mm256_min_pd(y_f64x4, hi_f64x4);
        _mm_storeu_si128((__m128i *)(result + k), _mm256_cvtpd_epi32(y_f64x4));
    }
    for (; k < n; ++k) {
        nk_f64_t y = factor * (nk_f64_t)a[k] + offset;
        nk_f64_to_i32_serial(&y, result + k);
    }
}
|
|
1097
|
+
|
|
1098
|
+
/**
 *  Computes `result[i] = saturate_i32(alpha × a[i] × b[i] + beta × c[i])` in double precision.
 *  A product of two `i32` values can exceed 2⁵³ and gets rounded in `f64`, so the vector body and the
 *  scalar tail both evaluate it as `(a × b) × α`. That way an element's result does not depend on
 *  whether it lands in the body or the tail. Results are clamped to [-2³¹, 2³¹ - 1] before narrowing.
 */
NK_PUBLIC void nk_each_fma_i32_haswell( //
    nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result) {
    nk_f64_t alpha_val = *alpha;
    nk_f64_t beta_val = *beta;
    __m256d alpha_f64x4 = _mm256_set1_pd(alpha_val);
    __m256d beta_f64x4 = _mm256_set1_pd(beta_val);
    __m256d min_f64x4 = _mm256_set1_pd(-2147483648.0);
    __m256d max_f64x4 = _mm256_set1_pd(2147483647.0);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d a_f64x4 = _mm256_cvtepi32_pd(_mm_loadu_si128((__m128i const *)(a + i)));
        __m256d b_f64x4 = _mm256_cvtepi32_pd(_mm_loadu_si128((__m128i const *)(b + i)));
        __m256d c_f64x4 = _mm256_cvtepi32_pd(_mm_loadu_si128((__m128i const *)(c + i)));
        // (a × b) × α — the tail below uses the same association.
        __m256d ab_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
        __m256d ab_scaled_f64x4 = _mm256_mul_pd(ab_f64x4, alpha_f64x4);
        __m256d result_f64x4 = _mm256_fmadd_pd(c_f64x4, beta_f64x4, ab_scaled_f64x4);
        // Clip to the range representable by signed 32-bit integers.
        result_f64x4 = _mm256_max_pd(result_f64x4, min_f64x4);
        result_f64x4 = _mm256_min_pd(result_f64x4, max_f64x4);
        __m128i result_i32x4 = _mm256_cvtpd_epi32(result_f64x4);
        _mm_storeu_si128((__m128i *)(result + i), result_i32x4);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f64_t ai = a[i], bi = b[i], ci = c[i];
        nk_f64_t ab_scaled = ai * bi * alpha_val; // (a × b) × α, matching the vector body
        nk_f64_t sum = ab_scaled + beta_val * ci;
        nk_f64_to_i32_serial(&sum, result + i);
    }
}
|
|
1131
|
+
|
|
1132
|
+
/**
 *  Lane-wise unsigned saturating addition of 32-bit integers (AVX2 has no `adds_epu32`).
 *  The wrapped sum is smaller than `a` exactly when the true sum exceeded 2³² - 1; such lanes are
 *  forced to all-ones, i.e. UINT32_MAX.
 */
NK_INTERNAL __m256i _mm256_adds_epu32_haswell(__m256i a, __m256i b) {
    __m256i const wrapped_u32x8 = _mm256_add_epi32(a, b);
    // min(sum, a) == a  <=>  sum >= a (unsigned)  <=>  no carry out of the lane.
    __m256i const kept_u32x8 = _mm256_cmpeq_epi32(_mm256_min_epu32(wrapped_u32x8, a), a);
    // Invert the mask by comparing with zero, then OR it in: carried lanes become 0xFFFFFFFF.
    __m256i const carried_u32x8 = _mm256_cmpeq_epi32(kept_u32x8, _mm256_setzero_si256());
    return _mm256_or_si256(wrapped_u32x8, carried_u32x8);
}
|
|
1139
|
+
|
|
1140
|
+
/**
 *  Converts four unsigned 32-bit integers to `f64` exactly, without leaving the vector registers.
 *  AVX2 only offers a signed `i32 -> f64` conversion. Flipping the top bit maps `u` to the signed
 *  value `u - 2³¹`, which is converted and then shifted back by adding 2³¹. Every intermediate is
 *  an integer below 2³² in magnitude, so each step is exact in `f64`.
 */
NK_INTERNAL __m256d _mm256_cvtepu32_pd_haswell(__m128i a) {
    __m128i const biased_i32x4 = _mm_xor_si128(a, _mm_set1_epi32((int)0x80000000));
    __m256d const biased_f64x4 = _mm256_cvtepi32_pd(biased_i32x4);
    return _mm256_add_pd(biased_f64x4, _mm256_set1_pd(2147483648.0));
}
|
|
1159
|
+
|
|
1160
|
+
/**
 *  Converts four `f64` values to unsigned 32-bit integers, truncating toward zero. For inputs in
 *  [0, 4294967295] this matches the C cast `(nk_u32_t)x`; callers in this file clamp to that range first.
 *  Out-of-range and NaN lanes yield 0 (the x86 "integer indefinite" value, re-biased) instead of the
 *  undefined behavior a C cast would have.
 *
 *  NOTE(review): this truncates, while the signed conversion `_mm256_cvtpd_epi32` rounds per MXCSR
 *  (nearest-even by default). Confirm which rounding `nk_f64_to_u32_serial` uses for the tails.
 */
NK_INTERNAL __m128i _mm256_cvtpd_epu32_haswell(__m256d a) {
    // Truncate first, so the bias below is applied to exact integers and cannot re-round.
    __m256d const truncated_f64x4 = _mm256_round_pd(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    // Shift [0, 2³²) into [-2³¹, 2³¹), which the signed conversion handles; exact for integers < 2³².
    __m256d const biased_f64x4 = _mm256_sub_pd(truncated_f64x4, _mm256_set1_pd(2147483648.0));
    __m128i const biased_i32x4 = _mm256_cvttpd_epi32(biased_f64x4);
    // Undo the bias: (u - 2³¹) ^ 0x80000000 == u in two's complement.
    return _mm_xor_si128(biased_i32x4, _mm_set1_epi32((int)0x80000000));
}
|
|
1171
|
+
|
|
1172
|
+
/**
 *  Element-wise sum of two unsigned 32-bit integer arrays.
 *  Full blocks of eight use the saturating `_mm256_adds_epu32_haswell`. Leftover elements are summed
 *  in 64 bits and narrowed by `nk_i64_to_u32_serial`.
 */
NK_PUBLIC void nk_each_sum_u32_haswell(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const blocks = n / 8;
    for (nk_size_t block = 0; block < blocks; ++block) {
        nk_size_t const at = block * 8;
        __m256i const lhs_u32x8 = _mm256_loadu_si256((__m256i const *)(a + at));
        __m256i const rhs_u32x8 = _mm256_loadu_si256((__m256i const *)(b + at));
        _mm256_storeu_si256((__m256i *)(result + at), _mm256_adds_epu32_haswell(lhs_u32x8, rhs_u32x8));
    }
    for (nk_size_t at = blocks * 8; at < n; ++at) {
        nk_i64_t wide_sum = (nk_i64_t)a[at] + (nk_i64_t)b[at];
        nk_i64_to_u32_serial(&wide_sum, result + at);
    }
}
|
|
1189
|
+
|
|
1190
|
+
/**
 *  Computes `result[i] = alpha × a[i] + beta` for unsigned 32-bit integers, in double precision.
 *  Four lanes per step: widen with `_mm256_cvtepu32_pd_haswell`, fused multiply-add, clamp into
 *  [0, 4294967295], then narrow with `_mm256_cvtpd_epu32_haswell`. Leftovers go through
 *  `nk_f64_to_u32_serial`.
 */
NK_PUBLIC void nk_each_scale_u32_haswell(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                         nk_u32_t *result) {
    nk_f64_t const factor = *alpha;
    nk_f64_t const offset = *beta;
    __m256d const factor_f64x4 = _mm256_set1_pd(factor);
    __m256d const offset_f64x4 = _mm256_set1_pd(offset);
    __m256d const lo_f64x4 = _mm256_set1_pd(0);
    __m256d const hi_f64x4 = _mm256_set1_pd(4294967295.0);

    nk_size_t const head = n & ~(nk_size_t)3;
    nk_size_t k = 0;
    for (; k < head; k += 4) {
        __m256d const x_f64x4 = _mm256_cvtepu32_pd_haswell(_mm_loadu_si128((__m128i const *)(a + k)));
        __m256d y_f64x4 = _mm256_fmadd_pd(x_f64x4, factor_f64x4, offset_f64x4);
        // Saturate to the unsigned 32-bit range; `max` first, so NaN lanes fall to 0.
        y_f64x4 = _mm256_min_pd(_mm256_max_pd(y_f64x4, lo_f64x4), hi_f64x4);
        _mm_storeu_si128((__m128i *)(result + k), _mm256_cvtpd_epu32_haswell(y_f64x4));
    }
    for (; k < n; ++k) {
        nk_f64_t y = factor * (nk_f64_t)a[k] + offset;
        nk_f64_to_u32_serial(&y, result + k);
    }
}
|
|
1218
|
+
|
|
1219
|
+
/**
 *  Computes `result[i] = saturate_u32(alpha × a[i] × b[i] + beta × c[i])` in double precision.
 *  A product of two `u32` values can exceed 2⁵³ and gets rounded in `f64`, so the vector body and the
 *  scalar tail both evaluate it as `(a × b) × α`. That way an element's result does not depend on
 *  whether it lands in the body or the tail. Results are clamped to [0, 4294967295] before narrowing.
 */
NK_PUBLIC void nk_each_fma_u32_haswell( //
    nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result) {
    nk_f64_t alpha_val = *alpha;
    nk_f64_t beta_val = *beta;
    __m256d alpha_f64x4 = _mm256_set1_pd(alpha_val);
    __m256d beta_f64x4 = _mm256_set1_pd(beta_val);
    __m256d min_f64x4 = _mm256_set1_pd(0);
    __m256d max_f64x4 = _mm256_set1_pd(4294967295.0);

    // The main loop:
    nk_size_t i = 0;
    for (; i + 4 <= n; i += 4) {
        __m256d a_f64x4 = _mm256_cvtepu32_pd_haswell(_mm_loadu_si128((__m128i const *)(a + i)));
        __m256d b_f64x4 = _mm256_cvtepu32_pd_haswell(_mm_loadu_si128((__m128i const *)(b + i)));
        __m256d c_f64x4 = _mm256_cvtepu32_pd_haswell(_mm_loadu_si128((__m128i const *)(c + i)));
        // (a × b) × α — the tail below uses the same association.
        __m256d ab_f64x4 = _mm256_mul_pd(a_f64x4, b_f64x4);
        __m256d ab_scaled_f64x4 = _mm256_mul_pd(ab_f64x4, alpha_f64x4);
        __m256d result_f64x4 = _mm256_fmadd_pd(c_f64x4, beta_f64x4, ab_scaled_f64x4);
        // Clip to the range representable by unsigned 32-bit integers.
        result_f64x4 = _mm256_max_pd(result_f64x4, min_f64x4);
        result_f64x4 = _mm256_min_pd(result_f64x4, max_f64x4);
        __m128i result_u32x4 = _mm256_cvtpd_epu32_haswell(result_f64x4);
        _mm_storeu_si128((__m128i *)(result + i), result_u32x4);
    }

    // The tail:
    for (; i < n; ++i) {
        nk_f64_t ai = a[i], bi = b[i], ci = c[i];
        nk_f64_t ab_scaled = ai * bi * alpha_val; // (a × b) × α, matching the vector body
        nk_f64_t sum = ab_scaled + beta_val * ci;
        nk_f64_to_u32_serial(&sum, result + i);
    }
}
|
|
1252
|
+
|
|
1253
|
+
/**
 *  Element-wise sum of two E4M3 FP8 arrays, computed in `f32` and converted back to E4M3.
 *  Eight elements (one 64-bit load) per step; leftovers use the serial converters.
 */
NK_PUBLIC void nk_each_sum_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
    nk_size_t const body = n - n % 8;
    nk_size_t pos = 0;
    for (; pos != body; pos += 8) {
        __m256 const lhs_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + pos)));
        __m256 const rhs_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(b + pos)));
        __m128i const packed_e4m3x8 = nk_f32x8_to_e4m3x8_haswell_(_mm256_add_ps(lhs_f32x8, rhs_f32x8));
        _mm_storel_epi64((__m128i *)(result + pos), packed_e4m3x8);
    }
    for (; pos < n; ++pos) {
        nk_f32_t lhs, rhs;
        nk_e4m3_to_f32_serial(a + pos, &lhs);
        nk_e4m3_to_f32_serial(b + pos, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_e4m3_serial(&total, result + pos);
    }
}
|
|
1272
|
+
|
|
1273
|
+
/**
 *  Element-wise sum of two E5M2 FP8 arrays, computed in `f32` and converted back to E5M2.
 *  Eight elements (one 64-bit load) per step; leftovers use the serial converters.
 */
NK_PUBLIC void nk_each_sum_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result) {
    nk_size_t const body = n - n % 8;
    nk_size_t pos = 0;
    for (; pos != body; pos += 8) {
        __m256 const lhs_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + pos)));
        __m256 const rhs_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(b + pos)));
        __m128i const packed_e5m2x8 = nk_f32x8_to_e5m2x8_haswell_(_mm256_add_ps(lhs_f32x8, rhs_f32x8));
        _mm_storel_epi64((__m128i *)(result + pos), packed_e5m2x8);
    }
    for (; pos < n; ++pos) {
        nk_f32_t lhs, rhs;
        nk_e5m2_to_f32_serial(a + pos, &lhs);
        nk_e5m2_to_f32_serial(b + pos, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_e5m2_serial(&total, result + pos);
    }
}
|
|
1292
|
+
|
|
1293
|
+
/**
 *  Affine map over E4M3 FP8 values, `result[i] = alpha × a[i] + beta`, evaluated in `f32`.
 *  The vector body uses one FMA per lane. Scale is a single multiply-add, so rounding once before
 *  the final E4M3 conversion keeps the accuracy the FP8 output can represent.
 */
NK_PUBLIC void nk_each_scale_e4m3_haswell(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_e4m3_t *result) {
    __m256 const mul_f32x8 = _mm256_set1_ps(*alpha);
    __m256 const add_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t at = 0;
    for (nk_size_t const stop = n - n % 8; at < stop; at += 8) {
        __m256 const x_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + at)));
        __m256 const y_f32x8 = _mm256_fmadd_ps(x_f32x8, mul_f32x8, add_f32x8);
        _mm_storel_epi64((__m128i *)(result + at), nk_f32x8_to_e4m3x8_haswell_(y_f32x8));
    }
    for (; at < n; ++at) {
        nk_f32_t x;
        nk_e4m3_to_f32_serial(a + at, &x);
        nk_f32_t y = *alpha * x + *beta;
        nk_f32_to_e4m3_serial(&y, result + at);
    }
}
|
|
1314
|
+
|
|
1315
|
+
/**
 *  Affine map over E5M2 FP8 values, `result[i] = alpha × a[i] + beta`, evaluated in `f32`.
 *  The vector body uses one FMA per lane. Scale is a single multiply-add, so rounding once before
 *  the final E5M2 conversion keeps the accuracy the FP8 output can represent.
 */
NK_PUBLIC void nk_each_scale_e5m2_haswell(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                          nk_e5m2_t *result) {
    __m256 const mul_f32x8 = _mm256_set1_ps(*alpha);
    __m256 const add_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t at = 0;
    for (nk_size_t const stop = n - n % 8; at < stop; at += 8) {
        __m256 const x_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + at)));
        __m256 const y_f32x8 = _mm256_fmadd_ps(x_f32x8, mul_f32x8, add_f32x8);
        _mm_storel_epi64((__m128i *)(result + at), nk_f32x8_to_e5m2x8_haswell_(y_f32x8));
    }
    for (; at < n; ++at) {
        nk_f32_t x;
        nk_e5m2_to_f32_serial(a + at, &x);
        nk_f32_t y = *alpha * x + *beta;
        nk_f32_to_e5m2_serial(&y, result + at);
    }
}
|
|
1336
|
+
|
|
1337
|
+
/**
 *  Weighted blend of two E4M3 FP8 arrays, `result[i] = alpha × a[i] + beta × b[i]`, evaluated in `f32`.
 *  In the vector body `α × a` is rounded on its own, and `β × b` is fused into the final addition.
 */
NK_PUBLIC void nk_each_blend_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                          nk_f32_t const *beta, nk_e4m3_t *result) {
    __m256 const weight_a_f32x8 = _mm256_set1_ps(*alpha);
    __m256 const weight_b_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t const full = n / 8 * 8;
    nk_size_t j = 0;
    while (j < full) {
        __m256 const x_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + j)));
        __m256 const y_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(b + j)));
        __m256 const weighted_x_f32x8 = _mm256_mul_ps(x_f32x8, weight_a_f32x8);
        __m256 const mix_f32x8 = _mm256_fmadd_ps(y_f32x8, weight_b_f32x8, weighted_x_f32x8);
        _mm_storel_epi64((__m128i *)(result + j), nk_f32x8_to_e4m3x8_haswell_(mix_f32x8));
        j += 8;
    }
    while (j < n) {
        nk_f32_t x, y;
        nk_e4m3_to_f32_serial(a + j, &x);
        nk_e4m3_to_f32_serial(b + j, &y);
        nk_f32_t mix = *alpha * x + *beta * y;
        nk_f32_to_e4m3_serial(&mix, result + j);
        ++j;
    }
}
|
|
1360
|
+
|
|
1361
|
+
/**
 *  Weighted blend of two E5M2 FP8 arrays, `result[i] = alpha × a[i] + beta × b[i]`, evaluated in `f32`.
 *  In the vector body `α × a` is rounded on its own, and `β × b` is fused into the final addition.
 */
NK_PUBLIC void nk_each_blend_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                          nk_f32_t const *beta, nk_e5m2_t *result) {
    __m256 const weight_a_f32x8 = _mm256_set1_ps(*alpha);
    __m256 const weight_b_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t const full = n / 8 * 8;
    nk_size_t j = 0;
    while (j < full) {
        __m256 const x_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(a + j)));
        __m256 const y_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(b + j)));
        __m256 const weighted_x_f32x8 = _mm256_mul_ps(x_f32x8, weight_a_f32x8);
        __m256 const mix_f32x8 = _mm256_fmadd_ps(y_f32x8, weight_b_f32x8, weighted_x_f32x8);
        _mm_storel_epi64((__m128i *)(result + j), nk_f32x8_to_e5m2x8_haswell_(mix_f32x8));
        j += 8;
    }
    while (j < n) {
        nk_f32_t x, y;
        nk_e5m2_to_f32_serial(a + j, &x);
        nk_e5m2_to_f32_serial(b + j, &y);
        nk_f32_t mix = *alpha * x + *beta * y;
        nk_f32_to_e5m2_serial(&mix, result + j);
        ++j;
    }
}
|
|
1384
|
+
|
|
1385
|
+
/**
 *  Computes `result[i] = alpha × a[i] × b[i] + beta × c[i]` over E4M3 FP8 values, in `f32`.
 *  The vector body and the scalar tail both evaluate the product as `(a × b) × α`. For FP8 operands
 *  (at most 4 significant bits each) `a × b` is exact in `f32`, so this order rounds the scaled
 *  product only once. It also means an element's result does not depend on whether it lands in the
 *  body or the tail. The final `+ β × c` is fused in the vector body only (the scalar expression is
 *  fused only if the compiler contracts it).
 */
NK_PUBLIC void nk_each_fma_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
                                        nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result) {
    __m256 alpha_f32x8 = _mm256_set1_ps(*alpha);
    __m256 beta_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_e4m3x8 = _mm_loadl_epi64((__m128i const *)(a + i));
        __m128i b_e4m3x8 = _mm_loadl_epi64((__m128i const *)(b + i));
        __m128i c_e4m3x8 = _mm_loadl_epi64((__m128i const *)(c + i));
        __m256 a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_e4m3x8);
        __m256 b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_e4m3x8);
        __m256 c_f32x8 = nk_e4m3x8_to_f32x8_haswell_(c_e4m3x8);
        // FP8 rounding note: separate MULs give (a × b) × α with a single effective rounding (a × b is
        // exact), then an FMA adds β × c. The scalar tail uses the same (a × b) × α association.
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
        __m256 ab_scaled_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, ab_scaled_f32x8);
        __m128i result_e4m3x8 = nk_f32x8_to_e4m3x8_haswell_(result_f32x8);
        _mm_storel_epi64((__m128i *)(result + i), result_e4m3x8);
    }
    for (; i < n; ++i) {
        nk_f32_t ai, bi, ci;
        nk_e4m3_to_f32_serial(a + i, &ai);
        nk_e4m3_to_f32_serial(b + i, &bi);
        nk_e4m3_to_f32_serial(c + i, &ci);
        nk_f32_t ab_scaled = ai * bi * *alpha; // (a × b) × α, matching the vector body
        nk_f32_t fma = ab_scaled + *beta * ci;
        nk_f32_to_e4m3_serial(&fma, result + i);
    }
}
|
|
1415
|
+
|
|
1416
|
+
/**
 *  Computes `result[i] = alpha × a[i] × b[i] + beta × c[i]` over E5M2 FP8 values, in `f32`.
 *  The vector body and the scalar tail both evaluate the product as `(a × b) × α`. For FP8 operands
 *  (at most 3 significant bits each) `a × b` is exact in `f32`, so this order rounds the scaled
 *  product only once. It also means an element's result does not depend on whether it lands in the
 *  body or the tail. The final `+ β × c` is fused in the vector body only (the scalar expression is
 *  fused only if the compiler contracts it).
 */
NK_PUBLIC void nk_each_fma_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
                                        nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result) {
    __m256 alpha_f32x8 = _mm256_set1_ps(*alpha);
    __m256 beta_f32x8 = _mm256_set1_ps(*beta);
    nk_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i a_e5m2x8 = _mm_loadl_epi64((__m128i const *)(a + i));
        __m128i b_e5m2x8 = _mm_loadl_epi64((__m128i const *)(b + i));
        __m128i c_e5m2x8 = _mm_loadl_epi64((__m128i const *)(c + i));
        __m256 a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(a_e5m2x8);
        __m256 b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(b_e5m2x8);
        __m256 c_f32x8 = nk_e5m2x8_to_f32x8_haswell_(c_e5m2x8);
        // FP8 rounding note: separate MULs give (a × b) × α with a single effective rounding (a × b is
        // exact), then an FMA adds β × c. The scalar tail uses the same (a × b) × α association.
        __m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
        __m256 ab_scaled_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
        __m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, ab_scaled_f32x8);
        __m128i result_e5m2x8 = nk_f32x8_to_e5m2x8_haswell_(result_f32x8);
        _mm_storel_epi64((__m128i *)(result + i), result_e5m2x8);
    }
    for (; i < n; ++i) {
        nk_f32_t ai, bi, ci;
        nk_e5m2_to_f32_serial(a + i, &ai);
        nk_e5m2_to_f32_serial(b + i, &bi);
        nk_e5m2_to_f32_serial(c + i, &ci);
        nk_f32_t ab_scaled = ai * bi * *alpha; // (a × b) × α, matching the vector body
        nk_f32_t fma = ab_scaled + *beta * ci;
        nk_f32_to_e5m2_serial(&fma, result + i);
    }
}
|
|
1446
|
+
|
|
1447
|
+
/**
 *  Complex affine map `result[i] = alpha × a[i] + beta` over `f32` complex numbers, treated as
 *  interleaved (real, imag) float pairs. Four complex values (eight floats) are handled per step.
 *  The product uses the swap + `fmaddsub` trick: even lanes receive αr·ar − αi·ai and odd lanes
 *  receive αr·ai + αi·ar.
 */
NK_PUBLIC void nk_each_scale_f32c_haswell(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
                                          nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *src = (nk_f32_t const *)a;
    nk_f32_t *dst = (nk_f32_t *)result;
    __m256 const re_alpha_f32x8 = _mm256_set1_ps(alpha->real);
    __m256 const im_alpha_f32x8 = _mm256_set1_ps(alpha->imag);
    // [βr, βi, βr, βi, ...]: odd lanes (mask 0xAA) come from the imaginary broadcast.
    __m256 const beta_pairs_f32x8 = _mm256_blend_ps(_mm256_set1_ps(beta->real), _mm256_set1_ps(beta->imag), 0xAA);
    nk_size_t k = 0;
    for (; k + 4 <= n; k += 4) {
        __m256 const z_f32x8 = _mm256_loadu_ps(src + 2 * k);
        __m256 const z_flipped_f32x8 = _mm256_shuffle_ps(z_f32x8, z_f32x8, 0xB1); // swap re/im in each pair
        __m256 const cross_f32x8 = _mm256_mul_ps(im_alpha_f32x8, z_flipped_f32x8);
        __m256 const product_f32x8 = _mm256_fmaddsub_ps(re_alpha_f32x8, z_f32x8, cross_f32x8);
        _mm256_storeu_ps(dst + 2 * k, _mm256_add_ps(product_f32x8, beta_pairs_f32x8));
    }
    for (; k < n; k++) {
        nk_f32_t const re = a[k].real, im = a[k].imag;
        result[k].real = alpha->real * re - alpha->imag * im + beta->real;
        result[k].imag = alpha->real * im + alpha->imag * re + beta->imag;
    }
}
|
|
1470
|
+
|
|
1471
|
+
/**
 *  Complex affine map `result[i] = alpha × a[i] + beta` over `f64` complex numbers, treated as
 *  interleaved (real, imag) double pairs. Two complex values (four doubles) are handled per step,
 *  using the swap + `fmaddsub` complex-product trick.
 */
NK_PUBLIC void nk_each_scale_f64c_haswell(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
                                          nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *src = (nk_f64_t const *)a;
    nk_f64_t *dst = (nk_f64_t *)result;
    __m256d const re_alpha_f64x4 = _mm256_set1_pd(alpha->real);
    __m256d const im_alpha_f64x4 = _mm256_set1_pd(alpha->imag);
    // [βr, βi, βr, βi]: odd lanes (mask 0xA) come from the imaginary broadcast.
    __m256d const beta_pairs_f64x4 = _mm256_blend_pd(_mm256_set1_pd(beta->real), _mm256_set1_pd(beta->imag), 0xA);
    nk_size_t k = 0;
    for (; k + 2 <= n; k += 2) {
        __m256d const z_f64x4 = _mm256_loadu_pd(src + 2 * k);
        __m256d const z_flipped_f64x4 = _mm256_shuffle_pd(z_f64x4, z_f64x4, 0x5); // swap re/im in each pair
        __m256d const cross_f64x4 = _mm256_mul_pd(im_alpha_f64x4, z_flipped_f64x4);
        __m256d const product_f64x4 = _mm256_fmaddsub_pd(re_alpha_f64x4, z_f64x4, cross_f64x4);
        _mm256_storeu_pd(dst + 2 * k, _mm256_add_pd(product_f64x4, beta_pairs_f64x4));
    }
    for (; k < n; k++) {
        nk_f64_t const re = a[k].real, im = a[k].imag;
        result[k].real = alpha->real * re - alpha->imag * im + beta->real;
        result[k].imag = alpha->real * im + alpha->imag * re + beta->imag;
    }
}
|
|
1493
|
+
|
|
1494
|
+
/**
 *  Complex weighted blend `result[i] = alpha × a[i] + beta × b[i]` over interleaved `f32` pairs.
 *  Each complex product uses the swap + `fmaddsub` trick; the two products are then added.
 *  Four complex values per vector step.
 */
NK_PUBLIC void nk_each_blend_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
                                          nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *lhs = (nk_f32_t const *)a;
    nk_f32_t const *rhs = (nk_f32_t const *)b;
    nk_f32_t *dst = (nk_f32_t *)result;
    __m256 const re_alpha_f32x8 = _mm256_set1_ps(alpha->real);
    __m256 const im_alpha_f32x8 = _mm256_set1_ps(alpha->imag);
    __m256 const re_beta_f32x8 = _mm256_set1_ps(beta->real);
    __m256 const im_beta_f32x8 = _mm256_set1_ps(beta->imag);
    nk_size_t k = 0;
    for (; k + 4 <= n; k += 4) {
        __m256 const x_f32x8 = _mm256_loadu_ps(lhs + 2 * k);
        __m256 const y_f32x8 = _mm256_loadu_ps(rhs + 2 * k);
        // α × a
        __m256 const x_cross_f32x8 = _mm256_mul_ps(im_alpha_f32x8, _mm256_shuffle_ps(x_f32x8, x_f32x8, 0xB1));
        __m256 const weighted_x_f32x8 = _mm256_fmaddsub_ps(re_alpha_f32x8, x_f32x8, x_cross_f32x8);
        // β × b
        __m256 const y_cross_f32x8 = _mm256_mul_ps(im_beta_f32x8, _mm256_shuffle_ps(y_f32x8, y_f32x8, 0xB1));
        __m256 const weighted_y_f32x8 = _mm256_fmaddsub_ps(re_beta_f32x8, y_f32x8, y_cross_f32x8);
        _mm256_storeu_ps(dst + 2 * k, _mm256_add_ps(weighted_x_f32x8, weighted_y_f32x8));
    }
    for (; k < n; k++) {
        nk_f32_t const x_re = a[k].real, x_im = a[k].imag;
        nk_f32_t const y_re = b[k].real, y_im = b[k].imag;
        nk_f32_t const wx_re = alpha->real * x_re - alpha->imag * x_im;
        nk_f32_t const wx_im = alpha->real * x_im + alpha->imag * x_re;
        nk_f32_t const wy_re = beta->real * y_re - beta->imag * y_im;
        nk_f32_t const wy_im = beta->real * y_im + beta->imag * y_re;
        result[k].real = wx_re + wy_re;
        result[k].imag = wx_im + wy_im;
    }
}
|
|
1526
|
+
|
|
1527
|
+
/**
 *  Complex weighted blend `result[i] = alpha × a[i] + beta × b[i]` over interleaved `f64` pairs.
 *  Each complex product uses the swap + `fmaddsub` trick; the two products are then added.
 *  Two complex values per vector step.
 */
NK_PUBLIC void nk_each_blend_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
                                          nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *lhs = (nk_f64_t const *)a;
    nk_f64_t const *rhs = (nk_f64_t const *)b;
    nk_f64_t *dst = (nk_f64_t *)result;
    __m256d const re_alpha_f64x4 = _mm256_set1_pd(alpha->real);
    __m256d const im_alpha_f64x4 = _mm256_set1_pd(alpha->imag);
    __m256d const re_beta_f64x4 = _mm256_set1_pd(beta->real);
    __m256d const im_beta_f64x4 = _mm256_set1_pd(beta->imag);
    nk_size_t k = 0;
    for (; k + 2 <= n; k += 2) {
        __m256d const x_f64x4 = _mm256_loadu_pd(lhs + 2 * k);
        __m256d const y_f64x4 = _mm256_loadu_pd(rhs + 2 * k);
        // α × a
        __m256d const x_cross_f64x4 = _mm256_mul_pd(im_alpha_f64x4, _mm256_shuffle_pd(x_f64x4, x_f64x4, 0x5));
        __m256d const weighted_x_f64x4 = _mm256_fmaddsub_pd(re_alpha_f64x4, x_f64x4, x_cross_f64x4);
        // β × b
        __m256d const y_cross_f64x4 = _mm256_mul_pd(im_beta_f64x4, _mm256_shuffle_pd(y_f64x4, y_f64x4, 0x5));
        __m256d const weighted_y_f64x4 = _mm256_fmaddsub_pd(re_beta_f64x4, y_f64x4, y_cross_f64x4);
        _mm256_storeu_pd(dst + 2 * k, _mm256_add_pd(weighted_x_f64x4, weighted_y_f64x4));
    }
    for (; k < n; k++) {
        nk_f64_t const x_re = a[k].real, x_im = a[k].imag;
        nk_f64_t const y_re = b[k].real, y_im = b[k].imag;
        nk_f64_t const wx_re = alpha->real * x_re - alpha->imag * x_im;
        nk_f64_t const wx_im = alpha->real * x_im + alpha->imag * x_re;
        nk_f64_t const wy_re = beta->real * y_re - beta->imag * y_im;
        nk_f64_t const wy_im = beta->real * y_im + beta->imag * y_re;
        result[k].real = wx_re + wy_re;
        result[k].imag = wx_im + wy_im;
    }
}
|
|
1559
|
+
|
|
1560
|
+
/**
 *  Complex fused update `result[i] = alpha × (a[i] × b[i]) + beta × c[i]` over interleaved `f32` pairs.
 *  `a × b` is formed first by broadcasting a's real and imaginary parts across each pair. It is then
 *  scaled by α with the swap + `fmaddsub` trick, β × c is formed the same way, and the two are added.
 */
NK_PUBLIC void nk_each_fma_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                        nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result) {
    nk_f32_t const *x_ptr = (nk_f32_t const *)a;
    nk_f32_t const *y_ptr = (nk_f32_t const *)b;
    nk_f32_t const *z_ptr = (nk_f32_t const *)c;
    nk_f32_t *dst = (nk_f32_t *)result;
    __m256 const re_alpha_f32x8 = _mm256_set1_ps(alpha->real);
    __m256 const im_alpha_f32x8 = _mm256_set1_ps(alpha->imag);
    __m256 const re_beta_f32x8 = _mm256_set1_ps(beta->real);
    __m256 const im_beta_f32x8 = _mm256_set1_ps(beta->imag);
    nk_size_t k = 0;
    for (; k + 4 <= n; k += 4) {
        __m256 const x_f32x8 = _mm256_loadu_ps(x_ptr + 2 * k);
        __m256 const y_f32x8 = _mm256_loadu_ps(y_ptr + 2 * k);
        __m256 const z_f32x8 = _mm256_loadu_ps(z_ptr + 2 * k);
        // a × b: duplicate a's real parts (0xA0 -> lanes 0,0,2,2) and imaginary parts (0xF5 -> 1,1,3,3).
        __m256 const x_re_f32x8 = _mm256_permute_ps(x_f32x8, 0xA0);
        __m256 const x_im_f32x8 = _mm256_permute_ps(x_f32x8, 0xF5);
        __m256 const xy_cross_f32x8 = _mm256_mul_ps(x_im_f32x8, _mm256_permute_ps(y_f32x8, 0xB1));
        __m256 const xy_f32x8 = _mm256_fmaddsub_ps(x_re_f32x8, y_f32x8, xy_cross_f32x8);
        // α × (a × b)
        __m256 const axy_cross_f32x8 = _mm256_mul_ps(im_alpha_f32x8, _mm256_permute_ps(xy_f32x8, 0xB1));
        __m256 const axy_f32x8 = _mm256_fmaddsub_ps(re_alpha_f32x8, xy_f32x8, axy_cross_f32x8);
        // β × c
        __m256 const bz_cross_f32x8 = _mm256_mul_ps(im_beta_f32x8, _mm256_permute_ps(z_f32x8, 0xB1));
        __m256 const bz_f32x8 = _mm256_fmaddsub_ps(re_beta_f32x8, z_f32x8, bz_cross_f32x8);
        _mm256_storeu_ps(dst + 2 * k, _mm256_add_ps(axy_f32x8, bz_f32x8));
    }
    for (; k < n; k++) {
        nk_f32_t const x_re = a[k].real, x_im = a[k].imag;
        nk_f32_t const y_re = b[k].real, y_im = b[k].imag;
        nk_f32_t const z_re = c[k].real, z_im = c[k].imag;
        nk_f32_t const xy_re = x_re * y_re - x_im * y_im;
        nk_f32_t const xy_im = x_re * y_im + x_im * y_re;
        nk_f32_t const axy_re = alpha->real * xy_re - alpha->imag * xy_im;
        nk_f32_t const axy_im = alpha->real * xy_im + alpha->imag * xy_re;
        nk_f32_t const bz_re = beta->real * z_re - beta->imag * z_im;
        nk_f32_t const bz_im = beta->real * z_im + beta->imag * z_re;
        result[k].real = axy_re + bz_re;
        result[k].imag = axy_im + bz_im;
    }
}
|
|
1602
|
+
|
|
1603
|
+
/**
 *  Complex fused update `result[i] = alpha × (a[i] × b[i]) + beta × c[i]` over interleaved `f64` pairs.
 *  `a × b` is formed first by broadcasting a's real and imaginary parts across each pair. It is then
 *  scaled by α with the swap + `fmaddsub` trick, β × c is formed the same way, and the two are added.
 */
NK_PUBLIC void nk_each_fma_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                        nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result) {
    nk_f64_t const *x_ptr = (nk_f64_t const *)a;
    nk_f64_t const *y_ptr = (nk_f64_t const *)b;
    nk_f64_t const *z_ptr = (nk_f64_t const *)c;
    nk_f64_t *dst = (nk_f64_t *)result;
    __m256d const re_alpha_f64x4 = _mm256_set1_pd(alpha->real);
    __m256d const im_alpha_f64x4 = _mm256_set1_pd(alpha->imag);
    __m256d const re_beta_f64x4 = _mm256_set1_pd(beta->real);
    __m256d const im_beta_f64x4 = _mm256_set1_pd(beta->imag);
    nk_size_t k = 0;
    for (; k + 2 <= n; k += 2) {
        __m256d const x_f64x4 = _mm256_loadu_pd(x_ptr + 2 * k);
        __m256d const y_f64x4 = _mm256_loadu_pd(y_ptr + 2 * k);
        __m256d const z_f64x4 = _mm256_loadu_pd(z_ptr + 2 * k);
        // a × b: unpacklo/unpackhi duplicate a's real parts (lanes 0,0,2,2) and imaginary parts (1,1,3,3).
        __m256d const x_re_f64x4 = _mm256_unpacklo_pd(x_f64x4, x_f64x4);
        __m256d const x_im_f64x4 = _mm256_unpackhi_pd(x_f64x4, x_f64x4);
        __m256d const xy_cross_f64x4 = _mm256_mul_pd(x_im_f64x4, _mm256_permute_pd(y_f64x4, 0x5));
        __m256d const xy_f64x4 = _mm256_fmaddsub_pd(x_re_f64x4, y_f64x4, xy_cross_f64x4);
        // α × (a × b)
        __m256d const axy_cross_f64x4 = _mm256_mul_pd(im_alpha_f64x4, _mm256_permute_pd(xy_f64x4, 0x5));
        __m256d const axy_f64x4 = _mm256_fmaddsub_pd(re_alpha_f64x4, xy_f64x4, axy_cross_f64x4);
        // β × c
        __m256d const bz_cross_f64x4 = _mm256_mul_pd(im_beta_f64x4, _mm256_permute_pd(z_f64x4, 0x5));
        __m256d const bz_f64x4 = _mm256_fmaddsub_pd(re_beta_f64x4, z_f64x4, bz_cross_f64x4);
        _mm256_storeu_pd(dst + 2 * k, _mm256_add_pd(axy_f64x4, bz_f64x4));
    }
    for (; k < n; k++) {
        nk_f64_t const x_re = a[k].real, x_im = a[k].imag;
        nk_f64_t const y_re = b[k].real, y_im = b[k].imag;
        nk_f64_t const z_re = c[k].real, z_im = c[k].imag;
        nk_f64_t const xy_re = x_re * y_re - x_im * y_im;
        nk_f64_t const xy_im = x_re * y_im + x_im * y_re;
        nk_f64_t const axy_re = alpha->real * xy_re - alpha->imag * xy_im;
        nk_f64_t const axy_im = alpha->real * xy_im + alpha->imag * xy_re;
        nk_f64_t const bz_re = beta->real * z_re - beta->imag * z_im;
        nk_f64_t const bz_im = beta->real * z_im + beta->imag * z_re;
        result[k].real = axy_re + bz_re;
        result[k].imag = axy_im + bz_im;
    }
}
|
|
1645
|
+
|
|
1646
|
+
#if defined(__clang__)
|
|
1647
|
+
#pragma clang attribute pop
|
|
1648
|
+
#elif defined(__GNUC__)
|
|
1649
|
+
#pragma GCC pop_options
|
|
1650
|
+
#endif
|
|
1651
|
+
|
|
1652
|
+
#if defined(__cplusplus)
|
|
1653
|
+
} // extern "C"
|
|
1654
|
+
#endif
|
|
1655
|
+
|
|
1656
|
+
#endif // NK_TARGET_HASWELL
|
|
1657
|
+
#endif // NK_TARGET_X86_
|
|
1658
|
+
#endif // NK_EACH_HASWELL_H
|