numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SWAR-accelerated Spatial Similarity Measures for SIMD-free CPUs.
|
|
3
|
+
* @file include/numkong/spatial/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPATIAL_SERIAL_H
|
|
10
|
+
#define NK_SPATIAL_SERIAL_H
|
|
11
|
+
|
|
12
|
+
#include "numkong/types.h"
|
|
13
|
+
#include "numkong/scalar/serial.h" // `nk_f32_rsqrt_serial`
|
|
14
|
+
#include "numkong/cast/serial.h"
|
|
15
|
+
#include "numkong/dot/serial.h" // `nk_dot_f64x2_state_serial_t`
|
|
16
|
+
|
|
17
|
+
#if defined(__cplusplus)
|
|
18
|
+
extern "C" {
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
/**
 * @brief Defines `nk_sqeuclidean_<input_type>_serial`, the squared L2 distance Σ (aᵢ − bᵢ)².
 *
 * Every element is converted into `accumulator_type` via `load_and_convert(source_pointer, target_pointer)`.
 * The squared difference is added to a running total. The rounding error of that addition is collected
 * in a separate correction term, which is added back once after the loop, before the final cast to
 * `output_type`.
 *
 * The correction follows Neumaier's improvement of Kahan–Babuška summation. With `next = total + term`:
 *  - if |total| ≥ |term|, the low-order bits of `term` were lost: correction += (total − next) + term;
 *  - otherwise the low-order bits of `total` were lost:          correction += (term − next) + total.
 * Unlike plain Kahan summation, this stays correct when a new term is larger than the running total.
 * The error bound is about 2u·|S| plus a second-order n·u² term, where u is the unit roundoff of
 * `accumulator_type`, so it barely grows with `n`. The price is a few extra operations per element.
 *
 * For integer accumulators the correction is always zero. The result is then exact as long as the
 * accumulator does not overflow.
 *
 * @see Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
 */
#define nk_define_sqeuclidean_(input_type, accumulator_type, output_type, load_and_convert)                        \
    NK_PUBLIC void nk_sqeuclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                        nk_size_t n, nk_##output_type##_t *result) {                \
        nk_##accumulator_type##_t running = 0, correction = 0, lhs, rhs;                                           \
        for (nk_size_t idx = 0; idx < n; idx++) {                                                                  \
            load_and_convert(a + idx, &lhs);                                                                        \
            load_and_convert(b + idx, &rhs);                                                                        \
            nk_##accumulator_type##_t const delta = lhs - rhs;                                                    \
            nk_##accumulator_type##_t const square = delta * delta;                                               \
            nk_##accumulator_type##_t const updated = running + square;                                           \
            if (nk_##accumulator_type##_abs_(running) >= nk_##accumulator_type##_abs_(square)) {                  \
                correction += (running - updated) + square;                                                        \
            }                                                                                                      \
            else { correction += (square - updated) + running; }                                                    \
            running = updated;                                                                                     \
        }                                                                                                          \
        *result = (nk_##output_type##_t)(running + correction);                                                    \
    }
|
|
56
|
+
|
|
57
|
+
/**
 * @brief Defines `nk_euclidean_<input_type>_serial` as the square root of `nk_sqeuclidean_<input_type>_serial`.
 *
 * The squared distance is produced in `l2sq_output_type`, converted to `output_type`, and passed through
 * `compute_sqrt`. This expansion does not reference `accumulator_type` or `load_and_convert`; they keep
 * the argument list parallel to the sibling macros.
 */
#define nk_define_euclidean_(input_type, accumulator_type, l2sq_output_type, output_type, load_and_convert,       \
                             compute_sqrt)                                                                       \
    NK_PUBLIC void nk_euclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                      nk_size_t n, nk_##output_type##_t *result) {                \
        nk_##l2sq_output_type##_t squared;                                                                      \
        nk_sqeuclidean_##input_type##_serial(a, b, n, &squared);                                               \
        nk_##output_type##_t const widened = (nk_##output_type##_t)squared;                                    \
        *result = compute_sqrt(widened);                                                                        \
    }
|
|
65
|
+
|
|
66
|
+
/**
 * @brief Defines `nk_angular_<input_type>_serial`, the cosine distance 1 − ⟨a,b⟩ / (‖a‖·‖b‖).
 *
 * Three accumulators track ⟨a,b⟩, ‖a‖² and ‖b‖² in `accumulator_type`. Each has its own Neumaier
 * correction term, using the same compensation step as `nk_define_sqeuclidean_`. Results:
 *  - both squared norms are zero → 0;
 *  - otherwise, a zero dot product → 1;
 *  - otherwise 1 − dot · compute_rsqrt(‖a‖²) · compute_rsqrt(‖b‖²), evaluated into `output_type`
 *    and clamped from below at 0, because rounding can push it slightly negative.
 */
#define nk_define_angular_(input_type, accumulator_type, output_type, load_and_convert, compute_rsqrt)              \
    NK_PUBLIC void nk_angular_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b,     \
                                                    nk_size_t n, nk_##output_type##_t *result) {                    \
        nk_##accumulator_type##_t ab_total = 0, aa_total = 0, bb_total = 0;                                        \
        nk_##accumulator_type##_t ab_error = 0, aa_error = 0, bb_error = 0;                                        \
        nk_##accumulator_type##_t lhs, rhs;                                                                        \
        for (nk_size_t idx = 0; idx < n; idx++) {                                                                  \
            load_and_convert(a + idx, &lhs);                                                                        \
            load_and_convert(b + idx, &rhs);                                                                        \
            nk_##accumulator_type##_t const ab_term = lhs * rhs, ab_next = ab_total + ab_term;                    \
            nk_##accumulator_type##_t const aa_term = lhs * lhs, aa_next = aa_total + aa_term;                    \
            nk_##accumulator_type##_t const bb_term = rhs * rhs, bb_next = bb_total + bb_term;                    \
            if (nk_##accumulator_type##_abs_(ab_total) >= nk_##accumulator_type##_abs_(ab_term)) {                \
                ab_error += (ab_total - ab_next) + ab_term;                                                        \
            }                                                                                                      \
            else { ab_error += (ab_term - ab_next) + ab_total; }                                                    \
            if (nk_##accumulator_type##_abs_(aa_total) >= nk_##accumulator_type##_abs_(aa_term)) {                \
                aa_error += (aa_total - aa_next) + aa_term;                                                        \
            }                                                                                                      \
            else { aa_error += (aa_term - aa_next) + aa_total; }                                                    \
            if (nk_##accumulator_type##_abs_(bb_total) >= nk_##accumulator_type##_abs_(bb_term)) {                \
                bb_error += (bb_total - bb_next) + bb_term;                                                        \
            }                                                                                                      \
            else { bb_error += (bb_term - bb_next) + bb_total; }                                                    \
            ab_total = ab_next;                                                                                    \
            aa_total = aa_next;                                                                                    \
            bb_total = bb_next;                                                                                    \
        }                                                                                                          \
        nk_##accumulator_type##_t const dot_product = ab_total + ab_error;                                         \
        nk_##accumulator_type##_t const a_norm_sq = aa_total + aa_error;                                           \
        nk_##accumulator_type##_t const b_norm_sq = bb_total + bb_error;                                           \
        if (a_norm_sq == 0 && b_norm_sq == 0) {                                                                    \
            *result = 0;                                                                                           \
            return;                                                                                                \
        }                                                                                                          \
        if (dot_product == 0) {                                                                                    \
            *result = 1;                                                                                           \
            return;                                                                                                \
        }                                                                                                          \
        nk_##output_type##_t const unclipped = 1 - dot_product * compute_rsqrt(a_norm_sq) *                        \
                                                       compute_rsqrt(b_norm_sq);                                   \
        *result = unclipped > 0 ? unclipped : 0;                                                                   \
    }
|
|
109
|
+
|
|
110
|
+
nk_define_angular_(f64, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f64_serial
|
|
111
|
+
nk_define_sqeuclidean_(f64, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f64_serial
|
|
112
|
+
nk_define_euclidean_(f64, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f64_serial
|
|
113
|
+
|
|
114
|
+
nk_define_angular_(f32, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f32_serial
|
|
115
|
+
nk_define_sqeuclidean_(f32, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f32_serial
|
|
116
|
+
nk_define_euclidean_(f32, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f32_serial
|
|
117
|
+
|
|
118
|
+
nk_define_angular_(f16, f32, f32, nk_f16_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_f16_serial
|
|
119
|
+
nk_define_sqeuclidean_(f16, f32, f32, nk_f16_to_f32_serial) // nk_sqeuclidean_f16_serial
|
|
120
|
+
nk_define_euclidean_(f16, f32, f32, f32, nk_f16_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_f16_serial
|
|
121
|
+
|
|
122
|
+
nk_define_angular_(bf16, f32, f32, nk_bf16_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_bf16_serial
|
|
123
|
+
nk_define_sqeuclidean_(bf16, f32, f32, nk_bf16_to_f32_serial) // nk_sqeuclidean_bf16_serial
|
|
124
|
+
nk_define_euclidean_(bf16, f32, f32, f32, nk_bf16_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_bf16_serial
|
|
125
|
+
|
|
126
|
+
nk_define_angular_(e4m3, f32, f32, nk_e4m3_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e4m3_serial
|
|
127
|
+
nk_define_sqeuclidean_(e4m3, f32, f32, nk_e4m3_to_f32_serial) // nk_sqeuclidean_e4m3_serial
|
|
128
|
+
nk_define_euclidean_(e4m3, f32, f32, f32, nk_e4m3_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e4m3_serial
|
|
129
|
+
|
|
130
|
+
nk_define_angular_(e5m2, f32, f32, nk_e5m2_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e5m2_serial
|
|
131
|
+
nk_define_sqeuclidean_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_sqeuclidean_e5m2_serial
|
|
132
|
+
nk_define_euclidean_(e5m2, f32, f32, f32, nk_e5m2_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e5m2_serial
|
|
133
|
+
|
|
134
|
+
nk_define_angular_(e2m3, f32, f32, nk_e2m3_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e2m3_serial
|
|
135
|
+
nk_define_sqeuclidean_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_sqeuclidean_e2m3_serial
|
|
136
|
+
nk_define_euclidean_(e2m3, f32, f32, f32, nk_e2m3_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e2m3_serial
|
|
137
|
+
|
|
138
|
+
nk_define_angular_(e3m2, f32, f32, nk_e3m2_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e3m2_serial
|
|
139
|
+
nk_define_sqeuclidean_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_sqeuclidean_e3m2_serial
|
|
140
|
+
nk_define_euclidean_(e3m2, f32, f32, f32, nk_e3m2_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e3m2_serial
|
|
141
|
+
|
|
142
|
+
nk_define_angular_(i8, i32, f32, nk_assign_from_to_, nk_f32_rsqrt_serial) // nk_angular_i8_serial
|
|
143
|
+
nk_define_sqeuclidean_(i8, i32, u32, nk_assign_from_to_) // nk_sqeuclidean_i8_serial
|
|
144
|
+
nk_define_euclidean_(i8, i32, u32, f32, nk_assign_from_to_, nk_f32_sqrt_serial) // nk_euclidean_i8_serial
|
|
145
|
+
|
|
146
|
+
nk_define_angular_(u8, u32, f32, nk_assign_from_to_, nk_f32_rsqrt_serial) // nk_angular_u8_serial
|
|
147
|
+
nk_define_sqeuclidean_(u8, u32, u32, nk_assign_from_to_) // nk_sqeuclidean_u8_serial
|
|
148
|
+
nk_define_euclidean_(u8, u32, u32, f32, nk_assign_from_to_, nk_f32_sqrt_serial) // nk_euclidean_u8_serial
|
|
149
|
+
|
|
150
|
+
#undef nk_define_sqeuclidean_
|
|
151
|
+
#undef nk_define_euclidean_
|
|
152
|
+
#undef nk_define_angular_
|
|
153
|
+
|
|
154
|
+
/**
 * @brief Squared Euclidean distance between two packed signed 4-bit vectors.
 *
 * Each `nk_i4x2_t` byte holds two signed nibbles, and `n` counts nibbles (dimensions), not bytes.
 * `n` is rounded up to an even count, so both nibbles of the final byte are read; any padding nibble
 * presumably must be zero (TODO confirm with callers).
 * NOTE(review): `nk_i4x2_low_` / `nk_i4x2_high_` are expected to return sign-extended values in [-8, 7]
 * (e.g. via (nibble ^ 8) - 8); confirm against their definitions.
 *
 * Each squared difference is at most 15² = 225, so one byte contributes at most 450. The total is
 * accumulated in `nk_u32_t`, matching the output type and the u4 kernel. Unsigned wrap-around is
 * well-defined, whereas a signed `nk_i32_t` accumulator would overflow with undefined behavior on
 * very long inputs.
 */
NK_PUBLIC void nk_sqeuclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    n = nk_size_round_up_to_multiple_(n, 2);
    nk_size_t n_bytes = n / 2;
    nk_u32_t sum = 0;
    for (nk_size_t i = 0; i < n_bytes; ++i) {
        nk_i32_t a_low = (nk_i32_t)nk_i4x2_low_(a[i]);
        nk_i32_t b_low = (nk_i32_t)nk_i4x2_low_(b[i]);
        nk_i32_t a_high = (nk_i32_t)nk_i4x2_high_(a[i]);
        nk_i32_t b_high = (nk_i32_t)nk_i4x2_high_(b[i]);
        nk_i32_t diff_low = a_low - b_low, diff_high = a_high - b_high;
        // The per-byte contribution lies in [0, 450], so the unsigned conversion preserves its value.
        sum += (nk_u32_t)(diff_low * diff_low + diff_high * diff_high);
    }
    *result = sum;
}
|
|
171
|
+
|
|
172
|
+
/**
 * @brief Euclidean distance between two packed signed 4-bit vectors.
 *
 * Takes the exact integer squared distance from `nk_sqeuclidean_i4_serial`, converts it to f32 once,
 * and takes its square root. `n` counts nibbles, as in the squared-distance kernel.
 */
NK_PUBLIC void nk_euclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_i4_serial(a, b, n, &squared_distance);
    nk_f32_t const squared_as_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_serial(squared_as_f32);
}
|
|
177
|
+
|
|
178
|
+
/**
 * @brief Angular (cosine) distance between two packed signed 4-bit vectors.
 *
 * `n` counts nibbles. It is rounded up to an even count, so both nibbles of the final byte are read.
 * The dot product and both squared norms are accumulated in `nk_i32_t`, then combined in f32 as
 * 1 − dot · rsqrt(‖a‖²) · rsqrt(‖b‖²), clamped at 0.
 * If both norms are zero the result is 0; otherwise a zero dot product gives 1.
 */
NK_PUBLIC void nk_angular_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t const padded = nk_size_round_up_to_multiple_(n, 2);
    nk_size_t const byte_count = padded / 2;
    nk_i32_t ab = 0, aa = 0, bb = 0;
    for (nk_size_t idx = 0; idx != byte_count; ++idx) {
        nk_i32_t const a0 = (nk_i32_t)nk_i4x2_low_(a[idx]);
        nk_i32_t const b0 = (nk_i32_t)nk_i4x2_low_(b[idx]);
        nk_i32_t const a1 = (nk_i32_t)nk_i4x2_high_(a[idx]);
        nk_i32_t const b1 = (nk_i32_t)nk_i4x2_high_(b[idx]);
        ab += a0 * b0 + a1 * b1;
        aa += a0 * a0 + a1 * a1;
        bb += b0 * b0 + b1 * b1;
    }
    if (aa == 0 && bb == 0) {
        *result = 0;
        return;
    }
    if (ab == 0) {
        *result = 1;
        return;
    }
    nk_f32_t const unclipped = 1.0f - (nk_f32_t)ab * nk_f32_rsqrt_serial((nk_f32_t)aa) *
                                          nk_f32_rsqrt_serial((nk_f32_t)bb);
    *result = unclipped > 0 ? unclipped : 0;
}
|
|
199
|
+
|
|
200
|
+
/**
 * @brief Squared Euclidean distance between two packed unsigned 4-bit vectors.
 *
 * Each `nk_u4x2_t` byte holds two unsigned nibbles in [0, 15], and `n` counts nibbles, not bytes.
 * `n` is rounded up to an even count, so both nibbles of the final byte are read.
 * Differences are formed in signed 32-bit arithmetic; each byte adds at most 2 × 15² = 450 to the
 * unsigned total.
 */
NK_PUBLIC void nk_sqeuclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const padded = nk_size_round_up_to_multiple_(n, 2);
    nk_size_t const byte_count = padded / 2;
    nk_u32_t total = 0;
    for (nk_size_t idx = 0; idx != byte_count; ++idx) {
        nk_i32_t const low_delta = (nk_i32_t)nk_u4x2_low_(a[idx]) - (nk_i32_t)nk_u4x2_low_(b[idx]);
        nk_i32_t const high_delta = (nk_i32_t)nk_u4x2_high_(a[idx]) - (nk_i32_t)nk_u4x2_high_(b[idx]);
        total += (nk_u32_t)(low_delta * low_delta + high_delta * high_delta);
    }
    *result = total;
}
|
|
217
|
+
|
|
218
|
+
/**
 * @brief Euclidean distance between two packed unsigned 4-bit vectors.
 *
 * Takes the exact integer squared distance from `nk_sqeuclidean_u4_serial`, converts it to f32 once,
 * and takes its square root. `n` counts nibbles.
 */
NK_PUBLIC void nk_euclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_u4_serial(a, b, n, &squared_distance);
    nk_f32_t const squared_as_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_serial(squared_as_f32);
}
|
|
223
|
+
|
|
224
|
+
/**
 * @brief Angular (cosine) distance between two packed unsigned 4-bit vectors.
 *
 * `n` counts nibbles. It is rounded up to an even count, so both nibbles of the final byte are read.
 * The dot product and both squared norms are accumulated in `nk_u32_t`, then combined in f32 as
 * 1 − dot · rsqrt(‖a‖²) · rsqrt(‖b‖²), clamped at 0.
 * If both norms are zero the result is 0; otherwise a zero dot product gives 1.
 */
NK_PUBLIC void nk_angular_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t const padded = nk_size_round_up_to_multiple_(n, 2);
    nk_size_t const byte_count = padded / 2;
    nk_u32_t ab = 0, aa = 0, bb = 0;
    for (nk_size_t idx = 0; idx != byte_count; ++idx) {
        nk_u32_t const a0 = (nk_u32_t)nk_u4x2_low_(a[idx]);
        nk_u32_t const b0 = (nk_u32_t)nk_u4x2_low_(b[idx]);
        nk_u32_t const a1 = (nk_u32_t)nk_u4x2_high_(a[idx]);
        nk_u32_t const b1 = (nk_u32_t)nk_u4x2_high_(b[idx]);
        ab += a0 * b0 + a1 * b1;
        aa += a0 * a0 + a1 * a1;
        bb += b0 * b0 + b1 * b1;
    }
    if (aa == 0 && bb == 0) {
        *result = 0;
        return;
    }
    if (ab == 0) {
        *result = 1;
        return;
    }
    nk_f32_t const unclipped = 1.0f - (nk_f32_t)ab * nk_f32_rsqrt_serial((nk_f32_t)aa) *
                                          nk_f32_rsqrt_serial((nk_f32_t)bb);
    *result = unclipped > 0 ? unclipped : 0;
}
|
|
245
|
+
|
|
246
|
+
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs (serial). */
|
|
247
|
+
NK_INTERNAL void nk_angular_through_f32_from_dot_serial_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
248
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
249
|
+
for (int i = 0; i < 4; ++i) {
|
|
250
|
+
nk_f32_t product = query_sumsq * target_sumsqs.f32s[i];
|
|
251
|
+
if (product > 0) {
|
|
252
|
+
nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
|
|
253
|
+
nk_f32_t normalized = dots.f32s[i] * rsqrt_val;
|
|
254
|
+
nk_f32_t result = 1.0f - normalized;
|
|
255
|
+
results->f32s[i] = result > 0 ? result : 0;
|
|
256
|
+
}
|
|
257
|
+
else { results->f32s[i] = (dots.f32s[i] == 0) ? 0.0f : 1.0f; }
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs (serial). */
|
|
262
|
+
NK_INTERNAL void nk_euclidean_through_f32_from_dot_serial_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
263
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
264
|
+
for (int i = 0; i < 4; ++i) {
|
|
265
|
+
nk_f32_t dist_sq = query_sumsq + target_sumsqs.f32s[i] - 2.0f * dots.f32s[i];
|
|
266
|
+
results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
/** @brief Angular from_dot for f64 precision. */
|
|
271
|
+
NK_INTERNAL void nk_angular_through_f64_from_dot_serial_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
272
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
273
|
+
for (int i = 0; i < 4; ++i) {
|
|
274
|
+
nk_f64_t product = query_sumsq * target_sumsqs.f64s[i];
|
|
275
|
+
if (product > 0) {
|
|
276
|
+
nk_f64_t rsqrt_val = nk_f64_rsqrt_serial(product);
|
|
277
|
+
nk_f64_t normalized = dots.f64s[i] * rsqrt_val;
|
|
278
|
+
nk_f64_t result = 1.0 - normalized;
|
|
279
|
+
results->f64s[i] = result > 0 ? result : 0;
|
|
280
|
+
}
|
|
281
|
+
else { results->f64s[i] = (dots.f64s[i] == 0) ? 0.0 : 1.0; }
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
/** @brief Euclidean from_dot for f64 precision. */
|
|
286
|
+
NK_INTERNAL void nk_euclidean_through_f64_from_dot_serial_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
287
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
288
|
+
for (int i = 0; i < 4; ++i) {
|
|
289
|
+
nk_f64_t dist_sq = query_sumsq + target_sumsqs.f64s[i] - 2.0 * dots.f64s[i];
|
|
290
|
+
results->f64s[i] = dist_sq > 0 ? nk_f64_sqrt_serial(dist_sq) : 0.0;
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
/** @brief Angular from_dot for i32 accumulators: cast to f32, then same math as f32 variant. */
|
|
295
|
+
NK_INTERNAL void nk_angular_through_i32_from_dot_serial_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
296
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
297
|
+
for (int i = 0; i < 4; ++i) {
|
|
298
|
+
nk_f32_t product = (nk_f32_t)query_sumsq * (nk_f32_t)target_sumsqs.i32s[i];
|
|
299
|
+
if (product > 0) {
|
|
300
|
+
nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
|
|
301
|
+
nk_f32_t normalized = (nk_f32_t)dots.i32s[i] * rsqrt_val;
|
|
302
|
+
nk_f32_t result = 1.0f - normalized;
|
|
303
|
+
results->f32s[i] = result > 0 ? result : 0;
|
|
304
|
+
}
|
|
305
|
+
else { results->f32s[i] = (dots.i32s[i] == 0) ? 0.0f : 1.0f; }
|
|
306
|
+
}
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
/** @brief Euclidean from_dot for i32 accumulators: cast to f32, then same math as f32 variant. */
|
|
310
|
+
NK_INTERNAL void nk_euclidean_through_i32_from_dot_serial_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
311
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
312
|
+
for (int i = 0; i < 4; ++i) {
|
|
313
|
+
nk_f32_t dist_sq = (nk_f32_t)query_sumsq + (nk_f32_t)target_sumsqs.i32s[i] - 2.0f * (nk_f32_t)dots.i32s[i];
|
|
314
|
+
results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
|
|
315
|
+
}
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
/** @brief Angular from_dot for u32 accumulators: cast to f32, then same math as f32 variant. */
|
|
319
|
+
NK_INTERNAL void nk_angular_through_u32_from_dot_serial_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
320
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
321
|
+
for (int i = 0; i < 4; ++i) {
|
|
322
|
+
nk_f32_t product = (nk_f32_t)query_sumsq * (nk_f32_t)target_sumsqs.u32s[i];
|
|
323
|
+
if (product > 0) {
|
|
324
|
+
nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
|
|
325
|
+
nk_f32_t normalized = (nk_f32_t)dots.u32s[i] * rsqrt_val;
|
|
326
|
+
nk_f32_t result = 1.0f - normalized;
|
|
327
|
+
results->f32s[i] = result > 0 ? result : 0;
|
|
328
|
+
}
|
|
329
|
+
else { results->f32s[i] = (dots.u32s[i] == 0) ? 0.0f : 1.0f; }
|
|
330
|
+
}
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
/** @brief Euclidean from_dot for u32 accumulators: cast to f32, then same math as f32 variant. */
|
|
334
|
+
NK_INTERNAL void nk_euclidean_through_u32_from_dot_serial_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
335
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
336
|
+
for (int i = 0; i < 4; ++i) {
|
|
337
|
+
nk_f32_t dist_sq = (nk_f32_t)query_sumsq + (nk_f32_t)target_sumsqs.u32s[i] - 2.0f * (nk_f32_t)dots.u32s[i];
|
|
338
|
+
results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
#if defined(__cplusplus)
|
|
343
|
+
} // extern "C"
|
|
344
|
+
#endif
|
|
345
|
+
|
|
346
|
+
#endif // NK_SPATIAL_SERIAL_H
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Sierra Forest.
|
|
3
|
+
* @file include/numkong/spatial/sierra.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_sierra_instructions AVXVNNIINT8 Instructions Performance
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Sierra Forest
|
|
12
|
+
* _mm256_dpbssds_epi32 VPDPBSSDS (YMM, YMM, YMM) 4cy @ p05
|
|
13
|
+
* _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) 4cy @ p05
|
|
14
|
+
* _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) 4cy @ p05
|
|
15
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0
|
|
16
|
+
* _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0
|
|
17
|
+
*
|
|
18
|
+
* Sierra Forest (AVXVNNIINT8) provides native signed x signed and unsigned x unsigned
|
|
19
|
+
* dot products, eliminating the need for algebraic corrections required on Alder Lake.
|
|
20
|
+
* This gives ~2.6x throughput over Haswell and ~1.3x over Alder for spatial kernels.
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_SPATIAL_SIERRA_H
|
|
23
|
+
#define NK_SPATIAL_SIERRA_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X86_
|
|
26
|
+
#if NK_TARGET_SIERRA
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/scalar/haswell.h" // `nk_f32_sqrt_haswell`
|
|
30
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`
|
|
31
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
|
|
32
|
+
|
|
33
|
+
#if defined(__cplusplus)
|
|
34
|
+
extern "C" {
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
#if defined(__clang__)
|
|
38
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni,avxvnniint8"))), apply_to = function)
|
|
39
|
+
#elif defined(__GNUC__)
|
|
40
|
+
#pragma GCC push_options
|
|
41
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni", "avxvnniint8")
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
/**
 *  @brief  Angular distance between two i8 vectors using native signed×signed VPDPBSSDS.
 *
 *  Whole 32-byte chunks accumulate a·b, ‖a‖² and ‖b‖² into saturating i32 lanes. The remaining
 *  fewer-than-32 elements are folded in with scalar arithmetic. The three totals are then passed to
 *  `nk_angular_normalize_f32_haswell_`.
 *
 *  @param a, b    Input vectors of @p n signed 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output angular distance.
 */
NK_PUBLIC void nk_angular_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i ab_i32x8 = _mm256_setzero_si256();
    __m256i aa_i32x8 = _mm256_setzero_si256();
    __m256i bb_i32x8 = _mm256_setzero_si256();

    nk_size_t const bulk_length = n - n % 32;
    nk_size_t offset = 0;
    while (offset < bulk_length) {
        __m256i const a_chunk_i8x32 = _mm256_loadu_si256((__m256i const *)(a + offset));
        __m256i const b_chunk_i8x32 = _mm256_loadu_si256((__m256i const *)(b + offset));
        ab_i32x8 = _mm256_dpbssds_epi32(ab_i32x8, a_chunk_i8x32, b_chunk_i8x32);
        aa_i32x8 = _mm256_dpbssds_epi32(aa_i32x8, a_chunk_i8x32, a_chunk_i8x32);
        bb_i32x8 = _mm256_dpbssds_epi32(bb_i32x8, b_chunk_i8x32, b_chunk_i8x32);
        offset += 32;
    }

    nk_i32_t ab_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
    nk_i32_t aa_i32 = nk_reduce_add_i32x8_haswell_(aa_i32x8);
    nk_i32_t bb_i32 = nk_reduce_add_i32x8_haswell_(bb_i32x8);

    // Scalar tail: whatever did not fill a whole 32-byte chunk.
    for (; offset < n; ++offset) {
        nk_i32_t const x = a[offset];
        nk_i32_t const y = b[offset];
        ab_i32 += x * y;
        aa_i32 += x * x;
        bb_i32 += y * y;
    }

    *result = nk_angular_normalize_f32_haswell_(ab_i32, aa_i32, bb_i32);
}
|
|
72
|
+
|
|
73
|
+
/**
 *  @brief  Squared Euclidean distance between two i8 vectors: ‖a‖² + ‖b‖² − 2·(a·b).
 *
 *  All accumulation is carried out modulo 2³²:
 *   - the SIMD body uses non-saturating VPDPBSSD, so i32 lanes wrap instead of clipping;
 *   - the lane totals are reinterpreted as `nk_u32_t`, and the scalar tail and the final combination
 *     use unsigned arithmetic, which wraps by definition instead of hitting signed-overflow UB.
 *  The true squared distance is non-negative, so the result is exact whenever it fits in `nk_u32_t`.
 *  This holds even when the individual norms or the dot product exceed INT32_MAX.
 *  NOTE(review): this relies on `nk_reduce_add_i32x8_haswell_` summing lanes with wrapping adds.
 *
 *  @param a, b    Input vectors of @p n signed 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output squared distance.
 */
NK_PUBLIC void nk_sqeuclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    __m256i dot_product_i32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
        // Non-saturating form: wrap-around keeps every lane correct modulo 2^32.
        dot_product_i32x8 = _mm256_dpbssd_epi32(dot_product_i32x8, a_i8x32, b_i8x32);
        a_norm_sq_i32x8 = _mm256_dpbssd_epi32(a_norm_sq_i32x8, a_i8x32, a_i8x32);
        b_norm_sq_i32x8 = _mm256_dpbssd_epi32(b_norm_sq_i32x8, b_i8x32, b_i8x32);
    }

    // Continue in unsigned arithmetic so the tail and the final combination wrap (well-defined).
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8);
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8);
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8);

    for (; i < n; ++i) {
        nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
        // Each product is at most 16384 in magnitude, so the i32 multiply itself cannot overflow.
        dot_product_u32 += (nk_u32_t)(a_element_i32 * b_element_i32);
        a_norm_sq_u32 += (nk_u32_t)(a_element_i32 * a_element_i32);
        b_norm_sq_u32 += (nk_u32_t)(b_element_i32 * b_element_i32);
    }

    *result = a_norm_sq_u32 + b_norm_sq_u32 - 2u * dot_product_u32;
}
|
|
102
|
+
|
|
103
|
+
/**
 *  @brief  Euclidean distance between two i8 vectors.
 *
 *  Square root of the integer squared distance produced by `nk_sqeuclidean_i8_sierra`.
 *
 *  @param a, b    Input vectors of @p n signed 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output distance.
 */
NK_PUBLIC void nk_euclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance = 0;
    nk_sqeuclidean_i8_sierra(a, b, n, &squared_distance);
    nk_f32_t const squared_distance_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_haswell(squared_distance_f32);
}
|
|
108
|
+
|
|
109
|
+
/**
 *  @brief  Angular distance between two u8 vectors using native unsigned×unsigned VPDPBUUD.
 *
 *  VPDPBUUD accumulates unsigned products into u32 lanes. The totals are therefore carried as
 *  `nk_u32_t` rather than `nk_i32_t`. This keeps sums above INT32_MAX positive and makes the scalar
 *  tail wrap instead of hitting signed-overflow UB (norms exceed INT32_MAX after ~33k elements of 255).
 *  NOTE(review): the benefit of the unsigned values depends on `nk_angular_normalize_f32_haswell_`
 *  taking floating-point (or wider) arguments. Confirm against its declaration.
 *
 *  @param a, b    Input vectors of @p n unsigned 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output angular distance.
 */
NK_PUBLIC void nk_angular_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i dot_product_u32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_u32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_u32x8 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
        dot_product_u32x8 = _mm256_dpbuud_epi32(dot_product_u32x8, a_u8x32, b_u8x32);
        a_norm_sq_u32x8 = _mm256_dpbuud_epi32(a_norm_sq_u32x8, a_u8x32, a_u8x32);
        b_norm_sq_u32x8 = _mm256_dpbuud_epi32(b_norm_sq_u32x8, b_u8x32, b_u8x32);
    }

    // The lanes hold unsigned sums; reinterpret the wrapped i32 reduction as u32 so it stays positive.
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(dot_product_u32x8);
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_u32x8);
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_u32x8);

    for (; i < n; ++i) {
        nk_u32_t a_element_u32 = a[i], b_element_u32 = b[i];
        dot_product_u32 += a_element_u32 * b_element_u32;
        a_norm_sq_u32 += a_element_u32 * a_element_u32;
        b_norm_sq_u32 += b_element_u32 * b_element_u32;
    }

    *result = nk_angular_normalize_f32_haswell_(dot_product_u32, a_norm_sq_u32, b_norm_sq_u32);
}
|
|
137
|
+
|
|
138
|
+
/**
 *  @brief  Squared Euclidean distance between two u8 vectors: ‖a‖² + ‖b‖² − 2·(a·b).
 *
 *  VPDPBUUD (non-saturating) accumulates into u32 lanes that wrap modulo 2³². The lane totals, the
 *  scalar tail and the final combination stay in unsigned `nk_u32_t` arithmetic, so intermediate
 *  wrap-around is well-defined. The previous signed formulation overflowed (UB) once a norm passed
 *  INT32_MAX, e.g. ~33k elements of 255 vs 0. Since the true distance is non-negative, the result is
 *  exact whenever it fits in `nk_u32_t`.
 *  NOTE(review): this relies on `nk_reduce_add_i32x8_haswell_` summing lanes with wrapping adds.
 *
 *  @param a, b    Input vectors of @p n unsigned 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output squared distance.
 */
NK_PUBLIC void nk_sqeuclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    __m256i dot_product_u32x8 = _mm256_setzero_si256();
    __m256i a_norm_sq_u32x8 = _mm256_setzero_si256();
    __m256i b_norm_sq_u32x8 = _mm256_setzero_si256();

    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
        __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
        dot_product_u32x8 = _mm256_dpbuud_epi32(dot_product_u32x8, a_u8x32, b_u8x32);
        a_norm_sq_u32x8 = _mm256_dpbuud_epi32(a_norm_sq_u32x8, a_u8x32, a_u8x32);
        b_norm_sq_u32x8 = _mm256_dpbuud_epi32(b_norm_sq_u32x8, b_u8x32, b_u8x32);
    }

    // Reinterpret the reductions as unsigned: the tail and the final combination then wrap safely.
    nk_u32_t dot_product_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(dot_product_u32x8);
    nk_u32_t a_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_u32x8);
    nk_u32_t b_norm_sq_u32 = (nk_u32_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_u32x8);

    for (; i < n; ++i) {
        nk_u32_t a_element_u32 = a[i], b_element_u32 = b[i];
        dot_product_u32 += a_element_u32 * b_element_u32;
        a_norm_sq_u32 += a_element_u32 * a_element_u32;
        b_norm_sq_u32 += b_element_u32 * b_element_u32;
    }

    *result = a_norm_sq_u32 + b_norm_sq_u32 - 2u * dot_product_u32;
}
|
|
167
|
+
|
|
168
|
+
/**
 *  @brief  Euclidean distance between two u8 vectors.
 *
 *  Square root of the integer squared distance produced by `nk_sqeuclidean_u8_sierra`.
 *
 *  @param a, b    Input vectors of @p n unsigned 8-bit integers.
 *  @param n       Number of elements.
 *  @param result  Output distance.
 */
NK_PUBLIC void nk_euclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance = 0;
    nk_sqeuclidean_u8_sierra(a, b, n, &squared_distance);
    nk_f32_t const squared_distance_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_haswell(squared_distance_f32);
}
|
|
173
|
+
|
|
174
|
+
NK_PUBLIC void nk_angular_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
175
|
+
nk_f32_t *result) {
|
|
176
|
+
// Angular distance for e2m3 using dual-VPSHUFB LUT + VPDPBSSD norm decomposition.
|
|
177
|
+
// Every e2m3 value × 16 is an exact integer in [-120, +120].
|
|
178
|
+
// DPBSSD(signed, signed) eliminates the need for unsigned conversion tricks.
|
|
179
|
+
//
|
|
180
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
181
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
182
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
183
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
184
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
185
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
186
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
187
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
188
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
189
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
190
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
191
|
+
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
192
|
+
|
|
193
|
+
nk_angular_e2m3_sierra_cycle:
|
|
194
|
+
if (count_scalars < 32) {
|
|
195
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
196
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
197
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
198
|
+
a_e2m3_u8x32 = a_vec.ymm;
|
|
199
|
+
b_e2m3_u8x32 = b_vec.ymm;
|
|
200
|
+
count_scalars = 0;
|
|
201
|
+
}
|
|
202
|
+
else {
|
|
203
|
+
a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
204
|
+
b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
205
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
// Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
|
|
209
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
210
|
+
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
211
|
+
__m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
212
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
|
|
213
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
|
|
214
|
+
__m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
215
|
+
__m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
|
|
216
|
+
_mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
|
|
217
|
+
|
|
218
|
+
// Decode b: same LUT decode + sign
|
|
219
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
220
|
+
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
221
|
+
__m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
222
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
|
|
223
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
|
|
224
|
+
__m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
225
|
+
__m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
|
|
226
|
+
_mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
|
|
227
|
+
|
|
228
|
+
// VPDPBSSD: signed × signed → i32
|
|
229
|
+
dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
230
|
+
a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
|
|
231
|
+
b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
|
|
232
|
+
|
|
233
|
+
if (count_scalars) goto nk_angular_e2m3_sierra_cycle;
|
|
234
|
+
|
|
235
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
236
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
237
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
238
|
+
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
242
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
243
|
+
// Squared Euclidean distance for e2m3 using norm decomposition + VPDPBSSD.
|
|
244
|
+
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
245
|
+
//
|
|
246
|
+
__m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
|
|
247
|
+
26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
248
|
+
__m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
249
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
250
|
+
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
251
|
+
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
252
|
+
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
253
|
+
__m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
|
|
254
|
+
__m256i dot_i32x8 = _mm256_setzero_si256();
|
|
255
|
+
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
256
|
+
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
257
|
+
__m256i a_e2m3_u8x32, b_e2m3_u8x32;
|
|
258
|
+
|
|
259
|
+
nk_sqeuclidean_e2m3_sierra_cycle:
|
|
260
|
+
if (count_scalars < 32) {
|
|
261
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
262
|
+
nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
|
|
263
|
+
nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
|
|
264
|
+
a_e2m3_u8x32 = a_vec.ymm;
|
|
265
|
+
b_e2m3_u8x32 = b_vec.ymm;
|
|
266
|
+
count_scalars = 0;
|
|
267
|
+
}
|
|
268
|
+
else {
|
|
269
|
+
a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
270
|
+
b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
271
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
// Decode a
|
|
275
|
+
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
276
|
+
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
277
|
+
__m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
278
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
|
|
279
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
|
|
280
|
+
__m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
281
|
+
__m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
|
|
282
|
+
_mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
|
|
283
|
+
|
|
284
|
+
// Decode b
|
|
285
|
+
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
286
|
+
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
287
|
+
__m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
288
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
|
|
289
|
+
_mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
|
|
290
|
+
__m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
291
|
+
__m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
|
|
292
|
+
_mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
|
|
293
|
+
|
|
294
|
+
dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
|
|
295
|
+
a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
|
|
296
|
+
b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
|
|
297
|
+
|
|
298
|
+
if (count_scalars) goto nk_sqeuclidean_e2m3_sierra_cycle;
|
|
299
|
+
|
|
300
|
+
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
301
|
+
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
302
|
+
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
303
|
+
*result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
/**
 *  @brief  Euclidean distance between two e2m3 vectors.
 *
 *  Square root of the squared distance produced by `nk_sqeuclidean_e2m3_sierra`.
 *
 *  @param a, b    Input vectors of @p n e2m3 scalars.
 *  @param n       Number of elements.
 *  @param result  Output distance.
 */
NK_PUBLIC void nk_euclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance = 0;
    nk_sqeuclidean_e2m3_sierra(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
310
|
+
|
|
311
|
+
#if defined(__clang__)
|
|
312
|
+
#pragma clang attribute pop
|
|
313
|
+
#elif defined(__GNUC__)
|
|
314
|
+
#pragma GCC pop_options
|
|
315
|
+
#endif
|
|
316
|
+
|
|
317
|
+
#if defined(__cplusplus)
|
|
318
|
+
} // extern "C"
|
|
319
|
+
#endif
|
|
320
|
+
|
|
321
|
+
#endif // NK_TARGET_SIERRA
|
|
322
|
+
#endif // NK_TARGET_X86_
|
|
323
|
+
#endif // NK_SPATIAL_SIERRA_H
|