numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief C++ wrappers for SIMD-accelerated Spatial Similarity Measures.
|
|
3
|
+
* @file include/numkong/spatial.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 5, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_SPATIAL_HPP
|
|
8
|
+
#define NK_SPATIAL_HPP
|
|
9
|
+
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <type_traits>
|
|
12
|
+
|
|
13
|
+
#include "numkong/spatial.h"
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.hpp"
|
|
16
|
+
|
|
17
|
+
namespace ashvardanian::numkong {
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* @brief L₂ (Euclidean) distance: √Σ(aᵢ − bᵢ)²
|
|
21
|
+
* @param[in] a,b First and second vectors
|
|
22
|
+
* @param[in] d Number of dimensions in input vectors
|
|
23
|
+
* @param[out] r Pointer to output distance value
|
|
24
|
+
*
|
|
25
|
+
* @tparam in_type_ Input vector element type
|
|
26
|
+
* @tparam result_type_ Accumulator type, defaults to `in_type_::euclidean_result_t`
|
|
27
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
28
|
+
*/
|
|
29
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
|
|
30
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
31
|
+
void euclidean(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
|
|
32
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
33
|
+
std::is_same_v<result_type_, typename in_type_::euclidean_result_t>;
|
|
34
|
+
|
|
35
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_euclidean_f64(&a->raw_, &b->raw_, d, &r->raw_);
|
|
36
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_euclidean_f32(&a->raw_, &b->raw_, d, &r->raw_);
|
|
37
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_euclidean_f16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
38
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_euclidean_bf16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
39
|
+
else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_euclidean_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
40
|
+
else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_euclidean_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
41
|
+
else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_euclidean_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
42
|
+
else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_euclidean_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
43
|
+
else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_euclidean_i8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
44
|
+
else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_euclidean_u8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
45
|
+
else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_euclidean_i4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
46
|
+
else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_euclidean_u4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
47
|
+
// Scalar fallback
|
|
48
|
+
else {
|
|
49
|
+
result_type_ sum {};
|
|
50
|
+
for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++)
|
|
51
|
+
sum = fdsa(a[i], b[i], sum);
|
|
52
|
+
*r = sum.sqrt();
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/**
|
|
57
|
+
* @brief Squared L₂ distance: Σ(aᵢ − bᵢ)²
|
|
58
|
+
* @param[in] a,b First and second vectors
|
|
59
|
+
* @param[in] d Number of dimensions in input vectors
|
|
60
|
+
* @param[out] r Pointer to output distance value
|
|
61
|
+
*
|
|
62
|
+
* @tparam in_type_ Input vector element type
|
|
63
|
+
* @tparam result_type_ Accumulator type, defaults to `in_type_::sqeuclidean_result_t`
|
|
64
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
65
|
+
*/
|
|
66
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
|
|
67
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
68
|
+
void sqeuclidean(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
|
|
69
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
70
|
+
std::is_same_v<result_type_, typename in_type_::sqeuclidean_result_t>;
|
|
71
|
+
|
|
72
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_sqeuclidean_f64(&a->raw_, &b->raw_, d, &r->raw_);
|
|
73
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_sqeuclidean_f32(&a->raw_, &b->raw_, d, &r->raw_);
|
|
74
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_sqeuclidean_f16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
75
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_sqeuclidean_bf16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
76
|
+
else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_sqeuclidean_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
77
|
+
else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_sqeuclidean_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
78
|
+
else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_sqeuclidean_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
79
|
+
else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_sqeuclidean_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
80
|
+
else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_sqeuclidean_i8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
81
|
+
else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_sqeuclidean_u8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
82
|
+
else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_sqeuclidean_i4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
83
|
+
else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_sqeuclidean_u4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
84
|
+
// Scalar fallback
|
|
85
|
+
else {
|
|
86
|
+
result_type_ sum {};
|
|
87
|
+
for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++)
|
|
88
|
+
sum = fdsa(a[i], b[i], sum);
|
|
89
|
+
*r = sum;
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/**
|
|
94
|
+
* @brief Angular similarity (cosine): ⟨a,b⟩ / (‖a‖ × ‖b‖)
|
|
95
|
+
* @param[in] a,b First and second vectors
|
|
96
|
+
* @param[in] d Number of dimensions in input vectors
|
|
97
|
+
* @param[out] r Pointer to output distance value
|
|
98
|
+
*
|
|
99
|
+
* @tparam in_type_ Input vector element type
|
|
100
|
+
* @tparam result_type_ Accumulator type, defaults to `in_type_::angular_result_t`
|
|
101
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
102
|
+
*/
|
|
103
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
|
|
104
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
105
|
+
void angular(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
|
|
106
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
107
|
+
std::is_same_v<result_type_, typename in_type_::angular_result_t>;
|
|
108
|
+
|
|
109
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_angular_f64(&a->raw_, &b->raw_, d, &r->raw_);
|
|
110
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_angular_f32(&a->raw_, &b->raw_, d, &r->raw_);
|
|
111
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_angular_f16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
112
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_angular_bf16(&a->raw_, &b->raw_, d, &r->raw_);
|
|
113
|
+
else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_angular_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
114
|
+
else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_angular_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
115
|
+
else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_angular_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
|
|
116
|
+
else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_angular_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
|
|
117
|
+
else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_angular_i8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
118
|
+
else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_angular_u8(&a->raw_, &b->raw_, d, &r->raw_);
|
|
119
|
+
else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_angular_i4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
120
|
+
else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_angular_u4(&a->raw_, &b->raw_, d, &r->raw_);
|
|
121
|
+
// Scalar fallback
|
|
122
|
+
else {
|
|
123
|
+
result_type_ ab {}, aa {}, bb {};
|
|
124
|
+
for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++) {
|
|
125
|
+
ab = fma(a[i], b[i], ab);
|
|
126
|
+
aa = fma(a[i], a[i], aa);
|
|
127
|
+
bb = fma(b[i], b[i], bb);
|
|
128
|
+
}
|
|
129
|
+
// Angular distance = 1 - cosine_similarity, clamped to [0, 2]
|
|
130
|
+
result_type_ cos_sim = ab / (aa.sqrt() * bb.sqrt());
|
|
131
|
+
result_type_ distance = result_type_(1) - cos_sim;
|
|
132
|
+
*r = distance > result_type_(0) ? distance : result_type_(0);
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
} // namespace ashvardanian::numkong
|
|
137
|
+
|
|
138
|
+
#include "numkong/tensor.hpp"
|
|
139
|
+
|
|
140
|
+
namespace ashvardanian::numkong {
|
|
141
|
+
|
|
142
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
|
|
143
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
|
|
144
|
+
void euclidean(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
|
|
145
|
+
result_type_ *r) noexcept {
|
|
146
|
+
euclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
|
|
150
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
151
|
+
void euclidean(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
|
|
152
|
+
euclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
|
|
156
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
|
|
157
|
+
void sqeuclidean(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
|
|
158
|
+
result_type_ *r) noexcept {
|
|
159
|
+
sqeuclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
|
|
163
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
164
|
+
void sqeuclidean(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
|
|
165
|
+
sqeuclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
|
|
169
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
|
|
170
|
+
void angular(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
|
|
171
|
+
result_type_ *r) noexcept {
|
|
172
|
+
angular<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
|
|
176
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
177
|
+
void angular(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
|
|
178
|
+
angular<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
} // namespace ashvardanian::numkong
|
|
182
|
+
|
|
183
|
+
#endif // NK_SPATIAL_HPP
|