numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,507 @@
1
+ /**
2
+ * @brief SVE2-accelerated Sparse Vector Operations.
3
+ * @file include/numkong/sparse/sve2.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/sparse.h
8
+ */
9
+ #ifndef NK_SPARSE_SVE2_H
10
+ #define NK_SPARSE_SVE2_H
11
+
12
+ #if NK_TARGET_ARM_
13
+
14
+ #include "numkong/types.h"
15
+
16
+ #if defined(__cplusplus)
17
+ extern "C" {
18
+ #endif
19
+
20
+ /* SVE2 introduces many new integer-oriented instructions, extending some of the NEON functionality
21
+ * to variable-length SVE registers. Those include "compare multiple" intrinsics:
22
+ *
23
+ * - `svmatch[_u16]` matches each element of the first vector against all members of the corresponding 128-bit segment of the second.
24
+ * - `svhistcnt[_s32]_z` does something similar, performing an inclusive prefix scan.
25
+ * - `svtbx[_u16]` performs an extended (merging) table lookup, keeping the fallback element for out-of-range indices.
26
+ *
27
+ * Other notable instructions:
28
+ *
29
+ * - `DUP`: Broadcast indexed predicate element
30
+ * https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/DUP--predicate---Broadcast-indexed-predicate-element-?lang=en
31
+ * - `SCLAMP` and `UCLAMP`: clamp values, i.e. combined min+max
32
+ * https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/SCLAMP--Signed-clamp-to-minimum-maximum-vector-?lang=en
33
+ * https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/UCLAMP--Unsigned-clamp-to-minimum-maximum-vector-?lang=en
34
+ * - `TBLQ`: Table lookup quadword
35
+ * https://developer.arm.com/documentation/ddi0602/2022-12/SVE-Instructions/TBLQ--Programmable-table-lookup-within-each-quadword-vector-segment--zeroing--?lang=en
36
+ *
37
+ * Great resources for SVE2 intrinsics:
38
+ *
39
+ * > ARM's Scalable Vector Extensions: A Critical Look at SVE2 For Integer Workloads
40
+ * https://gist.github.com/zingaburga/805669eb891c820bd220418ee3f0d6bd
41
+ */
42
+ #if NK_TARGET_SVE2
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+sve2"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8.2-a+sve+sve2")
48
+ #endif
49
+
50
+ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
51
+ nk_u16_t const *a, nk_u16_t const *b, //
52
+ nk_size_t a_length, nk_size_t b_length, //
53
+ nk_u16_t *result, nk_size_t *count) {
54
+
55
+ // A single SVE lane is 128 bits wide, so one lane fits 8 values.
56
+ nk_size_t const register_size = svcnth();
57
+ nk_size_t const lanes_count = register_size / 8;
58
+ nk_size_t a_idx = 0, b_idx = 0;
59
+ nk_size_t c = 0;
60
+
61
+ while (a_idx < a_length && b_idx < b_length) {
62
+ // Load `a_member` and broadcast it, load `b_members_vec` from memory
63
+ svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
64
+ svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
65
+ svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
66
+ svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
67
+
68
+ // Intersecting registers with `svmatch_u16` involves a lot of shuffling
69
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
70
+ nk_u16_t a_min;
71
+ nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
72
+ nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
73
+ nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);
74
+
75
+ // If the slices don't overlap, advance the appropriate pointer
76
+ while (a_max < b_min && (a_idx + register_size) <= a_length) {
77
+ a_idx += register_size;
78
+ a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
79
+ a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
80
+ a_max = svlastb(a_progress_u16x, a_u16x);
81
+ }
82
+ a_min = svlasta(svpfalse_b(), a_u16x);
83
+ while (b_max < a_min && (b_idx + register_size) <= b_length) {
84
+ b_idx += register_size;
85
+ b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
86
+ b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
87
+ b_max = svlastb(b_progress_u16x, b_u16x);
88
+ }
89
+ b_min = svlasta(svpfalse_b(), b_u16x);
90
+
91
+ // Before we evaluate the intersection size, obfurscating the order in `b_u16x`,
92
+ // let's estimate how much we will need to advance the pointers afterwards.
93
+ // For that, we don't even need to broadcast the values in SVE, as the whole
94
+ // register can be compared against a scalar:
95
+ //
96
+ // svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
97
+ // svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
98
+ svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
99
+ svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
100
+ nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
101
+ nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);
102
+
103
+ // Compare `a_u16x` with each lane of `b_u16x`
104
+ svbool_t equal_mask = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
105
+ for (nk_size_t i = 1; i < lanes_count; i++) {
106
+ b_u16x = svext_u16(b_u16x, b_u16x, 8);
107
+ equal_mask = svorr_z(svptrue_b16(), equal_mask, svmatch_u16(a_progress_u16x, a_u16x, b_u16x));
108
+ }
109
+ nk_size_t equal_count = svcntp_b16(svptrue_b16(), equal_mask);
110
+
111
+ // Manually compact and store matching elements (svcompact_u16 is not defined)
112
+ if (result) {
113
+ nk_u16_t a_data[16];
114
+ nk_u16_t mask_data[16];
115
+
116
+ svst1_u16(svptrue_b16(), a_data, a_u16x);
117
+ svst1_u16(svptrue_b16(), mask_data, svdup_n_u16_z(equal_mask, 1));
118
+
119
+ for (nk_size_t i = 0; i < svcnth(); i++)
120
+ if (mask_data[i]) result[c++] = a_data[i];
121
+ c -= equal_count;
122
+ }
123
+
124
+ // Advance
125
+ a_idx += a_step;
126
+ b_idx += b_step;
127
+ c += equal_count;
128
+ }
129
+ *count = c;
130
+ }
131
+
132
+ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
133
+ nk_u32_t const *a, nk_u32_t const *b, //
134
+ nk_size_t a_length, nk_size_t b_length, //
135
+ nk_u32_t *result, nk_size_t *count) {
136
+
137
+ // A single SVE lane is 128 bits wide, so one lane fits 4 values.
138
+ nk_size_t const register_size = svcntw();
139
+ nk_size_t const lanes_count = register_size / 4;
140
+ nk_size_t a_idx = 0, b_idx = 0;
141
+ nk_size_t c = 0;
142
+
143
+ while (a_idx < a_length && b_idx < b_length) {
144
+ // Load `a_member` and broadcast it, load `b_members_vec` from memory
145
+ svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
146
+ svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
147
+ svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
148
+ svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
149
+
150
+ // Intersecting registers with `svmatch_u16` involves a lot of shuffling
151
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
152
+ nk_u32_t a_min;
153
+ nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
154
+ nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
155
+ nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);
156
+
157
+ // If the slices don't overlap, advance the appropriate pointer
158
+ while (a_max < b_min && (a_idx + register_size) <= a_length) {
159
+ a_idx += register_size;
160
+ a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
161
+ a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
162
+ a_max = svlastb(a_progress_u32x, a_u32x);
163
+ }
164
+ a_min = svlasta(svpfalse_b(), a_u32x);
165
+ while (b_max < a_min && (b_idx + register_size) <= b_length) {
166
+ b_idx += register_size;
167
+ b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
168
+ b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
169
+ b_max = svlastb(b_progress_u32x, b_u32x);
170
+ }
171
+ b_min = svlasta(svpfalse_b(), b_u32x);
172
+
173
+ // Before we evaluate the intersection size, obfurscating the order in `b_u32x`,
174
+ // let's estimate how much we will need to advance the pointers afterwards.
175
+ // For that, we don't even need to broadcast the values in SVE, as the whole
176
+ // register can be compared against a scalar:
177
+ //
178
+ // svuint32_t a_last_broadcasted = svdup_n_u32(a_max);
179
+ // svuint32_t b_last_broadcasted = svdup_n_u32(b_max);
180
+ svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
181
+ svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
182
+ nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
183
+ nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);
184
+
185
+ // Comparing `a_u32x` with each lane of `b_u32x` can't be done with `svmatch`,
186
+ // the same way as in `nk_sparse_intersect_u16_sve2`, as that instruction is only
187
+ // available for 8-bit and 16-bit integers.
188
+ //
189
+ // svbool_t equal_mask = svpfalse_b();
190
+ // for (nk_size_t i = 0; i < register_size; i++) {
191
+ // equal_mask = svorr_z(svptrue_b32(), equal_mask, svcmpeq_u32(a_progress, a_u32x, b_u32x));
192
+ // b_u32x = svext_u32(b_u32x, b_u32x, 1);
193
+ // }
194
+ // nk_size_t equal_count = svcntp_b32(a_progress, equal_mask);
195
+ //
196
+ // Alternatively, one can use histogram instructions, like `svhistcnt_u32_z`.
197
+ // They practically compute the prefix-matching count, which is equivalent to
198
+ // the lower triangle of the row-major intersection matrix.
199
+ // To compute the upper triangle, we can reverse (with `svrev_b32`) the order of
200
+ // elements and repeat the operation, accumulating the results for top and bottom.
201
+ // Let's look at 4x element registers as an example:
202
+ //
203
+ // ⊐ α = {A, B, C, D}, β = {X, Y, Z, W}:
204
+ //
205
+ // hist(α, β): hist(α_rev, β_rev):
206
+ //
207
+ // X Y Z W W Z Y X
208
+ // A 1 0 0 0 D 1 0 0 0
209
+ // B 1 1 0 0 C 1 1 0 0
210
+ // C 1 1 1 0 B 1 1 1 0
211
+ // D 1 1 1 1 A 1 1 1 1
212
+ //
213
+ svuint32_t hist_lower = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
214
+ svuint32_t a_rev_u32x = svrev_u32(a_u32x);
215
+ svuint32_t b_rev_u32x = svrev_u32(b_u32x);
216
+ svuint32_t hist_upper = svrev_u32(svhistcnt_u32_z(svptrue_b32(), a_rev_u32x, b_rev_u32x));
217
+ svuint32_t hist = svorr_u32_x(a_progress_u32x, hist_lower, hist_upper);
218
+ svbool_t equal_mask = svcmpne_n_u32(a_progress_u32x, hist, 0);
219
+ nk_size_t equal_count = svcntp_b32(a_progress_u32x, equal_mask);
220
+
221
+ // Use SVE2 svcompact to compress matching elements and store to result buffer
222
+ if (result) {
223
+ svuint32_t compacted = svcompact_u32(equal_mask, a_u32x);
224
+ svbool_t store_predicate = svwhilelt_b32_u64(0, equal_count);
225
+ svst1_u32(store_predicate, result + c, compacted);
226
+ }
227
+
228
+ // Advance
229
+ a_idx += a_step;
230
+ b_idx += b_step;
231
+ c += equal_count;
232
+ }
233
+ *count = c;
234
+ }
235
+
236
+ NK_PUBLIC void nk_sparse_intersect_u64_sve2( //
237
+ nk_u64_t const *a, nk_u64_t const *b, //
238
+ nk_size_t a_length, nk_size_t b_length, //
239
+ nk_u64_t *result, nk_size_t *count) {
240
+
241
+ // A single SVE lane is 128 bits wide, so one lane fits 2 values.
242
+ nk_size_t const register_size = svcntd();
243
+ nk_size_t const lanes_count = register_size / 2;
244
+ nk_size_t a_idx = 0, b_idx = 0;
245
+ nk_size_t c = 0;
246
+
247
+ while (a_idx < a_length && b_idx < b_length) {
248
+ // Load `a_member` and broadcast it, load `b_members_vec` from memory
249
+ svbool_t a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
250
+ svbool_t b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
251
+ svuint64_t a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
252
+ svuint64_t b_u64x = svld1_u64(b_progress_u64x, b + b_idx);
253
+
254
+ // Intersecting registers involves comparisons,
255
+ // so we want to avoid it if the slices don't overlap at all.
256
+ nk_u64_t a_min;
257
+ nk_u64_t a_max = svlastb(a_progress_u64x, a_u64x);
258
+ nk_u64_t b_min = svlasta(svpfalse_b(), b_u64x);
259
+ nk_u64_t b_max = svlastb(b_progress_u64x, b_u64x);
260
+
261
+ // If the slices don't overlap, advance the appropriate pointer
262
+ while (a_max < b_min && (a_idx + register_size) <= a_length) {
263
+ a_idx += register_size;
264
+ a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
265
+ a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
266
+ a_max = svlastb(a_progress_u64x, a_u64x);
267
+ }
268
+ a_min = svlasta(svpfalse_b(), a_u64x);
269
+ while (b_max < a_min && (b_idx + register_size) <= b_length) {
270
+ b_idx += register_size;
271
+ b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
272
+ b_u64x = svld1_u64(b_progress_u64x, b + b_idx);
273
+ b_max = svlastb(b_progress_u64x, b_u64x);
274
+ }
275
+ b_min = svlasta(svpfalse_b(), b_u64x);
276
+
277
+ // Estimate how much we will need to advance the pointers afterwards.
278
+ svbool_t a_mask_u64x = svcmple_n_u64(a_progress_u64x, a_u64x, b_max);
279
+ svbool_t b_mask_u64x = svcmple_n_u64(b_progress_u64x, b_u64x, a_max);
280
+ nk_u64_t a_step = svcntp_b64(a_progress_u64x, a_mask_u64x);
281
+ nk_u64_t b_step = svcntp_b64(b_progress_u64x, b_mask_u64x);
282
+
283
+ // Use histogram instructions like `svhistcnt_u64_z` to compute intersection.
284
+ // They compute the prefix-matching count, equivalent to the lower triangle
285
+ // of the row-major intersection matrix.
286
+ svuint64_t hist_lower = svhistcnt_u64_z(a_progress_u64x, a_u64x, b_u64x);
287
+ svuint64_t a_rev_u64x = svrev_u64(a_u64x);
288
+ svuint64_t b_rev_u64x = svrev_u64(b_u64x);
289
+ svuint64_t hist_upper = svrev_u64(svhistcnt_u64_z(svptrue_b64(), a_rev_u64x, b_rev_u64x));
290
+ svuint64_t hist = svorr_u64_x(a_progress_u64x, hist_lower, hist_upper);
291
+ svbool_t equal_mask = svcmpne_n_u64(a_progress_u64x, hist, 0);
292
+ nk_size_t equal_count = svcntp_b64(a_progress_u64x, equal_mask);
293
+
294
+ // Use SVE2 svcompact to compress matching elements and store to result buffer
295
+ if (result) {
296
+ svuint64_t compacted = svcompact_u64(equal_mask, a_u64x);
297
+ svbool_t store_predicate = svwhilelt_b64_u64(0, equal_count);
298
+ svst1_u64(store_predicate, result + c, compacted);
299
+ }
300
+
301
+ // Advance
302
+ a_idx += a_step;
303
+ b_idx += b_step;
304
+ c += equal_count;
305
+ }
306
+ *count = c;
307
+ }
308
+
309
+ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
310
+ nk_u32_t const *a, nk_u32_t const *b, //
311
+ nk_f32_t const *a_weights, nk_f32_t const *b_weights, //
312
+ nk_size_t a_length, nk_size_t b_length, //
313
+ nk_f64_t *product) {
314
+
315
+ // A single SVE lane is 128 bits wide, so one lane fits 4 values.
316
+ nk_size_t const register_size = svcntw();
317
+ nk_size_t const vector_length_f64 = svcntd();
318
+ nk_size_t const lanes_count = register_size / 4;
319
+ nk_size_t a_idx = 0, b_idx = 0;
320
+ svbool_t const predicate_all_f32x = svptrue_b32();
321
+ svbool_t const predicate_all_f64x = svptrue_b64();
322
+ svfloat64_t product_f64x = svdup_f64(0.0);
323
+
324
+ while (a_idx < a_length && b_idx < b_length) {
325
+ // Load indices with progress predicates
326
+ svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
327
+ svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
328
+ svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
329
+ svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
330
+
331
+ // Avoid expensive intersection if slices don't overlap at all
332
+ nk_u32_t a_min;
333
+ nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
334
+ nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
335
+ nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);
336
+
337
+ // If the slices don't overlap, advance the appropriate pointer
338
+ while (a_max < b_min && (a_idx + register_size) <= a_length) {
339
+ a_idx += register_size;
340
+ a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
341
+ a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
342
+ a_max = svlastb(a_progress_u32x, a_u32x);
343
+ }
344
+ a_min = svlasta(svpfalse_b(), a_u32x);
345
+ while (b_max < a_min && (b_idx + register_size) <= b_length) {
346
+ b_idx += register_size;
347
+ b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
348
+ b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
349
+ b_max = svlastb(b_progress_u32x, b_u32x);
350
+ }
351
+ b_min = svlasta(svpfalse_b(), b_u32x);
352
+
353
+ // Calculate step sizes before modifying vectors
354
+ svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
355
+ svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
356
+ nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
357
+ nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);
358
+
359
+ // Use histogram-based intersection (svmatch_u32 doesn't exist)
360
+ svuint32_t hist_lower_u32x = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
361
+ svuint32_t a_rev_u32x = svrev_u32(a_u32x);
362
+ svuint32_t b_rev_u32x = svrev_u32(b_u32x);
363
+ svuint32_t hist_upper_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_f32x, a_rev_u32x, b_rev_u32x));
364
+ svuint32_t hist_u32x = svorr_u32_x(a_progress_u32x, hist_lower_u32x, hist_upper_u32x);
365
+ svbool_t a_equal_mask_u32x = svcmpne_n_u32(a_progress_u32x, hist_u32x, 0);
366
+ svbool_t a_overlap_mask_u32x = svand_b_z(predicate_all_f32x, a_progress_u32x, a_equal_mask_u32x);
367
+
368
+ if (!svptest_any(a_progress_u32x, a_overlap_mask_u32x)) {
369
+ a_idx += a_step;
370
+ b_idx += b_step;
371
+ continue;
372
+ }
373
+
374
+ // Load weights and mask by intersection
375
+ svfloat32_t a_weights_f32x = svsel_f32(a_overlap_mask_u32x, svld1_f32(a_progress_u32x, a_weights + a_idx),
376
+ svdup_f32(0.f));
377
+ svfloat32_t b_weights_f32x = svld1_f32(b_progress_u32x, b_weights + b_idx);
378
+ svbool_t predicate_low_f64x = svwhilelt_b64_u64(a_idx, a_length);
379
+ svbool_t predicate_high_f64x = svwhilelt_b64_u64(a_idx + vector_length_f64, a_length);
380
+ svfloat64_t a_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, a_weights_f32x);
381
+ svfloat64_t a_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, a_weights_f32x);
382
+
383
+ // For each position in a that matches something in b, we need the corresponding b weight.
384
+ // Use lane-by-lane matching for dot product.
385
+ for (nk_size_t i = 0; i < lanes_count; i++) {
386
+ // Check which elements of a match the current rotation of b
387
+ svbool_t equal_lane_u32x = svcmpeq_u32(a_progress_u32x, a_u32x, b_u32x);
388
+ svfloat32_t b_equal_weights_f32x = svsel_f32(equal_lane_u32x, b_weights_f32x, svdup_f32(0.f));
389
+ svfloat64_t b_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, b_equal_weights_f32x);
390
+ svfloat64_t b_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, b_equal_weights_f32x);
391
+ product_f64x = svmla_f64_x(predicate_low_f64x, product_f64x, a_low_f64x, b_low_f64x);
392
+ product_f64x = svmla_f64_x(predicate_high_f64x, product_f64x, a_high_f64x, b_high_f64x);
393
+ // Rotate b vectors
394
+ b_u32x = svext_u32(b_u32x, b_u32x, 4);
395
+ b_weights_f32x = svext_f32(b_weights_f32x, b_weights_f32x, 4);
396
+ }
397
+
398
+ // Advance
399
+ a_idx += a_step;
400
+ b_idx += b_step;
401
+ }
402
+ *product = svaddv_f64(predicate_all_f64x, product_f64x);
403
+ }
404
+
405
+ #if defined(__clang__)
406
+ #pragma clang attribute pop
407
+ #elif defined(__GNUC__)
408
+ #pragma GCC pop_options
409
+ #endif
410
+ #endif // NK_TARGET_SVE2
411
+
412
+ #if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
413
+ #if defined(__clang__)
414
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function)
415
+ #elif defined(__GNUC__)
416
+ #pragma GCC push_options
417
+ #pragma GCC target("arch=armv8.6-a+sve+sve2+bf16")
418
+ #endif
419
+
420
+ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
421
+ nk_u16_t const *a, nk_u16_t const *b, //
422
+ nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, //
423
+ nk_size_t a_length, nk_size_t b_length, //
424
+ nk_f32_t *product) {
425
+
426
+ // A single SVE lane is 128 bits wide, so one lane fits 8 values.
427
+ nk_size_t const register_size = svcnth();
428
+ nk_size_t const lanes_count = register_size / 8;
429
+ nk_size_t a_idx = 0, b_idx = 0;
430
+ svfloat32_t product_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
431
+
432
+ while (a_idx < a_length && b_idx < b_length) {
433
+ // Load `a_member` and broadcast it, load `b_members_vec` from memory
434
+ svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
435
+ svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
436
+ svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
437
+ svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
438
+
439
+ // Intersecting registers with `svmatch_u16` involves a lot of shuffling
440
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
441
+ nk_u16_t a_min;
442
+ nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
443
+ nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
444
+ nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);
445
+
446
+ // If the slices don't overlap, advance the appropriate pointer
447
+ while (a_max < b_min && (a_idx + register_size) <= a_length) {
448
+ a_idx += register_size;
449
+ a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
450
+ a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
451
+ a_max = svlastb(a_progress_u16x, a_u16x);
452
+ }
453
+ a_min = svlasta(svpfalse_b(), a_u16x);
454
+ while (b_max < a_min && (b_idx + register_size) <= b_length) {
455
+ b_idx += register_size;
456
+ b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
457
+ b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
458
+ b_max = svlastb(b_progress_u16x, b_u16x);
459
+ }
460
+ b_min = svlasta(svpfalse_b(), b_u16x);
461
+
462
+ // Before we evaluate the intersection size, obfurscating the order in `b_u16x`,
463
+ // let's estimate how much we will need to advance the pointers afterwards.
464
+ // For that, we don't even need to broadcast the values in SVE, as the whole
465
+ // register can be compared against a scalar:
466
+ //
467
+ // svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
468
+ // svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
469
+ svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
470
+ svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
471
+ nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
472
+ nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);
473
+
474
+ // Compare `a_u16x` with each lane of `b_u16x`
475
+ svbfloat16_t a_weights_bf16x = svld1_bf16(a_progress_u16x, (__bf16 const *)a_weights + a_idx);
476
+ svbfloat16_t b_weights_bf16x = svld1_bf16(b_progress_u16x, (__bf16 const *)b_weights + b_idx);
477
+ for (nk_size_t i = 0; i < lanes_count; i++) {
478
+ svbool_t equal_mask_u16x = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
479
+ //! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
480
+ //! So we reinterprete floats as integers and apply `svsel_s16`, but the `svreinterpret_s16_bs16`
481
+ //! and `svreinterpret_bf16_s16` are not always properly defined!
482
+ svint16_t b_equal_weights_s16x = svsel_s16(equal_mask_u16x, svreinterpret_s16_bf16(b_weights_bf16x),
483
+ svdup_n_s16(0));
484
+ product_f32x = svbfdot_f32(product_f32x, a_weights_bf16x, svreinterpret_bf16_s16(b_equal_weights_s16x));
485
+ b_u16x = svext_u16(b_u16x, b_u16x, 8);
486
+ }
487
+
488
+ // Advance
489
+ a_idx += a_step;
490
+ b_idx += b_step;
491
+ }
492
+ *product = svaddv_f32(svptrue_b32(), product_f32x);
493
+ }
494
+
495
+ #if defined(__clang__)
496
+ #pragma clang attribute pop
497
+ #elif defined(__GNUC__)
498
+ #pragma GCC pop_options
499
+ #endif
500
+ #endif // NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
501
+
502
+ #if defined(__cplusplus)
503
+ } // extern "C"
504
+ #endif
505
+
506
+ #endif // NK_TARGET_ARM_
507
+ #endif // NK_SPARSE_SVE2_H