numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,322 @@
1
+ /**
2
+ * @brief Turin-accelerated Sparse Vector Operations.
3
+ * @file include/numkong/sparse/turin.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/sparse.h
8
+ */
9
+ #ifndef NK_SPARSE_TURIN_H
10
+ #define NK_SPARSE_TURIN_H
11
+
12
+ #if NK_TARGET_X86_
13
+ #if NK_TARGET_TURIN
14
+
15
+ #include "numkong/types.h"
16
+
17
+ #if defined(__cplusplus)
18
+ extern "C" {
19
+ #endif
20
+
21
+ #if defined(__clang__)
22
+ #pragma clang attribute push( \
23
+ __attribute__((target( \
24
+ "avx2,avx512f,avx512vl,bmi,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2,avx512bf16,avx512vnni,avx512vp2intersect,avx512dq"))), \
25
+ apply_to = function)
26
+ #elif defined(__GNUC__)
27
+ #pragma GCC push_options
28
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "bmi", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2", \
29
+ "avx512bf16", "avx512vnni", "avx512vp2intersect", "avx512dq")
30
+ #endif
31
+
32
+ NK_PUBLIC void nk_sparse_intersect_u16_turin( //
33
+ nk_u16_t const *a, nk_u16_t const *b, //
34
+ nk_size_t a_length, nk_size_t b_length, //
35
+ nk_u16_t *result, nk_size_t *count) {
36
+
37
+ //! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant!
38
+ //! So instead of jumping through 32 entries at a time, like on Ice Lake, we will
39
+ //! step through 16 entries at a time.
40
+ nk_u16_t const *const a_end = a + a_length;
41
+ nk_u16_t const *const b_end = b + b_length;
42
+ nk_size_t c = 0;
43
+ nk_b256_vec_t a_vec, b_vec;
44
+
45
+ // Broadcast index for last element (hoisted outside loop)
46
+ __m256i const last_idx = _mm256_set1_epi16(15);
47
+ while (a + 16 <= a_end && b + 16 <= b_end) {
48
+ a_vec.ymm = _mm256_loadu_si256((__m256i const *)a);
49
+ b_vec.ymm = _mm256_loadu_si256((__m256i const *)b);
50
+
51
+ // Intersect the registers
52
+ __m512i a_i32x16 = _mm512_cvtepu16_epi32(a_vec.ymm);
53
+ __m512i b_i32x16 = _mm512_cvtepu16_epi32(b_vec.ymm);
54
+ __mmask16 a_matches_any_in_b, b_matches_any_in_a;
55
+ _mm512_2intersect_epi32(a_i32x16, b_i32x16, &a_matches_any_in_b, &b_matches_any_in_a);
56
+
57
+ // Export matches if result buffer is provided
58
+ if (result) { _mm256_mask_compressstoreu_epi16(result + c, a_matches_any_in_b, a_vec.ymm); }
59
+ c += _mm_popcnt_u32(a_matches_any_in_b); // MSVC has no `_popcnt32`
60
+
61
+ __m256i a_max_u16x16 = _mm256_permutexvar_epi16(last_idx, a_vec.ymm);
62
+ __m256i b_max_u16x16 = _mm256_permutexvar_epi16(last_idx, b_vec.ymm);
63
+ __mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_max_u16x16);
64
+ __mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_max_u16x16);
65
+ a += _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
66
+ b += _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
67
+ }
68
+
69
+ nk_size_t tail_count = 0;
70
+ nk_sparse_intersect_u16_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
71
+ *count = c + tail_count;
72
+ }
73
+
74
+ NK_PUBLIC void nk_sparse_intersect_u32_turin( //
75
+ nk_u32_t const *a, nk_u32_t const *b, //
76
+ nk_size_t a_length, nk_size_t b_length, //
77
+ nk_u32_t *result, nk_size_t *count) {
78
+
79
+ nk_u32_t const *const a_end = a + a_length;
80
+ nk_u32_t const *const b_end = b + b_length;
81
+ nk_size_t c = 0;
82
+ nk_b512_vec_t a_vec, b_vec;
83
+
84
+ // Broadcast index for last element (hoisted outside loop)
85
+ __m512i const last_idx = _mm512_set1_epi32(15);
86
+ while (a + 16 <= a_end && b + 16 <= b_end) {
87
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
88
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
89
+
90
+ // Intersect the registers
91
+ __mmask16 a_matches_any_in_b, b_matches_any_in_a;
92
+ _mm512_2intersect_epi32(a_vec.zmm, b_vec.zmm, &a_matches_any_in_b, &b_matches_any_in_a);
93
+
94
+ // Export matches if result buffer is provided
95
+ if (result) { _mm512_mask_compressstoreu_epi32(result + c, a_matches_any_in_b, a_vec.zmm); }
96
+ c += _mm_popcnt_u32(a_matches_any_in_b); // MSVC has no `_popcnt32`
97
+
98
+ // Pure SIMD broadcasts - no scalar extraction needed
99
+ __m512i a_max_u32x16 = _mm512_permutexvar_epi32(last_idx, a_vec.zmm);
100
+ __m512i b_max_u32x16 = _mm512_permutexvar_epi32(last_idx, b_vec.zmm);
101
+ __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_max_u32x16);
102
+ __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_max_u32x16);
103
+ a += _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
104
+ b += _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
105
+ }
106
+
107
+ nk_size_t tail_count = 0;
108
+ nk_sparse_intersect_u32_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
109
+ *count = c + tail_count;
110
+ }
111
+
112
+ NK_PUBLIC void nk_sparse_intersect_u64_turin( //
113
+ nk_u64_t const *a, nk_u64_t const *b, //
114
+ nk_size_t a_length, nk_size_t b_length, //
115
+ nk_u64_t *result, nk_size_t *count) {
116
+
117
+ nk_u64_t const *const a_end = a + a_length;
118
+ nk_u64_t const *const b_end = b + b_length;
119
+ nk_size_t c = 0;
120
+ nk_b512_vec_t a_vec, b_vec;
121
+
122
+ // Broadcast index for last element (hoisted outside loop)
123
+ __m512i const last_idx = _mm512_set1_epi64(7);
124
+ while (a + 8 <= a_end && b + 8 <= b_end) {
125
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
126
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
127
+
128
+ // Intersect the registers
129
+ __mmask8 a_matches_any_in_b, b_matches_any_in_a;
130
+ _mm512_2intersect_epi64(a_vec.zmm, b_vec.zmm, &a_matches_any_in_b, &b_matches_any_in_a);
131
+
132
+ // Export matches if result buffer is provided
133
+ if (result) { _mm512_mask_compressstoreu_epi64(result + c, a_matches_any_in_b, a_vec.zmm); }
134
+ c += _mm_popcnt_u32(a_matches_any_in_b); // MSVC has no `_popcnt32`
135
+
136
+ // Pure SIMD broadcasts - no scalar extraction needed
137
+ __m512i a_max_u64x8 = _mm512_permutexvar_epi64(last_idx, a_vec.zmm);
138
+ __m512i b_max_u64x8 = _mm512_permutexvar_epi64(last_idx, b_vec.zmm);
139
+ __mmask8 a_step_mask = _mm512_cmple_epu64_mask(a_vec.zmm, b_max_u64x8);
140
+ __mmask8 b_step_mask = _mm512_cmple_epu64_mask(b_vec.zmm, a_max_u64x8);
141
+ a += _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x100);
142
+ b += _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x100);
143
+ }
144
+
145
+ nk_size_t tail_count = 0;
146
+ nk_sparse_intersect_u64_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
147
+ *count = c + tail_count;
148
+ }
149
+
150
+ NK_PUBLIC void nk_sparse_dot_u16bf16_turin( //
151
+ nk_u16_t const *a, nk_u16_t const *b, //
152
+ nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, //
153
+ nk_size_t a_length, nk_size_t b_length, //
154
+ nk_f32_t *product) {
155
+
156
+ #if NK_ALLOW_ISA_REDIRECT
157
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
158
+ if (a_length < 64 && b_length < 64) {
159
+ nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_length, b_length, product);
160
+ return;
161
+ }
162
+ #endif
163
+
164
+ //! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant!
165
+ //! So instead of jumping through 32 entries at a time, like on Ice Lake, we will
166
+ //! step through 16 entries at a time.
167
+ nk_u16_t const *const a_end = a + a_length;
168
+ nk_u16_t const *const b_end = b + b_length;
169
+ nk_b256_vec_t a_vec, b_vec;
170
+ __m256 product_f32x8 = _mm256_setzero_ps();
171
+
172
+ // Broadcast index for last element (hoisted outside loop)
173
+ __m256i const last_idx = _mm256_set1_epi16(15);
174
+ while (a + 16 <= a_end && b + 16 <= b_end) {
175
+ a_vec.ymm = _mm256_loadu_si256((__m256i const *)a);
176
+ b_vec.ymm = _mm256_loadu_si256((__m256i const *)b);
177
+
178
+ // Intersecting registers with `_mm512_2intersect_epi16_mask` involves a lot of shuffling
179
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
180
+ nk_u16_t a_min;
181
+ nk_u16_t a_max = a_vec.u16s[15];
182
+ nk_u16_t b_min = b_vec.u16s[0];
183
+ nk_u16_t b_max = b_vec.u16s[15];
184
+
185
+ // If the slices don't overlap, advance the appropriate pointer
186
+ while (a_max < b_min && a + 32 <= a_end) {
187
+ a += 16, a_weights += 16;
188
+ a_vec.ymm = _mm256_loadu_si256((__m256i const *)a);
189
+ a_max = a_vec.u16s[15];
190
+ }
191
+ a_min = a_vec.u16s[0];
192
+ while (b_max < a_min && b + 32 <= b_end) {
193
+ b += 16, b_weights += 16;
194
+ b_vec.ymm = _mm256_loadu_si256((__m256i const *)b);
195
+ b_max = b_vec.u16s[15];
196
+ }
197
+ b_min = b_vec.u16s[0];
198
+
199
+ // Now we are likely to have some overlap, so we can intersect the registers
200
+ __m512i a_i32x16 = _mm512_cvtepu16_epi32(a_vec.ymm);
201
+ __m512i b_i32x16 = _mm512_cvtepu16_epi32(b_vec.ymm);
202
+ __mmask16 a_matches_any_in_b, b_matches_any_in_a;
203
+ _mm512_2intersect_epi32(a_i32x16, b_i32x16, &a_matches_any_in_b, &b_matches_any_in_a);
204
+
205
+ // Load and shift all the relevant weights to the start of the vector before doing the dot product
206
+ if (a_matches_any_in_b) {
207
+ __m256i a_weights_bf16x16 = _mm256_loadu_si256((__m256i const *)a_weights);
208
+ a_weights_bf16x16 = _mm256_maskz_compress_epi16(a_matches_any_in_b, a_weights_bf16x16);
209
+ __m256i b_weights_bf16x16 = _mm256_loadu_si256((__m256i const *)b_weights);
210
+ b_weights_bf16x16 = _mm256_maskz_compress_epi16(b_matches_any_in_a, b_weights_bf16x16);
211
+ product_f32x8 = _mm256_dpbf16_ps(product_f32x8, nk_m256bh_from_m256i_(a_weights_bf16x16),
212
+ nk_m256bh_from_m256i_(b_weights_bf16x16));
213
+ }
214
+
215
+ __m256i a_max_u16x16 = _mm256_permutexvar_epi16(last_idx, a_vec.ymm);
216
+ __m256i b_max_u16x16 = _mm256_permutexvar_epi16(last_idx, b_vec.ymm);
217
+ __mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_max_u16x16);
218
+ __mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_max_u16x16);
219
+ nk_size_t a_step = _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
220
+ nk_size_t b_step = _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
221
+ a += a_step, a_weights += a_step;
222
+ b += b_step, b_weights += b_step;
223
+ }
224
+ nk_f32_t tail_product = 0;
225
+ nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
226
+ *product = tail_product + _mm512_reduce_add_ps(_mm512_insertf32x8(_mm512_setzero_ps(), product_f32x8, 0));
227
+ }
228
+
229
+ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
230
+ nk_u32_t const *a, nk_u32_t const *b, //
231
+ nk_f32_t const *a_weights, nk_f32_t const *b_weights, //
232
+ nk_size_t a_length, nk_size_t b_length, //
233
+ nk_f64_t *product) {
234
+
235
+ #if NK_ALLOW_ISA_REDIRECT
236
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
237
+ if (a_length < 32 && b_length < 32) {
238
+ nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_length, b_length, product);
239
+ return;
240
+ }
241
+ #endif
242
+
243
+ // Native VP2INTERSECTD works directly on u32 - no conversion needed!
244
+ nk_u32_t const *const a_end = a + a_length;
245
+ nk_u32_t const *const b_end = b + b_length;
246
+ __m512d product_lower_f64x8 = _mm512_setzero_pd();
247
+ __m512d product_upper_f64x8 = _mm512_setzero_pd();
248
+ nk_b512_vec_t a_vec, b_vec;
249
+
250
+ while (a + 16 <= a_end && b + 16 <= b_end) {
251
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
252
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
253
+
254
+ // Avoid expensive intersection if slices don't overlap at all
255
+ nk_u32_t a_min;
256
+ nk_u32_t a_max = a_vec.u32s[15];
257
+ nk_u32_t b_min = b_vec.u32s[0];
258
+ nk_u32_t b_max = b_vec.u32s[15];
259
+
260
+ // If the slices don't overlap, advance the appropriate pointer
261
+ while (a_max < b_min && a + 32 <= a_end) {
262
+ a += 16, a_weights += 16;
263
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
264
+ a_max = a_vec.u32s[15];
265
+ }
266
+ a_min = a_vec.u32s[0];
267
+ while (b_max < a_min && b + 32 <= b_end) {
268
+ b += 16, b_weights += 16;
269
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
270
+ b_max = b_vec.u32s[15];
271
+ }
272
+ b_min = b_vec.u32s[0];
273
+
274
+ // Native u32 intersection - no conversion needed!
275
+ __mmask16 a_matches, b_matches;
276
+ _mm512_2intersect_epi32(a_vec.zmm, b_vec.zmm, &a_matches, &b_matches);
277
+
278
+ // Load and compress matching weights, then FMA
279
+ if (a_matches) {
280
+ __m512 a_weights_f32x16 = _mm512_loadu_ps(a_weights);
281
+ __m512 b_weights_f32x16 = _mm512_loadu_ps(b_weights);
282
+ __m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
283
+ __m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);
284
+ __m256 a_matched_lower_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
285
+ __m256 a_matched_upper_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
286
+ __m256 b_matched_lower_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
287
+ __m256 b_matched_upper_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
288
+
289
+ product_lower_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_lower_f32x8),
290
+ _mm512_cvtps_pd(b_matched_lower_f32x8), product_lower_f64x8);
291
+ product_upper_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_upper_f32x8),
292
+ _mm512_cvtps_pd(b_matched_upper_f32x8), product_upper_f64x8);
293
+ }
294
+
295
+ __m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
296
+ __m512i b_max_u32x16 = _mm512_set1_epi32(*(int const *)&b_max);
297
+ __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_max_u32x16);
298
+ __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_max_u32x16);
299
+ nk_size_t a_step = _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
300
+ nk_size_t b_step = _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
301
+ a += a_step, a_weights += a_step;
302
+ b += b_step, b_weights += b_step;
303
+ }
304
+
305
+ nk_f64_t tail_product = 0;
306
+ nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
307
+ *product = _mm512_reduce_add_pd(product_lower_f64x8) + _mm512_reduce_add_pd(product_upper_f64x8) + tail_product;
308
+ }
309
+
310
+ #if defined(__clang__)
311
+ #pragma clang attribute pop
312
+ #elif defined(__GNUC__)
313
+ #pragma GCC pop_options
314
+ #endif
315
+
316
+ #if defined(__cplusplus)
317
+ } // extern "C"
318
+ #endif
319
+
320
+ #endif // NK_TARGET_TURIN
321
+ #endif // NK_TARGET_X86_
322
+ #endif // NK_SPARSE_TURIN_H
@@ -0,0 +1,363 @@
1
+ /**
2
+ * @brief SIMD-accelerated Sparse Vector Dot Products.
3
+ * @file include/numkong/sparse.h
4
+ * @author Ash Vardanian
5
+ * @date March 21, 2024
6
+ *
7
+ * Contains:
8
+ *
9
+ * - Set intersection for sorted unique arrays → `nk_size_t` count, optionally with the matched indices
10
+ * - Sparse dot products for weighted sparse vectors
11
+ *
12
+ * For dtypes:
13
+ *
14
+ * - `u16`: indices for vocabularies under 64 thousand tokens
15
+ * - `u32`: indices for vocabularies under 4 billion tokens
16
+ * - `u64`: indices for trillion-scale combinatorics and graphs
17
+ * - `u16` indices + `bf16` weights → `f32` product
18
+ * - `u32` indices + `f32` weights → `f64` product
19
+ *
20
+ * For hardware architectures:
21
+ *
22
+ * - Arm: NEON, SVE2
23
+ * - x86: Ice Lake, Turin
24
+ *
25
+ * @section intersection_algorithm Intersection by Merge
26
+ *
27
+ * The core primitive is analogous to `std::set_intersection`, taking two sorted arrays
28
+ * of unique values and producing the intersection size:
29
+ *
30
+ * std::size_t intersection_size = 0;
31
+ * while (i != a_length && j != b_length) {
32
+ * scalar_t ai = a[i], bj = b[j];
33
+ * intersection_size += ai == bj;
34
+ * i += ai < bj;
35
+ * j += ai >= bj;
36
+ * }
37
+ *
38
+ * Weighted sparse dot-products follow the same merge loop, but accumulate a product
39
+ * for matching indices. For the `u32+f32` family the matched products are widened before
40
+ * accumulation, matching the widened `f64` public result.
41
+ *
42
+ * double product = 0;
43
+ * while (i != a_length && j != b_length) {
44
+ * scalar_t ai = a[i], bj = b[j];
45
+ * product += ai == bj ? a_weights[i] * b_weights[j] : 0;
46
+ * i += ai < bj;
47
+ * j += ai >= bj;
48
+ * }
49
+ *
50
+ * @section galloping_search Galloping vs Linear
51
+ *
52
+ * When the arrays are highly imbalanced, linear merge wastes cycles skipping elements.
53
+ * The serial implementation switches to a galloping search to jump over large gaps.
54
+ *
55
+ * @section x86_instructions Relevant x86 Instructions
56
+ *
57
+ * The Ice Lake kernels are shuffle/compare heavy; their throughput is often gated by port 5.
58
+ * On Genoa, many integer ops dual-issue on FP ports, often improving throughput despite higher latency.
59
+ *
60
+ * Intrinsic Instruction Ice Genoa
61
+ * _mm512_shuffle_epi32 VPSHUFD (ZMM, ZMM, I8) 1c @ p5 1c @ p123
62
+ * _mm512_mask_cmpneq_epi32_mask VPCMPD (K, ZMM, ZMM, I8) 3c @ p5 5c @ p01
63
+ * _mm512_alignr_epi32 VALIGND (ZMM, ZMM, ZMM, I8) 3c @ p5 6c @ p12
64
+ * _mm512_conflict_epi32 VPCONFLICTD (ZMM, ZMM) 26c @ p0/5 7c @ p01/12
65
+ * _mm256_maskz_compress_epi16 VPCOMPRESSW (YMM, K, YMM) 3-6c @ p5 4-8c @ p01/12
66
+ * _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) 4-5c @ p01 4c @ p01
67
+ * _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6c @ p01
68
+ *
69
+ * VP2INTERSECTD is unsupported on Ice Lake and not yet covered by uops.info for Zen5/Turin.
70
+ * Tiger Lake measures ~36-41c @ p5 for ZMM variants, which is why we always avoid it on Intel.
71
+ *
72
+ * @section references References
73
+ *
74
+ * - uops.info: https://uops.info/
75
+ * - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
76
+ * - Arm Intrinsics Reference: https://developer.arm.com/architectures/instruction-sets/intrinsics/
77
+ * - vp2intersect experiments: https://github.com/mozonaut/vp2intersect
78
+ * - Diez-Canas "Faster-Than-Native Alternatives for x86 VP2INTERSECT Instructions":
79
+ * https://arxiv.org/pdf/2112.06342.pdf
80
+ *
81
+ */
82
+ #ifndef NK_SPARSE_H
83
+ #define NK_SPARSE_H
84
+
85
+ #include "numkong/types.h"
86
+
87
+ #if defined(__cplusplus)
88
+ extern "C" {
89
+ #endif
90
+
91
+ /**
92
+ * @brief Intersects two sorted `u16` index arrays and reports how many indices they share.
93
+ *
94
+ * @param[in] a The first sorted array of `u16` indices.
95
+ * @param[in] b The second sorted array of `u16` indices.
96
+ * @param[in] a_length The number of elements in the first array.
97
+ * @param[in] b_length The number of elements in the second array.
98
+ * @param[out] result Output buffer for intersection elements, or NULL to count only; required capacity is not stated here, presumably min(a_length, b_length) - TODO confirm.
99
+ * @param[out] count Receives the number of indices present in both arrays.
100
+ * @note When `NK_DYNAMIC_DISPATCH` is 0 this header defines the function itself and selects a kernel at compile time.
101
+ * @note Inputs must be sorted in ascending order and contain unique elements.
102
+ */
103
+ NK_DYNAMIC void nk_sparse_intersect_u16( //
104
+ nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
105
+
106
+ /**
107
+ * @brief Intersects two sorted `u32` index arrays and reports how many indices they share.
108
+ *
109
+ * @param[in] a The first sorted array of `u32` indices.
110
+ * @param[in] b The second sorted array of `u32` indices.
111
+ * @param[in] a_length The number of elements in the first array.
112
+ * @param[in] b_length The number of elements in the second array.
113
+ * @param[out] result Output buffer for intersection elements, or NULL to count only; required capacity is not stated here, presumably min(a_length, b_length) - TODO confirm.
114
+ * @param[out] count Receives the number of indices present in both arrays.
115
+ * @note When `NK_DYNAMIC_DISPATCH` is 0 this header defines the function itself and selects a kernel at compile time.
116
+ * @note Inputs must be sorted in ascending order and contain unique elements.
117
+ */
118
+ NK_DYNAMIC void nk_sparse_intersect_u32( //
119
+ nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
120
+
121
+ /**
122
+ * @brief Intersects two sorted `u64` index arrays and reports how many indices they share.
123
+ *
124
+ * @param[in] a The first sorted array of `u64` indices.
125
+ * @param[in] b The second sorted array of `u64` indices.
126
+ * @param[in] a_length The number of elements in the first array.
127
+ * @param[in] b_length The number of elements in the second array.
128
+ * @param[out] result Output buffer for intersection elements, or NULL to count only; required capacity is not stated here, presumably min(a_length, b_length) - TODO confirm.
129
+ * @param[out] count Receives the number of indices present in both arrays.
130
+ * @note When `NK_DYNAMIC_DISPATCH` is 0 this header defines the function itself and selects a kernel at compile time.
131
+ * @note Inputs must be sorted in ascending order and contain unique elements.
132
+ */
133
+ NK_DYNAMIC void nk_sparse_intersect_u64( //
134
+ nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
135
+
136
+ /**
137
+ * @brief Sparse dot-product of two vectors stored as sorted `u16` indices with parallel `bf16` weights.
138
+ *
139
+ * @param[in] a The first sorted array of indices.
140
+ * @param[in] b The second sorted array of indices.
141
+ * @param[in] a_weights The bf16 weights for the first array; `a_weights[i]` belongs to index `a[i]`.
142
+ * @param[in] b_weights The bf16 weights for the second array; `b_weights[j]` belongs to index `b[j]`.
143
+ * @param[in] a_length The number of elements in the first array (indices and weights alike).
144
+ * @param[in] b_length The number of elements in the second array (indices and weights alike).
145
+ * @param[out] product Receives the `f32` sum of `a_weights[i] * b_weights[j]` over all pairs with `a[i] == b[j]`.
146
+ * @note When `NK_DYNAMIC_DISPATCH` is 0 this header defines the function itself and selects a kernel at compile time.
147
+ * @note Inputs must be sorted in ascending order and contain unique elements.
148
+ */
149
+ NK_DYNAMIC void nk_sparse_dot_u16bf16( //
150
+ nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, nk_size_t a_length,
151
+ nk_size_t b_length, nk_f32_t *product);
152
+
153
+ /**
154
+ * @brief Sparse dot-product of two vectors stored as sorted `u32` indices with parallel `f32` weights.
155
+ *
156
+ * @param[in] a The first sorted array of indices.
157
+ * @param[in] b The second sorted array of indices.
158
+ * @param[in] a_weights The f32 weights for the first array; `a_weights[i]` belongs to index `a[i]`.
159
+ * @param[in] b_weights The f32 weights for the second array; `b_weights[j]` belongs to index `b[j]`.
160
+ * @param[in] a_length The number of elements in the first array (indices and weights alike).
161
+ * @param[in] b_length The number of elements in the second array (indices and weights alike).
162
+ * @param[out] product Receives the `f64` dot product; per the file header, matched products are widened before accumulation.
163
+ * @note When `NK_DYNAMIC_DISPATCH` is 0 this header defines the function itself and selects a kernel at compile time.
164
+ * @note Inputs must be sorted in ascending order and contain unique elements.
165
+ */
166
+ NK_DYNAMIC void nk_sparse_dot_u32f32( //
167
+ nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights, nk_f32_t const *b_weights, nk_size_t a_length,
168
+ nk_size_t b_length, nk_f64_t *product);
169
+
170
+ /** @copydoc nk_sparse_intersect_u16 */
171
+ NK_PUBLIC void nk_sparse_intersect_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
172
+ nk_size_t b_length, nk_u16_t *result, nk_size_t *count); // final fallback of the static dispatcher
173
+ /** @copydoc nk_sparse_intersect_u32 */
174
+ NK_PUBLIC void nk_sparse_intersect_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
175
+ nk_size_t b_length, nk_u32_t *result, nk_size_t *count); // final fallback of the static dispatcher
176
+ /** @copydoc nk_sparse_intersect_u64 */
177
+ NK_PUBLIC void nk_sparse_intersect_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
178
+ nk_size_t b_length, nk_u64_t *result, nk_size_t *count); // final fallback of the static dispatcher
179
+ /** @copydoc nk_sparse_dot_u16bf16 */
180
+ NK_PUBLIC void nk_sparse_dot_u16bf16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
181
+ nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
182
+ nk_f32_t *product); // final fallback of the static dispatcher
183
+ /** @copydoc nk_sparse_dot_u32f32 */
184
+ NK_PUBLIC void nk_sparse_dot_u32f32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
185
+ nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
186
+ nk_f64_t *product); // final fallback of the static dispatcher; the Turin u32+f32 kernel also calls it for its tail
187
+
188
+ #if NK_TARGET_NEON // NEON: set intersections only; no weighted dot-product kernels are declared for this target
189
+ /** @copydoc nk_sparse_intersect_u16 */
190
+ NK_PUBLIC void nk_sparse_intersect_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
191
+ nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
192
+ /** @copydoc nk_sparse_intersect_u32 */
193
+ NK_PUBLIC void nk_sparse_intersect_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
194
+ nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
195
+ /** @copydoc nk_sparse_intersect_u64 */
196
+ NK_PUBLIC void nk_sparse_intersect_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
197
+ nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
198
+ #endif // NK_TARGET_NEON
199
+
200
+ #if NK_TARGET_SVE2 // SVE2: set intersections for u16/u32/u64 plus the u32+f32 weighted dot product
201
+ /** @copydoc nk_sparse_intersect_u16 */
202
+ NK_PUBLIC void nk_sparse_intersect_u16_sve2(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
203
+ nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
204
+ /** @copydoc nk_sparse_intersect_u32 */
205
+ NK_PUBLIC void nk_sparse_intersect_u32_sve2(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
206
+ nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
207
+ /** @copydoc nk_sparse_intersect_u64 */
208
+ NK_PUBLIC void nk_sparse_intersect_u64_sve2(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
209
+ nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
210
+ /** @copydoc nk_sparse_dot_u32f32 */
211
+ NK_PUBLIC void nk_sparse_dot_u32f32_sve2(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
212
+ nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
213
+ nk_f64_t *product);
214
+ #endif // NK_TARGET_SVE2
215
+
216
+ #if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT // the bf16 dot-product kernel needs SVEBFDOT in addition to SVE2
217
+ /** @copydoc nk_sparse_dot_u16bf16 */
218
+ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
219
+ nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
220
+ nk_f32_t *product);
221
+ #endif // NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
222
+
223
+ #if NK_TARGET_ICELAKE // Ice Lake: set intersections plus the u32+f32 dot product; no u16+bf16 kernel is declared
224
+ /** @copydoc nk_sparse_intersect_u16 */
225
+ NK_PUBLIC void nk_sparse_intersect_u16_icelake(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
226
+ nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
227
+ /** @copydoc nk_sparse_intersect_u32 */
228
+ NK_PUBLIC void nk_sparse_intersect_u32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
229
+ nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
230
+ /** @copydoc nk_sparse_intersect_u64 */
231
+ NK_PUBLIC void nk_sparse_intersect_u64_icelake(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
232
+ nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
233
+ /** @copydoc nk_sparse_dot_u32f32 */
234
+ NK_PUBLIC void nk_sparse_dot_u32f32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
235
+ nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
236
+ nk_f64_t *product);
237
+ #endif // NK_TARGET_ICELAKE
238
+
239
+ #if NK_TARGET_TURIN // Turin: set intersections plus both weighted dot products (u16+bf16 and u32+f32)
240
+ /** @copydoc nk_sparse_intersect_u16 */
241
+ NK_PUBLIC void nk_sparse_intersect_u16_turin(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
242
+ nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
243
+ /** @copydoc nk_sparse_intersect_u32 */
244
+ NK_PUBLIC void nk_sparse_intersect_u32_turin(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
245
+ nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
246
+ /** @copydoc nk_sparse_intersect_u64 */
247
+ NK_PUBLIC void nk_sparse_intersect_u64_turin(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
248
+ nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
249
+ /** @copydoc nk_sparse_dot_u16bf16 */
250
+ NK_PUBLIC void nk_sparse_dot_u16bf16_turin(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
251
+ nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
252
+ nk_f32_t *product);
253
+ /** @copydoc nk_sparse_dot_u32f32 */
254
+ NK_PUBLIC void nk_sparse_dot_u32f32_turin(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
255
+ nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
256
+ nk_f64_t *product);
257
+ #endif // NK_TARGET_TURIN
258
+
259
+ /**
260
+ * @brief Maps a sparse-dot input dtype (presumably the weights dtype - confirm with callers) to the dtype of its product; unlisted dtypes yield `nk_dtype_unknown_k`.
261
+ */
262
+ NK_INTERNAL nk_dtype_t nk_sparse_dot_output_dtype(nk_dtype_t dtype) {
263
+ switch (dtype) {
264
+ case nk_f32_k: return nk_f64_k; // mirrors nk_sparse_dot_u32f32, which writes an nk_f64_t product
265
+ case nk_bf16_k: return nk_f32_k; // mirrors nk_sparse_dot_u16bf16, which writes an nk_f32_t product
266
+ default: return nk_dtype_unknown_k; // no sparse dot product is declared in this header for other dtypes
267
+ }
268
+ }
269
+
270
+ #if defined(__cplusplus)
271
+ } // extern "C"
272
+ #endif
273
+
274
+ #include "numkong/sparse/serial.h"
275
+ #include "numkong/sparse/neon.h"
276
+ #include "numkong/sparse/sve2.h"
277
+ #include "numkong/sparse/icelake.h"
278
+ #include "numkong/sparse/turin.h"
279
+
280
+ #if defined(__cplusplus)
281
+ extern "C" {
282
+ #endif
283
+
284
+ #if !NK_DYNAMIC_DISPATCH
285
+
286
+ NK_PUBLIC void nk_sparse_intersect_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length, nk_size_t b_length,
287
+ nk_u16_t *result, nk_size_t *count) { // static dispatch: forwards all arguments to the first enabled kernel below
288
+ #if NK_TARGET_SVE2 // Arm: SVE2 takes precedence over NEON
289
+ nk_sparse_intersect_u16_sve2(a, b, a_length, b_length, result, count);
290
+ #elif NK_TARGET_NEON
291
+ nk_sparse_intersect_u16_neon(a, b, a_length, b_length, result, count);
292
+ #elif NK_TARGET_TURIN // x86: Turin takes precedence over Ice Lake
293
+ nk_sparse_intersect_u16_turin(a, b, a_length, b_length, result, count);
294
+ #elif NK_TARGET_ICELAKE
295
+ nk_sparse_intersect_u16_icelake(a, b, a_length, b_length, result, count);
296
+ #else // none of the targets above is enabled
297
+ nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
298
+ #endif
299
+ }
300
+
301
+ NK_PUBLIC void nk_sparse_intersect_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length, nk_size_t b_length,
302
+ nk_u32_t *result, nk_size_t *count) { // static dispatch: forwards all arguments to the first enabled kernel below
303
+ #if NK_TARGET_SVE2 // Arm: SVE2 takes precedence over NEON
304
+ nk_sparse_intersect_u32_sve2(a, b, a_length, b_length, result, count);
305
+ #elif NK_TARGET_NEON
306
+ nk_sparse_intersect_u32_neon(a, b, a_length, b_length, result, count);
307
+ #elif NK_TARGET_TURIN // x86: Turin takes precedence over Ice Lake
308
+ nk_sparse_intersect_u32_turin(a, b, a_length, b_length, result, count);
309
+ #elif NK_TARGET_ICELAKE
310
+ nk_sparse_intersect_u32_icelake(a, b, a_length, b_length, result, count);
311
+ #else // none of the targets above is enabled
312
+ nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
313
+ #endif
314
+ }
315
+
316
+ NK_PUBLIC void nk_sparse_intersect_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length, nk_size_t b_length,
317
+ nk_u64_t *result, nk_size_t *count) { // static dispatch: forwards all arguments to the first enabled kernel below
318
+ #if NK_TARGET_SVE2 // Arm: SVE2 takes precedence over NEON
319
+ nk_sparse_intersect_u64_sve2(a, b, a_length, b_length, result, count);
320
+ #elif NK_TARGET_NEON
321
+ nk_sparse_intersect_u64_neon(a, b, a_length, b_length, result, count);
322
+ #elif NK_TARGET_TURIN // x86: Turin takes precedence over Ice Lake
323
+ nk_sparse_intersect_u64_turin(a, b, a_length, b_length, result, count);
324
+ #elif NK_TARGET_ICELAKE
325
+ nk_sparse_intersect_u64_icelake(a, b, a_length, b_length, result, count);
326
+ #else // none of the targets above is enabled
327
+ nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
328
+ #endif
329
+ }
330
+
331
+ NK_PUBLIC void nk_sparse_dot_u16bf16(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
332
+ nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
333
+ nk_f32_t *product) { // static dispatch: forwards all arguments to the first enabled kernel below
334
+ #if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT // SVE2 alone is not enough: the bf16 kernel also requires SVEBFDOT
335
+ nk_sparse_dot_u16bf16_sve2(a, b, a_weights, b_weights, a_length, b_length, product);
336
+ #elif NK_TARGET_TURIN // the only x86 branch here; Ice Lake-only builds fall through to the serial kernel
337
+ nk_sparse_dot_u16bf16_turin(a, b, a_weights, b_weights, a_length, b_length, product);
338
+ #else // also reached on NEON-only and SVE2-without-SVEBFDOT builds
339
+ nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_length, b_length, product);
340
+ #endif
341
+ }
342
+
343
+ NK_PUBLIC void nk_sparse_dot_u32f32(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
344
+ nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
345
+ nk_f64_t *product) { // static dispatch: forwards all arguments to the first enabled kernel below
346
+ #if NK_TARGET_SVE2
347
+ nk_sparse_dot_u32f32_sve2(a, b, a_weights, b_weights, a_length, b_length, product);
348
+ #elif NK_TARGET_TURIN // x86: Turin takes precedence over Ice Lake
349
+ nk_sparse_dot_u32f32_turin(a, b, a_weights, b_weights, a_length, b_length, product);
350
+ #elif NK_TARGET_ICELAKE
351
+ nk_sparse_dot_u32f32_icelake(a, b, a_weights, b_weights, a_length, b_length, product);
352
+ #else // also reached on NEON-only builds: this dispatcher has no NEON branch
353
+ nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_length, b_length, product);
354
+ #endif
355
+ }
356
+
357
+ #endif // !NK_DYNAMIC_DISPATCH
358
+
359
+ #if defined(__cplusplus)
360
+ } // extern "C"
361
+ #endif
362
+
363
+ #endif