numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,183 @@
1
+ /**
2
+ * @brief C++ wrappers for SIMD-accelerated Spatial Similarity Measures.
3
+ * @file include/numkong/spatial.hpp
4
+ * @author Ash Vardanian
5
+ * @date February 5, 2026
6
+ */
7
+ #ifndef NK_SPATIAL_HPP
8
+ #define NK_SPATIAL_HPP
9
+
10
+ #include <cstdint>
11
+ #include <type_traits>
12
+
13
+ #include "numkong/spatial.h"
14
+
15
+ #include "numkong/types.hpp"
16
+
17
+ namespace ashvardanian::numkong {
18
+
19
+ /**
20
+ * @brief L₂ (Euclidean) distance: √Σ(aᵢ − bᵢ)²
21
+ * @param[in] a,b First and second vectors
22
+ * @param[in] d Number of dimensions in input vectors
23
+ * @param[out] r Pointer to output distance value
24
+ *
25
+ * @tparam in_type_ Input vector element type
26
+ * @tparam result_type_ Accumulator type, defaults to `in_type_::euclidean_result_t`
27
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
28
+ */
29
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
30
+ allow_simd_t allow_simd_ = prefer_simd_k>
31
+ void euclidean(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
32
+ constexpr bool simd = allow_simd_ == prefer_simd_k &&
33
+ std::is_same_v<result_type_, typename in_type_::euclidean_result_t>;
34
+
35
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_euclidean_f64(&a->raw_, &b->raw_, d, &r->raw_);
36
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_euclidean_f32(&a->raw_, &b->raw_, d, &r->raw_);
37
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_euclidean_f16(&a->raw_, &b->raw_, d, &r->raw_);
38
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_euclidean_bf16(&a->raw_, &b->raw_, d, &r->raw_);
39
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_euclidean_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
40
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_euclidean_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
41
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_euclidean_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
42
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_euclidean_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
43
+ else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_euclidean_i8(&a->raw_, &b->raw_, d, &r->raw_);
44
+ else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_euclidean_u8(&a->raw_, &b->raw_, d, &r->raw_);
45
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_euclidean_i4(&a->raw_, &b->raw_, d, &r->raw_);
46
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_euclidean_u4(&a->raw_, &b->raw_, d, &r->raw_);
47
+ // Scalar fallback
48
+ else {
49
+ result_type_ sum {};
50
+ for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++)
51
+ sum = fdsa(a[i], b[i], sum);
52
+ *r = sum.sqrt();
53
+ }
54
+ }
55
+
56
+ /**
57
+ * @brief Squared L₂ distance: Σ(aᵢ − bᵢ)²
58
+ * @param[in] a,b First and second vectors
59
+ * @param[in] d Number of dimensions in input vectors
60
+ * @param[out] r Pointer to output distance value
61
+ *
62
+ * @tparam in_type_ Input vector element type
63
+ * @tparam result_type_ Accumulator type, defaults to `in_type_::sqeuclidean_result_t`
64
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
65
+ */
66
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
67
+ allow_simd_t allow_simd_ = prefer_simd_k>
68
+ void sqeuclidean(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
69
+ constexpr bool simd = allow_simd_ == prefer_simd_k &&
70
+ std::is_same_v<result_type_, typename in_type_::sqeuclidean_result_t>;
71
+
72
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_sqeuclidean_f64(&a->raw_, &b->raw_, d, &r->raw_);
73
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_sqeuclidean_f32(&a->raw_, &b->raw_, d, &r->raw_);
74
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_sqeuclidean_f16(&a->raw_, &b->raw_, d, &r->raw_);
75
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_sqeuclidean_bf16(&a->raw_, &b->raw_, d, &r->raw_);
76
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_sqeuclidean_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
77
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_sqeuclidean_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
78
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_sqeuclidean_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
79
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_sqeuclidean_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
80
+ else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_sqeuclidean_i8(&a->raw_, &b->raw_, d, &r->raw_);
81
+ else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_sqeuclidean_u8(&a->raw_, &b->raw_, d, &r->raw_);
82
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_sqeuclidean_i4(&a->raw_, &b->raw_, d, &r->raw_);
83
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_sqeuclidean_u4(&a->raw_, &b->raw_, d, &r->raw_);
84
+ // Scalar fallback
85
+ else {
86
+ result_type_ sum {};
87
+ for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++)
88
+ sum = fdsa(a[i], b[i], sum);
89
+ *r = sum;
90
+ }
91
+ }
92
+
93
+ /**
94
+ * @brief Angular similarity (cosine): ⟨a,b⟩ / (‖a‖ × ‖b‖)
95
+ * @param[in] a,b First and second vectors
96
+ * @param[in] d Number of dimensions in input vectors
97
+ * @param[out] r Pointer to output distance value
98
+ *
99
+ * @tparam in_type_ Input vector element type
100
+ * @tparam result_type_ Accumulator type, defaults to `in_type_::angular_result_t`
101
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
102
+ */
103
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
104
+ allow_simd_t allow_simd_ = prefer_simd_k>
105
+ void angular(in_type_ const *a, in_type_ const *b, std::size_t d, result_type_ *r) noexcept {
106
+ constexpr bool simd = allow_simd_ == prefer_simd_k &&
107
+ std::is_same_v<result_type_, typename in_type_::angular_result_t>;
108
+
109
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_angular_f64(&a->raw_, &b->raw_, d, &r->raw_);
110
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_angular_f32(&a->raw_, &b->raw_, d, &r->raw_);
111
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_angular_f16(&a->raw_, &b->raw_, d, &r->raw_);
112
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_angular_bf16(&a->raw_, &b->raw_, d, &r->raw_);
113
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) nk_angular_e4m3(&a->raw_, &b->raw_, d, &r->raw_);
114
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) nk_angular_e5m2(&a->raw_, &b->raw_, d, &r->raw_);
115
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) nk_angular_e2m3(&a->raw_, &b->raw_, d, &r->raw_);
116
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) nk_angular_e3m2(&a->raw_, &b->raw_, d, &r->raw_);
117
+ else if constexpr (std::is_same_v<in_type_, i8_t> && simd) nk_angular_i8(&a->raw_, &b->raw_, d, &r->raw_);
118
+ else if constexpr (std::is_same_v<in_type_, u8_t> && simd) nk_angular_u8(&a->raw_, &b->raw_, d, &r->raw_);
119
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) nk_angular_i4(&a->raw_, &b->raw_, d, &r->raw_);
120
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) nk_angular_u4(&a->raw_, &b->raw_, d, &r->raw_);
121
+ // Scalar fallback
122
+ else {
123
+ result_type_ ab {}, aa {}, bb {};
124
+ for (std::size_t i = 0; i < divide_round_up(d, dimensions_per_value<in_type_>()); i++) {
125
+ ab = fma(a[i], b[i], ab);
126
+ aa = fma(a[i], a[i], aa);
127
+ bb = fma(b[i], b[i], bb);
128
+ }
129
+ // Angular distance = 1 - cosine_similarity, clamped to [0, 2]
130
+ result_type_ cos_sim = ab / (aa.sqrt() * bb.sqrt());
131
+ result_type_ distance = result_type_(1) - cos_sim;
132
+ *r = distance > result_type_(0) ? distance : result_type_(0);
133
+ }
134
+ }
135
+
136
+ } // namespace ashvardanian::numkong
137
+
138
+ #include "numkong/tensor.hpp"
139
+
140
+ namespace ashvardanian::numkong {
141
+
142
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
143
+ allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
144
+ void euclidean(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
145
+ result_type_ *r) noexcept {
146
+ euclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
147
+ }
148
+
149
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
150
+ allow_simd_t allow_simd_ = prefer_simd_k>
151
+ void euclidean(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
152
+ euclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
153
+ }
154
+
155
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
156
+ allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
157
+ void sqeuclidean(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
158
+ result_type_ *r) noexcept {
159
+ sqeuclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
160
+ }
161
+
162
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::sqeuclidean_result_t,
163
+ allow_simd_t allow_simd_ = prefer_simd_k>
164
+ void sqeuclidean(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
165
+ sqeuclidean<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
166
+ }
167
+
168
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
169
+ allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
170
+ void angular(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, std::size_t d,
171
+ result_type_ *r) noexcept {
172
+ angular<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
173
+ }
174
+
175
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
176
+ allow_simd_t allow_simd_ = prefer_simd_k>
177
+ void angular(vector_view<in_type_> a, vector_view<in_type_> b, std::size_t d, result_type_ *r) noexcept {
178
+ angular<in_type_, result_type_, allow_simd_>(a.data(), b.data(), d, r);
179
+ }
180
+
181
+ } // namespace ashvardanian::numkong
182
+
183
+ #endif // NK_SPATIAL_HPP