numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for Genoa.
|
|
3
|
+
* @file include/numkong/curved/genoa.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements bf16 bilinear forms using AVX-512 with BF16 extensions.
|
|
10
|
+
*/
|
|
11
|
+
#ifndef NK_CURVED_GENOA_H
|
|
12
|
+
#define NK_CURVED_GENOA_H
|
|
13
|
+
|
|
14
|
+
#if NK_TARGET_X86_
|
|
15
|
+
#if NK_TARGET_GENOA
|
|
16
|
+
|
|
17
|
+
#include "numkong/types.h"
|
|
18
|
+
#include "numkong/spatial/genoa.h" // `nk_substract_bf16x32_genoa_`
|
|
19
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
20
|
+
|
|
21
|
+
#if defined(__cplusplus)
|
|
22
|
+
extern "C" {
|
|
23
|
+
#endif
|
|
24
|
+
|
|
25
|
+
#if defined(__clang__)
|
|
26
|
+
#pragma clang attribute push( \
|
|
27
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
|
|
28
|
+
apply_to = function)
|
|
29
|
+
#elif defined(__GNUC__)
|
|
30
|
+
#pragma GCC push_options
|
|
31
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
NK_PUBLIC void nk_bilinear_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
35
|
+
nk_f32_t *result) {
|
|
36
|
+
nk_size_t const tail_length = n % 32;
|
|
37
|
+
nk_size_t const tail_start = n - tail_length;
|
|
38
|
+
__mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
|
|
39
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
40
|
+
|
|
41
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
42
|
+
nk_f32_t a_f32;
|
|
43
|
+
nk_bf16_to_f32_serial(a + i, &a_f32);
|
|
44
|
+
__m512 a_f32x16 = _mm512_set1_ps(a_f32);
|
|
45
|
+
__m512 cb_j_f32x16 = _mm512_setzero_ps();
|
|
46
|
+
__m512i b_bf16x32, c_bf16x32;
|
|
47
|
+
nk_size_t j = 0;
|
|
48
|
+
|
|
49
|
+
nk_bilinear_bf16_genoa_cycle:
|
|
50
|
+
if (j + 32 <= n) {
|
|
51
|
+
b_bf16x32 = _mm512_loadu_epi16(b + j);
|
|
52
|
+
c_bf16x32 = _mm512_loadu_epi16(c + i * n + j);
|
|
53
|
+
}
|
|
54
|
+
else {
|
|
55
|
+
b_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start);
|
|
56
|
+
c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start);
|
|
57
|
+
}
|
|
58
|
+
cb_j_f32x16 = _mm512_dpbf16_ps(cb_j_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(c_bf16x32));
|
|
59
|
+
j += 32;
|
|
60
|
+
if (j < n) goto nk_bilinear_bf16_genoa_cycle;
|
|
61
|
+
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, cb_j_f32x16, sum_f32x16);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
*result = _mm512_reduce_add_ps(sum_f32x16);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/**
 *  @brief Mahalanobis-style distance sqrt((a-b)^T * C * (a-b)) over `bf16` inputs,
 *  accumulating in `f32` with AVX-512 BF16 dot products. `c` is an `n x n` row-major matrix.
 *  The negative quadratic forms (possible when `c` is not positive semi-definite, or from
 *  rounding) are clamped to zero before the square root.
 */
NK_PUBLIC void nk_mahalanobis_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
                                         nk_f32_t *result) {
    nk_size_t const tail_length = n % 32;
    nk_size_t const tail_start = n - tail_length;
    // Mask with the lowest `tail_length` bits set, for the partial last chunk of each row.
    __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
    __m512 sum_f32x16 = _mm512_setzero_ps();

    for (nk_size_t i = 0; i != n; ++i) {
        // Scalar (a[i] - b[i]) difference, broadcast across one ZMM register.
        nk_f32_t a_i, b_i;
        nk_bf16_to_f32_serial(a + i, &a_i);
        nk_bf16_to_f32_serial(b + i, &b_i);
        __m512 diff_i_f32x16 = _mm512_set1_ps(a_i - b_i);
        __m512 cdiff_j_f32x16 = _mm512_setzero_ps();
        __m512i a_j_bf16x32, b_j_bf16x32, diff_j_bf16x32, c_bf16x32;
        nk_size_t j = 0;

        // The nested loop is cleaner to implement with a `goto` in this case:
    nk_mahalanobis_bf16_genoa_cycle:
        if (j + 32 <= n) {
            a_j_bf16x32 = _mm512_loadu_epi16(a + j);
            b_j_bf16x32 = _mm512_loadu_epi16(b + j);
            c_bf16x32 = _mm512_loadu_epi16(c + i * n + j);
        }
        else {
            // Masked loads zero the lanes past the end of the row.
            a_j_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, a + tail_start);
            b_j_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start);
            c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start);
        }
        diff_j_bf16x32 = nk_substract_bf16x32_genoa_(a_j_bf16x32, b_j_bf16x32);
        diff_j_f32x16 = diff_j_f32x16; // (no-op placeholder removed)
        cdiff_j_f32x16 = _mm512_dpbf16_ps(cdiff_j_f32x16, nk_m512bh_from_m512i_(diff_j_bf16x32),
                                          nk_m512bh_from_m512i_(c_bf16x32));
        j += 32;
        if (j < n) goto nk_mahalanobis_bf16_genoa_cycle;
        sum_f32x16 = _mm512_fmadd_ps(diff_i_f32x16, cdiff_j_f32x16, sum_f32x16);
    }

    nk_f32_t quadratic = _mm512_reduce_add_ps(sum_f32x16);
    // Clamp before the square root to avoid NaNs on tiny negative rounding residue.
    *result = nk_f32_sqrt_haswell(quadratic > 0 ? quadratic : 0);
}
|
|
106
|
+
|
|
107
|
+
NK_PUBLIC void nk_bilinear_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
|
|
108
|
+
nk_f32c_t *results) {
|
|
109
|
+
|
|
110
|
+
// We take into account, that FMS is the same as FMA with a negative multiplier.
|
|
111
|
+
// To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
|
|
112
|
+
// This way we can avoid the shuffling and the need for separate real and imaginary parts.
|
|
113
|
+
// For the imaginary part of the product, we would need to swap the real and imaginary parts of
|
|
114
|
+
// one of the vectors.
|
|
115
|
+
__m512i const sign_flip_i32x16 = _mm512_set1_epi32(0x80000000);
|
|
116
|
+
__m512i const swap_adjacent_i8x64 = _mm512_set_epi8( //
|
|
117
|
+
61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane
|
|
118
|
+
45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane
|
|
119
|
+
29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane
|
|
120
|
+
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane
|
|
121
|
+
);
|
|
122
|
+
|
|
123
|
+
// Default case for arbitrary size `n`
|
|
124
|
+
nk_size_t const tail_length = n % 16;
|
|
125
|
+
nk_size_t const tail_start = n - tail_length;
|
|
126
|
+
__mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
|
|
127
|
+
nk_f32_t sum_real = 0;
|
|
128
|
+
nk_f32_t sum_imag = 0;
|
|
129
|
+
|
|
130
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
131
|
+
nk_f32_t a_i_real, a_i_imag;
|
|
132
|
+
nk_bf16_to_f32_serial(&a[i].real, &a_i_real);
|
|
133
|
+
nk_bf16_to_f32_serial(&a[i].imag, &a_i_imag);
|
|
134
|
+
__m512 cb_j_real_f32x16 = _mm512_setzero_ps();
|
|
135
|
+
__m512 cb_j_imag_f32x16 = _mm512_setzero_ps();
|
|
136
|
+
__m512i b_bf16x32, c_bf16x32;
|
|
137
|
+
nk_size_t j = 0;
|
|
138
|
+
|
|
139
|
+
nk_bilinear_bf16c_skylake_cycle:
|
|
140
|
+
if (j + 16 <= n) {
|
|
141
|
+
b_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)(b + j));
|
|
142
|
+
c_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)(c + i * n + j));
|
|
143
|
+
}
|
|
144
|
+
else {
|
|
145
|
+
b_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (nk_i16_t const *)(b + tail_start));
|
|
146
|
+
c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (nk_i16_t const *)(c + i * n + tail_start));
|
|
147
|
+
}
|
|
148
|
+
cb_j_real_f32x16 = _mm512_dpbf16_ps( //
|
|
149
|
+
cb_j_real_f32x16, //
|
|
150
|
+
nk_m512bh_from_m512i_(_mm512_xor_si512(c_bf16x32, sign_flip_i32x16)), //
|
|
151
|
+
nk_m512bh_from_m512i_(b_bf16x32));
|
|
152
|
+
cb_j_imag_f32x16 = _mm512_dpbf16_ps( //
|
|
153
|
+
cb_j_imag_f32x16, //
|
|
154
|
+
nk_m512bh_from_m512i_(_mm512_shuffle_epi8(c_bf16x32, swap_adjacent_i8x64)), //
|
|
155
|
+
nk_m512bh_from_m512i_(b_bf16x32));
|
|
156
|
+
j += 16;
|
|
157
|
+
if (j < n) goto nk_bilinear_bf16c_skylake_cycle;
|
|
158
|
+
// Horizontal sums are the expensive part of the computation:
|
|
159
|
+
nk_f32_t const cb_j_real = nk_reduce_add_f32x16_skylake_(cb_j_real_f32x16);
|
|
160
|
+
nk_f32_t const cb_j_imag = nk_reduce_add_f32x16_skylake_(cb_j_imag_f32x16);
|
|
161
|
+
sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag;
|
|
162
|
+
sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real;
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
// Reduce horizontal sums:
|
|
166
|
+
results->real = sum_real;
|
|
167
|
+
results->imag = sum_imag;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
#if defined(__clang__)
|
|
171
|
+
#pragma clang attribute pop
|
|
172
|
+
#elif defined(__GNUC__)
|
|
173
|
+
#pragma GCC pop_options
|
|
174
|
+
#endif
|
|
175
|
+
|
|
176
|
+
#if defined(__cplusplus)
|
|
177
|
+
} // extern "C"
|
|
178
|
+
#endif
|
|
179
|
+
|
|
180
|
+
#endif // NK_TARGET_GENOA
|
|
181
|
+
#endif // NK_TARGET_X86_
|
|
182
|
+
#endif // NK_CURVED_GENOA_H
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for Haswell.
|
|
3
|
+
* @file include/numkong/curved/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements f16 and bf16 bilinear forms using AVX2 with F16C conversion.
|
|
10
|
+
*/
|
|
11
|
+
#ifndef NK_CURVED_HASWELL_H
|
|
12
|
+
#define NK_CURVED_HASWELL_H
|
|
13
|
+
|
|
14
|
+
#if NK_TARGET_X86_
|
|
15
|
+
#if NK_TARGET_HASWELL
|
|
16
|
+
|
|
17
|
+
#include "numkong/types.h"
|
|
18
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_f32x8_haswell_`
|
|
19
|
+
#include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
|
|
20
|
+
|
|
21
|
+
#if defined(__cplusplus)
|
|
22
|
+
extern "C" {
|
|
23
|
+
#endif
|
|
24
|
+
|
|
25
|
+
#if defined(__clang__)
|
|
26
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
27
|
+
#elif defined(__GNUC__)
|
|
28
|
+
#pragma GCC push_options
|
|
29
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
NK_PUBLIC void nk_bilinear_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
33
|
+
nk_f64_t *result) {
|
|
34
|
+
nk_size_t const tail_length = n % 4;
|
|
35
|
+
nk_size_t const tail_start = n - tail_length;
|
|
36
|
+
__m256d sum_f64x4 = _mm256_setzero_pd();
|
|
37
|
+
|
|
38
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
39
|
+
__m256d a_f64x4 = _mm256_set1_pd((nk_f64_t)a[i]);
|
|
40
|
+
__m256d cb_j_f64x4 = _mm256_setzero_pd();
|
|
41
|
+
for (nk_size_t j = 0; j + 4 <= n; j += 4) {
|
|
42
|
+
__m256d b_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(b + j));
|
|
43
|
+
__m256d c_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(c + i * n + j));
|
|
44
|
+
cb_j_f64x4 = _mm256_fmadd_pd(b_f64x4, c_f64x4, cb_j_f64x4);
|
|
45
|
+
}
|
|
46
|
+
sum_f64x4 = _mm256_fmadd_pd(a_f64x4, cb_j_f64x4, sum_f64x4);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
|
|
50
|
+
if (tail_length) {
|
|
51
|
+
nk_b128_vec_t b_tail_vec;
|
|
52
|
+
nk_partial_load_b32x4_haswell_(b + tail_start, &b_tail_vec, tail_length);
|
|
53
|
+
__m256d b_tail_f64x4 = _mm256_cvtps_pd(b_tail_vec.xmm_ps);
|
|
54
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
55
|
+
nk_f64_t a_i = (nk_f64_t)a[i];
|
|
56
|
+
nk_b128_vec_t c_tail_vec;
|
|
57
|
+
nk_partial_load_b32x4_haswell_(c + i * n + tail_start, &c_tail_vec, tail_length);
|
|
58
|
+
__m256d c_tail_f64x4 = _mm256_cvtps_pd(c_tail_vec.xmm_ps);
|
|
59
|
+
sum += a_i * nk_reduce_add_f64x4_haswell_(_mm256_mul_pd(b_tail_f64x4, c_tail_f64x4));
|
|
60
|
+
}
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
*result = sum;
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
/**
 *  @brief Mahalanobis-style distance sqrt((a-b)^T * C * (a-b)) over `f32` inputs,
 *  widened to `f64` for accumulation. `c` is an `n x n` row-major matrix. The quadratic
 *  form is clamped at zero before the square root to avoid NaN on rounding residue.
 */
NK_PUBLIC void nk_mahalanobis_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                          nk_f64_t *result) {
    nk_size_t const tail_length = n % 4;
    nk_size_t const tail_start = n - tail_length;
    __m256d sum_f64x4 = _mm256_setzero_pd();

    // Main body: full 4-wide chunks, widening every f32 load to f64 before the FMA.
    for (nk_size_t i = 0; i != n; ++i) {
        __m256d diff_i_f64x4 = _mm256_set1_pd((nk_f64_t)a[i] - (nk_f64_t)b[i]);
        __m256d cdiff_j_f64x4 = _mm256_setzero_pd();
        for (nk_size_t j = 0; j + 4 <= n; j += 4) {
            __m256d diff_j_f64x4 = _mm256_sub_pd( //
                _mm256_cvtps_pd(_mm_loadu_ps(a + j)), _mm256_cvtps_pd(_mm_loadu_ps(b + j)));
            __m256d c_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(c + i * n + j));
            cdiff_j_f64x4 = _mm256_fmadd_pd(diff_j_f64x4, c_f64x4, cdiff_j_f64x4);
        }
        sum_f64x4 = _mm256_fmadd_pd(diff_i_f64x4, cdiff_j_f64x4, sum_f64x4);
    }

    nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
    if (tail_length) {
        // The (a - b) tail does not depend on the row, so it is loaded once here.
        nk_b128_vec_t a_tail_vec, b_tail_vec;
        nk_partial_load_b32x4_haswell_(a + tail_start, &a_tail_vec, tail_length);
        nk_partial_load_b32x4_haswell_(b + tail_start, &b_tail_vec, tail_length);
        __m256d diff_tail_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(a_tail_vec.xmm_ps), _mm256_cvtps_pd(b_tail_vec.xmm_ps));
        for (nk_size_t i = 0; i != n; ++i) {
            nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
            nk_b128_vec_t c_tail_vec;
            nk_partial_load_b32x4_haswell_(c + i * n + tail_start, &c_tail_vec, tail_length);
            __m256d c_tail_f64x4 = _mm256_cvtps_pd(c_tail_vec.xmm_ps);
            sum += diff_i * nk_reduce_add_f64x4_haswell_(_mm256_mul_pd(diff_tail_f64x4, c_tail_f64x4));
        }
    }

    // Clamp before the square root to avoid NaN on tiny negative residue.
    *result = nk_f64_sqrt_haswell(sum > 0 ? sum : 0);
}
|
|
101
|
+
|
|
102
|
+
NK_PUBLIC void nk_bilinear_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
103
|
+
nk_f32_t *result) {
|
|
104
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
105
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
106
|
+
__m256 a_f32x8 = _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i)));
|
|
107
|
+
__m256 cb_j_f32x8 = _mm256_setzero_ps();
|
|
108
|
+
for (nk_size_t j = 0; j + 8 <= n; j += 8) {
|
|
109
|
+
__m256 b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(b + j)));
|
|
110
|
+
__m256 c_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
|
|
111
|
+
cb_j_f32x8 = _mm256_fmadd_ps(b_f32x8, c_f32x8, cb_j_f32x8);
|
|
112
|
+
}
|
|
113
|
+
sum_f32x8 = _mm256_fmadd_ps(a_f32x8, cb_j_f32x8, sum_f32x8);
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
// Handle the tail of every row
|
|
117
|
+
nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
|
|
118
|
+
nk_size_t const tail_length = n % 8;
|
|
119
|
+
nk_size_t const tail_start = n - tail_length;
|
|
120
|
+
if (tail_length) {
|
|
121
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
122
|
+
nk_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))));
|
|
123
|
+
nk_b256_vec_t b_vec;
|
|
124
|
+
nk_partial_load_f16x8_to_f32x8_haswell_(b + tail_start, &b_vec, tail_length);
|
|
125
|
+
__m256 b_f32x8 = b_vec.ymm_ps;
|
|
126
|
+
nk_b256_vec_t c_vec;
|
|
127
|
+
nk_partial_load_f16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
|
|
128
|
+
__m256 c_f32x8 = c_vec.ymm_ps;
|
|
129
|
+
nk_f32_t cb_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(b_f32x8, c_f32x8));
|
|
130
|
+
sum += a_i * cb_j;
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
*result = sum;
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
NK_PUBLIC void nk_mahalanobis_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
138
|
+
nk_f32_t *result) {
|
|
139
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
140
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
141
|
+
__m256 diff_i_f32x8 = _mm256_sub_ps( //
|
|
142
|
+
_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), //
|
|
143
|
+
_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i))));
|
|
144
|
+
__m256 cdiff_j_f32x8 = _mm256_setzero_ps();
|
|
145
|
+
for (nk_size_t j = 0; j + 8 <= n; j += 8) {
|
|
146
|
+
__m256 diff_j_f32x8 = _mm256_sub_ps( //
|
|
147
|
+
_mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(a + j))),
|
|
148
|
+
_mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(b + j))));
|
|
149
|
+
__m256 c_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
|
|
150
|
+
cdiff_j_f32x8 = _mm256_fmadd_ps(diff_j_f32x8, c_f32x8, cdiff_j_f32x8);
|
|
151
|
+
}
|
|
152
|
+
sum_f32x8 = _mm256_fmadd_ps(diff_i_f32x8, cdiff_j_f32x8, sum_f32x8);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
// Handle the tail of every row
|
|
156
|
+
nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
|
|
157
|
+
nk_size_t const tail_length = n % 8;
|
|
158
|
+
nk_size_t const tail_start = n - tail_length;
|
|
159
|
+
if (tail_length) {
|
|
160
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
161
|
+
nk_f32_t diff_i = _mm256_cvtss_f32(_mm256_sub_ps( //
|
|
162
|
+
_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), //
|
|
163
|
+
_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i)))));
|
|
164
|
+
nk_b256_vec_t a_tail_vec, b_tail_vec;
|
|
165
|
+
nk_partial_load_f16x8_to_f32x8_haswell_(a + tail_start, &a_tail_vec, tail_length);
|
|
166
|
+
nk_partial_load_f16x8_to_f32x8_haswell_(b + tail_start, &b_tail_vec, tail_length);
|
|
167
|
+
__m256 diff_j_f32x8 = _mm256_sub_ps(a_tail_vec.ymm_ps, b_tail_vec.ymm_ps);
|
|
168
|
+
nk_b256_vec_t c_vec;
|
|
169
|
+
nk_partial_load_f16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
|
|
170
|
+
__m256 c_f32x8 = c_vec.ymm_ps;
|
|
171
|
+
nk_f32_t cdiff_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
|
|
172
|
+
sum += diff_i * cdiff_j;
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
*result = nk_f32_sqrt_haswell(sum > 0 ? sum : 0);
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
NK_PUBLIC void nk_bilinear_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
180
|
+
nk_f32_t *result) {
|
|
181
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
182
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
183
|
+
// The `nk_bf16_to_f32_serial` is cheaper than `nk_bf16x8_to_f32x8_haswell_`
|
|
184
|
+
nk_f32_t a_f32;
|
|
185
|
+
nk_bf16_to_f32_serial(a + i, &a_f32);
|
|
186
|
+
__m256 a_f32x8 = _mm256_set1_ps(a_f32);
|
|
187
|
+
__m256 cb_j_f32x8 = _mm256_setzero_ps();
|
|
188
|
+
for (nk_size_t j = 0; j + 8 <= n; j += 8) {
|
|
189
|
+
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(b + j)));
|
|
190
|
+
__m256 c_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
|
|
191
|
+
cb_j_f32x8 = _mm256_fmadd_ps(b_f32x8, c_f32x8, cb_j_f32x8);
|
|
192
|
+
}
|
|
193
|
+
sum_f32x8 = _mm256_fmadd_ps(a_f32x8, cb_j_f32x8, sum_f32x8);
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
// Handle the tail of every row
|
|
197
|
+
nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
|
|
198
|
+
nk_size_t const tail_length = n % 8;
|
|
199
|
+
nk_size_t const tail_start = n - tail_length;
|
|
200
|
+
if (tail_length) {
|
|
201
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
202
|
+
nk_f32_t a_i;
|
|
203
|
+
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
204
|
+
nk_b256_vec_t b_vec;
|
|
205
|
+
nk_partial_load_bf16x8_to_f32x8_haswell_(b + tail_start, &b_vec, tail_length);
|
|
206
|
+
__m256 b_f32x8 = b_vec.ymm_ps;
|
|
207
|
+
nk_b256_vec_t c_vec;
|
|
208
|
+
nk_partial_load_bf16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
|
|
209
|
+
__m256 c_f32x8 = c_vec.ymm_ps;
|
|
210
|
+
nk_f32_t cb_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(b_f32x8, c_f32x8));
|
|
211
|
+
sum += a_i * cb_j;
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
*result = sum;
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
NK_PUBLIC void nk_mahalanobis_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
219
|
+
nk_f32_t *result) {
|
|
220
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
221
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
222
|
+
nk_f32_t a_i, b_i;
|
|
223
|
+
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
224
|
+
nk_bf16_to_f32_serial(b + i, &b_i);
|
|
225
|
+
__m256 diff_i_f32x8 = _mm256_sub_ps( //
|
|
226
|
+
_mm256_set1_ps(a_i), //
|
|
227
|
+
_mm256_set1_ps(b_i));
|
|
228
|
+
__m256 cdiff_j_f32x8 = _mm256_setzero_ps();
|
|
229
|
+
for (nk_size_t j = 0; j + 8 <= n; j += 8) {
|
|
230
|
+
__m256 diff_j_f32x8 = _mm256_sub_ps( //
|
|
231
|
+
nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(a + j))), //
|
|
232
|
+
nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(b + j))));
|
|
233
|
+
__m256 c_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
|
|
234
|
+
cdiff_j_f32x8 = _mm256_fmadd_ps(diff_j_f32x8, c_f32x8, cdiff_j_f32x8);
|
|
235
|
+
}
|
|
236
|
+
sum_f32x8 = _mm256_fmadd_ps(diff_i_f32x8, cdiff_j_f32x8, sum_f32x8);
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
// Handle the tail of every row
|
|
240
|
+
nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
|
|
241
|
+
nk_size_t const tail_length = n % 8;
|
|
242
|
+
nk_size_t const tail_start = n - tail_length;
|
|
243
|
+
if (tail_length) {
|
|
244
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
245
|
+
nk_f32_t a_i, b_i;
|
|
246
|
+
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
247
|
+
nk_bf16_to_f32_serial(b + i, &b_i);
|
|
248
|
+
nk_f32_t diff_i = a_i - b_i;
|
|
249
|
+
nk_b256_vec_t a_tail_vec, b_tail_vec;
|
|
250
|
+
nk_partial_load_bf16x8_to_f32x8_haswell_(a + tail_start, &a_tail_vec, tail_length);
|
|
251
|
+
nk_partial_load_bf16x8_to_f32x8_haswell_(b + tail_start, &b_tail_vec, tail_length);
|
|
252
|
+
__m256 diff_j_f32x8 = _mm256_sub_ps(a_tail_vec.ymm_ps, b_tail_vec.ymm_ps);
|
|
253
|
+
nk_b256_vec_t c_vec;
|
|
254
|
+
nk_partial_load_bf16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
|
|
255
|
+
__m256 c_f32x8 = c_vec.ymm_ps;
|
|
256
|
+
nk_f32_t cdiff_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
|
|
257
|
+
sum += diff_i * cdiff_j;
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
*result = nk_f32_sqrt_haswell(sum > 0 ? sum : 0);
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
#if defined(__clang__)
|
|
265
|
+
#pragma clang attribute pop
|
|
266
|
+
#elif defined(__GNUC__)
|
|
267
|
+
#pragma GCC pop_options
|
|
268
|
+
#endif
|
|
269
|
+
|
|
270
|
+
#if defined(__cplusplus)
|
|
271
|
+
} // extern "C"
|
|
272
|
+
#endif
|
|
273
|
+
|
|
274
|
+
#endif // NK_TARGET_HASWELL
|
|
275
|
+
#endif // NK_TARGET_X86_
|
|
276
|
+
#endif // NK_CURVED_HASWELL_H
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for NEON.
|
|
3
|
+
* @file include/numkong/curved/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements f32 bilinear forms and Mahalanobis distance using ARM NEON SIMD.
|
|
10
|
+
* Accumulates f32 inputs in f64 precision to avoid catastrophic cancellation.
|
|
11
|
+
*
|
|
12
|
+
* @section neon_curved_instructions Key NEON Instructions
|
|
13
|
+
*
|
|
14
|
+
* Intrinsic Instruction Latency Throughput
|
|
15
|
+
* A76 M4+/V1+/Oryon
|
|
16
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
|
|
17
|
+
* vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy 2/cy
|
|
18
|
+
* vaddvq_f64 FADDP (V.2D to scalar) 3cy 1/cy 1/cy
|
|
19
|
+
* vld1_f32 LD1 ({Vt.2S}, [Xn]) 4cy 2/cy 2/cy
|
|
20
|
+
* vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
|
|
21
|
+
*
|
|
22
|
+
* For f32 bilinear and Mahalanobis, we upcast to f64 for accumulation to preserve
|
|
23
|
+
* precision and avoid catastrophic cancellation in large-magnitude sums.
|
|
24
|
+
*/
|
|
25
|
+
#ifndef NK_CURVED_NEON_H
|
|
26
|
+
#define NK_CURVED_NEON_H
|
|
27
|
+
|
|
28
|
+
#if NK_TARGET_ARM_
|
|
29
|
+
#if NK_TARGET_NEON
|
|
30
|
+
|
|
31
|
+
#include "numkong/types.h"
|
|
32
|
+
#include "numkong/spatial/neon.h" // nk_f64_sqrt_neon
|
|
33
|
+
|
|
34
|
+
#if defined(__cplusplus)
|
|
35
|
+
extern "C" {
|
|
36
|
+
#endif
|
|
37
|
+
|
|
38
|
+
#if defined(__clang__)
|
|
39
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
|
|
40
|
+
#elif defined(__GNUC__)
|
|
41
|
+
#pragma GCC push_options
|
|
42
|
+
#pragma GCC target("arch=armv8-a+simd")
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
NK_PUBLIC void nk_bilinear_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                    nk_f64_t *result) {
    // Bilinear form aᵀ·C·b: every product is widened to f64 before accumulation
    // to avoid catastrophic cancellation in large-magnitude sums.
    nk_f64_t total = 0;

    for (nk_size_t row = 0; row != n; ++row) {
        nk_f64_t a_wide = (nk_f64_t)a[row];
        nk_f32_t const *c_row = c + row * n;

        // Inner accumulation of Σⱼ cᵢⱼ × bⱼ in a two-lane f64 vector, two columns per step.
        float64x2_t lane_acc = vdupq_n_f64(0);
        nk_size_t col = 0;
        for (; col + 2 <= n; col += 2) {
            float64x2_t b_wide = vcvt_f64_f32(vld1_f32(b + col));
            float64x2_t c_wide = vcvt_f64_f32(vld1_f32(c_row + col));
            lane_acc = vfmaq_f64(lane_acc, c_wide, b_wide);
        }

        // Horizontal fold of the two lanes, then the scalar tail (at most one column).
        nk_f64_t row_sum = vaddvq_f64(lane_acc);
        for (; col < n; ++col) { row_sum += (nk_f64_t)c_row[col] * (nk_f64_t)b[col]; }

        // Outer accumulation: total += aᵢ × inner sum.
        total += a_wide * row_sum;
    }

    *result = total;
}
|
|
83
|
+
|
|
84
|
+
NK_PUBLIC void nk_mahalanobis_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                       nk_f64_t *result) {
    // Mahalanobis distance √(dᵀ·C·d) with d = a − b, accumulated in f64 for precision.
    nk_f64_t total = 0;

    for (nk_size_t row = 0; row != n; ++row) {
        // Row-level difference is formed directly in f64.
        nk_f64_t row_diff = (nk_f64_t)a[row] - (nk_f64_t)b[row];
        nk_f32_t const *c_row = c + row * n;

        // Inner accumulation of Σⱼ cᵢⱼ × (aⱼ − bⱼ), two columns per step.
        float64x2_t lane_acc = vdupq_n_f64(0);
        nk_size_t col = 0;
        for (; col + 2 <= n; col += 2) {
            // NOTE: the vector path subtracts in f32 and widens afterwards,
            // while the scalar tail below widens first — kept as-is on purpose.
            float32x2_t diff_narrow = vsub_f32(vld1_f32(a + col), vld1_f32(b + col));
            float64x2_t diff_wide = vcvt_f64_f32(diff_narrow);
            float64x2_t c_wide = vcvt_f64_f32(vld1_f32(c_row + col));
            lane_acc = vfmaq_f64(lane_acc, c_wide, diff_wide);
        }

        // Horizontal fold of the two lanes, then the scalar tail (at most one column).
        nk_f64_t row_sum = vaddvq_f64(lane_acc);
        for (; col < n; ++col) {
            nk_f64_t tail_diff = (nk_f64_t)a[col] - (nk_f64_t)b[col];
            row_sum += (nk_f64_t)c_row[col] * tail_diff;
        }

        // Outer accumulation: total += dᵢ × inner sum.
        total += row_diff * row_sum;
    }

    // Clamp any tiny negative rounding residue before the square root.
    *result = nk_f64_sqrt_neon(total > 0 ? total : 0);
}
|
|
130
|
+
|
|
131
|
+
NK_PUBLIC void nk_bilinear_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs,
                                     nk_size_t n, nk_f64c_t *results) {
    // Complex bilinear form: results = Σᵢ aᵢ × (Σⱼ cᵢⱼ × bⱼ), with plain (non-conjugated)
    // complex multiplication throughout and all accumulation widened to f64.
    //
    // ARMv8.3-A FCMLA (`vcmlaq_f32`) was benchmarked for this complex inner loop.
    // The deinterleave+4FMA pattern is 2.3x faster on Apple M4 — see `dot/neon.h` comment.
    nk_f64_t outer_sum_real_f64 = 0;
    nk_f64_t outer_sum_imag_f64 = 0;

    for (nk_size_t i = 0; i != n; ++i) {
        // Convert a[i] to f64 for precision
        nk_f64_t a_real_f64 = (nk_f64_t)a_pairs[i].real;
        nk_f64_t a_imag_f64 = (nk_f64_t)a_pairs[i].imag;

        // Inner loop: accumulate Σⱼ cᵢⱼ × bⱼ in f64, split into real/imag accumulators
        float64x2_t inner_sum_real_f64x2 = vdupq_n_f64(0);
        float64x2_t inner_sum_imag_f64x2 = vdupq_n_f64(0);
        nk_size_t j = 0;

        // Vectorized inner loop: process 2 complex elements at a time
        for (; j + 2 <= n; j += 2) {
            // Load b[j:j+2] as interleaved complex pairs (real, imag, real, imag);
            // `vld2_f32` deinterleaves into separate real and imag lanes
            float32x2x2_t b_f32x2x2 = vld2_f32((nk_f32_t const *)(b_pairs + j));
            float64x2_t b_real_f64x2 = vcvt_f64_f32(b_f32x2x2.val[0]);
            float64x2_t b_imag_f64x2 = vcvt_f64_f32(b_f32x2x2.val[1]);

            // Load c[i*n+j : i*n+j+2] as interleaved complex pairs
            float32x2x2_t c_f32x2x2 = vld2_f32((nk_f32_t const *)(c_pairs + i * n + j));
            float64x2_t c_real_f64x2 = vcvt_f64_f32(c_f32x2x2.val[0]);
            float64x2_t c_imag_f64x2 = vcvt_f64_f32(c_f32x2x2.val[1]);

            // Real part of the complex product: c_real×b_real − c_imag×b_imag
            inner_sum_real_f64x2 = vfmaq_f64(inner_sum_real_f64x2, c_real_f64x2, b_real_f64x2);
            inner_sum_real_f64x2 = vfmsq_f64(inner_sum_real_f64x2, c_imag_f64x2, b_imag_f64x2);

            // Imaginary part: c_real×b_imag + c_imag×b_real
            inner_sum_imag_f64x2 = vfmaq_f64(inner_sum_imag_f64x2, c_real_f64x2, b_imag_f64x2);
            inner_sum_imag_f64x2 = vfmaq_f64(inner_sum_imag_f64x2, c_imag_f64x2, b_real_f64x2);
        }

        // Reduce the f64x2 accumulators to scalars
        nk_f64_t inner_sum_real_f64 = vaddvq_f64(inner_sum_real_f64x2);
        nk_f64_t inner_sum_imag_f64 = vaddvq_f64(inner_sum_imag_f64x2);

        // Handle tail elements (at most one remaining complex pair)
        for (; j < n; ++j) {
            nk_f64_t b_real = (nk_f64_t)b_pairs[j].real;
            nk_f64_t b_imag = (nk_f64_t)b_pairs[j].imag;
            nk_f64_t c_real = (nk_f64_t)c_pairs[i * n + j].real;
            nk_f64_t c_imag = (nk_f64_t)c_pairs[i * n + j].imag;
            // Complex multiply: c × b
            inner_sum_real_f64 += c_real * b_real - c_imag * b_imag;
            inner_sum_imag_f64 += c_real * b_imag + c_imag * b_real;
        }

        // Outer accumulation: outer += aᵢ × inner (complex multiply, no conjugation)
        outer_sum_real_f64 += a_real_f64 * inner_sum_real_f64 - a_imag_f64 * inner_sum_imag_f64;
        outer_sum_imag_f64 += a_real_f64 * inner_sum_imag_f64 + a_imag_f64 * inner_sum_real_f64;
    }

    results->real = outer_sum_real_f64;
    results->imag = outer_sum_imag_f64;
}
|
|
192
|
+
|
|
193
|
+
#if defined(__clang__)
|
|
194
|
+
#pragma clang attribute pop
|
|
195
|
+
#elif defined(__GNUC__)
|
|
196
|
+
#pragma GCC pop_options
|
|
197
|
+
#endif
|
|
198
|
+
|
|
199
|
+
#if defined(__cplusplus)
|
|
200
|
+
} // extern "C"
|
|
201
|
+
#endif
|
|
202
|
+
|
|
203
|
+
#endif // NK_TARGET_NEON
|
|
204
|
+
#endif // NK_TARGET_ARM_
|
|
205
|
+
#endif // NK_CURVED_NEON_H
|