numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,133 @@
1
+ /**
2
+ * @brief C++ bindings for multi-target MaxSim (ColBERT late-interaction) kernels.
3
+ * @file include/numkong/maxsim.hpp
4
+ * @author Ash Vardanian
5
+ * @date February 28, 2026
6
+ */
7
+ #ifndef NK_MAXSIM_HPP
8
+ #define NK_MAXSIM_HPP
9
+
10
+ #include <cstddef>
11
+ #include <cstring>
12
+ #include <limits>
13
+ #include <type_traits>
14
+
15
+ #include "numkong/maxsim.h"
16
+ #include "numkong/types.hpp"
17
+ #include "numkong/spatial.hpp" // angular<>
18
+
19
+ namespace ashvardanian::numkong {
20
+
21
+ /**
22
+ * @brief Computes angular distance late-interaction on pre-packed vectors.
23
+ * Writes Σᵢ minⱼ angular(qᵢ, dⱼ) to `result`.
24
+ * @param[in] query_packed Packed query vectors.
25
+ * @param[in] document_packed Packed document vectors.
26
+ * @param[in] query_count Number of query vectors.
27
+ * @param[in] document_count Number of document vectors.
28
+ * @param[in] depth Number of dimensions per vector.
29
+ * @param[out] result Pointer to store the sum of per-query minimum angular distances.
30
+ *
31
+ * @tparam in_type_ Input element type (bf16_t, f32_t, f16_t).
32
+ * @tparam result_type_ Result type, defaults to `in_type_::maxsim_result_t`.
33
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`.
34
+ */
35
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::maxsim_result_t,
36
+ allow_simd_t allow_simd_ = prefer_simd_k>
37
+ NK_PUBLIC void maxsim_packed(void const *query_packed, void const *document_packed, std::size_t query_count,
38
+ std::size_t document_count, std::size_t depth, result_type_ *result) {
39
+ constexpr bool simd = allow_simd_ == prefer_simd_k &&
40
+ std::is_same_v<result_type_, typename in_type_::maxsim_result_t>; // `nk_maxsim_packed_*` is only used for the default result type
41
+
42
+ if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
43
+ nk_maxsim_packed_bf16(query_packed, document_packed, query_count, document_count, depth, &result->raw_);
44
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
45
+ nk_maxsim_packed_f32(query_packed, document_packed, query_count, document_count, depth, &result->raw_);
46
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
47
+ nk_maxsim_packed_f16(query_packed, document_packed, query_count, document_count, depth, &result->raw_);
48
+ else { // fallback: any other input type, a custom result type, or SIMD not preferred
49
+ typename in_type_::raw_t const *q_ptr;
50
+ std::size_t q_stride;
51
+ char const *q_bytes = reinterpret_cast<char const *>(query_packed);
52
+ std::memcpy(&q_ptr, q_bytes, sizeof(void *)); // first sizeof(void *) bytes of the buffer are read as the row pointer
53
+ std::memcpy(&q_stride, q_bytes + sizeof(void *), sizeof(std::size_t)); // the following bytes are read as the row stride
54
+
55
+ typename in_type_::raw_t const *d_ptr;
56
+ std::size_t d_stride;
57
+ char const *d_bytes = reinterpret_cast<char const *>(document_packed);
58
+ std::memcpy(&d_ptr, d_bytes, sizeof(void *)); // same layout assumption as for the query buffer
59
+ std::memcpy(&d_stride, d_bytes + sizeof(void *), sizeof(std::size_t));
60
+
61
+ maxsim_reference<in_type_, result_type_>(q_ptr, query_count, q_stride, d_ptr, document_count, d_stride, depth,
62
+ result); // NOTE(review): allow_simd_ is not forwarded, so maxsim_reference uses its own default — confirm intended
63
+ }
64
+ }
65
+
66
+ /**
67
+ * @brief Exhaustive angular reference for testing: Σᵢ minⱼ angular(qᵢ, dⱼ).
68
+ * Computes all pairwise angular distances and picks the minimum per query.
69
+ * Accumulates the sum in `result_type_` (f64 only if that is the result type).
70
+ * @param[in] queries Query vectors in row-major order.
71
+ * @param[in] query_count Number of query vectors.
72
+ * @param[in] query_stride Row stride in bytes for query vectors.
73
+ * @param[in] documents Document vectors in row-major order.
74
+ * @param[in] document_count Number of document vectors.
75
+ * @param[in] document_stride Row stride in bytes for document vectors.
76
+ * @param[in] depth Number of dimensions per vector.
77
+ * @param[out] result Pointer to store the sum of per-query minimum angular distances.
78
+ *
79
+ * @tparam in_type_ Input element type (bf16_t, f32_t, f16_t).
80
+ * @tparam result_type_ Result type, defaults to `in_type_::angular_result_t`.
81
+ * @tparam allow_simd_ Forwarded unchanged to each pairwise `angular<>` call. */
82
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
83
+ allow_simd_t allow_simd_ = prefer_simd_k>
84
+ NK_PUBLIC void maxsim_reference(typename in_type_::raw_t const *queries, std::size_t query_count,
85
+ std::size_t query_stride, typename in_type_::raw_t const *documents,
86
+ std::size_t document_count, std::size_t document_stride, std::size_t depth,
87
+ result_type_ *result) {
88
+ result_type_ total_angular_distance {};
89
+
90
+ for (std::size_t query_index = 0; query_index < query_count; query_index++) {
91
+ in_type_ const *query_row = reinterpret_cast<in_type_ const *>(reinterpret_cast<char const *>(queries) +
92
+ query_index * query_stride); // offset is applied on a char pointer, i.e. in bytes
93
+
94
+ result_type_ min_angular = result_type_::finite_max(); // stays at finite_max() when document_count == 0
95
+
96
+ for (std::size_t document_index = 0; document_index < document_count; document_index++) {
97
+ in_type_ const *document_row = reinterpret_cast<in_type_ const *>(
98
+ reinterpret_cast<char const *>(documents) + document_index * document_stride);
99
+
100
+ result_type_ angular_distance {};
101
+ angular<in_type_, result_type_, allow_simd_>(query_row, document_row, depth, &angular_distance);
102
+
103
+ if (angular_distance < min_angular) min_angular = angular_distance; // keep the closest document for this query
104
+ }
105
+
106
+ total_angular_distance = total_angular_distance + min_angular;
107
+ }
108
+
109
+ *result = total_angular_distance;
110
+ }
111
+
112
+ } // namespace ashvardanian::numkong
113
+
114
+ #include "numkong/matrix.hpp"
115
+
116
+ namespace ashvardanian::numkong {
117
+
118
+ /** @brief MaxSim: Σᵢ minⱼ angular(qᵢ, dⱼ) on pre-packed vectors; returns a value-initialized result if either side is empty or the depths differ. */
119
+ template <numeric_dtype value_type_>
120
+ typename value_type_::maxsim_result_t maxsim(packed_maxsim<value_type_> const &queries,
121
+ packed_maxsim<value_type_> const &documents) noexcept {
122
+ using result_t = typename value_type_::maxsim_result_t;
123
+ result_t result {};
124
+ if (queries.empty() || documents.empty()) return result; // nothing to compare: return the value-initialized result
125
+ if (queries.depth() != documents.depth()) return result; // mismatched dimensions are not reported, same value is returned
126
+ maxsim_packed<value_type_>(queries.data(), documents.data(), queries.vector_count(), documents.vector_count(),
127
+ queries.depth(), &result);
128
+ return result;
129
+ }
130
+
131
+ } // namespace ashvardanian::numkong
132
+
133
+ #endif // NK_MAXSIM_HPP
@@ -0,0 +1,227 @@
1
+ # Point Cloud Alignment in NumKong
2
+
3
+ NumKong implements RMSD, Kabsch, and Umeyama algorithms for rigid-body superposition of 3D point clouds.
4
+ RMSD measures alignment quality, Kabsch finds the optimal rotation minimizing RMSD, and Umeyama extends Kabsch with uniform scaling.
5
+ These algorithms are used in structural biology (protein alignment), robotics (point cloud registration), and computer graphics (mesh registration).
6
+
7
+ Centroid:
8
+
9
+ ```math
10
+ \bar{a} = \frac{1}{n}\sum a_i
11
+ ```
12
+
13
+ Cross-covariance matrix:
14
+
15
+ ```math
16
+ H = \sum (a_i - \bar{a})(b_i - \bar{b})^T
17
+ ```
18
+
19
+ SVD-based rotation:
20
+
21
+ ```math
22
+ H = U \Sigma V^T, \quad R = V U^T
23
+ ```
24
+
25
+ Umeyama scale factor:
26
+
27
+ ```math
28
+ s = \frac{\text{tr}(\Sigma)}{n \cdot \sigma_a^2}
29
+ ```
30
+
31
+ RMSD after alignment:
32
+
33
+ ```math
34
+ \text{RMSD} = \sqrt{\frac{1}{n}\sum \|s \cdot R(a_i - \bar{a}) - (b_i - \bar{b})\|^2}
35
+ ```
36
+
37
+ Reformulating as Python pseudocode:
38
+
39
+ ```python
40
+ import numpy as np
41
+
42
+ def kabsch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
43
+ a_c, b_c = a - a.mean(0), b - b.mean(0)
44
+ H = a_c.T @ b_c
45
+ U, S, Vt = np.linalg.svd(H)
46
+ d = np.sign(np.linalg.det(Vt.T @ U.T))
47
+ R = Vt.T @ np.diag([1, 1, d]) @ U.T
48
+ return R
49
+
50
+ def umeyama(a: np.ndarray, b: np.ndarray) -> tuple:
51
+ a_c, b_c = a - a.mean(0), b - b.mean(0)
52
+ H = a_c.T @ b_c
53
+ U, S, Vt = np.linalg.svd(H)
54
+ d = np.sign(np.linalg.det(Vt.T @ U.T))
55
+ R = Vt.T @ np.diag([1, 1, d]) @ U.T
56
+ scale = S.sum() / (len(a) * np.var(a_c, axis=0).sum())
57
+ return R, scale
58
+
59
+ def rmsd(a: np.ndarray, b: np.ndarray) -> float:
60
+ return np.sqrt(np.mean(np.sum((a - b) ** 2, axis=1)))
61
+ ```
62
+
63
+ ## Input & Output Types
64
+
65
+ | Input Type | Output Type | Description |
66
+ | ---------- | ----------- | ---------------------------------------------- |
67
+ | `f64` | `f64` | 64-bit IEEE 754 double precision |
68
+ | `f32` | `f32` | 32-bit IEEE 754 single precision |
69
+ | `f16` | `f32` | 16-bit IEEE 754 half precision, widened output |
70
+ | `bf16` | `f32` | 16-bit brain float, widened output |
71
+
72
+ ## Optimizations
73
+
74
+ ### McAdams Branching-Free 3×3 SVD
75
+
76
+ `nk_kabsch_f32_serial`, `nk_kabsch_f64_haswell`, `nk_umeyama_f32_neon` use a Jacobi eigenanalysis with fixed 16 iterations (no convergence check) for deterministic behavior.
77
+ Quaternion-accumulated rotations: each Jacobi sweep updates a 4-element quaternion instead of recomputing eigenvectors.
78
+ Approximate Givens angles via `nk_approximate_givens_quaternion_` — a γ-threshold test selects between computed angles and precomputed cos(π/8), sin(π/8) constants.
79
+ Cyclic permutation of matrix elements avoids explicit sorting of eigenvalues.
80
+
81
+ ### Stride-3 Deinterleaving
82
+
83
+ Point clouds are stored interleaved as [x₀,y₀,z₀, x₁,y₁,z₁, ...].
84
+ NEON uses `vld3q_f32` to hardware-deinterleave 4 XYZ triplets in one instruction — no gather needed.
85
+ Haswell uses `_mm256_i32gather_ps` with indices [0,3,6,9,12,15,18,21] to load 8 x-coordinates from 8 points.
86
+ RVV uses indexed loads with dynamic stride to adapt to variable vector length.
87
+
88
+ ### Reflection Correction
89
+
90
+ `nk_kabsch_f32_haswell`, `nk_kabsch_f64_skylake` check for improper rotations (det(R) = -1, reflections) after computing R = V·Uᵀ.
91
+ If det(R) is negative, the last column of V is flipped.
92
+ This ensures the output is always a proper rotation matrix (det = +1).
93
+
94
+ ### Pre-Scaled Rotation for Umeyama
95
+
96
+ `nk_umeyama_f32_haswell`, `nk_umeyama_f64_skylake` fold the computed scale factor into the rotation matrix before applying to points.
97
+ `sr[i] = scale * r[i]` is computed once and broadcast — avoiding a per-point scalar multiply.
98
+
99
+ ### Why SME and SVE Were Removed
100
+
101
+ Historical note: experimental SME variants of RMSD, Kabsch, and Umeyama were implemented in 1,052 lines across `sme.h` and `smef64.h` (commit `0e0bc30c`) and removed 4 days later (commit `f55e9a71`).
102
+ The fundamental mismatch: the algorithm computes a 3×3 cross-covariance matrix $H = \sum (a_i - \bar{a})(b_i - \bar{b})^T$ — a sum of outer products of 3D vectors.
103
+ SME's `FMOPA` operates on SVL-wide vectors (16+ elements at SVL=512), but the outer products here are 3×3 — the tile is ≈96.5% wasted (9 useful cells out of 256).
104
+ Three approaches were explored in a design document (`sme_design.h`, 398 lines):
105
+ (1) batched outer products — reformulates as 9 independent dot products but loses SME's outer-product strength, falling back to what NEON already does;
106
+ (2) streaming SVE with `svld3` — hardware stride-3 deinterleaving processes 16 points per iteration vs NEON's 4, but `SMSTART`/`SMSTOP` mode transitions cost ~100 cycles and the 3×3 SVD step cannot use streaming mode at all;
107
+ (3) SME for SVD — the 3×3 matrix cannot fill even one 16×16 tile.
108
+ Performance estimates from the design document: NEON baseline ~2.25N cycles for N points; streaming SVE ~1.2N cycles but with ~100-cycle mode transition overhead — for typical protein alignment workloads (N = 100–500 atoms), the overhead dominates.
109
+ Experimental SVE mesh kernels (`sve.h`, `svehalf.h`, 112 lines total) were removed in the same commit — variable vector length added complexity without clear benefit over fixed-width NEON for the 3D point cloud problem.
110
+
111
+ ## Performance
112
+
113
+ The following performance tables are produced by manually re-running the bundled `nk_test` and `nk_bench` tools to measure both accuracy and throughput at different input shapes.
114
+ The input size is controlled by the `NK_MESH_POINTS` environment variable and set to 256, 1024, and 4096 points.
115
+ Each alignment computes centroids, covariance, and a 3×3 SVD over $N$ point pairs, so cost is $O(N)$ per alignment with a large constant.
116
+ The throughput is measured in mp/s as millions of 3D points aligned per second.
117
+ Accuracy is reported as mean ULP (units in last place) unless noted otherwise — the average number of representable floating-point values between the result and the exact answer.
118
+ Each kernel runs for at least 20 seconds per configuration.
119
+ Benchmark threads are pinned to specific cores; on machines with heterogeneous core types (e.g., Apple P/E cores), only the fastest cores are used.
120
+ Workloads that significantly degrade CPU frequencies (Intel AMX, Apple SME) run in separate passes to avoid affecting throughput measurements of other kernels.
121
+
122
+ ### Intel Sapphire Rapids
123
+
124
+ #### Native
125
+
126
+ | Kernel | 256 | 1024 | 4096 |
127
+ | :------------------------ | -----------------------: | -----------------------: | -----------------------: |
128
+ | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
129
+ | `nk_rmsd_f64_serial` | 354 mp/s, 1.4 ulp | 176 mp/s, 2.7 ulp | 159 mp/s, 5.0 ulp |
130
+ | `nk_kabsch_f64_serial` | 71.1 mp/s, 1.4 ulp | 70.8 mp/s, 2.7 ulp | 80.3 mp/s, 5.2 ulp |
131
+ | `nk_umeyama_f64_serial` | 70.1 mp/s, 1.0 ulp | 75.1 mp/s, 1.8 ulp | 79.1 mp/s, 3.9 ulp |
132
+ | `nk_rmsd_f64_haswell` | 405 mp/s, 0.3 ulp | 260 mp/s, 0.4 ulp | 192 mp/s, 0.8 ulp |
133
+ | `nk_kabsch_f64_haswell` | 82.1 mp/s, 0.9 ulp | 105 mp/s, 1.3 ulp | 133 mp/s, 2.3 ulp |
134
+ | `nk_umeyama_f64_haswell` | 82.6 mp/s, 0.4 ulp | 119 mp/s, 0.8 ulp | 134 mp/s, 1.5 ulp |
135
+ | `nk_rmsd_f64_skylake` | 540 mp/s, 0.3 ulp | 219 mp/s, 0.3 ulp | 213 mp/s, 0.5 ulp |
136
+ | `nk_kabsch_f64_skylake` | 96.8 mp/s, 0.7 ulp | 115 mp/s, 0.9 ulp | 159 mp/s, 1.1 ulp |
137
+ | `nk_umeyama_f64_skylake` | 101 mp/s, 0.2 ulp | 119 mp/s, 0.4 ulp | 157 mp/s, 0.8 ulp |
138
+ | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
139
+ | `nk_rmsd_f32_serial` | 480 mp/s, 1.4 ulp | 314 mp/s, 2.7 ulp | 270 mp/s, 5.4 ulp |
140
+ | `nk_kabsch_f32_serial` | 83.2 mp/s, 1.5 ulp | 91.6 mp/s, 2.6 ulp | 110 mp/s, 5.3 ulp |
141
+ | `nk_umeyama_f32_serial` | 80.4 mp/s, 1.0 ulp | 104 mp/s, 1.9 ulp | 106 mp/s, 3.7 ulp |
142
+ | `nk_rmsd_f32_haswell` | 447 mp/s, 0.3 ulp | 484 mp/s, 0.3 ulp | 350 mp/s, 0.4 ulp |
143
+ | `nk_kabsch_f32_haswell` | 101 mp/s, 0.7 ulp | 192 mp/s, 0.9 ulp | 213 mp/s, 1.3 ulp |
144
+ | `nk_umeyama_f32_haswell` | 97.4 mp/s, 0.3 ulp | 155 mp/s, 0.4 ulp | 207 mp/s, 0.8 ulp |
145
+ | `nk_rmsd_f32_skylake` | 936 mp/s, 0.3 ulp | 970 mp/s, 0.3 ulp | 426 mp/s, 0.3 ulp |
146
+ | `nk_kabsch_f32_skylake` | 122 mp/s, 0.7 ulp | 258 mp/s, 0.7 ulp | 290 mp/s, 0.9 ulp |
147
+ | `nk_umeyama_f32_skylake` | 133 mp/s, 0.2 ulp | 231 mp/s, 0.3 ulp | 285 mp/s, 0.5 ulp |
148
+ | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
149
+ | `nk_rmsd_bf16_haswell` | 511 mp/s, 0.3 ulp | 481 mp/s, 3.5 ulp | 497 mp/s, 12.8 ulp |
150
+ | `nk_kabsch_bf16_haswell` | 52.4 mp/s, 0.7 ulp | 65.3 mp/s, 0.9 ulp | 74.8 mp/s, 1.3 ulp |
151
+ | `nk_umeyama_bf16_haswell` | 51.5 mp/s, 0.2 ulp | 69.2 mp/s, 0.4 ulp | 74.6 mp/s, 0.8 ulp |
152
+ | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
153
+ | `nk_rmsd_f16_haswell` | 415 mp/s, 0.3 ulp | 497 mp/s, 0.7 ulp | 458 mp/s, 2.5 ulp |
154
+ | `nk_kabsch_f16_haswell` | 151 mp/s, 0.7 ulp | 222 mp/s, 0.9 ulp | 221 mp/s, 1.4 ulp |
155
+ | `nk_umeyama_f16_haswell` | 186 mp/s, 0.2 ulp | 232 mp/s, 0.5 ulp | 222 mp/s, 0.9 ulp |
156
+
157
+ #### WASM
158
+
159
+ Measured with Wasmtime v42 (Cranelift backend).
160
+
161
+ | Kernel | 256 | 1024 | 4096 |
162
+ | :--------------------------- | -----------------------: | -----------------------: | -----------------------: |
163
+ | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
164
+ | `nk_rmsd_f64_serial` | 178 mp/s, 1.4 ulp | 158 mp/s, 2.6 ulp | ? mp/s, 5.3 ulp |
165
+ | `nk_rmsd_f64_v128relaxed` | 273 mp/s, 0.4 ulp | 307 mp/s, 0.7 ulp | ? mp/s, 1.3 ulp |
166
+ | `nk_kabsch_f64_serial` | 37.7 mp/s, 1.4 ulp | 51.7 mp/s, 2.5 ulp | ? mp/s, 5.2 ulp |
167
+ | `nk_kabsch_f64_v128relaxed` | 31.7 mp/s, 1.2 ulp | 56.9 mp/s, 2.3 ulp | ? mp/s, 4.5 ulp |
168
+ | `nk_umeyama_f64_serial` | 36.5 mp/s, 0.9 ulp | 49.6 mp/s, 1.9 ulp | ? mp/s, 3.6 ulp |
169
+ | `nk_umeyama_f64_v128relaxed` | 32.6 mp/s, 0.8 ulp | 55.5 mp/s, 1.5 ulp | ? mp/s, 3.2 ulp |
170
+ | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
171
+ | `nk_rmsd_f32_serial` | 105 mp/s, 1.4 ulp | 122 mp/s, 2.7 ulp | ? mp/s, 5.2 ulp |
172
+ | `nk_rmsd_f32_v128relaxed` | 213 mp/s, 0.3 ulp | 258 mp/s, 0.4 ulp | ? mp/s, 0.8 ulp |
173
+ | `nk_kabsch_f32_serial` | 15.5 mp/s, 1.4 ulp | 32.8 mp/s, 2.6 ulp | ? mp/s, 5.1 ulp |
174
+ | `nk_kabsch_f32_v128relaxed` | 13.5 mp/s, 0.9 ulp | 46.2 mp/s, 1.3 ulp | ? mp/s, 2.5 ulp |
175
+ | `nk_umeyama_f32_serial` | 15.2 mp/s, 1.0 ulp | 37.4 mp/s, 1.8 ulp | ? mp/s, 3.7 ulp |
176
+ | `nk_umeyama_f32_v128relaxed` | 18.3 mp/s, 0.4 ulp | 38.9 mp/s, 0.8 ulp | ? mp/s, 1.5 ulp |
177
+
178
+
179
+ ### Apple M4
180
+
181
+ #### Native
182
+
183
+ | Kernel | 256 | 1024 | 4096 |
184
+ | :----------------------- | -----------------------: | -----------------------: | -----------------------: |
185
+ | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
186
+ | `nk_rmsd_f64_serial` | 447 mp/s, 1.4 ulp | 410 mp/s, 2.6 ulp | 412 mp/s, 5.3 ulp |
187
+ | `nk_kabsch_f64_serial` | 95.2 mp/s, 1.4 ulp | 169 mp/s, 2.6 ulp | 214 mp/s, 5.4 ulp |
188
+ | `nk_umeyama_f64_serial` | 89.2 mp/s, 1.0 ulp | 157 mp/s, 1.9 ulp | 195 mp/s, 3.7 ulp |
189
+ | `nk_rmsd_f64_neon` | 823 mp/s, 0.4 ulp | 761 mp/s, 0.7 ulp | 702 mp/s, 1.3 ulp |
190
+ | `nk_kabsch_f64_neon` | 105 mp/s, 0.8 ulp | 213 mp/s, 1.3 ulp | 287 mp/s, 2.2 ulp |
191
+ | `nk_umeyama_f64_neon` | 106 mp/s, 0.4 ulp | 214 mp/s, 0.8 ulp | 297 mp/s, 1.6 ulp |
192
+ | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
193
+ | `nk_rmsd_f32_serial` | 554 mp/s, 1.4 ulp | 566 mp/s, 2.6 ulp | 532 mp/s, 5.2 ulp |
194
+ | `nk_kabsch_f32_serial` | 110 mp/s, 1.4 ulp | 214 mp/s, 2.7 ulp | 264 mp/s, 5.0 ulp |
195
+ | `nk_umeyama_f32_serial` | 104 mp/s, 0.9 ulp | 197 mp/s, 1.8 ulp | 240 mp/s, 3.5 ulp |
196
+ | `nk_rmsd_f32_neon` | 1,580 mp/s, 0.3 ulp | 1,560 mp/s, 0.4 ulp | 1,200 mp/s, 0.8 ulp |
197
+ | `nk_kabsch_f32_neon` | 139 mp/s, 0.7 ulp | 336 mp/s, 0.9 ulp | 485 mp/s, 1.4 ulp |
198
+ | `nk_umeyama_f32_neon` | 137 mp/s, 0.3 ulp | 325 mp/s, 0.4 ulp | 470 mp/s, 0.8 ulp |
199
+ | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
200
+ | `nk_rmsd_bf16_serial` | 1,740 mp/s, 0.5 ulp | 1,880 mp/s, 6.0 ulp | 1,860 mp/s, 10.0 ulp |
201
+ | `nk_kabsch_bf16_serial` | 137 mp/s, 0.7 ulp | 335 mp/s, 0.9 ulp | 527 mp/s, 1.3 ulp |
202
+ | `nk_umeyama_bf16_serial` | 135 mp/s, 0.2 ulp | 329 mp/s, 0.4 ulp | 510 mp/s, 0.8 ulp |
203
+ | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
204
+ | `nk_rmsd_f16_serial` | 1,840 mp/s, 0.4 ulp | 1,900 mp/s, 1.7 ulp | 1,860 mp/s, 4.6 ulp |
205
+ | `nk_kabsch_f16_serial` | 140 mp/s, 0.9 ulp | 349 mp/s, 1.3 ulp | 547 mp/s, 2.4 ulp |
206
+ | `nk_umeyama_f16_serial` | 135 mp/s, 0.4 ulp | 316 mp/s, 0.8 ulp | 474 mp/s, 1.5 ulp |
207
+
208
+ #### WASM
209
+
210
+ Measured with Wasmtime v42 (Cranelift backend).
211
+
212
+ | Kernel | 256 | 1024 | 4096 |
213
+ | :--------------------------- | -----------------------: | -----------------------: | -----------------------: |
214
+ | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
215
+ | `nk_rmsd_f64_serial` | 219 mp/s, 2.6 ulp | 202 mp/s, 2.6 ulp | 255 mp/s, 2.6 ulp |
216
+ | `nk_rmsd_f64_v128relaxed` | 434 mp/s, 0.8 ulp | 363 mp/s, 0.8 ulp | 586 mp/s, 0.8 ulp |
217
+ | `nk_kabsch_f64_serial` | 42.8 mp/s, 2.7 ulp | 76.0 mp/s, 2.7 ulp | 110 mp/s, 2.7 ulp |
218
+ | `nk_kabsch_f64_v128relaxed` | 55.2 mp/s, 2.2 ulp | 110 mp/s, 2.2 ulp | 202 mp/s, 2.2 ulp |
219
+ | `nk_umeyama_f64_serial` | 36.1 mp/s, 1.8 ulp | 58.9 mp/s, 1.8 ulp | 98.7 mp/s, 1.8 ulp |
220
+ | `nk_umeyama_f64_v128relaxed` | 52.4 mp/s, 1.5 ulp | 103 mp/s, 1.5 ulp | 183 mp/s, 1.5 ulp |
221
+ | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
222
+ | `nk_rmsd_f32_serial` | 218 mp/s, 2.7 ulp | 223 mp/s, 2.7 ulp | 271 mp/s, 2.7 ulp |
223
+ | `nk_rmsd_f32_v128relaxed` | 626 mp/s, 0.5 ulp | 626 mp/s, 0.5 ulp | 687 mp/s, 0.5 ulp |
224
+ | `nk_kabsch_f32_serial` | 45.5 mp/s, 2.6 ulp | 77.0 mp/s, 2.6 ulp | 112 mp/s, 2.6 ulp |
225
+ | `nk_kabsch_f32_v128relaxed` | 68.6 mp/s, 1.3 ulp | 160 mp/s, 1.3 ulp | 273 mp/s, 1.3 ulp |
226
+ | `nk_umeyama_f32_serial` | 38.7 mp/s, 1.8 ulp | 60.0 mp/s, 1.8 ulp | 80.5 mp/s, 1.8 ulp |
227
+ | `nk_umeyama_f32_v128relaxed` | 66.9 mp/s, 0.8 ulp | 157 mp/s, 0.8 ulp | 291 mp/s, 0.8 ulp |