numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,639 @@
1
+ /**
2
+ * @brief C++ bindings for multi-target dot-product kernels.
3
+ * @file include/numkong/dots.hpp
4
+ * @author Ash Vardanian
5
+ * @date February 5, 2026
6
+ */
7
+ #ifndef NK_DOTS_HPP
8
+ #define NK_DOTS_HPP
9
+
10
+ #include <bit>
11
+ #include <cstdint>
12
+ #include <cstring>
13
+ #include <limits>
14
+ #include <type_traits>
15
+
16
+ #include "numkong/dot.h"
17
+ #include "numkong/dots.h"
18
+ #include "numkong/sets.h"
19
+
20
+ #include "numkong/types.hpp"
21
+
22
+ namespace ashvardanian::numkong {
23
+
24
+ /**
25
+ * @brief Reference unpacked GEMM: C = A × Bᵀ (row-major A and B, B transposed).
26
+ *
27
+ * This matches BLAS sgemm/dgemm with CblasNoTrans for A and CblasTrans for B.
28
+ * Useful as a reference implementation for validating BLAS/MKL/Accelerate.
29
+ *
30
+ * @param a Matrix A [m x k] row-major
31
+ * @param b Matrix B [n x k] row-major (accessed as Bᵀ)
32
+ * @param c Output matrix C [m x n] row-major
33
+ * @param row_count Rows of A and C (m)
34
+ * @param column_count Rows of B and columns of C (n)
35
+ * @param depth Columns of A and B (k)
36
+ * @param a_stride_in_bytes Stride between rows of A in bytes
37
+ * @param b_stride_in_bytes Stride between rows of B in bytes
38
+ * @param c_stride_in_bytes Stride between rows of C in bytes
39
+ * @tparam in_type_ Input element type (e.g., f32_t, bf16_t)
40
+ * @tparam result_type_ Accumulator/output type (e.g., f32_t, f118_t for high precision)
41
+ */
42
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t>
43
+ void dots_unpacked(in_type_ const *a, in_type_ const *b, result_type_ *c, size_t row_count, size_t column_count,
44
+ size_t depth, size_t a_stride_in_bytes, size_t b_stride_in_bytes,
45
+ size_t c_stride_in_bytes) noexcept {
46
+ char const *a_bytes = reinterpret_cast<char const *>(a);
47
+ char const *b_bytes = reinterpret_cast<char const *>(b);
48
+ char *c_bytes = reinterpret_cast<char *>(c);
49
+ std::size_t const depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
50
+
51
+ for (size_t i = 0; i < row_count; i++) {
52
+ in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
53
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
54
+ for (size_t j = 0; j < column_count; j++) {
55
+ in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
56
+ result_type_ sum {};
57
+ for (size_t l = 0; l < depth_values; l++) sum = fma(a_row[l], b_row[l], sum);
58
+ c_row[j] = sum;
59
+ }
60
+ }
61
+ }
62
+
63
+ /**
64
+ * @brief Conjugated unpacked dot products: C = A × Bᴴ (Hermitian inner product, row-major)
65
+ *
66
+ * Same as `dots_unpacked`, but conjugates elements of B before multiplication.
67
+ * For real types this is identical to `dots_unpacked`. For complex types this
68
+ * computes the standard Hermitian inner product matching `cblas_{c,z}gemm` with
69
+ * `CblasConjTrans`.
70
+ */
71
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t>
72
+ void dots_unpacked_conjugated(in_type_ const *a, in_type_ const *b, result_type_ *c, size_t row_count,
73
+ size_t column_count, size_t depth, size_t a_stride_in_bytes, size_t b_stride_in_bytes,
74
+ size_t c_stride_in_bytes) noexcept {
75
+ char const *a_bytes = reinterpret_cast<char const *>(a);
76
+ char const *b_bytes = reinterpret_cast<char const *>(b);
77
+ char *c_bytes = reinterpret_cast<char *>(c);
78
+ std::size_t const depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
79
+
80
+ for (size_t i = 0; i < row_count; i++) {
81
+ in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
82
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
83
+ for (size_t j = 0; j < column_count; j++) {
84
+ in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
85
+ result_type_ sum {};
86
+ for (size_t l = 0; l < depth_values; l++) sum = fcma(b_row[l], a_row[l], sum);
87
+ c_row[j] = sum;
88
+ }
89
+ }
90
+ }
91
+
92
+ /**
93
+ * @brief Packed dot products (batch matrix multiply): C = A × B (row-major)
94
+ * @param[in] a Matrix A [m x k]
95
+ * @param[in] b_packed Packed matrix B [k x n] with stride metadata appended
96
+ * @param[out] c Output matrix C [m x n]
97
+ * @param[in] row_count Rows of A and C (m)
98
+ * @param[in] column_count Columns of B and C (n)
99
+ * @param[in] depth Columns of A, Rows of B (k)
100
+ * @param[in] a_stride_in_bytes Stride between rows of A in bytes
101
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
102
+ *
103
+ * @tparam in_type_ Input element type
104
+ * @tparam result_type_ Accumulator/output type, defaults to `in_type_::dot_result_t`
105
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
106
+ */
107
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t,
108
+ allow_simd_t allow_simd_ = prefer_simd_k>
109
+ void dots_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_t row_count, size_t column_count,
110
+ size_t depth, size_t a_stride_in_bytes, size_t c_stride_in_bytes) noexcept {
111
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
112
+ std::is_same_v<result_type_, typename in_type_::dot_result_t>;
113
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
114
+ nk_dots_packed_f64(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
115
+ c_stride_in_bytes);
116
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
117
+ nk_dots_packed_f32(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
118
+ c_stride_in_bytes);
119
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
120
+ nk_dots_packed_f16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
121
+ c_stride_in_bytes);
122
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
123
+ nk_dots_packed_bf16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
124
+ c_stride_in_bytes);
125
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
126
+ nk_dots_packed_i8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
127
+ c_stride_in_bytes);
128
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
129
+ nk_dots_packed_u8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
130
+ c_stride_in_bytes);
131
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
132
+ nk_dots_packed_e4m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
133
+ c_stride_in_bytes);
134
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
135
+ nk_dots_packed_e5m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
136
+ c_stride_in_bytes);
137
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
138
+ nk_dots_packed_e2m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
139
+ c_stride_in_bytes);
140
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
141
+ nk_dots_packed_e3m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
142
+ c_stride_in_bytes);
143
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
144
+ nk_dots_packed_u4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
145
+ c_stride_in_bytes);
146
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
147
+ nk_dots_packed_i4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
148
+ c_stride_in_bytes);
149
+ else {
150
+ in_type_ const *b;
151
+ size_t b_stride_in_bytes;
152
+ char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
153
+ std::memcpy(&b, b_packed_bytes, sizeof(void *));
154
+ std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
155
+ dots_unpacked<in_type_, result_type_>(a, b, c, row_count, column_count, depth, a_stride_in_bytes,
156
+ b_stride_in_bytes, c_stride_in_bytes);
157
+ }
158
+ }
159
+
160
+ /**
161
+ * @brief Symmetric dot products: C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩
162
+ * @param[in] a Matrix A [n x k] (n vectors of dimension k)
163
+ * @param[in] n_vectors Number of vectors (n)
164
+ * @param[in] depth Dimension of each vector (k)
165
+ * @param[in] a_stride_in_bytes Stride between vectors in A
166
+ * @param[out] c Output matrix C [n x n]
167
+ * @param[in] c_stride_in_bytes Stride between rows of C in bytes
168
+ *
169
+ * @tparam in_type_ Input element type
170
+ * @tparam result_type_ Accumulator/output type, defaults to `in_type_::dot_result_t`
171
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
172
+ */
173
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t,
174
+ allow_simd_t allow_simd_ = prefer_simd_k>
175
+ void dots_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
176
+ result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
177
+ std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
178
+ if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
179
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
180
+ std::is_same_v<result_type_, typename in_type_::dot_result_t>;
181
+
182
+ if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
183
+ nk_dots_symmetric_f64(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
184
+ row_count);
185
+ else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
186
+ nk_dots_symmetric_f32(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
187
+ row_count);
188
+ else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
189
+ nk_dots_symmetric_f16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
190
+ row_count);
191
+ else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
192
+ nk_dots_symmetric_bf16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
193
+ row_count);
194
+ else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
195
+ nk_dots_symmetric_i8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
196
+ row_count);
197
+ else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
198
+ nk_dots_symmetric_u8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
199
+ row_count);
200
+ else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
201
+ nk_dots_symmetric_e4m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
202
+ row_count);
203
+ else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
204
+ nk_dots_symmetric_e5m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
205
+ row_count);
206
+ else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
207
+ nk_dots_symmetric_e2m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
208
+ row_count);
209
+ else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
210
+ nk_dots_symmetric_e3m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
211
+ row_count);
212
+ else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
213
+ nk_dots_symmetric_u4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
214
+ row_count);
215
+ else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
216
+ nk_dots_symmetric_i4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
217
+ row_count);
218
+ else {
219
+ std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
220
+ char const *a_bytes = reinterpret_cast<char const *>(a);
221
+ char *c_bytes = reinterpret_cast<char *>(c);
222
+ std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
223
+
224
+ for (std::size_t i = row_start; i < row_end; i++) {
225
+ in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
226
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
227
+ for (std::size_t j = 0; j < n_vectors; j++) {
228
+ in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
229
+ result_type_ sum {};
230
+ for (std::size_t l = 0; l < depth_values; l++) sum = fma(a_i[l], a_j[l], sum);
231
+ c_row[j] = sum;
232
+ }
233
+ }
234
+ }
235
+ }
236
+
237
+ /**
238
+ * @brief Symmetric Hamming distance matrix: C[i,j] = hamming(A[i], A[j])
239
+ * @param[in] a Input matrix (n_vectors x depth)
240
+ * @param[in] n_vectors Number of vectors
241
+ * @param[in] depth Number of dimensions per vector
242
+ * @param[in] a_stride_in_bytes Row stride in bytes
243
+ * @param[out] c Output matrix (n_vectors x n_vectors)
244
+ * @param[in] c_stride_in_bytes Output row stride in bytes
245
+ * @param[in] row_start Starting row index (default 0)
246
+ * @param[in] row_count Number of rows to compute (default all)
247
+ *
248
+ * Computes Hamming distances between all pairs of binary vectors.
249
+ * For u1x8_t inputs, distances are exact bit counts (u32_t outputs).
250
+ *
251
+ * @tparam in_type_ Input element type (u1x8_t)
252
+ * @tparam result_type_ Output type (u32_t for Hamming distances)
253
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
254
+ */
255
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::hamming_result_t,
256
+ allow_simd_t allow_simd_ = prefer_simd_k>
257
+ void hammings_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
258
+ result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
259
+ std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
260
+ if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
261
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
262
+ std::is_same_v<result_type_, typename in_type_::hamming_result_t>;
263
+
264
+ if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
265
+ nk_hammings_symmetric_u1(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
266
+ row_count);
267
+ else {
268
+ using raw_t = typename in_type_::raw_t;
269
+ std::size_t depth_bytes = divide_round_up(depth, 8);
270
+ char const *a_bytes = reinterpret_cast<char const *>(a);
271
+ char *c_bytes = reinterpret_cast<char *>(c);
272
+ std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
273
+
274
+ for (std::size_t i = row_start; i < row_end; i++) {
275
+ raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
276
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
277
+
278
+ for (std::size_t j = 0; j < n_vectors; j++) {
279
+ raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
280
+ typename result_type_::raw_t distance = 0;
281
+ for (std::size_t b = 0; b < depth_bytes; b++) {
282
+ auto xor_val = a_i[b] ^ a_j[b];
283
+ distance += std::popcount(static_cast<unsigned>(xor_val));
284
+ }
285
+ c_row[j] = result_type_::from_raw(distance);
286
+ }
287
+ }
288
+ }
289
+ }
290
+
291
+ /**
292
+ * @brief Computes Hamming distances between rows of A and columns of packed B.
293
+ * @param[in] a Pointer to the first matrix (m x k).
294
+ * @param[in] b_packed Pointer to the packed second matrix (k x n).
295
+ * @param[out] c Pointer to the output matrix (m x n).
296
+ * @param[in] row_count Number of rows in A (m).
297
+ * @param[in] column_count Number of columns in B (n).
298
+ * @param[in] depth Depth dimension in bits (k).
299
+ * @param[in] a_stride_in_bytes Stride between consecutive rows of A in bytes.
300
+ * @param[in] c_stride_in_bytes Stride between consecutive rows of C in bytes.
301
+ *
302
+ * Computes Hamming distances between binary vectors using optimized packed format.
303
+ * For u1x8_t inputs, distances are exact bit counts (u32_t outputs).
304
+ *
305
+ * @tparam in_type_ Input element type (u1x8_t)
306
+ * @tparam result_type_ Output type (u32_t for Hamming distances)
307
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
308
+ */
309
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::hamming_result_t,
310
+ allow_simd_t allow_simd_ = prefer_simd_k>
311
+ void hammings_packed(in_type_ const *a, void const *b_packed, result_type_ *c, std::size_t row_count,
312
+ std::size_t column_count, std::size_t depth, std::size_t a_stride_in_bytes = 0,
313
+ std::size_t c_stride_in_bytes = 0) noexcept {
314
+ // Compute default strides
315
+ if (!a_stride_in_bytes) a_stride_in_bytes = divide_round_up(depth, 8) * sizeof(in_type_);
316
+ if (!c_stride_in_bytes) c_stride_in_bytes = column_count * sizeof(result_type_);
317
+
318
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
319
+ std::is_same_v<result_type_, typename in_type_::hamming_result_t>;
320
+
321
+ if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch) {
322
+ nk_hammings_packed_u1(reinterpret_cast<nk_u1x8_t const *>(a), b_packed, reinterpret_cast<nk_u32_t *>(c),
323
+ row_count, column_count, depth, a_stride_in_bytes, c_stride_in_bytes);
324
+ }
325
+ else {
326
+ // Scalar fallback: extract pointer and stride from b_packed, then compute directly
327
+ in_type_ const *b;
328
+ size_t b_stride_in_bytes;
329
+ char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
330
+ std::memcpy(&b, b_packed_bytes, sizeof(void *));
331
+ std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
332
+
333
+ // Compute Hamming distances using unpacked matrices
334
+ char const *a_bytes = reinterpret_cast<char const *>(a);
335
+ char const *b_bytes = reinterpret_cast<char const *>(b);
336
+ char *c_bytes = reinterpret_cast<char *>(c);
337
+ std::size_t depth_bytes = divide_round_up(depth, 8);
338
+
339
+ for (std::size_t i = 0; i < row_count; i++) {
340
+ typename in_type_::raw_t const *a_row = reinterpret_cast<typename in_type_::raw_t const *>(
341
+ a_bytes + i * a_stride_in_bytes);
342
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
343
+
344
+ for (std::size_t j = 0; j < column_count; j++) {
345
+ typename in_type_::raw_t const *b_row = reinterpret_cast<typename in_type_::raw_t const *>(
346
+ b_bytes + j * b_stride_in_bytes);
347
+
348
+ // Compute Hamming distance: XOR then popcount
349
+ typename result_type_::raw_t distance = 0;
350
+ for (std::size_t byte_idx = 0; byte_idx < depth_bytes; byte_idx++) {
351
+ auto xor_val = a_row[byte_idx] ^ b_row[byte_idx];
352
+ distance += std::popcount(static_cast<unsigned>(xor_val));
353
+ }
354
+ c_row[j] = result_type_::from_raw(distance);
355
+ }
356
+ }
357
+ }
358
+ }
359
+
360
+ /**
361
+ * @brief Symmetric Jaccard distance matrix: C[i,j] = jaccard(A[i], A[j])
362
+ */
363
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::jaccard_result_t,
364
+ allow_simd_t allow_simd_ = prefer_simd_k>
365
+ void jaccards_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
366
+ result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
367
+ std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
368
+ if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
369
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
370
+ std::is_same_v<result_type_, typename in_type_::jaccard_result_t>;
371
+
372
+ if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
373
+ nk_jaccards_symmetric_u1(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
374
+ row_count);
375
+ else {
376
+ using raw_t = typename in_type_::raw_t;
377
+ std::size_t depth_bytes = divide_round_up(depth, 8);
378
+ char const *a_bytes = reinterpret_cast<char const *>(a);
379
+ char *c_bytes = reinterpret_cast<char *>(c);
380
+ std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
381
+
382
+ for (std::size_t i = row_start; i < row_end; i++) {
383
+ raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
384
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
385
+
386
+ for (std::size_t j = 0; j < n_vectors; j++) {
387
+ raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
388
+ unsigned intersection = 0, union_ = 0;
389
+ for (std::size_t b = 0; b < depth_bytes; b++) {
390
+ intersection += std::popcount(static_cast<unsigned>(a_i[b] & a_j[b]));
391
+ union_ += std::popcount(static_cast<unsigned>(a_i[b] | a_j[b]));
392
+ }
393
+ c_row[j] = result_type_::from_raw(union_ ? 1.0f - static_cast<float>(intersection) / union_ : 0.0f);
394
+ }
395
+ }
396
+ }
397
+ }
398
+
399
+ /**
400
+ * @brief Computes Jaccard distances between rows of A and columns of packed B.
401
+ */
402
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::jaccard_result_t,
403
+ allow_simd_t allow_simd_ = prefer_simd_k>
404
+ void jaccards_packed(in_type_ const *a, void const *b_packed, result_type_ *c, std::size_t row_count,
405
+ std::size_t column_count, std::size_t depth, std::size_t a_stride_in_bytes = 0,
406
+ std::size_t c_stride_in_bytes = 0) noexcept {
407
+ if (!a_stride_in_bytes) a_stride_in_bytes = divide_round_up(depth, 8) * sizeof(in_type_);
408
+ if (!c_stride_in_bytes) c_stride_in_bytes = column_count * sizeof(result_type_);
409
+
410
+ constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
411
+ std::is_same_v<result_type_, typename in_type_::jaccard_result_t>;
412
+
413
+ if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch) {
414
+ nk_jaccards_packed_u1(reinterpret_cast<nk_u1x8_t const *>(a), b_packed, reinterpret_cast<nk_f32_t *>(c),
415
+ row_count, column_count, depth, a_stride_in_bytes, c_stride_in_bytes);
416
+ }
417
+ else {
418
+ // Scalar fallback: extract pointer and stride from b_packed, then compute directly
419
+ in_type_ const *b;
420
+ size_t b_stride_in_bytes;
421
+ char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
422
+ std::memcpy(&b, b_packed_bytes, sizeof(void *));
423
+ std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
424
+
425
+ char const *a_bytes = reinterpret_cast<char const *>(a);
426
+ char const *b_bytes = reinterpret_cast<char const *>(b);
427
+ char *c_bytes = reinterpret_cast<char *>(c);
428
+ std::size_t depth_bytes = divide_round_up(depth, 8);
429
+
430
+ for (std::size_t i = 0; i < row_count; i++) {
431
+ typename in_type_::raw_t const *a_row = reinterpret_cast<typename in_type_::raw_t const *>(
432
+ a_bytes + i * a_stride_in_bytes);
433
+ result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
434
+
435
+ for (std::size_t j = 0; j < column_count; j++) {
436
+ typename in_type_::raw_t const *b_row = reinterpret_cast<typename in_type_::raw_t const *>(
437
+ b_bytes + j * b_stride_in_bytes);
438
+ unsigned intersection = 0, union_ = 0;
439
+ for (std::size_t byte_idx = 0; byte_idx < depth_bytes; byte_idx++) {
440
+ intersection += std::popcount(static_cast<unsigned>(a_row[byte_idx] & b_row[byte_idx]));
441
+ union_ += std::popcount(static_cast<unsigned>(a_row[byte_idx] | b_row[byte_idx]));
442
+ }
443
+ c_row[j] = result_type_::from_raw(union_ ? 1.0f - static_cast<float>(intersection) / union_ : 0.0f);
444
+ }
445
+ }
446
+ }
447
+ }
448
+
449
+ } // namespace ashvardanian::numkong
450
+
451
+ #include "numkong/tensor.hpp"
452
+
453
+ namespace ashvardanian::numkong {
454
+
455
+ #pragma region - Concept-Constrained Symmetric Dot Products
456
+
457
+ /** @brief C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩. */
458
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
459
+ mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
460
+ bool dots_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
461
+ std::size_t num_vectors = input.extent(0);
462
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
463
+ numkong::dots_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
464
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
465
+ static_cast<std::size_t>(output.stride_bytes(0)));
466
+ return true;
467
+ }
468
+
469
+ /** @brief Partitioned symmetric dot products for parallel row-range work. */
470
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
471
+ mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
472
+ bool dots_symmetric(input_matrix_ const &input, output_matrix_ output, std::size_t row_start,
473
+ std::size_t row_count) noexcept {
474
+ std::size_t num_vectors = input.extent(0);
475
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
476
+ numkong::dots_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
477
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
478
+ static_cast<std::size_t>(output.stride_bytes(0)), row_start, row_count);
479
+ return true;
480
+ }
481
+
482
+ /** @brief Allocating symmetric dot products: C = A × Aᵀ. */
483
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
484
+ typename allocator_type_ = aligned_allocator<typename value_type_::dot_result_t>>
485
+ matrix<typename value_type_::dot_result_t, allocator_type_> try_dots_symmetric(input_matrix_ const &input) noexcept {
486
+ using result_t = typename value_type_::dot_result_t;
487
+ using out_tensor_t = matrix<result_t, allocator_type_>;
488
+ if (input.empty()) return out_tensor_t {};
489
+ std::size_t num_vectors = input.extent(0);
490
+ auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
491
+ if (result.empty()) return result;
492
+ if (!dots_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
493
+ return result;
494
+ }
495
+
496
+ /** @brief Symmetric Hamming distances: C[i,j] = hamming(A[i], A[j]). */
497
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
498
+ mutable_matrix_of<typename value_type_::hamming_result_t> output_matrix_>
499
+ bool hammings_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
500
+ std::size_t num_vectors = input.extent(0);
501
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
502
+ numkong::hammings_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
503
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
504
+ static_cast<std::size_t>(output.stride_bytes(0)));
505
+ return true;
506
+ }
507
+
508
+ /** @brief Allocating symmetric Hamming distances. */
509
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
510
+ typename allocator_type_ = aligned_allocator<typename value_type_::hamming_result_t>>
511
+ matrix<typename value_type_::hamming_result_t, allocator_type_> try_hammings_symmetric(
512
+ input_matrix_ const &input) noexcept {
513
+ using result_t = typename value_type_::hamming_result_t;
514
+ using out_tensor_t = matrix<result_t, allocator_type_>;
515
+ if (input.empty()) return out_tensor_t {};
516
+ std::size_t num_vectors = input.extent(0);
517
+ auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
518
+ if (result.empty()) return result;
519
+ if (!hammings_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
520
+ return result;
521
+ }
522
+
523
+ /** @brief Symmetric Jaccard distances: C[i,j] = jaccard(A[i], A[j]). */
524
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
525
+ mutable_matrix_of<typename value_type_::jaccard_result_t> output_matrix_>
526
+ bool jaccards_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
527
+ std::size_t num_vectors = input.extent(0);
528
+ if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
529
+ numkong::jaccards_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
530
+ static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
531
+ static_cast<std::size_t>(output.stride_bytes(0)));
532
+ return true;
533
+ }
534
+
535
+ /** @brief Allocating symmetric Jaccard distances. */
536
+ template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
537
+ typename allocator_type_ = aligned_allocator<typename value_type_::jaccard_result_t>>
538
+ matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_symmetric(
539
+ input_matrix_ const &input) noexcept {
540
+ using result_t = typename value_type_::jaccard_result_t;
541
+ using out_tensor_t = matrix<result_t, allocator_type_>;
542
+ if (input.empty()) return out_tensor_t {};
543
+ std::size_t num_vectors = input.extent(0);
544
+ auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
545
+ if (result.empty()) return result;
546
+ if (!jaccards_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
547
+ return result;
548
+ }
549
+
550
+ #pragma endregion - Concept - Constrained Symmetric Dot Products
551
+
552
+ #pragma region - Concept-Constrained Packed Dot Products
553
+
554
+ /** @brief Packed dot products: C = A × B_packedᵀ. */
555
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
556
+ mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
557
+ bool dots_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
558
+ if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
559
+ if (a.extent(1) != packed_b.depth()) return false;
560
+ if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
561
+ numkong::dots_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
562
+ packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
563
+ static_cast<std::size_t>(c.stride_bytes(0)));
564
+ return true;
565
+ }
566
+
567
+ /** @brief Allocating packed dot products: C = A × B_packedᵀ. */
568
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
569
+ typename allocator_type_ = aligned_allocator<typename value_type_::dot_result_t>>
570
+ matrix<typename value_type_::dot_result_t, allocator_type_> try_dots_packed(input_matrix_ const &a,
571
+ packed_type_ const &packed_b) noexcept {
572
+ using result_t = typename value_type_::dot_result_t;
573
+ using out_t = matrix<result_t, allocator_type_>;
574
+ if (packed_b.empty() || a.rank() < 2) return out_t {};
575
+ auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
576
+ if (c.empty()) return c;
577
+ if (!dots_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
578
+ return c;
579
+ }
580
+
581
+ /** @brief Packed Hamming distances: C = hamming(A, B_packed). */
582
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
583
+ mutable_matrix_of<typename value_type_::hamming_result_t> output_matrix_>
584
+ bool hammings_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
585
+ if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
586
+ if (a.extent(1) != packed_b.depth()) return false;
587
+ if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
588
+ numkong::hammings_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
589
+ packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
590
+ static_cast<std::size_t>(c.stride_bytes(0)));
591
+ return true;
592
+ }
593
+
594
+ /** @brief Allocating packed Hamming distances. */
595
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
596
+ typename allocator_type_ = aligned_allocator<typename value_type_::hamming_result_t>>
597
+ matrix<typename value_type_::hamming_result_t, allocator_type_> try_hammings_packed(
598
+ input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
599
+ using result_t = typename value_type_::hamming_result_t;
600
+ using out_t = matrix<result_t, allocator_type_>;
601
+ if (packed_b.empty() || a.rank() < 2) return out_t {};
602
+ auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
603
+ if (c.empty()) return c;
604
+ if (!hammings_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
605
+ return c;
606
+ }
607
+
608
+ /** @brief Packed Jaccard distances: C = jaccard(A, B_packed). */
609
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
610
+ mutable_matrix_of<typename value_type_::jaccard_result_t> output_matrix_>
611
+ bool jaccards_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
612
+ if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
613
+ if (a.extent(1) != packed_b.depth()) return false;
614
+ if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
615
+ numkong::jaccards_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
616
+ packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
617
+ static_cast<std::size_t>(c.stride_bytes(0)));
618
+ return true;
619
+ }
620
+
621
+ /** @brief Allocating packed Jaccard distances. */
622
+ template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
623
+ typename allocator_type_ = aligned_allocator<typename value_type_::jaccard_result_t>>
624
+ matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_packed(
625
+ input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
626
+ using result_t = typename value_type_::jaccard_result_t;
627
+ using out_t = matrix<result_t, allocator_type_>;
628
+ if (packed_b.empty() || a.rank() < 2) return out_t {};
629
+ auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
630
+ if (c.empty()) return c;
631
+ if (!jaccards_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
632
+ return c;
633
+ }
634
+
635
+ #pragma endregion - Concept - Constrained Packed Dot Products
636
+
637
+ } // namespace ashvardanian::numkong
638
+
639
+ #endif // NK_DOTS_HPP