numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,607 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Alder Lake.
|
|
3
|
+
* @file include/numkong/spatial/alder.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 4, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_alder_instructions AVX-VNNI Instructions Performance
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Alder Lake Raptor Lake
|
|
12
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
13
|
+
* _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
|
|
14
|
+
* _mm256_xor_si256 VPXOR (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
|
|
15
|
+
* _mm256_add_epi64 VPADDQ (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
|
|
16
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0 5cy @ p0
|
|
17
|
+
* _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0 12cy @ p0
|
|
18
|
+
*
|
|
19
|
+
* All spatial kernels use the dpbusd norm-decomposition approach:
|
|
20
|
+
* ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
21
|
+
* This avoids the p5 bottleneck from unpack operations, achieving ~2x throughput
|
|
22
|
+
* over Haswell's subs+unpack+madd approach (16 elem/cy vs 8 elem/cy).
|
|
23
|
+
*/
|
|
24
|
+
#ifndef NK_SPATIAL_ALDER_H
|
|
25
|
+
#define NK_SPATIAL_ALDER_H
|
|
26
|
+
|
|
27
|
+
#if NK_TARGET_X86_
|
|
28
|
+
#if NK_TARGET_ALDER
|
|
29
|
+
|
|
30
|
+
#include "numkong/types.h"
|
|
31
|
+
#include "numkong/dot/alder.h" // VEX compat macros + dpbusd helpers
|
|
32
|
+
#include "numkong/scalar/haswell.h" // `nk_f32_sqrt_haswell`
|
|
33
|
+
#include "numkong/reduce/haswell.h"
|
|
34
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
|
|
35
|
+
|
|
36
|
+
#if defined(__cplusplus)
|
|
37
|
+
extern "C" {
|
|
38
|
+
#endif
|
|
39
|
+
|
|
40
|
+
#if defined(__clang__)
|
|
41
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni"))), apply_to = function)
|
|
42
|
+
#elif defined(__GNUC__)
|
|
43
|
+
#pragma GCC push_options
|
|
44
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni")
|
|
45
|
+
#endif
|
|
46
|
+
|
|
47
|
+
NK_PUBLIC void nk_angular_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance for signed bytes using AVX-VNNI DPBUSD (u8 x i8 -> i32, groups of 4).
    //
    // Flipping the top bit of a signed byte (XOR 0x80) yields the unsigned byte x' = x + 128, so:
    //   dpbusd(a', b) = dot(a,b) + 128*sum(b)
    //   dpbusd(a', a) = ||a||^2  + 128*sum(a)
    //   dpbusd(b', b) = ||b||^2  + 128*sum(b)
    // VPSADBW against zero gives sum(x') = sum(x) + 128*count, from which each 128*sum(x) correction follows.
    __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i const zeros_u8x32 = _mm256_setzero_si256();
    __m256i dot_product_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i sum_a_biased_i64x4 = _mm256_setzero_si256();
    __m256i sum_b_biased_i64x4 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));

        // Bias into the unsigned operand range of DPBUSD: x' = x + 128
        __m256i a_unsigned_u8x32 = _mm256_xor_si256(a_i8x32, xor_mask_u8x32);
        __m256i b_unsigned_u8x32 = _mm256_xor_si256(b_i8x32, xor_mask_u8x32);

        // (a+128)*b, (a+128)*a, (b+128)*b accumulated into wrapping 32-bit lanes
        dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_unsigned_u8x32, b_i8x32);
        a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_unsigned_u8x32, a_i8x32);
        b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_unsigned_u8x32, b_i8x32);

        // Biased sums sum(a+128), sum(b+128) via SAD against zero, kept in 64-bit lanes
        sum_a_biased_i64x4 = _mm256_add_epi64(sum_a_biased_i64x4, _mm256_sad_epu8(a_unsigned_u8x32, zeros_u8x32));
        sum_b_biased_i64x4 = _mm256_add_epi64(sum_b_biased_i64x4, _mm256_sad_epu8(b_unsigned_u8x32, zeros_u8x32));
    }

    // correction_x = 128*sum(x) = 128*sum(x') - 16384*i, exact in 64 bits
    nk_i64_t sum_a_biased = nk_reduce_add_i64x4_haswell_(sum_a_biased_i64x4);
    nk_i64_t sum_b_biased = nk_reduce_add_i64x4_haswell_(sum_b_biased_i64x4);
    nk_i64_t correction_a = 128LL * sum_a_biased - 16384LL * (nk_i64_t)i;
    nk_i64_t correction_b = 128LL * sum_b_biased - 16384LL * (nk_i64_t)i;

    // The VNNI lanes wrap modulo 2^32, so the reduced value and the correction may sit far apart in the i32 range
    // even when the true result fits. Subtracting two `nk_i32_t` values here would be signed overflow (undefined
    // behavior) for large `n`; subtract in 64 bits instead and narrow afterwards. The narrowing is
    // implementation-defined rather than undefined, and reduces modulo 2^32 on mainstream compilers.
    nk_i32_t dot_product_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8) - correction_b);
    nk_i32_t a_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) - correction_a);
    nk_i32_t b_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) - correction_b);

    // Scalar tail for the remaining n % 32 elements
    for (; i < n; ++i) {
        nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
        dot_product_i32 += a_element_i32 * b_element_i32;
        a_norm_sq_i32 += a_element_i32 * a_element_i32;
        b_norm_sq_i32 += b_element_i32 * b_element_i32;
    }

    *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
}
|
|
107
|
+
|
|
108
|
+
NK_PUBLIC void nk_sqeuclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Squared Euclidean distance for signed bytes using DPBUSD with norm decomposition:
    //   ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
    //
    // DPBUSD takes an unsigned first operand, so inputs are biased with XOR 0x80 (x' = x + 128):
    //   dpbusd(a', b) = dot(a,b) + 128*sum(b)
    //   dpbusd(a', a) = ||a||^2  + 128*sum(a)
    //   dpbusd(b', b) = ||b||^2  + 128*sum(b)
    //
    // The partial terms (norms, 2*dot) exceed INT32_MAX long before the distance exceeds UINT32_MAX, and the VNNI
    // lanes wrap modulo 2^32 anyway. All recombination is therefore done in `nk_u32_t`, whose wraparound is
    // well-defined in C, so the result is exact whenever the true squared distance fits in 32 unsigned bits.
    __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i const zeros_u8x32 = _mm256_setzero_si256();
    __m256i dot_product_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i sum_a_biased_i64x4 = _mm256_setzero_si256();
    __m256i sum_b_biased_i64x4 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
        __m256i a_unsigned_u8x32 = _mm256_xor_si256(a_i8x32, xor_mask_u8x32);
        __m256i b_unsigned_u8x32 = _mm256_xor_si256(b_i8x32, xor_mask_u8x32);

        dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_unsigned_u8x32, b_i8x32);
        a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_unsigned_u8x32, a_i8x32);
        b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_unsigned_u8x32, b_i8x32);

        // Biased sums sum(a+128), sum(b+128) via SAD against zero, kept in 64-bit lanes
        sum_a_biased_i64x4 = _mm256_add_epi64(sum_a_biased_i64x4, _mm256_sad_epu8(a_unsigned_u8x32, zeros_u8x32));
        sum_b_biased_i64x4 = _mm256_add_epi64(sum_b_biased_i64x4, _mm256_sad_epu8(b_unsigned_u8x32, zeros_u8x32));
    }

    // correction_x = 128*sum(x) = 128*sum(x') - 16384*i, exact in 64 bits
    nk_i64_t sum_a_biased = nk_reduce_add_i64x4_haswell_(sum_a_biased_i64x4);
    nk_i64_t sum_b_biased = nk_reduce_add_i64x4_haswell_(sum_b_biased_i64x4);
    nk_i64_t correction_a = 128LL * sum_a_biased - 16384LL * (nk_i64_t)i;
    nk_i64_t correction_b = 128LL * sum_b_biased - 16384LL * (nk_i64_t)i;

    // Conversions to unsigned reduce modulo 2^32 (well-defined), matching the wrapping VNNI accumulators
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8) - (nk_u32_t)correction_b;
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) - (nk_u32_t)correction_a;
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) - (nk_u32_t)correction_b;
    nk_u32_t distance_sq_u32 = a_norm_sq_u32 + b_norm_sq_u32 - 2u * dot_product_u32;

    // Scalar tail: accumulate squared differences directly, each term is at most 255^2
    for (; i < n; ++i) {
        nk_i32_t difference_i32 = (nk_i32_t)a[i] - (nk_i32_t)b[i];
        distance_sq_u32 += (nk_u32_t)(difference_i32 * difference_i32);
    }

    *result = distance_sq_u32;
}
|
|
154
|
+
|
|
155
|
+
NK_PUBLIC void nk_euclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance: square root of the 32-bit squared distance from the VNNI kernel above.
    nk_u32_t squared;
    nk_sqeuclidean_i8_alder(a, b, n, &squared);
    nk_f32_t const squared_f32 = (nk_f32_t)squared;
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
160
|
+
|
|
161
|
+
NK_PUBLIC void nk_sqeuclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Squared Euclidean distance for unsigned bytes using DPBUSD with norm decomposition:
    //   ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
    //
    // DPBUSD takes a signed second operand; XOR 0x80 reinterprets y as the signed byte (y - 128), so:
    //   dpbusd(a, b ^ 0x80) = dot(a,b) - 128*sum(a)
    //   dpbusd(a, a ^ 0x80) = ||a||^2  - 128*sum(a)
    //   dpbusd(b, b ^ 0x80) = ||b||^2  - 128*sum(b)
    //
    // ||a||^2 alone exceeds INT32_MAX at ~33k elements while the distance fits `nk_u32_t` up to ~66k, and the VNNI
    // lanes wrap modulo 2^32 anyway. All recombination is therefore done in `nk_u32_t`, whose wraparound is
    // well-defined in C, so the result is exact whenever the true squared distance fits in 32 unsigned bits.
    __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i const zeros_u8x32 = _mm256_setzero_si256();
    __m256i dot_product_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i sum_a_u64x4 = _mm256_setzero_si256();
    __m256i sum_b_u64x4 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
        __m256i a_signed_i8x32 = _mm256_xor_si256(a_u8x32, xor_mask_u8x32);
        __m256i b_signed_i8x32 = _mm256_xor_si256(b_u8x32, xor_mask_u8x32);

        // a*(b-128), a*(a-128), b*(b-128) accumulated into wrapping 32-bit lanes
        dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_u8x32, b_signed_i8x32);
        a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_u8x32, a_signed_i8x32);
        b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_u8x32, b_signed_i8x32);

        // Plain byte sums via SAD against zero, kept in 64-bit lanes
        sum_a_u64x4 = _mm256_add_epi64(sum_a_u64x4, _mm256_sad_epu8(a_u8x32, zeros_u8x32));
        sum_b_u64x4 = _mm256_add_epi64(sum_b_u64x4, _mm256_sad_epu8(b_u8x32, zeros_u8x32));
    }

    // Corrections: x*(y-128) + 128*sum(x) = x*y, recombined modulo 2^32 (well-defined for unsigned types)
    nk_i64_t sum_a_i64 = nk_reduce_add_i64x4_haswell_(sum_a_u64x4);
    nk_i64_t sum_b_i64 = nk_reduce_add_i64x4_haswell_(sum_b_u64x4);
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8) +
                               (nk_u32_t)(128LL * sum_a_i64);
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) + (nk_u32_t)(128LL * sum_a_i64);
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) + (nk_u32_t)(128LL * sum_b_i64);
    nk_u32_t distance_sq_u32 = a_norm_sq_u32 + b_norm_sq_u32 - 2u * dot_product_u32;

    // Scalar tail: accumulate squared differences directly, each term is at most 255^2
    for (; i < n; ++i) {
        nk_i32_t difference_i32 = (nk_i32_t)a[i] - (nk_i32_t)b[i];
        distance_sq_u32 += (nk_u32_t)(difference_i32 * difference_i32);
    }

    *result = distance_sq_u32;
}
|
|
209
|
+
|
|
210
|
+
NK_PUBLIC void nk_euclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance: square root of the 32-bit squared distance from the VNNI kernel above.
    nk_u32_t squared;
    nk_sqeuclidean_u8_alder(a, b, n, &squared);
    nk_f32_t const squared_f32 = (nk_f32_t)squared;
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
215
|
+
|
|
216
|
+
NK_PUBLIC void nk_angular_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance for unsigned bytes via AVX-VNNI DPBUSD (u8 x i8 -> i32, groups of 4).
    //
    // Flipping the top bit of the second operand reinterprets y as the signed byte (y - 128), so each product
    // is shifted by 128*x:
    //   dpbusd(a, b ^ 0x80) = dot(a,b) - 128*sum(a)
    //   dpbusd(a, a ^ 0x80) = ||a||^2  - 128*sum(a)
    //   dpbusd(b, b ^ 0x80) = ||b||^2  - 128*sum(b)
    // Plain byte sums come from VPSADBW against zero and are added back once the vector loop ends.
    __m256i const flip_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i const zero_u8x32 = _mm256_setzero_si256();
    __m256i ab_i32x8 = _mm256_setzero_si256();
    __m256i aa_i32x8 = _mm256_setzero_si256();
    __m256i bb_i32x8 = _mm256_setzero_si256();
    __m256i a_total_u64x4 = _mm256_setzero_si256();
    __m256i b_total_u64x4 = _mm256_setzero_si256();

    nk_size_t idx = 0;
    while (idx + 32 <= n) {
        __m256i const a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + idx));
        __m256i const b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + idx));
        __m256i const a_shifted_i8x32 = _mm256_xor_si256(a_u8x32, flip_u8x32);
        __m256i const b_shifted_i8x32 = _mm256_xor_si256(b_u8x32, flip_u8x32);

        ab_i32x8 = _mm256_dpbusd_avx_epi32(ab_i32x8, a_u8x32, b_shifted_i8x32);
        aa_i32x8 = _mm256_dpbusd_avx_epi32(aa_i32x8, a_u8x32, a_shifted_i8x32);
        bb_i32x8 = _mm256_dpbusd_avx_epi32(bb_i32x8, b_u8x32, b_shifted_i8x32);

        a_total_u64x4 = _mm256_add_epi64(a_total_u64x4, _mm256_sad_epu8(a_u8x32, zero_u8x32));
        b_total_u64x4 = _mm256_add_epi64(b_total_u64x4, _mm256_sad_epu8(b_u8x32, zero_u8x32));
        idx += 32;
    }

    // Undo the -128*sum(x) shift in 64-bit arithmetic, then narrow back to 32 bits
    nk_i64_t const a_total_i64 = nk_reduce_add_i64x4_haswell_(a_total_u64x4);
    nk_i64_t const b_total_i64 = nk_reduce_add_i64x4_haswell_(b_total_u64x4);
    nk_i64_t const a_shift_i64 = 128LL * a_total_i64;
    nk_i64_t const b_shift_i64 = 128LL * b_total_i64;
    nk_i32_t ab_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(ab_i32x8) + a_shift_i64);
    nk_i32_t aa_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(aa_i32x8) + a_shift_i64);
    nk_i32_t bb_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(bb_i32x8) + b_shift_i64);

    // Leftover n % 32 elements, handled one at a time
    while (idx < n) {
        nk_i32_t const x = a[idx];
        nk_i32_t const y = b[idx];
        ab_i32 += x * y;
        aa_i32 += x * x;
        bb_i32 += y * y;
        ++idx;
    }

    *result = nk_angular_normalize_f32_haswell_(ab_i32, aa_i32, bb_i32);
}
|
|
259
|
+
|
|
260
|
+
NK_PUBLIC void nk_angular_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
261
|
+
nk_f32_t *result) {
|
|
262
|
+
// Angular distance for e2m3 using dual-VPSHUFB LUT + VPDPBUSD norm decomposition.
|
|
263
|
+
// Every e2m3 value × 16 is an exact integer in [-120, +120].
|
|
264
|
+
// We compute dot(a,b), ||a||^2, ||b||^2 in scaled integer domain,
|
|
265
|
+
// then normalize: angular = 1 - dot / sqrt(||a||^2 * ||b||^2).
|
|
266
|
+
// Final division by 256.0f for dot and norms cancels in the ratio.
|
|
267
|
+
//
|
|
268
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
269
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
270
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
271
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
272
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
273
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
274
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
275
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
276
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
277
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
278
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
279
|
+
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
280
|
+
|
|
281
|
+
nk_angular_e2m3_alder_cycle:
|
|
282
|
+
if (count_scalars < 32) {
|
|
283
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
284
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
285
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
286
|
+
a_e2m3_u8x32 = a_vec.ymm;
|
|
287
|
+
b_e2m3_u8x32 = b_vec.ymm;
|
|
288
|
+
count_scalars = 0;
|
|
289
|
+
}
|
|
290
|
+
else {
|
|
291
|
+
a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
292
|
+
b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
293
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
// Decode a: extract magnitude, dual-VPSHUFB LUT
|
|
297
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
298
|
+
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
299
|
+
__m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
300
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
|
|
301
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
|
|
302
|
+
|
|
303
|
+
// Decode b: same LUT decode
|
|
304
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
305
|
+
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
306
|
+
__m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
307
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
|
|
308
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
|
|
309
|
+
|
|
310
|
+
// Dot product with sign: combined sign from (a XOR b) & 0x20
|
|
311
|
+
__m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
312
|
+
__m256i negate_mask = _mm256_cmpeq_epi8(sign_combined, sign_mask_u8x32);
|
|
313
|
+
__m256i b_negated = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
|
|
314
|
+
__m256i b_dot_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated, negate_mask);
|
|
315
|
+
|
|
316
|
+
// DPBUSD: a_unsigned[u8] × b_signed[i8] → i32 for dot product
|
|
317
|
+
dot_i32x8 = _mm256_dpbusd_avx_epi32(dot_i32x8, a_unsigned_u8x32, b_dot_i8x32);
|
|
318
|
+
// Norms: magnitude² is always positive, DPBUSD(unsigned, unsigned-as-signed) works since max=120 < 127
|
|
319
|
+
a_norm_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_i32x8, a_unsigned_u8x32, a_unsigned_u8x32);
|
|
320
|
+
b_norm_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_i32x8, b_unsigned_u8x32, b_unsigned_u8x32);
|
|
321
|
+
|
|
322
|
+
if (count_scalars) goto nk_angular_e2m3_alder_cycle;
|
|
323
|
+
|
|
324
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
325
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
326
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
327
|
+
// The 256.0f factor cancels in the angular normalization ratio
|
|
328
|
+
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
332
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
333
|
+
// Squared Euclidean distance for e2m3 using norm decomposition:
|
|
334
|
+
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
335
|
+
// Each value × 16 is exact integer, so result = integer_result / 256.0f
|
|
336
|
+
//
|
|
337
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
338
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
339
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
340
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
341
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
342
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
343
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
344
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
345
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
346
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
347
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
348
|
+
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
349
|
+
|
|
350
|
+
nk_sqeuclidean_e2m3_alder_cycle:
|
|
351
|
+
if (count_scalars < 32) {
|
|
352
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
353
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
354
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
355
|
+
a_e2m3_u8x32 = a_vec.ymm;
|
|
356
|
+
b_e2m3_u8x32 = b_vec.ymm;
|
|
357
|
+
count_scalars = 0;
|
|
358
|
+
}
|
|
359
|
+
else {
|
|
360
|
+
a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
361
|
+
b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
362
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
// Decode a and b magnitudes via LUT
|
|
366
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
367
|
+
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
368
|
+
__m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
369
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
|
|
370
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
|
|
371
|
+
|
|
372
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
373
|
+
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
374
|
+
__m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
375
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
|
|
376
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
|
|
377
|
+
|
|
378
|
+
// Signed dot product: combined sign from (a XOR b) & 0x20
|
|
379
|
+
__m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
380
|
+
__m256i negate_mask = _mm256_cmpeq_epi8(sign_combined, sign_mask_u8x32);
|
|
381
|
+
__m256i b_negated = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
|
|
382
|
+
__m256i b_dot_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated, negate_mask);
|
|
383
|
+
|
|
384
|
+
dot_i32x8 = _mm256_dpbusd_avx_epi32(dot_i32x8, a_unsigned_u8x32, b_dot_i8x32);
|
|
385
|
+
a_norm_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_i32x8, a_unsigned_u8x32, a_unsigned_u8x32);
|
|
386
|
+
b_norm_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_i32x8, b_unsigned_u8x32, b_unsigned_u8x32);
|
|
387
|
+
|
|
388
|
+
if (count_scalars) goto nk_sqeuclidean_e2m3_alder_cycle;
|
|
389
|
+
|
|
390
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
391
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
392
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
393
|
+
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), scaled by 256
|
|
394
|
+
*result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
NK_PUBLIC void nk_euclidean_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance: square root of the squared distance from the e2m3 Alder Lake kernel.
    nk_f32_t squared_distance;
    nk_sqeuclidean_e2m3_alder(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
401
|
+
|
|
402
|
+
NK_PUBLIC void nk_angular_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
403
|
+
nk_f32_t *result) {
|
|
404
|
+
// Angular distance for e3m2 using dual-VPSHUFB LUT decode to i16 + VPDPWSSD norm decomposition.
|
|
405
|
+
// Every e3m2 value × 16 is an exact integer (max magnitude 448), requiring i16.
|
|
406
|
+
// VPDPWSSD replaces Haswell's VPMADDWD + VPADDD, saving one instruction per accumulation.
|
|
407
|
+
//
|
|
408
|
+
__m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
|
|
409
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
410
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
411
|
+
__m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
|
|
412
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
413
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
414
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
415
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
416
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
417
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
418
|
+
__m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
|
|
419
|
+
__m256i const ones_u8x32 = _mm256_set1_epi8(1);
|
|
420
|
+
__m256i const ones_i16x16 = _mm256_set1_epi16(1);
|
|
421
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
422
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
423
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
424
|
+
__m256i a_e3m2_u8x32, b_e3m2_u8x32;
|
|
425
|
+
|
|
426
|
+
nk_angular_e3m2_alder_cycle:
|
|
427
|
+
if (count_scalars < 32) {
|
|
428
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
429
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
430
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
431
|
+
a_e3m2_u8x32 = a_vec.ymm;
|
|
432
|
+
b_e3m2_u8x32 = b_vec.ymm;
|
|
433
|
+
count_scalars = 0;
|
|
434
|
+
}
|
|
435
|
+
else {
|
|
436
|
+
a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
437
|
+
b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
438
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
// Extract 5-bit magnitude, split into low 4 bits and bit 4
|
|
442
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
|
|
443
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
444
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
445
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
446
|
+
__m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
447
|
+
half_select_u8x32);
|
|
448
|
+
__m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
449
|
+
half_select_u8x32);
|
|
450
|
+
|
|
451
|
+
// Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
|
|
452
|
+
__m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
|
|
453
|
+
_mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
|
|
454
|
+
a_upper_select_u8x32);
|
|
455
|
+
__m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
|
|
456
|
+
_mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
|
|
457
|
+
b_upper_select_u8x32);
|
|
458
|
+
|
|
459
|
+
// High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
|
|
460
|
+
__m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
|
|
461
|
+
__m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
|
|
462
|
+
|
|
463
|
+
// Interleave low and high bytes into i16 (little-endian: low byte first)
|
|
464
|
+
__m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
|
|
465
|
+
__m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
|
|
466
|
+
__m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
|
|
467
|
+
__m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
|
|
468
|
+
|
|
469
|
+
// Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
|
|
470
|
+
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
|
|
471
|
+
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
|
|
472
|
+
__m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
473
|
+
__m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
474
|
+
__m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
|
|
475
|
+
__m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
|
|
476
|
+
__m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
|
|
477
|
+
__m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
|
|
478
|
+
|
|
479
|
+
// VPDPWSSD: i16×i16→i32 fused dot-product-accumulate (replaces VPMADDWD + VPADDD)
|
|
480
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_lo_i16x16, b_signed_lo_i16x16);
|
|
481
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_hi_i16x16, b_signed_hi_i16x16);
|
|
482
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_lo_i16x16, a_lo_i16x16);
|
|
483
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_hi_i16x16, a_hi_i16x16);
|
|
484
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_lo_i16x16, b_lo_i16x16);
|
|
485
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_hi_i16x16, b_hi_i16x16);
|
|
486
|
+
|
|
487
|
+
if (count_scalars) goto nk_angular_e3m2_alder_cycle;
|
|
488
|
+
|
|
489
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
490
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
491
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
492
|
+
// The 256.0f factor cancels in the angular normalization ratio
|
|
493
|
+
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
497
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
498
|
+
// Squared Euclidean distance for e3m2 using norm decomposition + VPDPWSSD:
|
|
499
|
+
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
500
|
+
// Each value × 16 is exact integer, so result = integer_result / 256.0f
|
|
501
|
+
//
|
|
502
|
+
__m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
|
|
503
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
504
|
+
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
505
|
+
__m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
|
|
506
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
507
|
+
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
508
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
509
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
510
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
511
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
512
|
+
__m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
|
|
513
|
+
__m256i const ones_u8x32 = _mm256_set1_epi8(1);
|
|
514
|
+
__m256i const ones_i16x16 = _mm256_set1_epi16(1);
|
|
515
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
516
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
517
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
518
|
+
__m256i a_e3m2_u8x32, b_e3m2_u8x32;
|
|
519
|
+
|
|
520
|
+
nk_sqeuclidean_e3m2_alder_cycle:
|
|
521
|
+
if (count_scalars < 32) {
|
|
522
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
523
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
524
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
525
|
+
a_e3m2_u8x32 = a_vec.ymm;
|
|
526
|
+
b_e3m2_u8x32 = b_vec.ymm;
|
|
527
|
+
count_scalars = 0;
|
|
528
|
+
}
|
|
529
|
+
else {
|
|
530
|
+
a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
531
|
+
b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
532
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
// Extract 5-bit magnitude, split into low 4 bits and bit 4
|
|
536
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
|
|
537
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
538
|
+
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
539
|
+
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
540
|
+
__m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
541
|
+
half_select_u8x32);
|
|
542
|
+
__m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
543
|
+
half_select_u8x32);
|
|
544
|
+
|
|
545
|
+
// Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
|
|
546
|
+
__m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
|
|
547
|
+
_mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
|
|
548
|
+
a_upper_select_u8x32);
|
|
549
|
+
__m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
|
|
550
|
+
_mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
|
|
551
|
+
b_upper_select_u8x32);
|
|
552
|
+
|
|
553
|
+
// High byte: 1 iff magnitude >= 28
|
|
554
|
+
__m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
|
|
555
|
+
__m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
|
|
556
|
+
|
|
557
|
+
// Interleave low and high bytes into i16
|
|
558
|
+
__m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
|
|
559
|
+
__m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
|
|
560
|
+
__m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
|
|
561
|
+
__m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
|
|
562
|
+
|
|
563
|
+
// Combined sign for dot product
|
|
564
|
+
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
|
|
565
|
+
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
|
|
566
|
+
__m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
567
|
+
__m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
568
|
+
__m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
|
|
569
|
+
__m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
|
|
570
|
+
__m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
|
|
571
|
+
__m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
|
|
572
|
+
|
|
573
|
+
// VPDPWSSD: i16×i16→i32 fused dot-product-accumulate
|
|
574
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_lo_i16x16, b_signed_lo_i16x16);
|
|
575
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_hi_i16x16, b_signed_hi_i16x16);
|
|
576
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_lo_i16x16, a_lo_i16x16);
|
|
577
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_hi_i16x16, a_hi_i16x16);
|
|
578
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_lo_i16x16, b_lo_i16x16);
|
|
579
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_hi_i16x16, b_hi_i16x16);
|
|
580
|
+
|
|
581
|
+
if (count_scalars) goto nk_sqeuclidean_e3m2_alder_cycle;
|
|
582
|
+
|
|
583
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
584
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
585
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
586
|
+
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), scaled by 256
|
|
587
|
+
*result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
NK_PUBLIC void nk_euclidean_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance: the e3m2 squared-distance kernel followed by a square root.
    nk_f32_t distance_squared;
    nk_sqeuclidean_e3m2_alder(a, b, n, &distance_squared);
    *result = nk_f32_sqrt_haswell(distance_squared);
}
|
|
594
|
+
|
|
595
|
+
#if defined(__clang__)
|
|
596
|
+
#pragma clang attribute pop
|
|
597
|
+
#elif defined(__GNUC__)
|
|
598
|
+
#pragma GCC pop_options
|
|
599
|
+
#endif
|
|
600
|
+
|
|
601
|
+
#if defined(__cplusplus)
|
|
602
|
+
} // extern "C"
|
|
603
|
+
#endif
|
|
604
|
+
|
|
605
|
+
#endif // NK_TARGET_ALDER
|
|
606
|
+
#endif // NK_TARGET_X86_
|
|
607
|
+
#endif // NK_SPATIAL_ALDER_H
|