numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,508 @@
1
+ /**
2
+ * @brief C++ wrappers for SIMD-accelerated Batched Spatial Distance Matrices.
3
+ * @file include/numkong/spatials.hpp
4
+ * @author Ash Vardanian
5
+ * @date March 2026
6
+ */
7
+ #ifndef NK_SPATIALS_HPP
8
+ #define NK_SPATIALS_HPP
9
+
10
+ #include <cstdint>
11
+ #include <cstring>
12
+ #include <type_traits>
13
+
14
+ #include "numkong/spatials.h"
15
+
16
+ #include "numkong/types.hpp"
17
+
18
+ namespace ashvardanian::numkong {
19
+
20
+ /**
21
+ * @brief Symmetric angular distance matrix: C[i,j] = angular(A[i], A[j])
22
+ * @param[in] a Matrix A [n_vectors x depth]
23
+ * @param[in] n_vectors Number of vectors (n)
24
+ * @param[in] depth Dimension of each vector (k)
25
+ * @param[in] a_stride_in_bytes Stride between vectors in A
26
+ * @param[out] c Output matrix C [n x n]
27
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
28
+ * @param[in] row_start Starting row index (default 0)
29
+ * @param[in] row_count Number of rows to compute (default all)
30
+ *
31
+ * @tparam in_type_ Input element type
32
+ * @tparam result_type_ Output type, defaults to `in_type_::angular_result_t`
33
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
34
+ */
35
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
36
+ allow_simd_t allow_simd_ = prefer_simd_k>
37
+ void angulars_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
38
+ result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
39
+ std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
40
+ if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
41
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
42
+ std::is_same_v<result_type_, typename in_type_::angular_result_t>;
43
+
44
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
45
+ nk_angulars_symmetric_f64(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
46
+ row_count);
47
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
48
+ nk_angulars_symmetric_f32(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
49
+ row_count);
50
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
51
+ nk_angulars_symmetric_f16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
52
+ row_count);
53
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
54
+ nk_angulars_symmetric_bf16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
55
+ row_start, row_count);
56
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
57
+ nk_angulars_symmetric_e4m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
58
+ row_start, row_count);
59
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
60
+ nk_angulars_symmetric_e5m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
61
+ row_start, row_count);
62
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
63
+ nk_angulars_symmetric_e2m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
64
+ row_start, row_count);
65
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
66
+ nk_angulars_symmetric_e3m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
67
+ row_start, row_count);
68
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
69
+ nk_angulars_symmetric_i8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
70
+ row_count);
71
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
72
+ nk_angulars_symmetric_u8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
73
+ row_count);
74
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
75
+ nk_angulars_symmetric_i4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
76
+ row_count);
77
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
78
+ nk_angulars_symmetric_u4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
79
+ row_count);
80
+ else {
81
+ std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
82
+ char const *a_bytes = reinterpret_cast<char const *>(a);
83
+ char *c_bytes = reinterpret_cast<char *>(c);
84
+ std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
85
+
86
+ for (std::size_t i = row_start; i < row_end; i++) {
87
+ in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
88
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
89
+ for (std::size_t j = 0; j < n_vectors; j++) {
90
+ in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
91
+ result_type_ ab {}, aa {}, bb {};
92
+ for (std::size_t l = 0; l < depth_values; l++) {
93
+ ab = fma(a_i[l], a_j[l], ab);
94
+ aa = fma(a_i[l], a_i[l], aa);
95
+ bb = fma(a_j[l], a_j[l], bb);
96
+ }
97
+ result_type_ cos_sim = ab / (aa.sqrt() * bb.sqrt());
98
+ result_type_ distance = result_type_(1) - cos_sim;
99
+ c_row[j] = distance > result_type_(0) ? distance : result_type_(0);
100
+ }
101
+ }
102
+ }
103
+ }
104
+
105
+ /**
106
+ * @brief Symmetric Euclidean distance matrix: C[i,j] = euclidean(A[i], A[j])
107
+ * @param[in] a Matrix A [n_vectors x depth]
108
+ * @param[in] n_vectors Number of vectors (n)
109
+ * @param[in] depth Dimension of each vector (k)
110
+ * @param[in] a_stride_in_bytes Stride between vectors in A
111
+ * @param[out] c Output matrix C [n x n]
112
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
113
+ * @param[in] row_start Starting row index (default 0)
114
+ * @param[in] row_count Number of rows to compute (default all)
115
+ *
116
+ * @tparam in_type_ Input element type
117
+ * @tparam result_type_ Output type, defaults to `in_type_::euclidean_result_t`
118
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
119
+ */
120
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
121
+ allow_simd_t allow_simd_ = prefer_simd_k>
122
+ void euclideans_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
123
+ result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
124
+ std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
125
+ if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
126
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
127
+ std::is_same_v<result_type_, typename in_type_::euclidean_result_t>;
128
+
129
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
130
+ nk_euclideans_symmetric_f64(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
131
+ row_start, row_count);
132
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
133
+ nk_euclideans_symmetric_f32(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
134
+ row_start, row_count);
135
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
136
+ nk_euclideans_symmetric_f16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
137
+ row_start, row_count);
138
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
139
+ nk_euclideans_symmetric_bf16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
140
+ row_start, row_count);
141
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
142
+ nk_euclideans_symmetric_e4m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
143
+ row_start, row_count);
144
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
145
+ nk_euclideans_symmetric_e5m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
146
+ row_start, row_count);
147
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
148
+ nk_euclideans_symmetric_e2m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
149
+ row_start, row_count);
150
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
151
+ nk_euclideans_symmetric_e3m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
152
+ row_start, row_count);
153
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
154
+ nk_euclideans_symmetric_i8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
155
+ row_start, row_count);
156
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
157
+ nk_euclideans_symmetric_u8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
158
+ row_start, row_count);
159
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
160
+ nk_euclideans_symmetric_i4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
161
+ row_start, row_count);
162
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
163
+ nk_euclideans_symmetric_u4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
164
+ row_start, row_count);
165
+ else {
166
+ std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
167
+ char const *a_bytes = reinterpret_cast<char const *>(a);
168
+ char *c_bytes = reinterpret_cast<char *>(c);
169
+ std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
170
+
171
+ for (std::size_t i = row_start; i < row_end; i++) {
172
+ in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
173
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
174
+ for (std::size_t j = 0; j < n_vectors; j++) {
175
+ in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
176
+ result_type_ sum {};
177
+ for (std::size_t l = 0; l < depth_values; l++) sum = fdsa(a_i[l], a_j[l], sum);
178
+ c_row[j] = sum.sqrt();
179
+ }
180
+ }
181
+ }
182
+ }
183
+
184
+ /**
185
+ * @brief Packed angular distances: C = angular(A, B_packed)
186
+ * @param[in] a Matrix A [row_count x depth]
187
+ * @param[in] b_packed Packed B matrix (produced by nk_dots_pack_*)
188
+ * @param[out] c Output matrix C [row_count x column_count]
189
+ * @param[in] row_count Rows of A and C (m)
190
+ * @param[in] column_count Columns of B and C (n)
191
+ * @param[in] depth Shared inner dimension (k)
192
+ * @param[in] a_stride_in_bytes Stride between rows of A in bytes
193
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
194
+ *
195
+ * @tparam in_type_ Input element type
196
+ * @tparam result_type_ Output type, defaults to `in_type_::angular_result_t`
197
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
198
+ */
199
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
200
+ allow_simd_t allow_simd_ = prefer_simd_k>
201
+ void angulars_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_t row_count, size_t column_count,
202
+ size_t depth, size_t a_stride_in_bytes, size_t c_stride_in_bytes) noexcept {
203
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
204
+ std::is_same_v<result_type_, typename in_type_::angular_result_t>;
205
+
206
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
207
+ nk_angulars_packed_f64(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
208
+ c_stride_in_bytes);
209
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
210
+ nk_angulars_packed_f32(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
211
+ c_stride_in_bytes);
212
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
213
+ nk_angulars_packed_f16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
214
+ c_stride_in_bytes);
215
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
216
+ nk_angulars_packed_bf16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
217
+ c_stride_in_bytes);
218
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
219
+ nk_angulars_packed_e4m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
220
+ c_stride_in_bytes);
221
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
222
+ nk_angulars_packed_e5m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
223
+ c_stride_in_bytes);
224
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
225
+ nk_angulars_packed_e2m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
226
+ c_stride_in_bytes);
227
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
228
+ nk_angulars_packed_e3m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
229
+ c_stride_in_bytes);
230
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
231
+ nk_angulars_packed_i8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
232
+ c_stride_in_bytes);
233
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
234
+ nk_angulars_packed_u8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
235
+ c_stride_in_bytes);
236
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
237
+ nk_angulars_packed_i4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
238
+ c_stride_in_bytes);
239
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
240
+ nk_angulars_packed_u4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
241
+ c_stride_in_bytes);
242
+ else {
243
+ // Scalar fallback: extract pointer and stride, compute pairwise angular distances
244
+ in_type_ const *b;
245
+ size_t b_stride_in_bytes;
246
+ char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
247
+ std::memcpy(&b, b_packed_bytes, sizeof(void *));
248
+ std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
249
+
250
+ char const *a_bytes = reinterpret_cast<char const *>(a);
251
+ char const *b_bytes = reinterpret_cast<char const *>(b);
252
+ char *c_bytes = reinterpret_cast<char *>(c);
253
+ std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
254
+
255
+ for (size_t i = 0; i < row_count; i++) {
256
+ in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
257
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
258
+ for (size_t j = 0; j < column_count; j++) {
259
+ in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
260
+ result_type_ ab {}, aa {}, bb {};
261
+ for (std::size_t l = 0; l < depth_values; l++) {
262
+ ab = fma(a_row[l], b_row[l], ab);
263
+ aa = fma(a_row[l], a_row[l], aa);
264
+ bb = fma(b_row[l], b_row[l], bb);
265
+ }
266
+ result_type_ cos_sim = ab / (aa.sqrt() * bb.sqrt());
267
+ result_type_ distance = result_type_(1) - cos_sim;
268
+ c_row[j] = distance > result_type_(0) ? distance : result_type_(0);
269
+ }
270
+ }
271
+ }
272
+ }
273
+
274
+ /**
275
+ * @brief Packed Euclidean distances: C = euclidean(A, B_packed)
276
+ * @param[in] a Matrix A [row_count x depth]
277
+ * @param[in] b_packed Packed B matrix (produced by nk_dots_pack_*)
278
+ * @param[out] c Output matrix C [row_count x column_count]
279
+ * @param[in] row_count Rows of A and C (m)
280
+ * @param[in] column_count Columns of B and C (n)
281
+ * @param[in] depth Shared inner dimension (k)
282
+ * @param[in] a_stride_in_bytes Stride between rows of A in bytes
283
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
284
+ *
285
+ * @tparam in_type_ Input element type
286
+ * @tparam result_type_ Output type, defaults to `in_type_::euclidean_result_t`
287
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
288
+ */
289
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
290
+ allow_simd_t allow_simd_ = prefer_simd_k>
291
+ void euclideans_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_t row_count, size_t column_count,
292
+ size_t depth, size_t a_stride_in_bytes, size_t c_stride_in_bytes) noexcept {
293
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
294
+ std::is_same_v<result_type_, typename in_type_::euclidean_result_t>;
295
+
296
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
297
+ nk_euclideans_packed_f64(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
298
+ c_stride_in_bytes);
299
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
300
+ nk_euclideans_packed_f32(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
301
+ c_stride_in_bytes);
302
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
303
+ nk_euclideans_packed_f16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
304
+ c_stride_in_bytes);
305
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
306
+ nk_euclideans_packed_bf16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
307
+ c_stride_in_bytes);
308
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
309
+ nk_euclideans_packed_e4m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
310
+ c_stride_in_bytes);
311
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
312
+ nk_euclideans_packed_e5m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
313
+ c_stride_in_bytes);
314
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
315
+ nk_euclideans_packed_e2m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
316
+ c_stride_in_bytes);
317
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
318
+ nk_euclideans_packed_e3m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
319
+ c_stride_in_bytes);
320
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
321
+ nk_euclideans_packed_i8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
322
+ c_stride_in_bytes);
323
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
324
+ nk_euclideans_packed_u8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
325
+ c_stride_in_bytes);
326
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
327
+ nk_euclideans_packed_i4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
328
+ c_stride_in_bytes);
329
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
330
+ nk_euclideans_packed_u4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
331
+ c_stride_in_bytes);
332
+ else {
333
+ // Scalar fallback: extract pointer and stride, compute pairwise euclidean distances
334
+ in_type_ const *b;
335
+ size_t b_stride_in_bytes;
336
+ char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
337
+ std::memcpy(&b, b_packed_bytes, sizeof(void *));
338
+ std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
339
+
340
+ char const *a_bytes = reinterpret_cast<char const *>(a);
341
+ char const *b_bytes = reinterpret_cast<char const *>(b);
342
+ char *c_bytes = reinterpret_cast<char *>(c);
343
+ std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
344
+
345
+ for (size_t i = 0; i < row_count; i++) {
346
+ in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
347
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
348
+ for (size_t j = 0; j < column_count; j++) {
349
+ in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
350
+ result_type_ sum {};
351
+ for (std::size_t l = 0; l < depth_values; l++) sum = fdsa(a_row[l], b_row[l], sum);
352
+ c_row[j] = sum.sqrt();
353
+ }
354
+ }
355
+ }
356
+ }
357
+
358
+ } // namespace ashvardanian::numkong
359
+
360
+ #include "numkong/tensor.hpp"
361
+
362
+ namespace ashvardanian::numkong {
363
+
364
+ #pragma region - Concept-Constrained Symmetric Spatial Distances
365
+
366
+ /** @brief Symmetric angular distances: C[i,j] = angular(A[i], A[j]). */
367
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
368
+ mutable_matrix_of<typename value_type_::angular_result_t> output_matrix_>
369
+ bool angulars_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
370
+ std::size_t num_vectors = input.extent(0);
371
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
372
+ numkong::angulars_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
373
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
374
+ static_cast<std::size_t>(output.stride_bytes(0)));
375
+ return true;
376
+ }
377
+
378
+ /** @brief Allocating symmetric angular distances. */
379
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
380
+ typename allocator_type_ = aligned_allocator<typename value_type_::angular_result_t>>
381
+ matrix<typename value_type_::angular_result_t, allocator_type_> try_angulars_symmetric(
382
+ input_matrix_ const &input) noexcept {
383
+ using result_t = typename value_type_::angular_result_t;
384
+ using out_tensor_t = matrix<result_t, allocator_type_>;
385
+ if (input.empty()) return out_tensor_t {};
386
+ std::size_t num_vectors = input.extent(0);
387
+ auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
388
+ if (result.empty()) return result;
389
+ if (!angulars_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
390
+ return result;
391
+ }
392
+
393
+ /** @brief Symmetric Euclidean distances: C[i,j] = euclidean(A[i], A[j]). */
394
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
395
+ mutable_matrix_of<typename value_type_::euclidean_result_t> output_matrix_>
396
+ bool euclideans_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
397
+ std::size_t num_vectors = input.extent(0);
398
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
399
+ numkong::euclideans_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
400
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
401
+ static_cast<std::size_t>(output.stride_bytes(0)));
402
+ return true;
403
+ }
404
+
405
+ /** @brief Allocating symmetric Euclidean distances. */
406
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
407
+ typename allocator_type_ = aligned_allocator<typename value_type_::euclidean_result_t>>
408
+ matrix<typename value_type_::euclidean_result_t, allocator_type_> try_euclideans_symmetric(
409
+ input_matrix_ const &input) noexcept {
410
+ using result_t = typename value_type_::euclidean_result_t;
411
+ using out_tensor_t = matrix<result_t, allocator_type_>;
412
+ if (input.empty()) return out_tensor_t {};
413
+ std::size_t num_vectors = input.extent(0);
414
+ auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
415
+ if (result.empty()) return result;
416
+ if (!euclideans_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
417
+ return result;
418
+ }
419
+
420
+ /** @brief Partitioned symmetric angular distances for parallel row-range work. */
421
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
422
+ mutable_matrix_of<typename value_type_::angular_result_t> output_matrix_>
423
+ bool angulars_symmetric(input_matrix_ const &input, output_matrix_ &&output, std::size_t row_start,
424
+ std::size_t row_count) noexcept {
425
+ std::size_t num_vectors = input.extent(0);
426
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
427
+ numkong::angulars_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
428
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
429
+ static_cast<std::size_t>(output.stride_bytes(0)), row_start, row_count);
430
+ return true;
431
+ }
432
+
433
+ /** @brief Partitioned symmetric Euclidean distances for parallel row-range work. */
434
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
435
+ mutable_matrix_of<typename value_type_::euclidean_result_t> output_matrix_>
436
+ bool euclideans_symmetric(input_matrix_ const &input, output_matrix_ &&output, std::size_t row_start,
437
+ std::size_t row_count) noexcept {
438
+ std::size_t num_vectors = input.extent(0);
439
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
440
+ numkong::euclideans_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
441
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
442
+ static_cast<std::size_t>(output.stride_bytes(0)), row_start, row_count);
443
+ return true;
444
+ }
445
+
446
+ #pragma endregion - Concept - Constrained Symmetric Spatial Distances
447
+
448
+ #pragma region - Concept-Constrained Packed Spatial Distances
449
+
450
+ /** @brief Packed angular distances: C = angular(A, B_packed). */
451
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
452
+ mutable_matrix_of<typename value_type_::angular_result_t> output_matrix_>
453
+ bool angulars_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
454
+ if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
455
+ if (a.extent(1) != packed_b.depth()) return false;
456
+ if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
457
+ numkong::angulars_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
458
+ packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
459
+ static_cast<std::size_t>(c.stride_bytes(0)));
460
+ return true;
461
+ }
462
+
463
+ /** @brief Allocating packed angular distances. */
464
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
465
+ typename allocator_type_ = aligned_allocator<typename value_type_::angular_result_t>>
466
+ matrix<typename value_type_::angular_result_t, allocator_type_> try_angulars_packed(
467
+ input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
468
+ using result_t = typename value_type_::angular_result_t;
469
+ using out_t = matrix<result_t, allocator_type_>;
470
+ if (packed_b.empty() || a.rank() < 2) return out_t {};
471
+ auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
472
+ if (c.empty()) return c;
473
+ if (!angulars_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
474
+ return c;
475
+ }
476
+
477
+ /** @brief Packed Euclidean distances: C = euclidean(A, B_packed). */
478
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
479
+ mutable_matrix_of<typename value_type_::euclidean_result_t> output_matrix_>
480
+ bool euclideans_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
481
+ if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
482
+ if (a.extent(1) != packed_b.depth()) return false;
483
+ if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
484
+ numkong::euclideans_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
485
+ packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
486
+ static_cast<std::size_t>(c.stride_bytes(0)));
487
+ return true;
488
+ }
489
+
490
+ /** @brief Allocating packed Euclidean distances. */
491
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
492
+ typename allocator_type_ = aligned_allocator<typename value_type_::euclidean_result_t>>
493
+ matrix<typename value_type_::euclidean_result_t, allocator_type_> try_euclideans_packed(
494
+ input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
495
+ using result_t = typename value_type_::euclidean_result_t;
496
+ using out_t = matrix<result_t, allocator_type_>;
497
+ if (packed_b.empty() || a.rank() < 2) return out_t {};
498
+ auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
499
+ if (c.empty()) return c;
500
+ if (!euclideans_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
501
+ return c;
502
+ }
503
+
504
+ #pragma endregion - Concept - Constrained Packed Spatial Distances
505
+
506
+ } // namespace ashvardanian::numkong
507
+
508
+ #endif // NK_SPATIALS_HPP