numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,960 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Haswell.
|
|
3
|
+
* @file include/numkong/spatial/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_haswell_instructions Key AVX2 Spatial Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
|
|
13
|
+
* _mm256_mul_ps VMULPS (YMM, YMM, YMM) 5cy 0.5/cy p01
|
|
14
|
+
* _mm256_add_ps VADDPS (YMM, YMM, YMM) 3cy 1/cy p01
|
|
15
|
+
* _mm256_sub_ps VSUBPS (YMM, YMM, YMM) 3cy 1/cy p01
|
|
16
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy 1/cy p0
|
|
17
|
+
* _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy 7cy p0
|
|
18
|
+
* _mm256_sqrt_ps VSQRTPS (YMM, YMM) 12cy 14cy p0
|
|
19
|
+
*
|
|
20
|
+
 * For angular distance normalization, `_mm_rsqrt_ps` provides ~12-bit precision (max relative error 1.5 × 2⁻¹²).
|
|
21
|
+
* Newton-Raphson refinement doubles precision to ~22-24 bits, sufficient for f32. For f64 we use
|
|
22
|
+
 * the correctly-rounded `_mm_sqrt_pd` / `_mm256_sqrt_pd` instructions, since fast rsqrt approximations lack f64 precision.
|
|
23
|
+
*/
|
|
24
|
+
#ifndef NK_SPATIAL_HASWELL_H
|
|
25
|
+
#define NK_SPATIAL_HASWELL_H
|
|
26
|
+
|
|
27
|
+
#if NK_TARGET_X86_
|
|
28
|
+
#if NK_TARGET_HASWELL
|
|
29
|
+
|
|
30
|
+
#include "numkong/types.h"
|
|
31
|
+
#include "numkong/scalar/haswell.h" // `nk_f32_sqrt_haswell`
|
|
32
|
+
#include "numkong/dot/haswell.h" // `nk_dot_f32x4_state_haswell_t`
|
|
33
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_f32x8_haswell_`
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#if defined(__clang__)
|
|
40
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
41
|
+
#elif defined(__GNUC__)
|
|
42
|
+
#pragma GCC push_options
|
|
43
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
/** @brief Reciprocal square root of 4 floats with Newton-Raphson refinement. */
|
|
47
|
+
NK_INTERNAL __m128 nk_rsqrt_f32x4_haswell_(__m128 x) {
|
|
48
|
+
__m128 rsqrt_f32x4 = _mm_rsqrt_ps(x);
|
|
49
|
+
__m128 nr_f32x4 = _mm_mul_ps(_mm_mul_ps(x, rsqrt_f32x4), rsqrt_f32x4);
|
|
50
|
+
nr_f32x4 = _mm_sub_ps(_mm_set1_ps(3.0f), nr_f32x4);
|
|
51
|
+
return _mm_mul_ps(_mm_mul_ps(_mm_set1_ps(0.5f), rsqrt_f32x4), nr_f32x4);
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
/** @brief Safe square root of 4 floats with zero-clamping for numerical stability. */
|
|
55
|
+
NK_INTERNAL __m128 nk_safe_sqrt_f32x4_haswell_(__m128 x) { return _mm_sqrt_ps(_mm_max_ps(x, _mm_setzero_ps())); }
|
|
56
|
+
|
|
57
|
+
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs. */
|
|
58
|
+
NK_INTERNAL void nk_angular_through_f32_from_dot_haswell_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
59
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
60
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
61
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps(query_sumsq);
|
|
62
|
+
__m128 products_f32x4 = _mm_mul_ps(query_sumsq_f32x4, target_sumsqs.xmm_ps);
|
|
63
|
+
__m128 rsqrt_f32x4 = nk_rsqrt_f32x4_haswell_(products_f32x4);
|
|
64
|
+
__m128 normalized_f32x4 = _mm_mul_ps(dots_f32x4, rsqrt_f32x4);
|
|
65
|
+
__m128 angular_f32x4 = _mm_sub_ps(_mm_set1_ps(1.0f), normalized_f32x4);
|
|
66
|
+
results->xmm_ps = _mm_max_ps(angular_f32x4, _mm_setzero_ps());
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs. */
|
|
70
|
+
NK_INTERNAL void nk_euclidean_through_f32_from_dot_haswell_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
71
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
72
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
73
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps(query_sumsq);
|
|
74
|
+
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, target_sumsqs.xmm_ps);
|
|
75
|
+
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
76
|
+
results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/** @brief Angular from_dot for native f64: 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs. */
|
|
80
|
+
NK_INTERNAL void nk_angular_through_f64_from_dot_haswell_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
81
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
82
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
83
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd(query_sumsq);
|
|
84
|
+
__m256d products_f64x4 = _mm256_mul_pd(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
85
|
+
__m256d sqrt_products_f64x4 = _mm256_sqrt_pd(products_f64x4);
|
|
86
|
+
__m256d normalized_f64x4 = _mm256_div_pd(dots_f64x4, sqrt_products_f64x4);
|
|
87
|
+
__m256d angular_f64x4 = _mm256_sub_pd(_mm256_set1_pd(1.0), normalized_f64x4);
|
|
88
|
+
results->ymm_pd = _mm256_max_pd(angular_f64x4, _mm256_setzero_pd());
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/** @brief Euclidean from_dot for native f64: √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs. */
|
|
92
|
+
NK_INTERNAL void nk_euclidean_through_f64_from_dot_haswell_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
93
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
94
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
95
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd(query_sumsq);
|
|
96
|
+
__m256d sum_sq_f64x4 = _mm256_add_pd(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
97
|
+
__m256d dist_sq_f64x4 = _mm256_fnmadd_pd(_mm256_set1_pd(2.0), dots_f64x4, sum_sq_f64x4);
|
|
98
|
+
results->ymm_pd = _mm256_sqrt_pd(_mm256_max_pd(dist_sq_f64x4, _mm256_setzero_pd()));
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/** @brief Angular from_dot for i32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
|
|
102
|
+
NK_INTERNAL void nk_angular_through_i32_from_dot_haswell_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
103
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
104
|
+
__m128 dots_f32x4 = _mm_cvtepi32_ps(dots.xmm);
|
|
105
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
106
|
+
__m128 products_f32x4 = _mm_mul_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
107
|
+
__m128 rsqrt_f32x4 = nk_rsqrt_f32x4_haswell_(products_f32x4);
|
|
108
|
+
__m128 normalized_f32x4 = _mm_mul_ps(dots_f32x4, rsqrt_f32x4);
|
|
109
|
+
__m128 angular_f32x4 = _mm_sub_ps(_mm_set1_ps(1.0f), normalized_f32x4);
|
|
110
|
+
results->xmm_ps = _mm_max_ps(angular_f32x4, _mm_setzero_ps());
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/** @brief Euclidean from_dot for i32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
114
|
+
NK_INTERNAL void nk_euclidean_through_i32_from_dot_haswell_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
115
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
116
|
+
__m128 dots_f32x4 = _mm_cvtepi32_ps(dots.xmm);
|
|
117
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
118
|
+
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
119
|
+
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
120
|
+
results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
/** @brief Angular from_dot for u32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
|
|
124
|
+
NK_INTERNAL void nk_angular_through_u32_from_dot_haswell_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
125
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
126
|
+
__m128 dots_f32x4 = _mm_cvtepi32_ps(dots.xmm);
|
|
127
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
128
|
+
__m128 products_f32x4 = _mm_mul_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
129
|
+
__m128 rsqrt_f32x4 = nk_rsqrt_f32x4_haswell_(products_f32x4);
|
|
130
|
+
__m128 normalized_f32x4 = _mm_mul_ps(dots_f32x4, rsqrt_f32x4);
|
|
131
|
+
__m128 angular_f32x4 = _mm_sub_ps(_mm_set1_ps(1.0f), normalized_f32x4);
|
|
132
|
+
results->xmm_ps = _mm_max_ps(angular_f32x4, _mm_setzero_ps());
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/** @brief Euclidean from_dot for u32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
136
|
+
NK_INTERNAL void nk_euclidean_through_u32_from_dot_haswell_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
137
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
138
|
+
__m128 dots_f32x4 = _mm_cvtepi32_ps(dots.xmm);
|
|
139
|
+
__m128 query_sumsq_f32x4 = _mm_set1_ps((nk_f32_t)query_sumsq);
|
|
140
|
+
__m128 sum_sq_f32x4 = _mm_add_ps(query_sumsq_f32x4, _mm_cvtepi32_ps(target_sumsqs.xmm));
|
|
141
|
+
__m128 dist_sq_f32x4 = _mm_fnmadd_ps(_mm_set1_ps(2.0f), dots_f32x4, sum_sq_f32x4);
|
|
142
|
+
results->xmm_ps = nk_safe_sqrt_f32x4_haswell_(dist_sq_f32x4);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
/**
 *  @brief Angular distance from a dot product `ab` and squared norms `a2`, `b2`, in full f64 precision:
 *  max(0, 1 − ab / (√a2 × √b2)).
 *  Returns 0 when both squared norms are zero, and 1 when the dot product is zero otherwise.
 */
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_haswell_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
    // Two zero vectors are treated as identical.
    if (a2 == 0 && b2 == 0) return 0;
    // Zero dot product: orthogonal vectors, or exactly one zero vector (where the division below
    // would be undefined). Both map to a distance of 1.
    if (ab == 0) return 1;

    // Precision: `_mm_rsqrt_ps` has a max relative error of 1.5 × 2⁻¹² (~11-12 bits), and even one
    // Newton-Raphson step only reaches ~22-24 bits, far short of f64's 52-bit mantissa. The
    // correctly-rounded `_mm_sqrt_pd` is slower but keeps full f64 precision.
    // https://web.archive.org/web/20210208132927/http://assemblyrequired.crashworks.org/timing-square-root/
    __m128d const roots_f64x2 = _mm_sqrt_pd(_mm_set_pd(a2, b2)); // lane 1 = √a2, lane 0 = √b2
    nk_f64_t const b_norm = _mm_cvtsd_f64(roots_f64x2);
    nk_f64_t const a_norm = _mm_cvtsd_f64(_mm_unpackhi_pd(roots_f64x2, roots_f64x2));
    nk_f64_t const distance = 1 - ab / (a_norm * b_norm);
    return distance > 0 ? distance : 0;
}
|
|
166
|
+
|
|
167
|
+
/**
 *  @brief Angular distance from a dot product `ab` and squared norms `a2`, `b2`, in f32:
 *  max(0, 1 − ab × rsqrt(a2) × rsqrt(b2)), with each rsqrt refined by one Newton-Raphson step.
 *  Returns 0 when both squared norms are zero, and 1 when the dot product is zero otherwise.
 */
NK_INTERNAL nk_f32_t nk_angular_normalize_f32_haswell_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
    // Two zero vectors are treated as identical.
    if (a2 == 0.0f && b2 == 0.0f) return 0.0f;
    // Zero dot product: orthogonal vectors, or exactly one zero vector. Both map to 1.
    if (ab == 0.0f) return 1.0f;

    // Lanes, lowest first: {b2, a2, b2, a2}. Only lanes 0 and 1 are read back.
    __m128 const squares_f32x4 = _mm_set_ps(a2, b2, a2, b2);
    __m128 const estimate_f32x4 = _mm_rsqrt_ps(squares_f32x4);

    // One Newton-Raphson step: y' = y × (1.5 − 0.5 × (x × (y × y))).
    __m128 const estimate_sq_f32x4 = _mm_mul_ps(estimate_f32x4, estimate_f32x4);
    __m128 const scaled_f32x4 = _mm_mul_ps(_mm_set1_ps(0.5f), _mm_mul_ps(squares_f32x4, estimate_sq_f32x4));
    __m128 const correction_f32x4 = _mm_sub_ps(_mm_set1_ps(1.5f), scaled_f32x4);
    __m128 const inv_roots_f32x4 = _mm_mul_ps(estimate_f32x4, correction_f32x4);

    nk_f32_t const b_inv_norm = _mm_cvtss_f32(inv_roots_f32x4);
    nk_f32_t const a_inv_norm =
        _mm_cvtss_f32(_mm_shuffle_ps(inv_roots_f32x4, inv_roots_f32x4, _MM_SHUFFLE(0, 0, 0, 1)));

    nk_f32_t const distance = 1.0f - ab * a_inv_norm * b_inv_norm;
    return distance > 0 ? distance : 0;
}
|
|
196
|
+
|
|
197
|
+
#pragma region - Smaller Floats
|
|
198
|
+
|
|
199
|
+
/**
 *  @brief Squared Euclidean distance between two f16 vectors of length `n`, accumulated in f32.
 *  Full 8-element chunks are loaded directly; the final partial chunk goes through
 *  `nk_partial_load_f16x8_to_f32x8_haswell_` (assumed to zero-fill lanes past `n` — TODO confirm).
 */
NK_PUBLIC void nk_sqeuclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 sum_f32x8 = _mm256_setzero_ps();
    // The body always runs at least once; a chunk with fewer than 8 elements ends the loop.
    do {
        __m256 a_f32x8, b_f32x8;
        if (n >= 8) {
            a_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a));
            b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b256_vec_t a_tail, b_tail;
            nk_partial_load_f16x8_to_f32x8_haswell_(a, &a_tail, n);
            nk_partial_load_f16x8_to_f32x8_haswell_(b, &b_tail, n);
            a_f32x8 = a_tail.ymm_ps, b_f32x8 = b_tail.ymm_ps;
            n = 0;
        }
        __m256 const diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        sum_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, sum_f32x8);
    } while (n != 0);
    *result = nk_reduce_add_f32x8_haswell_(sum_f32x8);
}
|
|
223
|
+
|
|
224
|
+
/** @brief Euclidean distance between two f16 vectors: the square root of their squared distance. */
NK_PUBLIC void nk_euclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_haswell(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
228
|
+
|
|
229
|
+
/**
 *  @brief Angular (cosine) distance between two f16 vectors of length `n`.
 *  Accumulates a·b, a·a and b·b in f32, then finalizes with `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 ab_f32x8 = _mm256_setzero_ps();
    __m256 aa_f32x8 = _mm256_setzero_ps();
    __m256 bb_f32x8 = _mm256_setzero_ps();
    // The body always runs at least once; a chunk with fewer than 8 elements ends the loop.
    do {
        __m256 a_f32x8, b_f32x8;
        if (n >= 8) {
            a_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a));
            b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b256_vec_t a_tail, b_tail;
            nk_partial_load_f16x8_to_f32x8_haswell_(a, &a_tail, n);
            nk_partial_load_f16x8_to_f32x8_haswell_(b, &b_tail, n);
            a_f32x8 = a_tail.ymm_ps, b_f32x8 = b_tail.ymm_ps;
            n = 0;
        }
        ab_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, ab_f32x8);
        aa_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, aa_f32x8);
        bb_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, bb_f32x8);
    } while (n != 0);

    nk_f32_t const ab = nk_reduce_add_f32x8_haswell_(ab_f32x8);
    nk_f32_t const aa = nk_reduce_add_f32x8_haswell_(aa_f32x8);
    nk_f32_t const bb = nk_reduce_add_f32x8_haswell_(bb_f32x8);
    *result = nk_angular_normalize_f32_haswell_(ab, aa, bb);
}
|
|
258
|
+
|
|
259
|
+
/**
 *  @brief Squared Euclidean distance between two bf16 vectors of length `n`, accumulated in f32.
 *  Full 8-element chunks are widened with `nk_bf16x8_to_f32x8_haswell_`; the final partial chunk goes
 *  through `nk_partial_load_bf16x8_to_f32x8_haswell_` (assumed to zero-fill lanes past `n` — TODO confirm).
 */
NK_PUBLIC void nk_sqeuclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 sum_f32x8 = _mm256_setzero_ps();
    // The body always runs at least once; a chunk with fewer than 8 elements ends the loop.
    do {
        __m256 a_f32x8, b_f32x8;
        if (n >= 8) {
            a_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
            b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b256_vec_t a_tail, b_tail;
            nk_partial_load_bf16x8_to_f32x8_haswell_(a, &a_tail, n);
            nk_partial_load_bf16x8_to_f32x8_haswell_(b, &b_tail, n);
            a_f32x8 = a_tail.ymm_ps, b_f32x8 = b_tail.ymm_ps;
            n = 0;
        }
        __m256 const diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        sum_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, sum_f32x8);
    } while (n != 0);
    *result = nk_reduce_add_f32x8_haswell_(sum_f32x8);
}
|
|
283
|
+
|
|
284
|
+
/** @brief Euclidean distance between two bf16 vectors: the square root of their squared distance. */
NK_PUBLIC void nk_euclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_bf16_haswell(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
288
|
+
|
|
289
|
+
/**
 *  @brief Angular (cosine) distance between two bf16 vectors of length `n`.
 *  Accumulates a·b, a·a and b·b in f32, then finalizes with `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 ab_f32x8 = _mm256_setzero_ps();
    __m256 aa_f32x8 = _mm256_setzero_ps();
    __m256 bb_f32x8 = _mm256_setzero_ps();
    // The body always runs at least once; a chunk with fewer than 8 elements ends the loop.
    do {
        __m256 a_f32x8, b_f32x8;
        if (n >= 8) {
            a_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
            b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b256_vec_t a_tail, b_tail;
            nk_partial_load_bf16x8_to_f32x8_haswell_(a, &a_tail, n);
            nk_partial_load_bf16x8_to_f32x8_haswell_(b, &b_tail, n);
            a_f32x8 = a_tail.ymm_ps, b_f32x8 = b_tail.ymm_ps;
            n = 0;
        }
        ab_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, ab_f32x8);
        aa_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, aa_f32x8);
        bb_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, bb_f32x8);
    } while (n != 0);

    nk_f32_t const ab = nk_reduce_add_f32x8_haswell_(ab_f32x8);
    nk_f32_t const aa = nk_reduce_add_f32x8_haswell_(aa_f32x8);
    nk_f32_t const bb = nk_reduce_add_f32x8_haswell_(bb_f32x8);
    *result = nk_angular_normalize_f32_haswell_(ab, aa, bb);
}
|
|
318
|
+
|
|
319
|
+
#pragma endregion - Smaller Floats
|
|
320
|
+
#pragma region - Small Integers
|
|
321
|
+
|
|
322
|
+
NK_PUBLIC void nk_sqeuclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Optimized i8 L2-squared using saturating subtract + VPMADDWD
    //
    // Approach:
    // - XOR with 0x80 to reinterpret signed i8 as unsigned u8
    // - Compute |a-b| using unsigned saturating subtraction: diff = (a ⊖ b) | (b ⊖ a)
    // - Zero-extend u8→u16 using unpacking (1cy latency @ p5)
    // - Square using vpmaddwd on u16 values (32 elements/iteration)
    //
    // The XOR bias is needed because subs_epu8 (unsigned) saturates to 0 when the result
    // would be negative, so OR-ing both directions gives the true |a-b|.
    // A naive subs_epi8 (signed) would saturate to -128, corrupting the OR trick.
    //
    // Correctness: For squared distance, |a-b|² = (a-b)², so unsigned absolute differences are valid.
    // The XOR preserves distances: |a-b| = |(a^0x80) - (b^0x80)|.
    //
    // Overflow: every squared difference is at most 255² = 65025 and a VPMADDWD pair sum is at
    // most 130050, so each step is exact. The running total is kept modulo 2^32: the vector
    // lanes wrap on `_mm256_add_epi32`, and the scalar tail accumulates in `nk_u32_t`, where
    // wraparound is well defined. A signed `nk_i32_t` tail accumulator would be undefined
    // behavior once the total exceeds INT32_MAX (possible from roughly 33k elements on).
    // NOTE(review): this assumes `nk_reduce_add_i32x8_haswell_` also sums lanes with wrapping
    // SIMD adds; the result is exact whenever the true total fits in 32 unsigned bits.
    __m256i distance_sq_low_i32x8 = _mm256_setzero_si256();
    __m256i distance_sq_high_i32x8 = _mm256_setzero_si256();
    __m256i const zeros_i8x32 = _mm256_setzero_si256();
    __m256i const bias_i8x32 = _mm256_set1_epi8((char)0x80);
    __m256i diff_low_i16x16, diff_high_i16x16;
    __m256i a_i8x32, b_i8x32, diff_u8x32;

    // Process 32 elements per iteration with 256-bit loads
    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));

        // Reinterpret signed i8 as unsigned u8 by flipping the sign bit
        a_i8x32 = _mm256_xor_si256(a_i8x32, bias_i8x32);
        b_i8x32 = _mm256_xor_si256(b_i8x32, bias_i8x32);

        // Compute |a-b| using unsigned saturating subtraction:
        // one direction saturates to 0, so OR-ing both gives the absolute difference.
        diff_u8x32 = _mm256_or_si256(_mm256_subs_epu8(a_i8x32, b_i8x32), _mm256_subs_epu8(b_i8x32, a_i8x32));

        // Zero-extend to i16 by interleaving with zeros (per 128-bit lane; order is irrelevant for a sum)
        diff_low_i16x16 = _mm256_unpacklo_epi8(diff_u8x32, zeros_i8x32);
        diff_high_i16x16 = _mm256_unpackhi_epi8(diff_u8x32, zeros_i8x32);

        // Multiply and accumulate at i16 level, accumulate at i32 level
        distance_sq_low_i32x8 = _mm256_add_epi32(distance_sq_low_i32x8,
                                                 _mm256_madd_epi16(diff_low_i16x16, diff_low_i16x16));
        distance_sq_high_i32x8 = _mm256_add_epi32(distance_sq_high_i32x8,
                                                  _mm256_madd_epi16(diff_high_i16x16, diff_high_i16x16));
    }

    // Reduce to scalar; reinterpret the wrapped 32-bit lane sum as unsigned.
    nk_u32_t distance_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(
        _mm256_add_epi32(distance_sq_low_i32x8, distance_sq_high_i32x8));

    // Take care of the tail: each square fits easily in i32, the running sum is unsigned.
    for (; i < n; ++i) {
        nk_i32_t diff_i32 = (nk_i32_t)(a[i]) - b[i];
        distance_sq_u32 += (nk_u32_t)(diff_i32 * diff_i32);
    }

    *result = distance_sq_u32;
}
|
|
383
|
+
|
|
384
|
+
NK_PUBLIC void nk_euclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: the integer squared distance is converted to f32, then square-rooted.
    nk_u32_t squared_u32;
    nk_sqeuclidean_i8_haswell(a, b, n, &squared_u32);
    nk_f32_t const squared_f32 = (nk_f32_t)squared_u32;
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
389
|
+
|
|
390
|
+
NK_PUBLIC void nk_angular_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between i8 vectors, 16 elements per iteration.
    //
    // Why not `_mm256_maddubs_epi16`? It multiplies unsigned bytes by signed bytes, so a
    // signed×signed product would need |a| in the first operand and b with a's sign applied
    // in the second (`_mm256_abs_epi8` + `_mm256_sign_epi8`). Negating -128 in a byte wraps back
    // to -128, so lanes where b == -128 would get the wrong sign. Instead, the bytes are
    // sign-extended to 16-bit lanes and multiplied with `_mm256_madd_epi16`, which is exact:
    // every product is at most 128² and a pair sum fits comfortably in i32.
    __m256i ab_i32x8 = _mm256_setzero_si256();
    __m256i aa_i32x8 = _mm256_setzero_si256();
    __m256i bb_i32x8 = _mm256_setzero_si256();

    nk_size_t const body_length = n - n % 16;
    nk_size_t idx = 0;
    while (idx < body_length) {
        // 128-bit load + sign extension straight into a 256-bit register.
        __m256i const a_i16x16 = _mm256_cvtepi8_epi16(_mm_loadu_si128((__m128i const *)(a + idx)));
        __m256i const b_i16x16 = _mm256_cvtepi8_epi16(_mm_loadu_si128((__m128i const *)(b + idx)));

        ab_i32x8 = _mm256_add_epi32(ab_i32x8, _mm256_madd_epi16(a_i16x16, b_i16x16));
        aa_i32x8 = _mm256_add_epi32(aa_i32x8, _mm256_madd_epi16(a_i16x16, a_i16x16));
        bb_i32x8 = _mm256_add_epi32(bb_i32x8, _mm256_madd_epi16(b_i16x16, b_i16x16));
        idx += 16;
    }

    nk_i32_t ab_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
    nk_i32_t aa_i32 = nk_reduce_add_i32x8_haswell_(aa_i32x8);
    nk_i32_t bb_i32 = nk_reduce_add_i32x8_haswell_(bb_i32x8);

    // Remaining 0-15 elements, one at a time.
    for (; idx < n; ++idx) {
        nk_i32_t const a_value = a[idx], b_value = b[idx];
        ab_i32 += a_value * b_value;
        aa_i32 += a_value * a_value;
        bb_i32 += b_value * b_value;
    }

    *result = nk_angular_normalize_f32_haswell_(ab_i32, aa_i32, bb_i32);
}
|
|
438
|
+
|
|
439
|
+
NK_PUBLIC void nk_sqeuclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Squared L2 distance between u8 vectors, 32 elements per iteration:
    // - |a-b| from two saturating subtractions OR-ed together (one direction is always 0),
    // - zero-extend u8 -> i16 by interleaving with zeros (`unpacklo/hi` operate per 128-bit
    //   lane, which reorders elements but cannot change the sum),
    // - square and pair-wise add into i32 lanes with VPMADDWD (at most 2 * 255² = 130050).
    //
    // The running total is kept modulo 2^32: vector lanes wrap on `_mm256_add_epi32`, and the
    // scalar tail accumulates in `nk_u32_t`, where wraparound is well defined. A signed
    // `nk_i32_t` tail accumulator would be undefined behavior once the total exceeds INT32_MAX.
    // NOTE(review): assumes `nk_reduce_add_i32x8_haswell_` sums lanes with wrapping SIMD adds.
    __m256i distance_sq_low_i32x8 = _mm256_setzero_si256();
    __m256i distance_sq_high_i32x8 = _mm256_setzero_si256();
    __m256i const zeros_i8x32 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));

        // Subtracting unsigned vectors in AVX2 is done by saturating subtraction:
        __m256i diff_u8x32 = _mm256_or_si256(_mm256_subs_epu8(a_u8x32, b_u8x32), _mm256_subs_epu8(b_u8x32, a_u8x32));

        // Upcast `uint8` to `int16` by interleaving with zeros.
        __m256i diff_low_i16x16 = _mm256_unpacklo_epi8(diff_u8x32, zeros_i8x32);
        __m256i diff_high_i16x16 = _mm256_unpackhi_epi8(diff_u8x32, zeros_i8x32);

        // Multiply and accumulate at `int16` level, accumulate at `int32` level:
        distance_sq_low_i32x8 = _mm256_add_epi32(distance_sq_low_i32x8,
                                                 _mm256_madd_epi16(diff_low_i16x16, diff_low_i16x16));
        distance_sq_high_i32x8 = _mm256_add_epi32(distance_sq_high_i32x8,
                                                  _mm256_madd_epi16(diff_high_i16x16, diff_high_i16x16));
    }

    // Accumulate the 32-bit lanes of both accumulators; reinterpret the wrapped sum as unsigned.
    nk_u32_t distance_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(
        _mm256_add_epi32(distance_sq_low_i32x8, distance_sq_high_i32x8));

    // Take care of the tail: each square fits easily in i32, the running sum is unsigned.
    for (; i < n; ++i) {
        nk_i32_t diff_i32 = (nk_i32_t)(a[i]) - b[i];
        distance_sq_u32 += (nk_u32_t)(diff_i32 * diff_i32);
    }

    *result = distance_sq_u32;
}
|
|
477
|
+
|
|
478
|
+
NK_PUBLIC void nk_euclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: the integer squared distance is converted to f32, then square-rooted.
    nk_u32_t squared_u32;
    nk_sqeuclidean_u8_haswell(a, b, n, &squared_u32);
    nk_f32_t const squared_f32 = (nk_f32_t)squared_u32;
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
483
|
+
|
|
484
|
+
NK_PUBLIC void nk_angular_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between u8 vectors, 32 elements per iteration.
    //
    // Bytes are zero-extended to 16-bit lanes by interleaving with zeros; `unpacklo/hi` work per
    // 128-bit lane, which reorders elements but not the sums. `_mm256_madd_epi16` is exact here
    // because every pair sum is at most 2 * 255² = 130050.
    //
    // All three sums (a·b, |a|², |b|²) are non-negative, so they are reduced and finished in
    // `nk_u32_t`. Unsigned wraparound is well defined (unlike overflow of a signed tail
    // accumulator), and totals in [2^31, 2^32) reach the normalizer as large positive values
    // instead of being misread as negative.
    // NOTE(review): assumes `nk_reduce_add_i32x8_haswell_` sums lanes with wrapping SIMD adds.
    __m256i dot_product_low_i32x8 = _mm256_setzero_si256();
    __m256i dot_product_high_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_low_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_high_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_low_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_high_i32x8 = _mm256_setzero_si256();
    __m256i const zeros_i8x32 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));

        // Upcast `uint8` to `int16` by interleaving with zeros.
        __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_u8x32, zeros_i8x32);
        __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_u8x32, zeros_i8x32);
        __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_u8x32, zeros_i8x32);
        __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_u8x32, zeros_i8x32);

        // Multiply and accumulate as `int16`, accumulate products as `int32`
        dot_product_low_i32x8 = _mm256_add_epi32(dot_product_low_i32x8, _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
        dot_product_high_i32x8 = _mm256_add_epi32(dot_product_high_i32x8,
                                                  _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
        a_norm_sq_low_i32x8 = _mm256_add_epi32(a_norm_sq_low_i32x8, _mm256_madd_epi16(a_low_i16x16, a_low_i16x16));
        a_norm_sq_high_i32x8 = _mm256_add_epi32(a_norm_sq_high_i32x8, _mm256_madd_epi16(a_high_i16x16, a_high_i16x16));
        b_norm_sq_low_i32x8 = _mm256_add_epi32(b_norm_sq_low_i32x8, _mm256_madd_epi16(b_low_i16x16, b_low_i16x16));
        b_norm_sq_high_i32x8 = _mm256_add_epi32(b_norm_sq_high_i32x8, _mm256_madd_epi16(b_high_i16x16, b_high_i16x16));
    }

    // Further reduce to a single sum for each quantity, reinterpreted as unsigned.
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(
        _mm256_add_epi32(dot_product_low_i32x8, dot_product_high_i32x8));
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(
        _mm256_add_epi32(a_norm_sq_low_i32x8, a_norm_sq_high_i32x8));
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(
        _mm256_add_epi32(b_norm_sq_low_i32x8, b_norm_sq_high_i32x8));

    // Take care of the tail: every product is at most 255² and fits in u32.
    for (; i < n; ++i) {
        nk_u32_t a_element_u32 = a[i], b_element_u32 = b[i];
        dot_product_u32 += a_element_u32 * b_element_u32;
        a_norm_sq_u32 += a_element_u32 * a_element_u32;
        b_norm_sq_u32 += b_element_u32 * b_element_u32;
    }

    *result = nk_angular_normalize_f32_haswell_(dot_product_u32, a_norm_sq_u32, b_norm_sq_u32);
}
|
|
544
|
+
|
|
545
|
+
#pragma endregion - Small Integers
|
|
546
|
+
#pragma region - Traditional Floats
|
|
547
|
+
|
|
548
|
+
NK_PUBLIC void nk_sqeuclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Squared L2 distance of f32 inputs with f64 accumulation: four floats at a time are
    // widened to f64 before subtracting and squaring, for extra precision in the sum.
    __m256d acc_f64x4 = _mm256_setzero_pd();

    nk_size_t const body_length = n - n % 4;
    nk_size_t idx = 0;
    while (idx < body_length) {
        __m256d const a_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(a + idx));
        __m256d const b_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(b + idx));
        __m256d const delta_f64x4 = _mm256_sub_pd(a_f64x4, b_f64x4);
        acc_f64x4 = _mm256_fmadd_pd(delta_f64x4, delta_f64x4, acc_f64x4);
        idx += 4;
    }

    // Fold the lanes, then finish the remaining 0-3 elements in scalar f64.
    nk_f64_t total_f64 = nk_reduce_add_f64x4_haswell_(acc_f64x4);
    for (; idx < n; ++idx) {
        nk_f64_t const delta_f64 = (nk_f64_t)a[idx] - b[idx];
        total_f64 += delta_f64 * delta_f64;
    }

    *result = total_f64;
}
|
|
569
|
+
|
|
570
|
+
NK_PUBLIC void nk_euclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // L2 distance of f32 inputs. Delegates to the f64-accumulating squared-distance kernel
    // instead of duplicating its loop, like the other `nk_euclidean_*_haswell` wrappers,
    // and takes the square root in f64.
    nk_f64_t distance_sq_f64;
    nk_sqeuclidean_f32_haswell(a, b, n, &distance_sq_f64);
    *result = nk_f64_sqrt_haswell(distance_sq_f64);
}
|
|
591
|
+
|
|
592
|
+
NK_PUBLIC void nk_angular_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Angular distance for f32 inputs. Four floats at a time are widened to f64, and the
    // cross-product plus both squared norms are accumulated in f64 for extra precision.
    __m256d ab_f64x4 = _mm256_setzero_pd();
    __m256d aa_f64x4 = _mm256_setzero_pd();
    __m256d bb_f64x4 = _mm256_setzero_pd();

    nk_size_t const body_length = n - n % 4;
    nk_size_t idx = 0;
    while (idx < body_length) {
        __m256d const a_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(a + idx));
        __m256d const b_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(b + idx));
        ab_f64x4 = _mm256_fmadd_pd(a_f64x4, b_f64x4, ab_f64x4);
        aa_f64x4 = _mm256_fmadd_pd(a_f64x4, a_f64x4, aa_f64x4);
        bb_f64x4 = _mm256_fmadd_pd(b_f64x4, b_f64x4, bb_f64x4);
        idx += 4;
    }

    nk_f64_t ab_f64 = nk_reduce_add_f64x4_haswell_(ab_f64x4);
    nk_f64_t aa_f64 = nk_reduce_add_f64x4_haswell_(aa_f64x4);
    nk_f64_t bb_f64 = nk_reduce_add_f64x4_haswell_(bb_f64x4);

    // Remaining 0-3 elements in scalar f64.
    for (; idx < n; ++idx) {
        nk_f64_t const a_value = a[idx], b_value = b[idx];
        ab_f64 += a_value * b_value;
        aa_f64 += a_value * a_value;
        bb_f64 += b_value * b_value;
    }

    *result = nk_angular_normalize_f64_haswell_(ab_f64, aa_f64, bb_f64);
}
|
|
619
|
+
|
|
620
|
+
NK_PUBLIC void nk_sqeuclidean_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Squared L2 distance of f64 vectors, four lanes per pass.
    __m256d acc_f64x4 = _mm256_setzero_pd();

    // At least one pass always runs; the pass that sees fewer than 4 remaining elements
    // (possibly zero) uses the partial loader and terminates the loop.
    do {
        __m256d a_chunk_f64x4, b_chunk_f64x4;
        if (n >= 4) {
            a_chunk_f64x4 = _mm256_loadu_pd(a);
            b_chunk_f64x4 = _mm256_loadu_pd(b);
            a += 4, b += 4, n -= 4;
        }
        else {
            nk_b256_vec_t a_rest, b_rest;
            nk_partial_load_b64x4_serial_(a, &a_rest, n);
            nk_partial_load_b64x4_serial_(b, &b_rest, n);
            a_chunk_f64x4 = a_rest.ymm_pd;
            b_chunk_f64x4 = b_rest.ymm_pd;
            n = 0;
        }
        __m256d const delta_f64x4 = _mm256_sub_pd(a_chunk_f64x4, b_chunk_f64x4);
        acc_f64x4 = _mm256_fmadd_pd(delta_f64x4, delta_f64x4, acc_f64x4);
    } while (n != 0);

    *result = nk_reduce_add_f64x4_haswell_(acc_f64x4);
}
|
|
644
|
+
|
|
645
|
+
NK_PUBLIC void nk_euclidean_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // L2 distance: square root of the squared L2 distance produced by the sibling kernel.
    nk_f64_t squared_f64;
    nk_sqeuclidean_f64_haswell(a, b, n, &squared_f64);
    *result = nk_f64_sqrt_haswell(squared_f64);
}
|
|
649
|
+
|
|
650
|
+
NK_PUBLIC void nk_angular_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Angular distance for f64 inputs.
    //
    // The cross-product a·b can suffer cancellation, so it is accumulated with the compensated
    // Dot2 scheme (Ogita, Rump & Oishi, 2005). Each product is split by TwoProd (one FMA), each
    // addition by Knuth's branch-free TwoSum (chosen over a Neumaier variant with blends because
    // products may be signed), and the rounding errors are collected in a separate vector.
    // The squared norms only add non-negative terms, so plain FMA accumulation suffices there.
    __m256d dot_hi_f64x4 = _mm256_setzero_pd(); // running sum of products
    __m256d dot_lo_f64x4 = _mm256_setzero_pd(); // running sum of rounding errors
    __m256d aa_f64x4 = _mm256_setzero_pd();
    __m256d bb_f64x4 = _mm256_setzero_pd();

    // At least one pass always runs; the pass that sees fewer than 4 remaining elements
    // (possibly zero) uses the partial loader and terminates the loop.
    do {
        __m256d a_chunk_f64x4, b_chunk_f64x4;
        if (n >= 4) {
            a_chunk_f64x4 = _mm256_loadu_pd(a);
            b_chunk_f64x4 = _mm256_loadu_pd(b);
            a += 4, b += 4, n -= 4;
        }
        else {
            nk_b256_vec_t a_rest, b_rest;
            nk_partial_load_b64x4_serial_(a, &a_rest, n);
            nk_partial_load_b64x4_serial_(b, &b_rest, n);
            a_chunk_f64x4 = a_rest.ymm_pd;
            b_chunk_f64x4 = b_rest.ymm_pd;
            n = 0;
        }

        // TwoProd: product + product_err reproduces a * b (barring underflow).
        __m256d const product_f64x4 = _mm256_mul_pd(a_chunk_f64x4, b_chunk_f64x4);
        __m256d const product_err_f64x4 = _mm256_fmsub_pd(a_chunk_f64x4, b_chunk_f64x4, product_f64x4);

        // TwoSum: sum + sum_err reproduces dot_hi + product (barring overflow).
        __m256d const sum_f64x4 = _mm256_add_pd(dot_hi_f64x4, product_f64x4);
        __m256d const addend_seen_f64x4 = _mm256_sub_pd(sum_f64x4, dot_hi_f64x4);
        __m256d const hi_lost_f64x4 = _mm256_sub_pd(dot_hi_f64x4, _mm256_sub_pd(sum_f64x4, addend_seen_f64x4));
        __m256d const addend_lost_f64x4 = _mm256_sub_pd(product_f64x4, addend_seen_f64x4);
        __m256d const sum_err_f64x4 = _mm256_add_pd(hi_lost_f64x4, addend_lost_f64x4);

        dot_hi_f64x4 = sum_f64x4;
        dot_lo_f64x4 = _mm256_add_pd(dot_lo_f64x4, _mm256_add_pd(sum_err_f64x4, product_err_f64x4));

        aa_f64x4 = _mm256_fmadd_pd(a_chunk_f64x4, a_chunk_f64x4, aa_f64x4);
        bb_f64x4 = _mm256_fmadd_pd(b_chunk_f64x4, b_chunk_f64x4, bb_f64x4);
    } while (n != 0);

    *result = nk_angular_normalize_f64_haswell_( //
        nk_dot_stable_sum_f64x4_haswell_(dot_hi_f64x4, dot_lo_f64x4),
        nk_reduce_add_f64x4_haswell_(aa_f64x4), nk_reduce_add_f64x4_haswell_(bb_f64x4));
}
|
|
695
|
+
|
|
696
|
+
#pragma endregion - Traditional Floats
|
|
697
|
+
#pragma region - Smaller Floats
|
|
698
|
+
|
|
699
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared L2 distance between e2m3 vectors: 8 one-byte values at a time are widened to
    // f32 and (a - b)² is accumulated with FMA.
    __m256 distance_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8, diff_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e2m3x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e2m3x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e2m3x8_to_f32x8_haswell_(b_vec.xmm);
    diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
    distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);

    *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
}
|
|
723
|
+
|
|
724
|
+
NK_PUBLIC void nk_euclidean_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: square root of the squared L2 distance produced by the sibling kernel.
    nk_f32_t squared_f32;
    nk_sqeuclidean_e2m3_haswell(a, b, n, &squared_f32);
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
728
|
+
|
|
729
|
+
NK_PUBLIC void nk_angular_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between e2m3 vectors: 8 one-byte values at a time are widened to f32 and
    // the cross-product plus both squared norms are accumulated with FMA.
    __m256 dot_product_f32x8 = _mm256_setzero_ps();
    __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e2m3x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e2m3x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e2m3x8_to_f32x8_haswell_(b_vec.xmm);
    dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);

    nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(b_norm_sq_f32x8);
    *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
760
|
+
|
|
761
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_haswell(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared L2 distance between e3m2 vectors: 8 one-byte values at a time are widened to
    // f32 and (a - b)² is accumulated with FMA.
    __m256 distance_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8, diff_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e3m2x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e3m2x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e3m2x8_to_f32x8_haswell_(b_vec.xmm);
    diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
    distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);

    *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
}
|
|
785
|
+
|
|
786
|
+
NK_PUBLIC void nk_euclidean_e3m2_haswell(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: square root of the squared L2 distance produced by the sibling kernel.
    nk_f32_t squared_f32;
    nk_sqeuclidean_e3m2_haswell(a, b, n, &squared_f32);
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
790
|
+
|
|
791
|
+
NK_PUBLIC void nk_angular_e3m2_haswell(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between e3m2 vectors: 8 one-byte values at a time are widened to f32 and
    // the cross-product plus both squared norms are accumulated with FMA.
    __m256 dot_product_f32x8 = _mm256_setzero_ps();
    __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e3m2x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e3m2x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e3m2x8_to_f32x8_haswell_(b_vec.xmm);
    dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);

    nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(b_norm_sq_f32x8);
    *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
822
|
+
|
|
823
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared L2 distance between e4m3 vectors: 8 one-byte values at a time are widened to
    // f32 and (a - b)² is accumulated with FMA.
    __m256 distance_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8, diff_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e4m3x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_vec.xmm);
    diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
    distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);

    *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
}
|
|
847
|
+
|
|
848
|
+
NK_PUBLIC void nk_euclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: square root of the squared L2 distance produced by the sibling kernel.
    nk_f32_t squared_f32;
    nk_sqeuclidean_e4m3_haswell(a, b, n, &squared_f32);
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
852
|
+
|
|
853
|
+
NK_PUBLIC void nk_angular_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between e4m3 vectors: 8 one-byte values at a time are widened to f32 and
    // the cross-product plus both squared norms are accumulated with FMA.
    __m256 dot_product_f32x8 = _mm256_setzero_ps();
    __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e4m3x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_vec.xmm);
    dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);

    nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(b_norm_sq_f32x8);
    *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
884
|
+
|
|
885
|
+
NK_PUBLIC void nk_sqeuclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared L2 distance between e5m2 vectors: 8 one-byte values at a time are widened to
    // f32 and (a - b)² is accumulated with FMA.
    __m256 distance_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8, diff_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e5m2x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(b_vec.xmm);
    diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
    distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);

    *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
}
|
|
909
|
+
|
|
910
|
+
NK_PUBLIC void nk_euclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // L2 distance: square root of the squared L2 distance produced by the sibling kernel.
    nk_f32_t squared_f32;
    nk_sqeuclidean_e5m2_haswell(a, b, n, &squared_f32);
    *result = nk_f32_sqrt_haswell(squared_f32);
}
|
|
914
|
+
|
|
915
|
+
NK_PUBLIC void nk_angular_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between e5m2 vectors: 8 one-byte values at a time are widened to f32 and
    // the cross-product plus both squared norms are accumulated with FMA.
    __m256 dot_product_f32x8 = _mm256_setzero_ps();
    __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
    __m256 a_f32x8, b_f32x8;

    // Full chunks: 8 one-byte elements are exactly 64 bits, so load them with `_mm_loadl_epi64`
    // (upper half zeroed). A 16-byte `_mm_loadu_si128` would read up to 8 bytes past the end of
    // the inputs (e.g. for 8 <= n < 16), which can fault at a page boundary.
    // NOTE(review): assumes `nk_e5m2x8_to_f32x8_haswell_` only consumes the low 8 bytes.
    for (; n >= 8; n -= 8, a += 8, b += 8) {
        a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a));
        b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b));
        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
    }

    // Remaining 0-7 elements go through a 16-byte staging buffer filled by the partial loader.
    nk_b128_vec_t a_vec, b_vec;
    nk_partial_load_b8x16_serial_(a, &a_vec, n);
    nk_partial_load_b8x16_serial_(b, &b_vec, n);
    a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(a_vec.xmm);
    b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(b_vec.xmm);
    dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);

    nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(b_norm_sq_f32x8);
    *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
946
|
+
|
|
947
|
+
// Close the last region before popping the target options and the `extern "C"` scope,
// so the region markers nest inside the scopes they were opened in.
#pragma endregion - Smaller Floats

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#if defined(__cplusplus)
} // extern "C"
#endif

#endif // NK_TARGET_HASWELL
#endif // NK_TARGET_X86_
#endif // NK_SPATIAL_HASWELL_H
|