numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,606 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Skylake.
|
|
3
|
+
* @file include/numkong/spatial/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_skylake_instructions Key AVX-512 Spatial Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
|
|
13
|
+
* _mm512_sub_ps VSUBPS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
|
|
14
|
+
* _mm512_rsqrt14_ps VRSQRT14PS (ZMM, ZMM) 4cy 1/cy p0
|
|
15
|
+
* _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 12cy 3cy p0
|
|
16
|
+
* _mm512_reduce_add_ps (sequence) ~8-10cy - -
|
|
17
|
+
*
|
|
18
|
+
* Distance computations benefit from Skylake-X's dual FMA units achieving 0.5cy throughput for
|
|
19
|
+
* fused multiply-add operations. VRSQRT14PS provides ~14-bit precision reciprocal square root;
|
|
20
|
+
* with Newton-Raphson refinement, this exceeds f32's 23-bit mantissa requirements.
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_SPATIAL_SKYLAKE_H
|
|
23
|
+
#define NK_SPATIAL_SKYLAKE_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X86_
|
|
26
|
+
#if NK_TARGET_SKYLAKE
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
30
|
+
#include "numkong/dot/skylake.h" // `nk_dot_f64x8_state_skylake_t`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
38
|
+
apply_to = function)
|
|
39
|
+
#elif defined(__GNUC__)
|
|
40
|
+
#pragma GCC push_options
|
|
41
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
/** @brief Reciprocal square root of 16 floats with Newton-Raphson refinement (~28-bit precision). */
|
|
45
|
+
NK_INTERNAL __m512 nk_rsqrt_f32x16_skylake_(__m512 x) {
|
|
46
|
+
__m512 rsqrt = _mm512_rsqrt14_ps(x);
|
|
47
|
+
__m512 nr = _mm512_mul_ps(_mm512_mul_ps(x, rsqrt), rsqrt);
|
|
48
|
+
nr = _mm512_sub_ps(_mm512_set1_ps(3.0f), nr);
|
|
49
|
+
return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(0.5f), rsqrt), nr);
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
/** @brief Reciprocal square root of 8 doubles with Newton-Raphson refinement (~28-bit precision). */
|
|
53
|
+
NK_INTERNAL __m512d nk_rsqrt_f64x8_skylake_(__m512d x) {
|
|
54
|
+
__m512d rsqrt = _mm512_rsqrt14_pd(x);
|
|
55
|
+
__m512d nr = _mm512_mul_pd(_mm512_mul_pd(x, rsqrt), rsqrt);
|
|
56
|
+
nr = _mm512_sub_pd(_mm512_set1_pd(3.0), nr);
|
|
57
|
+
return _mm512_mul_pd(_mm512_mul_pd(_mm512_set1_pd(0.5), rsqrt), nr);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
#pragma region - Traditional Floats
|
|
61
|
+
|
|
62
|
+
/** @brief Squared Euclidean distance between two f32 vectors, accumulated in f64 lanes. */
NK_PUBLIC void nk_sqeuclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Upcast every 8-float chunk to f64 before squaring to avoid f32 accumulation error.
    __m512d acc_f64x8 = _mm512_setzero_pd();
    do {
        __m256 lhs_f32x8, rhs_f32x8;
        if (n < 8) {
            // Masked tail load: lanes at or past `n` are zeroed and contribute nothing.
            __mmask8 tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f32x8 = _mm256_maskz_loadu_ps(tail, a);
            rhs_f32x8 = _mm256_maskz_loadu_ps(tail, b);
            n = 0;
        }
        else {
            lhs_f32x8 = _mm256_loadu_ps(a);
            rhs_f32x8 = _mm256_loadu_ps(b);
            a += 8, b += 8, n -= 8;
        }
        __m512d diff_f64x8 = _mm512_sub_pd(_mm512_cvtps_pd(lhs_f32x8), _mm512_cvtps_pd(rhs_f32x8));
        acc_f64x8 = _mm512_fmadd_pd(diff_f64x8, diff_f64x8, acc_f64x8);
    } while (n);
    *result = _mm512_reduce_add_pd(acc_f64x8);
}
|
|
87
|
+
|
|
88
|
+
/** @brief Euclidean distance between two f32 vectors: the square root of the squared-Euclidean distance. */
NK_PUBLIC void nk_euclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Reuse the squared-distance kernel, then take a scalar square root of the sum.
    nk_sqeuclidean_f32_skylake(a, b, n, result);
    *result = nk_f64_sqrt_haswell(*result);
}
|
|
92
|
+
|
|
93
|
+
/** @brief Converts a dot-product `ab` and squared norms `a2`, `b2` into an angular distance:
 *  1 − ab / (√a2 × √b2), clamped to be non-negative. Handles zero-magnitude inputs explicitly. */
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_skylake_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {

    // If both vectors have magnitude 0, the distance is 0.
    if (a2 == 0 && b2 == 0) return 0;
    // If any one of the vectors is 0, the square root of the product is 0,
    // the division is illformed, and the result is 1.
    else if (ab == 0) return 1;

    // Design note: We use exact `_mm_sqrt_pd` instead of `_mm_rsqrt14_pd` approximation.
    // The AVX-512 `_mm_rsqrt14_pd` has max relative error of 2⁻¹⁴ (~14 bits precision).
    // Even with Newton-Raphson refinement (doubles precision to ~28 bits), this is
    // insufficient for f64's 52-bit mantissa, causing ULP errors in the tens of millions.
    // The `_mm_sqrt_pd` instruction provides full f64 precision.
    //
    // Precision comparison for 1536-dimensional vectors:
    //      DType       rsqrt14+NR Error        Exact sqrt Error
    //      float64     1.35e-11 ± 1.85e-11     ~0 (2 ULP max)
    //
    // https://web.archive.org/web/20210208132927/http://assemblyrequired.crashworks.org/timing-square-root/
    // Both square roots are computed with a single 128-bit instruction:
    // `a2` lands in the high lane, `b2` in the low lane.
    __m128d squares_f64x2 = _mm_set_pd(a2, b2);
    __m128d sqrts_f64x2 = _mm_sqrt_pd(squares_f64x2);
    nk_f64_t a_sqrt = _mm_cvtsd_f64(_mm_unpackhi_pd(sqrts_f64x2, sqrts_f64x2));
    nk_f64_t b_sqrt = _mm_cvtsd_f64(sqrts_f64x2);
    nk_f64_t result = 1 - ab / (a_sqrt * b_sqrt);
    // Rounding can push the result slightly below zero; clamp it.
    return result > 0 ? result : 0;
}
|
|
119
|
+
|
|
120
|
+
/** @brief Angular (cosine) distance between two f32 vectors, with all three sums accumulated in f64. */
NK_PUBLIC void nk_angular_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Three running sums: ⟨a,b⟩, ‖a‖², and ‖b‖², each in 8 f64 lanes.
    __m512d ab_f64x8 = _mm512_setzero_pd();
    __m512d aa_f64x8 = _mm512_setzero_pd();
    __m512d bb_f64x8 = _mm512_setzero_pd();
    do {
        __m256 lhs_f32x8, rhs_f32x8;
        if (n < 8) {
            // Masked tail load: lanes at or past `n` are zeroed and contribute nothing.
            __mmask8 tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f32x8 = _mm256_maskz_loadu_ps(tail, a);
            rhs_f32x8 = _mm256_maskz_loadu_ps(tail, b);
            n = 0;
        }
        else {
            lhs_f32x8 = _mm256_loadu_ps(a);
            rhs_f32x8 = _mm256_loadu_ps(b);
            a += 8, b += 8, n -= 8;
        }
        // Upcast to f64 before multiplying, so products are exact for f32 inputs.
        __m512d lhs_f64x8 = _mm512_cvtps_pd(lhs_f32x8);
        __m512d rhs_f64x8 = _mm512_cvtps_pd(rhs_f32x8);
        ab_f64x8 = _mm512_fmadd_pd(lhs_f64x8, rhs_f64x8, ab_f64x8);
        aa_f64x8 = _mm512_fmadd_pd(lhs_f64x8, lhs_f64x8, aa_f64x8);
        bb_f64x8 = _mm512_fmadd_pd(rhs_f64x8, rhs_f64x8, bb_f64x8);
    } while (n);
    *result = nk_angular_normalize_f64_skylake_(_mm512_reduce_add_pd(ab_f64x8), _mm512_reduce_add_pd(aa_f64x8),
                                                _mm512_reduce_add_pd(bb_f64x8));
}
|
|
151
|
+
|
|
152
|
+
/** @brief Squared Euclidean distance between two f64 vectors using 512-bit lanes. */
NK_PUBLIC void nk_sqeuclidean_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    __m512d acc_f64x8 = _mm512_setzero_pd();
    do {
        __m512d lhs_f64x8, rhs_f64x8;
        if (n < 8) {
            // Masked tail load: lanes at or past `n` are zeroed and contribute nothing.
            __mmask8 tail = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
            lhs_f64x8 = _mm512_maskz_loadu_pd(tail, a);
            rhs_f64x8 = _mm512_maskz_loadu_pd(tail, b);
            n = 0;
        }
        else {
            lhs_f64x8 = _mm512_loadu_pd(a);
            rhs_f64x8 = _mm512_loadu_pd(b);
            a += 8, b += 8, n -= 8;
        }
        __m512d diff_f64x8 = _mm512_sub_pd(lhs_f64x8, rhs_f64x8);
        acc_f64x8 = _mm512_fmadd_pd(diff_f64x8, diff_f64x8, acc_f64x8);
    } while (n);
    *result = _mm512_reduce_add_pd(acc_f64x8);
}
|
|
174
|
+
|
|
175
|
+
/** @brief Euclidean distance between two f64 vectors: the square root of the squared-Euclidean distance. */
NK_PUBLIC void nk_euclidean_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Reuse the squared-distance kernel, then take a scalar square root of the sum.
    nk_sqeuclidean_f64_skylake(a, b, n, result);
    *result = nk_f64_sqrt_haswell(*result);
}
|
|
179
|
+
|
|
180
|
+
/** @brief Angular (cosine) distance between two f64 vectors with compensated dot-product accumulation.
 *  The cross-product sum uses error-free transformations (Dot2); the self-products use plain FMA. */
NK_PUBLIC void nk_angular_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Dot2 (Ogita-Rump-Oishi 2005) for cross-product a × b only - it may have cancellation.
    // Self-products ‖a‖² and ‖b‖² use simple FMA - all terms are non-negative, no cancellation.
    __m512d dot_sum_f64x8 = _mm512_setzero_pd();
    __m512d dot_compensation_f64x8 = _mm512_setzero_pd();
    __m512d a_norm_sq_f64x8 = _mm512_setzero_pd();
    __m512d b_norm_sq_f64x8 = _mm512_setzero_pd();
    __m512d a_f64x8, b_f64x8;

nk_angular_f64_skylake_cycle:
    if (n < 8) {
        // Masked tail load: lanes at or past `n` are zeroed and contribute nothing.
        __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n);
        a_f64x8 = _mm512_maskz_loadu_pd(mask, a);
        b_f64x8 = _mm512_maskz_loadu_pd(mask, b);
        n = 0;
    }
    else {
        a_f64x8 = _mm512_loadu_pd(a);
        b_f64x8 = _mm512_loadu_pd(b);
        a += 8, b += 8, n -= 8;
    }
    // TwoProd for cross-product: product = a * b, error = fma(a, b, -product)
    // The FMA recovers the exact rounding error of the multiplication.
    __m512d x_f64x8 = _mm512_mul_pd(a_f64x8, b_f64x8);
    __m512d product_error_f64x8 = _mm512_fmsub_pd(a_f64x8, b_f64x8, x_f64x8);
    // Neumaier TwoSum: t = sum + x, with masked error recovery
    // (the branch on |sum| >= |x| is replaced by per-lane mask blends).
    __m512d tentative_sum_f64x8 = _mm512_add_pd(dot_sum_f64x8, x_f64x8);
    __m512d abs_sum_f64x8 = _mm512_abs_pd(dot_sum_f64x8);
    __m512d abs_x_f64x8 = _mm512_abs_pd(x_f64x8);
    __mmask8 sum_ge_x = _mm512_cmp_pd_mask(abs_sum_f64x8, abs_x_f64x8, _CMP_GE_OQ);
    // z = t - larger, error = smaller - z
    __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, x_f64x8);
    virtual_addend_f64x8 = _mm512_mask_sub_pd(virtual_addend_f64x8, sum_ge_x, tentative_sum_f64x8, dot_sum_f64x8);
    __m512d sum_error_f64x8 = _mm512_sub_pd(dot_sum_f64x8, virtual_addend_f64x8);
    sum_error_f64x8 = _mm512_mask_sub_pd(sum_error_f64x8, sum_ge_x, x_f64x8, virtual_addend_f64x8);
    dot_sum_f64x8 = tentative_sum_f64x8;
    // Both rounding errors (product and sum) fold into one compensation stream.
    dot_compensation_f64x8 = _mm512_add_pd(dot_compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
    // Simple FMA for self-products (no cancellation possible)
    a_norm_sq_f64x8 = _mm512_fmadd_pd(a_f64x8, a_f64x8, a_norm_sq_f64x8);
    b_norm_sq_f64x8 = _mm512_fmadd_pd(b_f64x8, b_f64x8, b_norm_sq_f64x8);
    if (n) goto nk_angular_f64_skylake_cycle;

    // Horizontal reduction: the dot product merges sum + compensation via the stable helper.
    nk_f64_t dot_product_f64 = nk_dot_stable_sum_f64x8_skylake_(dot_sum_f64x8, dot_compensation_f64x8);
    nk_f64_t a_norm_sq_f64 = _mm512_reduce_add_pd(a_norm_sq_f64x8);
    nk_f64_t b_norm_sq_f64 = _mm512_reduce_add_pd(b_norm_sq_f64x8);
    *result = nk_angular_normalize_f64_skylake_(dot_product_f64, a_norm_sq_f64, b_norm_sq_f64);
}
|
|
226
|
+
|
|
227
|
+
/** @brief Angular from_dot for native f64: 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs. */
|
|
228
|
+
NK_INTERNAL void nk_angular_f64x4_from_dot_skylake_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
229
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
230
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
231
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd(query_sumsq);
|
|
232
|
+
__m256d products_f64x4 = _mm256_mul_pd(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
233
|
+
__m256d sqrt_products_f64x4 = _mm256_sqrt_pd(products_f64x4);
|
|
234
|
+
__m256d normalized_f64x4 = _mm256_div_pd(dots_f64x4, sqrt_products_f64x4);
|
|
235
|
+
__m256d ones_f64x4 = _mm256_set1_pd(1.0);
|
|
236
|
+
__m256d angular_f64x4 = _mm256_sub_pd(ones_f64x4, normalized_f64x4);
|
|
237
|
+
results->ymm_pd = _mm256_max_pd(angular_f64x4, _mm256_setzero_pd());
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
/** @brief Euclidean from_dot for native f64: √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs. */
|
|
241
|
+
NK_INTERNAL void nk_euclidean_f64x4_from_dot_skylake_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
242
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
243
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
244
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd(query_sumsq);
|
|
245
|
+
__m256d two_f64x4 = _mm256_set1_pd(2.0);
|
|
246
|
+
__m256d sum_sq_f64x4 = _mm256_add_pd(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
247
|
+
__m256d dist_sq_f64x4 = _mm256_fnmadd_pd(two_f64x4, dots_f64x4, sum_sq_f64x4);
|
|
248
|
+
__m256d zeros_f64x4 = _mm256_setzero_pd();
|
|
249
|
+
__m256d clamped_f64x4 = _mm256_max_pd(dist_sq_f64x4, zeros_f64x4);
|
|
250
|
+
results->ymm_pd = _mm256_sqrt_pd(clamped_f64x4);
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
/** @brief Angular from_dot: f32 dots upcast to f64 for precision. Output via nk_b128_vec_t (f32). */
|
|
254
|
+
NK_INTERNAL void nk_angular_through_f64_from_dot_skylake_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
255
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
256
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
257
|
+
__m256d dots_f64x4 = _mm256_cvtps_pd(dots_f32x4);
|
|
258
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd((nk_f64_t)query_sumsq);
|
|
259
|
+
__m256d target_sumsqs_f64x4 = _mm256_cvtps_pd(target_sumsqs.xmm_ps);
|
|
260
|
+
__m256d products_f64x4 = _mm256_mul_pd(query_sumsq_f64x4, target_sumsqs_f64x4);
|
|
261
|
+
__m256d sqrt_products_f64x4 = _mm256_sqrt_pd(products_f64x4);
|
|
262
|
+
__m256d normalized_f64x4 = _mm256_div_pd(dots_f64x4, sqrt_products_f64x4);
|
|
263
|
+
__m256d ones_f64x4 = _mm256_set1_pd(1.0);
|
|
264
|
+
__m256d angular_f64x4 = _mm256_sub_pd(ones_f64x4, normalized_f64x4);
|
|
265
|
+
angular_f64x4 = _mm256_max_pd(angular_f64x4, _mm256_setzero_pd());
|
|
266
|
+
results->xmm_ps = _mm256_cvtpd_ps(angular_f64x4);
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
/** @brief Euclidean from_dot: f32 dots upcast to f64 for precision. Output via nk_b128_vec_t (f32). */
|
|
270
|
+
NK_INTERNAL void nk_euclidean_through_f64_from_dot_skylake_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
271
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
272
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
273
|
+
__m256d dots_f64x4 = _mm256_cvtps_pd(dots_f32x4);
|
|
274
|
+
__m256d query_sumsq_f64x4 = _mm256_set1_pd((nk_f64_t)query_sumsq);
|
|
275
|
+
__m256d target_sumsqs_f64x4 = _mm256_cvtps_pd(target_sumsqs.xmm_ps);
|
|
276
|
+
__m256d two_f64x4 = _mm256_set1_pd(2.0);
|
|
277
|
+
__m256d sum_sq_f64x4 = _mm256_add_pd(query_sumsq_f64x4, target_sumsqs_f64x4);
|
|
278
|
+
__m256d dist_sq_f64x4 = _mm256_fnmadd_pd(two_f64x4, dots_f64x4, sum_sq_f64x4);
|
|
279
|
+
__m256d zeros_f64x4 = _mm256_setzero_pd();
|
|
280
|
+
__m256d clamped_f64x4 = _mm256_max_pd(dist_sq_f64x4, zeros_f64x4);
|
|
281
|
+
__m256d dist_f64x4 = _mm256_sqrt_pd(clamped_f64x4);
|
|
282
|
+
results->xmm_ps = _mm256_cvtpd_ps(dist_f64x4);
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
#pragma endregion - Traditional Floats
|
|
286
|
+
#pragma region - Smaller Floats
|
|
287
|
+
|
|
288
|
+
/**
 *  @brief Squared Euclidean distance between two f16 vectors of length @p n,
 *  upcasting halves to f32 lanes (AVX-512 + F16C-style `vcvtph2ps`) and
 *  accumulating (a-b)^2 with fused multiply-adds. Result written to @p result.
 */
NK_PUBLIC void nk_sqeuclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m256i a_f16x16, b_f16x16;

nk_sqeuclidean_f16_skylake_cycle:
    if (n < 16) {
        // Tail: build a mask covering the remaining `n` lanes; masked loads
        // zero-fill the rest, contributing (0-0)^2 == 0 to the sum.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_f16x16 = _mm256_maskz_loadu_epi16(mask, a);
        b_f16x16 = _mm256_maskz_loadu_epi16(mask, b);
        n = 0;
    }
    else {
        a_f16x16 = _mm256_loadu_si256((__m256i const *)a);
        b_f16x16 = _mm256_loadu_si256((__m256i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
    __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
    __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
    sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
    if (n) goto nk_sqeuclidean_f16_skylake_cycle;

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
312
|
+
|
|
313
|
+
/** @brief Euclidean distance between two f16 vectors: square root of the squared-Euclidean kernel. */
NK_PUBLIC void nk_euclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_skylake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
317
|
+
|
|
318
|
+
/**
 *  @brief Angular (cosine) distance between two f16 vectors of length @p n.
 *  Single pass accumulates the dot product and both squared norms in f32,
 *  then delegates the final normalization to a shared scalar helper.
 */
NK_PUBLIC void nk_angular_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 dot_f32x16 = _mm512_setzero_ps();
    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
    __m256i a_f16x16, b_f16x16;

nk_angular_f16_skylake_cycle:
    if (n < 16) {
        // Tail: masked loads zero-fill unused lanes, which add 0 to all three accumulators.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_f16x16 = _mm256_maskz_loadu_epi16(mask, a);
        b_f16x16 = _mm256_maskz_loadu_epi16(mask, b);
        n = 0;
    }
    else {
        a_f16x16 = _mm256_loadu_si256((__m256i const *)a);
        b_f16x16 = _mm256_loadu_si256((__m256i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
    __m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
    dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
    if (n) goto nk_angular_f16_skylake_cycle;

    // Horizontal reductions, then the shared normalization combines
    // dot and norms into the angular distance.
    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
348
|
+
|
|
349
|
+
/**
 *  @brief Squared Euclidean distance between two e4m3 (FP8) vectors of length @p n.
 *  Sixteen bytes are loaded at a time, widened to f32 via the e4m3 converter,
 *  and (a-b)^2 is accumulated with fused multiply-adds.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m128i a_e4m3x16, b_e4m3x16;

nk_sqeuclidean_e4m3_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes (0x00 presumably decodes
        // to +0.0 in e4m3, so padded lanes add nothing — relies on the converter).
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e4m3x16 = _mm_loadu_si128((__m128i const *)a);
        b_e4m3x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
    __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
    __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
    sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
    if (n) goto nk_sqeuclidean_e4m3_skylake_cycle;

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
373
|
+
|
|
374
|
+
/** @brief Euclidean distance between two e4m3 vectors: square root of the squared-Euclidean kernel. */
NK_PUBLIC void nk_euclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e4m3_skylake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
378
|
+
|
|
379
|
+
/**
 *  @brief Angular (cosine) distance between two e4m3 (FP8) vectors of length @p n.
 *  Single pass accumulates the dot product and both squared norms in f32,
 *  then delegates the final normalization to a shared scalar helper.
 */
NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 dot_f32x16 = _mm512_setzero_ps();
    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
    __m128i a_e4m3x16, b_e4m3x16;

nk_angular_e4m3_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes, adding 0 to all accumulators.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e4m3x16 = _mm_loadu_si128((__m128i const *)a);
        b_e4m3x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3x16);
    __m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3x16);
    dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
    if (n) goto nk_angular_e4m3_skylake_cycle;

    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
409
|
+
|
|
410
|
+
/**
 *  @brief Squared Euclidean distance between two e5m2 (FP8) vectors of length @p n.
 *  Sixteen bytes are loaded at a time, widened to f32 via the e5m2 converter,
 *  and (a-b)^2 is accumulated with fused multiply-adds.
 */
NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m128i a_e5m2x16, b_e5m2x16;

nk_sqeuclidean_e5m2_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes (0x00 presumably decodes
        // to +0.0 in e5m2, so padded lanes add nothing — relies on the converter).
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e5m2x16 = _mm_loadu_si128((__m128i const *)a);
        b_e5m2x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
    __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
    __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
    sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
    if (n) goto nk_sqeuclidean_e5m2_skylake_cycle;

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
434
|
+
|
|
435
|
+
/** @brief Euclidean distance between two e5m2 vectors: square root of the squared-Euclidean kernel. */
NK_PUBLIC void nk_euclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e5m2_skylake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
439
|
+
|
|
440
|
+
/**
 *  @brief Angular (cosine) distance between two e5m2 (FP8) vectors of length @p n.
 *  Single pass accumulates the dot product and both squared norms in f32,
 *  then delegates the final normalization to a shared scalar helper.
 */
NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 dot_f32x16 = _mm512_setzero_ps();
    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
    __m128i a_e5m2x16, b_e5m2x16;

nk_angular_e5m2_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes, adding 0 to all accumulators.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e5m2x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e5m2x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e5m2x16 = _mm_loadu_si128((__m128i const *)a);
        b_e5m2x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2x16);
    __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2x16);
    dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
    if (n) goto nk_angular_e5m2_skylake_cycle;

    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
470
|
+
|
|
471
|
+
/**
 *  @brief Squared Euclidean distance between two e2m3 (6-bit-style micro-float) vectors
 *  of length @p n. Sixteen bytes are loaded at a time, widened to f32 via the e2m3
 *  converter, and (a-b)^2 is accumulated with fused multiply-adds.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m128i a_e2m3x16, b_e2m3x16;

nk_sqeuclidean_e2m3_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes (0x00 presumably decodes
        // to +0.0 in e2m3, so padded lanes add nothing — relies on the converter).
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e2m3x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e2m3x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
        b_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3x16);
    __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3x16);
    __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
    sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
    if (n) goto nk_sqeuclidean_e2m3_skylake_cycle;

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
495
|
+
|
|
496
|
+
/** @brief Euclidean distance between two e2m3 vectors: square root of the squared-Euclidean kernel. */
NK_PUBLIC void nk_euclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e2m3_skylake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
500
|
+
|
|
501
|
+
/**
 *  @brief Angular (cosine) distance between two e2m3 vectors of length @p n.
 *  Single pass accumulates the dot product and both squared norms in f32,
 *  then delegates the final normalization to a shared scalar helper.
 */
NK_PUBLIC void nk_angular_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 dot_f32x16 = _mm512_setzero_ps();
    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
    __m128i a_e2m3x16, b_e2m3x16;

nk_angular_e2m3_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes, adding 0 to all accumulators.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e2m3x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e2m3x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e2m3x16 = _mm_loadu_si128((__m128i const *)a);
        b_e2m3x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3x16);
    __m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3x16);
    dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
    if (n) goto nk_angular_e2m3_skylake_cycle;

    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
531
|
+
|
|
532
|
+
/**
 *  @brief Squared Euclidean distance between two e3m2 micro-float vectors of length @p n.
 *  Sixteen bytes are loaded at a time, widened to f32 via the e3m2 converter,
 *  and (a-b)^2 is accumulated with fused multiply-adds.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m128i a_e3m2x16, b_e3m2x16;

nk_sqeuclidean_e3m2_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes (0x00 presumably decodes
        // to +0.0 in e3m2, so padded lanes add nothing — relies on the converter).
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e3m2x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e3m2x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
        b_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2x16);
    __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2x16);
    __m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
    sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
    if (n) goto nk_sqeuclidean_e3m2_skylake_cycle;

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
556
|
+
|
|
557
|
+
/** @brief Euclidean distance between two e3m2 vectors: square root of the squared-Euclidean kernel. */
NK_PUBLIC void nk_euclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e3m2_skylake(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
561
|
+
|
|
562
|
+
/**
 *  @brief Angular (cosine) distance between two e3m2 vectors of length @p n.
 *  Single pass accumulates the dot product and both squared norms in f32,
 *  then delegates the final normalization to a shared scalar helper.
 */
NK_PUBLIC void nk_angular_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 dot_f32x16 = _mm512_setzero_ps();
    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
    __m128i a_e3m2x16, b_e3m2x16;

nk_angular_e3m2_skylake_cycle:
    if (n < 16) {
        // Tail: masked byte loads zero-fill unused lanes, adding 0 to all accumulators.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_e3m2x16 = _mm_maskz_loadu_epi8(mask, a);
        b_e3m2x16 = _mm_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e3m2x16 = _mm_loadu_si128((__m128i const *)a);
        b_e3m2x16 = _mm_loadu_si128((__m128i const *)b);
        a += 16, b += 16, n -= 16;
    }
    __m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2x16);
    __m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2x16);
    dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
    if (n) goto nk_angular_e3m2_skylake_cycle;

    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
592
|
+
|
|
593
|
+
#if defined(__clang__)
|
|
594
|
+
#pragma clang attribute pop
|
|
595
|
+
#elif defined(__GNUC__)
|
|
596
|
+
#pragma GCC pop_options
|
|
597
|
+
#endif
|
|
598
|
+
|
|
599
|
+
#if defined(__cplusplus)
|
|
600
|
+
} // extern "C"
|
|
601
|
+
#endif
|
|
602
|
+
|
|
603
|
+
#pragma endregion - Smaller Floats
|
|
604
|
+
#endif // NK_TARGET_SKYLAKE
|
|
605
|
+
#endif // NK_TARGET_X86_
|
|
606
|
+
#endif // NK_SPATIAL_SKYLAKE_H
|