numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for Skylake.
|
|
3
|
+
* @file include/numkong/curved/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements f32 and f64 bilinear forms and Mahalanobis distance using AVX-512:
|
|
10
|
+
* - f32 inputs accumulate in f64 to avoid catastrophic cancellation
|
|
11
|
+
* - f64 inputs use Dot2 algorithm (Ogita-Rump-Oishi 2005) for error compensation
|
|
12
|
+
*/
|
|
13
|
+
#ifndef NK_CURVED_SKYLAKE_H
|
|
14
|
+
#define NK_CURVED_SKYLAKE_H
|
|
15
|
+
|
|
16
|
+
#if NK_TARGET_X86_
|
|
17
|
+
#if NK_TARGET_SKYLAKE
|
|
18
|
+
|
|
19
|
+
#include "numkong/types.h"
|
|
20
|
+
#include "numkong/spatial/haswell.h" // `nk_f64_sqrt_haswell`
|
|
21
|
+
|
|
22
|
+
#if defined(__cplusplus)
|
|
23
|
+
extern "C" {
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
#if defined(__clang__)
|
|
27
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
28
|
+
apply_to = function)
|
|
29
|
+
#elif defined(__GNUC__)
|
|
30
|
+
#pragma GCC push_options
|
|
31
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
NK_PUBLIC void nk_bilinear_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
35
|
+
nk_f64_t *result) {
|
|
36
|
+
|
|
37
|
+
// Default case for arbitrary size `n`
|
|
38
|
+
nk_size_t const tail_length = n % 8;
|
|
39
|
+
nk_size_t const tail_start = n - tail_length;
|
|
40
|
+
__m512d sum_f64x8 = _mm512_setzero_pd();
|
|
41
|
+
__mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
|
|
42
|
+
|
|
43
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
44
|
+
__m512d a_f64x8 = _mm512_set1_pd((nk_f64_t)a[i]);
|
|
45
|
+
__m512d cb_j_f64x8 = _mm512_setzero_pd();
|
|
46
|
+
__m256 b_f32x8, c_f32x8;
|
|
47
|
+
nk_size_t j = 0;
|
|
48
|
+
|
|
49
|
+
nk_bilinear_f32_skylake_cycle:
|
|
50
|
+
if (j + 8 <= n) {
|
|
51
|
+
b_f32x8 = _mm256_loadu_ps(b + j);
|
|
52
|
+
c_f32x8 = _mm256_loadu_ps(c + i * n + j);
|
|
53
|
+
}
|
|
54
|
+
else {
|
|
55
|
+
b_f32x8 = _mm256_maskz_loadu_ps(tail_mask, b + tail_start);
|
|
56
|
+
c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, c + i * n + tail_start);
|
|
57
|
+
}
|
|
58
|
+
cb_j_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(b_f32x8), _mm512_cvtps_pd(c_f32x8), cb_j_f64x8);
|
|
59
|
+
j += 8;
|
|
60
|
+
if (j < n) goto nk_bilinear_f32_skylake_cycle;
|
|
61
|
+
sum_f64x8 = _mm512_fmadd_pd(a_f64x8, cb_j_f64x8, sum_f64x8);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
*result = _mm512_reduce_add_pd(sum_f64x8);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
NK_PUBLIC void nk_mahalanobis_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
68
|
+
nk_f64_t *result) {
|
|
69
|
+
// We use f64 accumulators to prevent catastrophic cancellation.
|
|
70
|
+
nk_size_t const tail_length = n % 8;
|
|
71
|
+
nk_size_t const tail_start = n - tail_length;
|
|
72
|
+
__m512d sum_f64x8 = _mm512_setzero_pd();
|
|
73
|
+
__mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
|
|
74
|
+
|
|
75
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
76
|
+
__m512d diff_i_f64x8 = _mm512_set1_pd((nk_f64_t)a[i] - (nk_f64_t)b[i]);
|
|
77
|
+
__m512d cdiff_j_f64x8 = _mm512_setzero_pd();
|
|
78
|
+
__m256 a_j_f32x8, b_j_f32x8, c_f32x8;
|
|
79
|
+
nk_size_t j = 0;
|
|
80
|
+
|
|
81
|
+
// The nested loop is cleaner to implement with a `goto` in this case:
|
|
82
|
+
nk_mahalanobis_f32_skylake_cycle:
|
|
83
|
+
if (j + 8 <= n) {
|
|
84
|
+
a_j_f32x8 = _mm256_loadu_ps(a + j);
|
|
85
|
+
b_j_f32x8 = _mm256_loadu_ps(b + j);
|
|
86
|
+
c_f32x8 = _mm256_loadu_ps(c + i * n + j);
|
|
87
|
+
}
|
|
88
|
+
else {
|
|
89
|
+
a_j_f32x8 = _mm256_maskz_loadu_ps(tail_mask, a + tail_start);
|
|
90
|
+
b_j_f32x8 = _mm256_maskz_loadu_ps(tail_mask, b + tail_start);
|
|
91
|
+
c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, c + i * n + tail_start);
|
|
92
|
+
}
|
|
93
|
+
__m512d diff_j_f64x8 = _mm512_sub_pd(_mm512_cvtps_pd(a_j_f32x8), _mm512_cvtps_pd(b_j_f32x8));
|
|
94
|
+
cdiff_j_f64x8 = _mm512_fmadd_pd(diff_j_f64x8, _mm512_cvtps_pd(c_f32x8), cdiff_j_f64x8);
|
|
95
|
+
j += 8;
|
|
96
|
+
if (j < n) goto nk_mahalanobis_f32_skylake_cycle;
|
|
97
|
+
sum_f64x8 = _mm512_fmadd_pd(diff_i_f64x8, cdiff_j_f64x8, sum_f64x8);
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
nk_f64_t quadratic = _mm512_reduce_add_pd(sum_f64x8);
|
|
101
|
+
*result = nk_f64_sqrt_haswell(quadratic > 0 ? quadratic : 0);
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/**
 * @brief Complex bilinear form over interleaved f32 complex inputs:
 *        results = sum over i,j of a[i] * c[i][j] * b[j] (plain complex products,
 *        no conjugation of `a`), accumulated in f64.
 *
 * @param a Vector of `n` complex f32 values (interleaved real/imag pairs).
 * @param b Vector of `n` complex f32 values.
 * @param c Row-major `n`-by-`n` complex f32 matrix (row `i` starts at `c + i * n`).
 * @param n Dimensionality in complex elements; any value handled via masked tails.
 * @param results Output: single complex f64 accumulator.
 */
NK_PUBLIC void nk_bilinear_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                        nk_f64c_t *results) {

    // We take into account, that FMS is the same as FMA with a negative multiplier.
    // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
    // This way we can avoid the shuffling and the need for separate real and imaginary parts.
    // For the imaginary part of the product, we would need to swap the real and imaginary parts of
    // one of the vectors. We use f64 accumulators to prevent catastrophic cancellation.
    // Odd (imaginary-slot) lanes carry the sign-flip bit; even lanes are untouched.
    __m512i const sign_flip_i64x8 = _mm512_set_epi64( //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 //
    );

    // Default case for arbitrary size `n`.
    // 4 complex values = 8 f32 scalars per 256-bit load; the mask counts scalars.
    nk_size_t const tail_length = n % 4;
    nk_size_t const tail_start = n - tail_length;
    __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
    nk_f64_t sum_real = 0;
    nk_f64_t sum_imag = 0;

    for (nk_size_t i = 0; i != n; ++i) {
        nk_f64_t const a_i_real = (nk_f64_t)a[i].real;
        nk_f64_t const a_i_imag = (nk_f64_t)a[i].imag;
        // Lane-wise partial products of row i of `c` against `b`, prior to the
        // horizontal reduction into a scalar complex value below.
        __m512d cb_j_real_f64x8 = _mm512_setzero_pd();
        __m512d cb_j_imag_f64x8 = _mm512_setzero_pd();
        __m256 b_f32x8, c_f32x8;
        nk_size_t j = 0;

        // Inner do-while style loop over columns, expressed with a `goto`;
        // it runs at least once, which is safe because here n > 0.
    nk_bilinear_f32c_skylake_cycle:
        if (j + 4 <= n) {
            b_f32x8 = _mm256_loadu_ps((nk_f32_t const *)(b + j));
            c_f32x8 = _mm256_loadu_ps((nk_f32_t const *)(c + i * n + j));
        }
        else {
            // Masked tail: scalar lanes past 2*n load as zero.
            b_f32x8 = _mm256_maskz_loadu_ps(tail_mask, (nk_f32_t const *)(b + tail_start));
            c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, (nk_f32_t const *)(c + i * n + tail_start));
        }
        // Widen to f64 before multiplying.
        __m512d b_f64x8 = _mm512_cvtps_pd(b_f32x8);
        __m512d c_f64x8 = _mm512_cvtps_pd(c_f32x8);
        // The real part of the product: b.real * c.real - b.imag * c.imag.
        // The subtraction will be performed later with a sign flip.
        cb_j_real_f64x8 = _mm512_fmadd_pd(c_f64x8, b_f64x8, cb_j_real_f64x8);
        // The imaginary part of the product: b.real * c.imag + b.imag * c.real.
        // Swap the imaginary and real parts of `c` before multiplication:
        c_f64x8 = _mm512_permute_pd(c_f64x8, 0x55); //? Same as 0b01010101. Swap adjacent entries within each pair
        cb_j_imag_f64x8 = _mm512_fmadd_pd(c_f64x8, b_f64x8, cb_j_imag_f64x8);
        j += 4;
        if (j < n) goto nk_bilinear_f32c_skylake_cycle;
        // Flip the sign bit in every second scalar before accumulation,
        // turning the pending additions of b.imag * c.imag into subtractions:
        cb_j_real_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(cb_j_real_f64x8), sign_flip_i64x8));
        // Horizontal sums are the expensive part of the computation:
        nk_f64_t const cb_j_real = _mm512_reduce_add_pd(cb_j_real_f64x8);
        nk_f64_t const cb_j_imag = _mm512_reduce_add_pd(cb_j_imag_f64x8);
        // Scalar complex multiply-accumulate: sum += a[i] * (C·b)[i].
        sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag;
        sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real;
    }

    // Reduce horizontal sums:
    results->real = sum_real;
    results->imag = sum_imag;
}
|
|
165
|
+
|
|
166
|
+
/**
 * @brief Bilinear form over f64 inputs: *result = sum over i,j of a[i] * c[i][j] * b[j],
 *        using the Dot2 compensated dot product (Ogita-Rump-Oishi 2005).
 *
 * Each multiply is split with TwoProd (FMA-based) and each accumulation with
 * Knuth's branchless TwoSum, so rounding errors are carried in a separate
 * compensation term and folded back in at the end — roughly doubling the
 * effective accumulation precision. The statement order below IS the
 * algorithm; do not reassociate these operations.
 *
 * @param a Vector of `n` f64 scalars.
 * @param b Vector of `n` f64 scalars.
 * @param c Row-major `n`-by-`n` f64 matrix (row `i` starts at `c + i * n`).
 * @param n Dimensionality; any value handled via masked tails (n == 0 yields 0).
 * @param result Output scalar.
 */
NK_PUBLIC void nk_bilinear_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                       nk_f64_t *result) {

    // Default case for arbitrary size `n`
    // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
    nk_size_t const tail_length = n % 8;
    nk_size_t const tail_start = n - tail_length;
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d compensation_f64x8 = _mm512_setzero_pd();
    __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);

    for (nk_size_t i = 0; i != n; ++i) {
        __m512d a_f64x8 = _mm512_set1_pd(a[i]);
        __m512d cb_j_f64x8 = _mm512_setzero_pd();
        __m512d inner_compensation_f64x8 = _mm512_setzero_pd();
        __m512d b_f64x8, c_f64x8;
        nk_size_t j = 0;

        // Inner loop over columns via `goto`; runs at least once since n > 0 here.
    nk_bilinear_f64_skylake_cycle:
        if (j + 8 <= n) {
            b_f64x8 = _mm512_loadu_pd(b + j);
            c_f64x8 = _mm512_loadu_pd(c + i * n + j);
        }
        else {
            // Masked tail: lanes past `n` load as zero and contribute nothing.
            b_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b + tail_start);
            c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start);
        }
        // Inner loop Dot2: accumulate cb_j = sum(b[j] * c[i,j])
        // TwoProd: product = b * c, product_error = fma(b, c, -product)
        {
            __m512d product_f64x8 = _mm512_mul_pd(b_f64x8, c_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(b_f64x8, c_f64x8, product_f64x8);
            // TwoSum (Knuth, branchless): t = cb_j + product, then recover the
            // exact rounding error of that addition.
            __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(cb_j_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            cb_j_f64x8 = tentative_sum_f64x8;
            inner_compensation_f64x8 = _mm512_add_pd(inner_compensation_f64x8,
                                                     _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
        j += 8;
        if (j < n) goto nk_bilinear_f64_skylake_cycle;

        // Combine inner sum with compensation before outer accumulation
        cb_j_f64x8 = _mm512_add_pd(cb_j_f64x8, inner_compensation_f64x8);

        // Outer loop Dot2: accumulate sum += a[i] * cb_j
        // TwoProd: product = a * cb_j, product_error = fma(a, cb_j, -product)
        {
            __m512d product_f64x8 = _mm512_mul_pd(a_f64x8, cb_j_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(a_f64x8, cb_j_f64x8, product_f64x8);
            // TwoSum: t = sum + product
            __m512d tentative_sum_f64x8 = _mm512_add_pd(sum_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, sum_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            sum_f64x8 = tentative_sum_f64x8;
            compensation_f64x8 = _mm512_add_pd(compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
    }

    // Final: combine sum + compensation before reduce
    *result = _mm512_reduce_add_pd(_mm512_add_pd(sum_f64x8, compensation_f64x8));
}
|
|
233
|
+
|
|
234
|
+
/**
 * @brief Mahalanobis distance over f64 inputs:
 *        *result = sqrt(max(0, sum over i,j of (a[i]-b[i]) * c[i][j] * (a[j]-b[j]))),
 *        using the Dot2 compensated dot product (Ogita-Rump-Oishi 2005).
 *
 * Multiplies are split with TwoProd (FMA-based) and accumulations with Knuth's
 * branchless TwoSum, carrying rounding errors in a separate compensation term.
 * The statement order below IS the algorithm; do not reassociate. The quadratic
 * form is clamped at zero before the square root so rounding residue cannot
 * produce a NaN.
 *
 * @param a Vector of `n` f64 scalars.
 * @param b Vector of `n` f64 scalars.
 * @param c Row-major `n`-by-`n` f64 matrix (row `i` starts at `c + i * n`).
 * @param n Dimensionality; any value handled via masked tails.
 * @param result Output scalar distance.
 */
NK_PUBLIC void nk_mahalanobis_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                          nk_f64_t *result) {
    // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
    nk_size_t const tail_length = n % 8;
    nk_size_t const tail_start = n - tail_length;
    __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d compensation_f64x8 = _mm512_setzero_pd();

    for (nk_size_t i = 0; i != n; ++i) {
        __m512d diff_i_f64x8 = _mm512_set1_pd(a[i] - b[i]);
        __m512d cdiff_j_f64x8 = _mm512_setzero_pd();
        __m512d inner_compensation_f64x8 = _mm512_setzero_pd();
        __m512d a_j_f64x8, b_j_f64x8, diff_j_f64x8, c_f64x8;
        nk_size_t j = 0;

        // The nested loop is cleaner to implement with a `goto` in this case;
        // it runs at least once, which is safe because here n > 0.
    nk_mahalanobis_f64_skylake_cycle:
        if (j + 8 <= n) {
            a_j_f64x8 = _mm512_loadu_pd(a + j);
            b_j_f64x8 = _mm512_loadu_pd(b + j);
            c_f64x8 = _mm512_loadu_pd(c + i * n + j);
        }
        else {
            // Masked tail: lanes past `n` load as zero and contribute nothing.
            a_j_f64x8 = _mm512_maskz_loadu_pd(tail_mask, a + tail_start);
            b_j_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b + tail_start);
            c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start);
        }
        diff_j_f64x8 = _mm512_sub_pd(a_j_f64x8, b_j_f64x8);

        // Inner loop Dot2: accumulate cdiff_j = sum(diff_j * c[i,j])
        // TwoProd: product = diff_j * c, product_error = fma(diff_j, c, -product)
        {
            __m512d product_f64x8 = _mm512_mul_pd(diff_j_f64x8, c_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(diff_j_f64x8, c_f64x8, product_f64x8);
            // TwoSum (Knuth, branchless): t = cdiff_j + product, then recover
            // the exact rounding error of that addition.
            __m512d tentative_sum_f64x8 = _mm512_add_pd(cdiff_j_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cdiff_j_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(cdiff_j_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            cdiff_j_f64x8 = tentative_sum_f64x8;
            inner_compensation_f64x8 = _mm512_add_pd(inner_compensation_f64x8,
                                                     _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
        j += 8;
        if (j < n) goto nk_mahalanobis_f64_skylake_cycle;

        // Combine inner sum with compensation before outer accumulation
        cdiff_j_f64x8 = _mm512_add_pd(cdiff_j_f64x8, inner_compensation_f64x8);

        // Outer loop Dot2: accumulate sum += diff_i * cdiff_j
        // TwoProd: product = diff_i * cdiff_j, product_error = fma(diff_i, cdiff_j, -product)
        {
            __m512d product_f64x8 = _mm512_mul_pd(diff_i_f64x8, cdiff_j_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(diff_i_f64x8, cdiff_j_f64x8, product_f64x8);
            // TwoSum: t = sum + product
            __m512d tentative_sum_f64x8 = _mm512_add_pd(sum_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, sum_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            sum_f64x8 = tentative_sum_f64x8;
            compensation_f64x8 = _mm512_add_pd(compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
    }

    // Final: combine sum + compensation before reduce
    nk_f64_t quadratic = _mm512_reduce_add_pd(_mm512_add_pd(sum_f64x8, compensation_f64x8));
    *result = nk_f64_sqrt_haswell(quadratic > 0 ? quadratic : 0);
}
|
|
305
|
+
|
|
306
|
+
/// Computes the complex bilinear form `*results = aT * C * b` (no conjugation —
/// the complex products below are plain complex multiplies) over interleaved
/// double-precision complex numbers, using AVX-512 intrinsics.
///
/// @param a        Vector of `n` complex values (left operand).
/// @param b        Vector of `n` complex values (right operand).
/// @param c        Matrix of complex values, addressed as `c[i * n + j]`
///                 (presumably a row-major `n x n` matrix — confirm with callers).
/// @param n        Number of complex elements in `a` and `b`.
/// @param results  Output: single complex scalar. For `n == 0` the outer loop
///                 never runs and `{0, 0}` is written.
///
/// Accuracy: both the vectorized inner loop (over `j`) and the scalar outer
/// accumulation (over `i`) use the compensated Dot2 algorithm
/// (Ogita-Rump-Oishi 2005): TwoProduct captures the product rounding error,
/// TwoSum captures the summation rounding error, and both are folded into a
/// running compensation term added back at the very end. The statement order
/// and grouping of these expressions is load-bearing — do not reassociate.
NK_PUBLIC void nk_bilinear_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                        nk_f64c_t *results) {

    // We take into account, that FMS is the same as FMA with a negative multiplier.
    // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
    // This way we can avoid the shuffling and the need for separate real and imaginary parts.
    // For the imaginary part of the product, we would need to swap the real and imaginary parts of
    // one of the vectors.
    // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
    // Sign-flip mask: toggles the sign bit of every second (odd-indexed) f64 lane,
    // turning the lane-wise `b.i * c.i` terms negative so that a single horizontal
    // reduce yields `b.r * c.r - b.i * c.i`.
    __m512i const sign_flip_i64x8 = _mm512_set_epi64( //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
        0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 //
    );

    // Default case for arbitrary size `n`.
    // One ZMM register holds 4 complex numbers (8 doubles), so the tail mask
    // covers `tail_length * 2` f64 lanes.
    nk_size_t const tail_length = n % 4;
    nk_size_t const tail_start = n - tail_length;
    __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
    // Scalar Dot2 state for the outer accumulation over `i`.
    nk_f64_t sum_real = 0;
    nk_f64_t sum_imag = 0;
    nk_f64_t compensation_real = 0;
    nk_f64_t compensation_imag = 0;

    for (nk_size_t i = 0; i != n; ++i) {
        nk_f64_t const a_i_real = a[i].real;
        nk_f64_t const a_i_imag = a[i].imag;
        // Vector Dot2 state for the inner product `(C b)_i = sum_j c[i*n+j] * b[j]`.
        __m512d cb_j_real_f64x8 = _mm512_setzero_pd();
        __m512d cb_j_imag_f64x8 = _mm512_setzero_pd();
        __m512d compensation_real_f64x8 = _mm512_setzero_pd();
        __m512d compensation_imag_f64x8 = _mm512_setzero_pd();
        __m512d b_f64x8, c_f64x8;
        nk_size_t j = 0;

    // goto-based loop: full 4-complex chunks, then one masked tail iteration.
    // Runs at least once per `i`; fine since `n >= 1` inside this outer loop.
    nk_bilinear_f64c_skylake_cycle:
        if (j + 4 <= n) {
            b_f64x8 = _mm512_loadu_pd((nk_f64_t const *)(b + j));
            c_f64x8 = _mm512_loadu_pd((nk_f64_t const *)(c + i * n + j));
        }
        else {
            b_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)(b + tail_start));
            c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)(c + i * n + tail_start));
        }
        // The real part of the product: b.real * c.real - b.imag * c.imag.
        // The subtraction will be performed later with a sign flip.
        // Inner loop Dot2 for real accumulator
        {
            // TwoProduct: `fmsub(c, b, c*b)` recovers the exact rounding error of the multiply.
            __m512d product_f64x8 = _mm512_mul_pd(c_f64x8, b_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(c_f64x8, b_f64x8, product_f64x8);
            // TwoSum (Knuth): exact error of `sum + product` without branching.
            __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_real_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_real_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(cb_j_real_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            cb_j_real_f64x8 = tentative_sum_f64x8;
            compensation_real_f64x8 = _mm512_add_pd(compensation_real_f64x8,
                                                    _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
        // The imaginary part of the product: b.real * c.imag + b.imag * c.real.
        // Swap the imaginary and real parts of `c` before multiplication:
        c_f64x8 = _mm512_permute_pd(c_f64x8, 0x55); //? Same as 0b01010101.
        // Inner loop Dot2 for imaginary accumulator
        {
            __m512d product_f64x8 = _mm512_mul_pd(c_f64x8, b_f64x8);
            __m512d product_error_f64x8 = _mm512_fmsub_pd(c_f64x8, b_f64x8, product_f64x8);
            __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_imag_f64x8, product_f64x8);
            __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_imag_f64x8);
            __m512d sum_error_f64x8 = _mm512_add_pd(
                _mm512_sub_pd(cb_j_imag_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
                _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
            cb_j_imag_f64x8 = tentative_sum_f64x8;
            compensation_imag_f64x8 = _mm512_add_pd(compensation_imag_f64x8,
                                                    _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
        }
        j += 4;
        if (j < n) goto nk_bilinear_f64c_skylake_cycle;

        // Flip the sign bit in every second scalar before accumulation:
        // the compensation vector must be flipped with the same mask, since its
        // lanes carry errors of the same (to-be-negated) products.
        cb_j_real_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(cb_j_real_f64x8), sign_flip_i64x8));
        compensation_real_f64x8 = _mm512_castsi512_pd(
            _mm512_xor_si512(_mm512_castpd_si512(compensation_real_f64x8), sign_flip_i64x8));

        // Combine inner sums with compensation before horizontal reduce
        cb_j_real_f64x8 = _mm512_add_pd(cb_j_real_f64x8, compensation_real_f64x8);
        cb_j_imag_f64x8 = _mm512_add_pd(cb_j_imag_f64x8, compensation_imag_f64x8);

        // Horizontal sums are the expensive part of the computation:
        nk_f64_t const cb_j_real = _mm512_reduce_add_pd(cb_j_real_f64x8);
        nk_f64_t const cb_j_imag = _mm512_reduce_add_pd(cb_j_imag_f64x8);

        // Outer loop Dot2 for real part: sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag
        {
            // First term: a_i_real * cb_j_real
            // NOTE(review): under strict IEEE evaluation `(x * y) - product` is exactly 0,
            // so these scalar product-error terms only capture anything if the compiler
            // contracts the re-multiply into an FMA (-ffp-contract). Consider `fma(x, y, -product)`
            // to make the TwoProduct unconditional — confirm against build flags.
            nk_f64_t product1 = a_i_real * cb_j_real;
            nk_f64_t product_error1 = (a_i_real * cb_j_real) - product1;
            // Second term: -a_i_imag * cb_j_imag
            nk_f64_t product2 = a_i_imag * cb_j_imag;
            nk_f64_t product_error2 = (a_i_imag * cb_j_imag) - product2;
            // TwoSum for first addition: t = sum_real + product1
            nk_f64_t t1 = sum_real + product1;
            nk_f64_t z1 = t1 - sum_real;
            nk_f64_t sum_error1 = (sum_real - (t1 - z1)) + (product1 - z1);
            sum_real = t1;
            compensation_real += sum_error1 + product_error1;
            // TwoSum for subtraction: t = sum_real - product2
            nk_f64_t t2 = sum_real - product2;
            nk_f64_t z2 = t2 - sum_real;
            nk_f64_t sum_error2 = (sum_real - (t2 - z2)) + (-product2 - z2);
            sum_real = t2;
            // Subtracting `product2` means its rounding error enters with a minus sign too.
            compensation_real += sum_error2 - product_error2;
        }

        // Outer loop Dot2 for imaginary part: sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real
        {
            // First term: a_i_real * cb_j_imag
            // NOTE(review): same FP-contraction caveat as the real-part block above.
            nk_f64_t product1 = a_i_real * cb_j_imag;
            nk_f64_t product_error1 = (a_i_real * cb_j_imag) - product1;
            // Second term: a_i_imag * cb_j_real
            nk_f64_t product2 = a_i_imag * cb_j_real;
            nk_f64_t product_error2 = (a_i_imag * cb_j_real) - product2;
            // TwoSum for first addition: t = sum_imag + product1
            nk_f64_t t1 = sum_imag + product1;
            nk_f64_t z1 = t1 - sum_imag;
            nk_f64_t sum_error1 = (sum_imag - (t1 - z1)) + (product1 - z1);
            sum_imag = t1;
            compensation_imag += sum_error1 + product_error1;
            // TwoSum for second addition: t = sum_imag + product2
            nk_f64_t t2 = sum_imag + product2;
            nk_f64_t z2 = t2 - sum_imag;
            nk_f64_t sum_error2 = (sum_imag - (t2 - z2)) + (product2 - z2);
            sum_imag = t2;
            compensation_imag += sum_error2 + product_error2;
        }
    }

    // Final: combine sum + compensation
    results->real = sum_real + compensation_real;
    results->imag = sum_imag + compensation_imag;
}
|
|
444
|
+
|
|
445
|
+
#if defined(__clang__)
|
|
446
|
+
#pragma clang attribute pop
|
|
447
|
+
#elif defined(__GNUC__)
|
|
448
|
+
#pragma GCC pop_options
|
|
449
|
+
#endif
|
|
450
|
+
|
|
451
|
+
#if defined(__cplusplus)
|
|
452
|
+
} // extern "C"
|
|
453
|
+
#endif
|
|
454
|
+
|
|
455
|
+
#endif // NK_TARGET_SKYLAKE
|
|
456
|
+
#endif // NK_TARGET_X86_
|
|
457
|
+
#endif // NK_CURVED_SKYLAKE_H
|