numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,288 @@
1
+ /**
2
+ * @brief NEON-accelerated Sparse Vector Operations.
3
+ * @file include/numkong/sparse/neon.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/sparse.h
8
+ */
9
+ #ifndef NK_SPARSE_NEON_H
10
+ #define NK_SPARSE_NEON_H
11
+
12
+ #if NK_TARGET_ARM_
13
+ #if NK_TARGET_NEON
14
+
15
+ #include "numkong/types.h"
16
+
17
+ #if defined(__cplusplus)
18
+ extern "C" {
19
+ #endif
20
+
21
+ #if defined(__clang__)
22
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a"))), apply_to = function)
23
+ #elif defined(__GNUC__)
24
+ #pragma GCC push_options
25
+ #pragma GCC target("arch=armv8-a")
26
+ #endif
27
+
28
+ NK_INTERNAL uint32x4_t nk_intersect_u32x4_neon_(uint32x4_t a, uint32x4_t b) {
29
+ uint32x4_t b_rot1 = vextq_u32(b, b, 1);
30
+ uint32x4_t b_rot2 = vextq_u32(b, b, 2);
31
+ uint32x4_t b_rot3 = vextq_u32(b, b, 3);
32
+ uint32x4_t matches_rot0 = vceqq_u32(a, b);
33
+ uint32x4_t matches_rot1 = vceqq_u32(a, b_rot1);
34
+ uint32x4_t matches_rot2 = vceqq_u32(a, b_rot2);
35
+ uint32x4_t matches_rot3 = vceqq_u32(a, b_rot3);
36
+ uint32x4_t matches = vorrq_u32(vorrq_u32(matches_rot0, matches_rot1), vorrq_u32(matches_rot2, matches_rot3));
37
+ return matches;
38
+ }
39
+
40
+ NK_INTERNAL uint16x8_t nk_intersect_u16x8_neon_(uint16x8_t a, uint16x8_t b) {
41
+ uint16x8_t b_rot1 = vextq_u16(b, b, 1);
42
+ uint16x8_t b_rot2 = vextq_u16(b, b, 2);
43
+ uint16x8_t b_rot3 = vextq_u16(b, b, 3);
44
+ uint16x8_t b_rot4 = vextq_u16(b, b, 4);
45
+ uint16x8_t b_rot5 = vextq_u16(b, b, 5);
46
+ uint16x8_t b_rot6 = vextq_u16(b, b, 6);
47
+ uint16x8_t b_rot7 = vextq_u16(b, b, 7);
48
+ uint16x8_t matches_rot0 = vceqq_u16(a, b);
49
+ uint16x8_t matches_rot1 = vceqq_u16(a, b_rot1);
50
+ uint16x8_t matches_rot2 = vceqq_u16(a, b_rot2);
51
+ uint16x8_t matches_rot3 = vceqq_u16(a, b_rot3);
52
+ uint16x8_t matches_rot4 = vceqq_u16(a, b_rot4);
53
+ uint16x8_t matches_rot5 = vceqq_u16(a, b_rot5);
54
+ uint16x8_t matches_rot6 = vceqq_u16(a, b_rot6);
55
+ uint16x8_t matches_rot7 = vceqq_u16(a, b_rot7);
56
+ uint16x8_t matches = vorrq_u16(
57
+ vorrq_u16(vorrq_u16(matches_rot0, matches_rot1), vorrq_u16(matches_rot2, matches_rot3)),
58
+ vorrq_u16(vorrq_u16(matches_rot4, matches_rot5), vorrq_u16(matches_rot6, matches_rot7)));
59
+ return matches;
60
+ }
61
+
62
+ NK_PUBLIC void nk_sparse_intersect_u16_neon( //
63
+ nk_u16_t const *a, nk_u16_t const *b, //
64
+ nk_size_t a_length, nk_size_t b_length, //
65
+ nk_u16_t *result, nk_size_t *count) {
66
+
67
+ // NEON lacks compress-store, so fall back to serial for result output
68
+ if (result) {
69
+ nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
70
+ return;
71
+ }
72
+
73
+ #if NK_ALLOW_ISA_REDIRECT
74
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
75
+ if (a_length < 32 && b_length < 32) {
76
+ nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
77
+ return;
78
+ }
79
+ #endif
80
+
81
+ nk_u16_t const *const a_end = a + a_length;
82
+ nk_u16_t const *const b_end = b + b_length;
83
+ nk_b128_vec_t a_vec, b_vec;
84
+ uint16x8_t c_counts_u16x8 = vdupq_n_u16(0);
85
+
86
+ while (a + 8 <= a_end && b + 8 <= b_end) {
87
+ a_vec.u16x8 = vld1q_u16(a);
88
+ b_vec.u16x8 = vld1q_u16(b);
89
+
90
+ // Intersecting registers with `nk_intersect_u16x8_neon_` involves a lot of shuffling
91
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
92
+ nk_u16_t a_min;
93
+ nk_u16_t a_max = a_vec.u16s[7];
94
+ nk_u16_t b_min = b_vec.u16s[0];
95
+ nk_u16_t b_max = b_vec.u16s[7];
96
+
97
+ // If the slices don't overlap, advance the appropriate pointer
98
+ while (a_max < b_min && a + 16 <= a_end) {
99
+ a += 8;
100
+ a_vec.u16x8 = vld1q_u16(a);
101
+ a_max = a_vec.u16s[7];
102
+ }
103
+ a_min = a_vec.u16s[0];
104
+ while (b_max < a_min && b + 16 <= b_end) {
105
+ b += 8;
106
+ b_vec.u16x8 = vld1q_u16(b);
107
+ b_max = b_vec.u16s[7];
108
+ }
109
+ b_min = b_vec.u16s[0];
110
+
111
+ // Transform match-masks into "ones", accumulate them between the cycles,
112
+ // and merge all together in the end.
113
+ uint16x8_t a_matches = nk_intersect_u16x8_neon_(a_vec.u16x8, b_vec.u16x8);
114
+ c_counts_u16x8 = vaddq_u16(c_counts_u16x8, vandq_u16(a_matches, vdupq_n_u16(1)));
115
+
116
+ // Use `vclz_u32` to compute leading zeros for both `a_step` and `b_step` in parallel.
117
+ // Narrow comparison masks from 128→64→32 bits, pack both into a `uint32x2_t`.
118
+ uint16x8_t a_inrange_u16x8 = vcleq_u16(a_vec.u16x8, vdupq_n_u16(b_max));
119
+ uint16x8_t b_inrange_u16x8 = vcleq_u16(b_vec.u16x8, vdupq_n_u16(a_max));
120
+ uint8x8_t a_narrow_u8x8 = vmovn_u16(a_inrange_u16x8);
121
+ uint8x8_t b_narrow_u8x8 = vmovn_u16(b_inrange_u16x8);
122
+ uint8x8_t packed_u8x8 = vshrn_n_u16(vreinterpretq_u16_u8(vcombine_u8(a_narrow_u8x8, b_narrow_u8x8)), 4);
123
+ uint32x2_t clz_u32x2 = vclz_u32(vreinterpret_u32_u8(packed_u8x8));
124
+ a += (32 - vget_lane_u32(clz_u32x2, 0)) / 4;
125
+ b += (32 - vget_lane_u32(clz_u32x2, 1)) / 4;
126
+ }
127
+
128
+ nk_size_t tail_count = 0;
129
+ nk_sparse_intersect_u16_serial(a, b, a_end - a, b_end - b, 0, &tail_count);
130
+ *count = tail_count + (nk_size_t)vaddvq_u16(c_counts_u16x8);
131
+ }
132
+
133
+ NK_PUBLIC void nk_sparse_intersect_u32_neon( //
134
+ nk_u32_t const *a, nk_u32_t const *b, //
135
+ nk_size_t a_length, nk_size_t b_length, //
136
+ nk_u32_t *result, nk_size_t *count) {
137
+
138
+ // NEON lacks compress-store, so fall back to serial for result output
139
+ if (result) {
140
+ nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
141
+ return;
142
+ }
143
+
144
+ #if NK_ALLOW_ISA_REDIRECT
145
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
146
+ if (a_length < 32 && b_length < 32) {
147
+ nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
148
+ return;
149
+ }
150
+ #endif
151
+
152
+ nk_u32_t const *const a_end = a + a_length;
153
+ nk_u32_t const *const b_end = b + b_length;
154
+ nk_b128_vec_t a_vec, b_vec;
155
+ uint32x4_t c_counts_u32x4 = vdupq_n_u32(0);
156
+
157
+ while (a + 4 <= a_end && b + 4 <= b_end) {
158
+ a_vec.u32x4 = vld1q_u32(a);
159
+ b_vec.u32x4 = vld1q_u32(b);
160
+
161
+ // Intersecting registers with `nk_intersect_u32x4_neon_` involves a lot of shuffling
162
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
163
+ nk_u32_t a_min;
164
+ nk_u32_t a_max = a_vec.u32s[3];
165
+ nk_u32_t b_min = b_vec.u32s[0];
166
+ nk_u32_t b_max = b_vec.u32s[3];
167
+
168
+ // If the slices don't overlap, advance the appropriate pointer
169
+ while (a_max < b_min && a + 8 <= a_end) {
170
+ a += 4;
171
+ a_vec.u32x4 = vld1q_u32(a);
172
+ a_max = a_vec.u32s[3];
173
+ }
174
+ a_min = a_vec.u32s[0];
175
+ while (b_max < a_min && b + 8 <= b_end) {
176
+ b += 4;
177
+ b_vec.u32x4 = vld1q_u32(b);
178
+ b_max = b_vec.u32s[3];
179
+ }
180
+ b_min = b_vec.u32s[0];
181
+
182
+ // Transform match-masks into "ones", accumulate them between the cycles,
183
+ // and merge all together in the end.
184
+ uint32x4_t a_matches = nk_intersect_u32x4_neon_(a_vec.u32x4, b_vec.u32x4);
185
+ c_counts_u32x4 = vaddq_u32(c_counts_u32x4, vandq_u32(a_matches, vdupq_n_u32(1)));
186
+
187
+ uint32x4_t a_inrange_u32x4 = vcleq_u32(a_vec.u32x4, vdupq_n_u32(b_max));
188
+ uint32x4_t b_inrange_u32x4 = vcleq_u32(b_vec.u32x4, vdupq_n_u32(a_max));
189
+ uint8x8_t packed_u8x8 = vmovn_u16(vcombine_u16(vmovn_u32(a_inrange_u32x4), vmovn_u32(b_inrange_u32x4)));
190
+ uint32x2_t clz_u32x2 = vclz_u32(vreinterpret_u32_u8(packed_u8x8));
191
+ a += (32 - vget_lane_u32(clz_u32x2, 0)) / 8;
192
+ b += (32 - vget_lane_u32(clz_u32x2, 1)) / 8;
193
+ }
194
+
195
+ nk_size_t tail_count = 0;
196
+ nk_sparse_intersect_u32_serial(a, b, a_end - a, b_end - b, 0, &tail_count);
197
+ *count = tail_count + (nk_size_t)vaddvq_u32(c_counts_u32x4);
198
+ }
199
+
200
+ NK_INTERNAL uint64x2_t nk_intersect_u64x2_neon_(uint64x2_t a, uint64x2_t b) {
201
+ uint64x2_t b_rot1 = vextq_u64(b, b, 1);
202
+ uint64x2_t matches_rot0 = vceqq_u64(a, b);
203
+ uint64x2_t matches_rot1 = vceqq_u64(a, b_rot1);
204
+ uint64x2_t matches = vorrq_u64(matches_rot0, matches_rot1);
205
+ return matches;
206
+ }
207
+
208
+ NK_PUBLIC void nk_sparse_intersect_u64_neon( //
209
+ nk_u64_t const *a, nk_u64_t const *b, //
210
+ nk_size_t a_length, nk_size_t b_length, //
211
+ nk_u64_t *result, nk_size_t *count) {
212
+
213
+ // NEON lacks compress-store, so fall back to serial for result output
214
+ if (result) {
215
+ nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
216
+ return;
217
+ }
218
+
219
+ #if NK_ALLOW_ISA_REDIRECT
220
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
221
+ if (a_length < 8 && b_length < 8) {
222
+ nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
223
+ return;
224
+ }
225
+ #endif
226
+
227
+ nk_u64_t const *const a_end = a + a_length;
228
+ nk_u64_t const *const b_end = b + b_length;
229
+ nk_b128_vec_t a_vec, b_vec;
230
+ uint64x2_t c_counts_u64x2 = vdupq_n_u64(0);
231
+
232
+ while (a + 2 <= a_end && b + 2 <= b_end) {
233
+ a_vec.u64x2 = vld1q_u64(a);
234
+ b_vec.u64x2 = vld1q_u64(b);
235
+
236
+ // Intersecting registers with `nk_intersect_u64x2_neon_` involves comparisons,
237
+ // so we want to avoid it if the slices don't overlap at all.
238
+ nk_u64_t a_min;
239
+ nk_u64_t a_max = a_vec.u64s[1];
240
+ nk_u64_t b_min = b_vec.u64s[0];
241
+ nk_u64_t b_max = b_vec.u64s[1];
242
+
243
+ // If the slices don't overlap, advance the appropriate pointer
244
+ while (a_max < b_min && a + 4 <= a_end) {
245
+ a += 2;
246
+ a_vec.u64x2 = vld1q_u64(a);
247
+ a_max = a_vec.u64s[1];
248
+ }
249
+ a_min = a_vec.u64s[0];
250
+ while (b_max < a_min && b + 4 <= b_end) {
251
+ b += 2;
252
+ b_vec.u64x2 = vld1q_u64(b);
253
+ b_max = b_vec.u64s[1];
254
+ }
255
+ b_min = b_vec.u64s[0];
256
+
257
+ // Now we are likely to have some overlap, so we can intersect the registers
258
+ // Transform match-masks into "ones", accumulate them between the cycles,
259
+ // and merge all together in the end.
260
+ uint64x2_t a_matches = nk_intersect_u64x2_neon_(a_vec.u64x2, b_vec.u64x2);
261
+ c_counts_u64x2 = vaddq_u64(c_counts_u64x2, vandq_u64(a_matches, vdupq_n_u64(1)));
262
+
263
+ uint64x2_t a_inrange_u64x2 = vcleq_u64(a_vec.u64x2, vdupq_n_u64(b_max));
264
+ uint64x2_t b_inrange_u64x2 = vcleq_u64(b_vec.u64x2, vdupq_n_u64(a_max));
265
+ uint16x4_t packed_u16x4 = vmovn_u32(vcombine_u32(vmovn_u64(a_inrange_u64x2), vmovn_u64(b_inrange_u64x2)));
266
+ uint32x2_t clz_u32x2 = vclz_u32(vreinterpret_u32_u16(packed_u16x4));
267
+ a += (32 - vget_lane_u32(clz_u32x2, 0)) / 16;
268
+ b += (32 - vget_lane_u32(clz_u32x2, 1)) / 16;
269
+ }
270
+
271
+ nk_size_t tail_count = 0;
272
+ nk_sparse_intersect_u64_serial(a, b, a_end - a, b_end - b, 0, &tail_count);
273
+ *count = tail_count + (nk_size_t)vaddvq_u64(c_counts_u64x2);
274
+ }
275
+
276
+ #if defined(__clang__)
277
+ #pragma clang attribute pop
278
+ #elif defined(__GNUC__)
279
+ #pragma GCC pop_options
280
+ #endif
281
+
282
+ #if defined(__cplusplus)
283
+ } // extern "C"
284
+ #endif
285
+
286
+ #endif // NK_TARGET_NEON
287
+ #endif // NK_TARGET_ARM_
288
+ #endif // NK_SPARSE_NEON_H
@@ -0,0 +1,117 @@
1
+ /**
2
+ * @brief Serial Sparse Vector Operations.
3
+ * @file include/numkong/sparse/serial.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/sparse.h
8
+ */
9
+ #ifndef NK_SPARSE_SERIAL_H
10
+ #define NK_SPARSE_SERIAL_H
11
+
12
+ #include "numkong/types.h"
13
+ #include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`, `nk_assign_from_to_`
14
+
15
+ #if defined(__cplusplus)
16
+ extern "C" {
17
+ #endif
18
+
19
/**
 * @brief Defines the serial intersection kernels for sorted arrays of `nk_<input_type>_t`.
 *
 * Expands to three functions:
 * - `nk_sparse_intersect_<input_type>_galloping_search_(array, start, length, val)` returns the first
 *   index `k` in `[start, length)` with `array[k] >= val`, or `length` if there is none. It probes
 *   ahead of `start` with doubling step sizes, then binary-searches the final bracket. When
 *   `start >= length` it returns `length` without reading `array`.
 * - `nk_sparse_intersect_<input_type>_linear_scan_(a, b, a_length, b_length, result)` is a two-pointer
 *   merge. It returns the number of equal pairs and writes them to `result` when that is non-null.
 * - `nk_sparse_intersect_<input_type>_serial(...)` is the public entry point. It swaps the inputs so
 *   `shorter` is not longer than `longer`, and uses the merge unless `longer` has at least 64 times as
 *   many elements. Otherwise it gallops through `longer` once per element of `shorter`.
 *
 * Both strategies expect the inputs sorted ascending.
 */
#define nk_define_sparse_intersect_(input_type) \
    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
        nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) { \
        nk_size_t low = start, high = start + 1, step = 1; \
        /* An exhausted range has no candidates: bail out before reading `array[start]` out of bounds. */ \
        if (start >= length) { return length; } \
        /* Gallop with doubling steps relative to `start`; every element before `low` is below `val`. */ \
        while (high < length && array[high] < val) { \
            low = high + 1; \
            step *= 2; \
            high = (length - low > step) ? low + step : length; \
        } \
        /* Binary search for the first element `>= val` within `[low, high)`. */ \
        while (low < high) { \
            nk_size_t mid = low + (high - low) / 2; \
            if (array[mid] < val) { low = mid + 1; } \
            else { high = mid; } \
        } \
        return low; \
    } \
    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length, \
        nk_##input_type##_t *result) { \
        nk_size_t intersection_size = 0; \
        nk_size_t i = 0, j = 0; \
        while (i != a_length && j != b_length) { \
            nk_##input_type##_t ai = a[i]; \
            nk_##input_type##_t bj = b[j]; \
            if (ai == bj) { \
                if (result) { result[intersection_size] = ai; } \
                intersection_size++; \
            } \
            /* Branchless advance: the smaller side moves, both move on equality. */ \
            i += ai <= bj; \
            j += ai >= bj; \
        } \
        return intersection_size; \
    } \
    NK_PUBLIC void nk_sparse_intersect_##input_type##_serial( \
        nk_##input_type##_t const *shorter, nk_##input_type##_t const *longer, nk_size_t shorter_length, \
        nk_size_t longer_length, nk_##input_type##_t *result, nk_size_t *count) { \
        /* Swap arrays if necessary, so that `longer` has at least as many elements as `shorter`. */ \
        if (longer_length < shorter_length) { \
            nk_##input_type##_t const *temp = shorter; \
            shorter = longer; \
            longer = temp; \
            nk_size_t temp_length = shorter_length; \
            shorter_length = longer_length; \
            longer_length = temp_length; \
        } \
        \
        /* Galloping only pays off when `longer` is 64x longer; dividing avoids overflowing `64 * n`. */ \
        if (longer_length / 64 < shorter_length) { \
            *count = nk_sparse_intersect_##input_type##_linear_scan_(shorter, longer, shorter_length, longer_length, \
                                                                     result); \
            return; \
        } \
        \
        /* Perform galloping, shrinking the target range; stop once `longer` is exhausted. */ \
        nk_size_t intersection_size = 0; \
        nk_size_t j = 0; \
        for (nk_size_t i = 0; i < shorter_length && j < longer_length; ++i) { \
            nk_##input_type##_t shorter_i = shorter[i]; \
            j = nk_sparse_intersect_##input_type##_galloping_search_(longer, j, longer_length, shorter_i); \
            if (j < longer_length && longer[j] == shorter_i) { \
                if (result) { result[intersection_size] = shorter_i; } \
                intersection_size++; \
            } \
        } \
        *count = intersection_size; \
    }
85
+
86
/**
 * @brief Defines a serial weighted sparse dot product over two sorted index arrays.
 *
 * Expands to `nk_sparse_dot_<input_type><weight_type>_serial`. It merges `a` and `b` and, for every
 * pair with `a[i] == b[j]`, adds `a_weights[i] * b_weights[j]` to the result. Each weight is converted
 * into the accumulator type with `load_and_convert(source_pointer, destination_pointer)`.
 *
 * Only matching pairs are loaded and multiplied. The previous branchless `matches * awi * bwi` form
 * evaluated `0 * inf` (NaN) whenever a non-matching index carried an infinite or NaN weight, which
 * poisoned the whole sum.
 */
#define nk_define_sparse_dot_(input_type, weight_type, accumulator_type, load_and_convert) \
    NK_PUBLIC void nk_sparse_dot_##input_type##weight_type##_serial( \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_##weight_type##_t const *a_weights, \
        nk_##weight_type##_t const *b_weights, nk_size_t a_length, nk_size_t b_length, \
        nk_##accumulator_type##_t *product) { \
        nk_##accumulator_type##_t weights_product = 0; \
        nk_size_t i = 0, j = 0; \
        while (i != a_length && j != b_length) { \
            nk_##input_type##_t ai = a[i]; \
            nk_##input_type##_t bj = b[j]; \
            if (ai == bj) { \
                nk_##accumulator_type##_t awi, bwi; \
                load_and_convert(a_weights + i, &awi); \
                load_and_convert(b_weights + j, &bwi); \
                weights_product += awi * bwi; \
            } \
            /* On equality only `j` moves; `i` follows on the next step once `a[i] < b[j]`. */ \
            i += ai < bj; \
            j += ai >= bj; \
        } \
        *product = weights_product; \
    }
105
+
106
+ nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
107
+ nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
108
+ nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
109
+
110
+ nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
111
+ nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_) // nk_sparse_dot_u32f32_serial
112
+
113
+ #if defined(__cplusplus)
114
+ } // extern "C"
115
+ #endif
116
+
117
+ #endif // NK_SPARSE_SERIAL_H