numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -60,31 +60,31 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
60
60
 
61
61
  while (a_idx < a_length && b_idx < b_length) {
62
62
  // Load `a_member` and broadcast it, load `b_members_vec` from memory
63
- svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
64
- svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
65
- svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
66
- svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
63
+ svbool_t a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
64
+ svbool_t b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
65
+ svuint16_t a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
66
+ svuint16_t b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
67
67
 
68
68
  // Intersecting registers with `svmatch_u16` involves a lot of shuffling
69
69
  // and comparisons, so we want to avoid it if the slices don't overlap at all..
70
70
  nk_u16_t a_min;
71
- nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
71
+ nk_u16_t a_max = svlastb(a_progress_b16x, a_u16x);
72
72
  nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
73
- nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);
73
+ nk_u16_t b_max = svlastb(b_progress_b16x, b_u16x);
74
74
 
75
75
  // If the slices don't overlap, advance the appropriate pointer
76
76
  while (a_max < b_min && (a_idx + register_size) <= a_length) {
77
77
  a_idx += register_size;
78
- a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
79
- a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
80
- a_max = svlastb(a_progress_u16x, a_u16x);
78
+ a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
79
+ a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
80
+ a_max = svlastb(a_progress_b16x, a_u16x);
81
81
  }
82
82
  a_min = svlasta(svpfalse_b(), a_u16x);
83
83
  while (b_max < a_min && (b_idx + register_size) <= b_length) {
84
84
  b_idx += register_size;
85
- b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
86
- b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
87
- b_max = svlastb(b_progress_u16x, b_u16x);
85
+ b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
86
+ b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
87
+ b_max = svlastb(b_progress_b16x, b_u16x);
88
88
  }
89
89
  b_min = svlasta(svpfalse_b(), b_u16x);
90
90
 
@@ -95,18 +95,18 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
95
95
  //
96
96
  // svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
97
97
  // svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
98
- svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
99
- svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
100
- nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
101
- nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);
98
+ svbool_t a_mask_b16x = svcmple_n_u16(a_progress_b16x, a_u16x, b_max);
99
+ svbool_t b_mask_b16x = svcmple_n_u16(b_progress_b16x, b_u16x, a_max);
100
+ nk_u64_t a_step = svcntp_b16(a_progress_b16x, a_mask_b16x);
101
+ nk_u64_t b_step = svcntp_b16(b_progress_b16x, b_mask_b16x);
102
102
 
103
103
  // Compare `a_u16x` with each lane of `b_u16x`
104
- svbool_t equal_mask = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
104
+ svbool_t equal_mask_b16x = svmatch_u16(a_progress_b16x, a_u16x, b_u16x);
105
105
  for (nk_size_t i = 1; i < lanes_count; i++) {
106
106
  b_u16x = svext_u16(b_u16x, b_u16x, 8);
107
- equal_mask = svorr_z(svptrue_b16(), equal_mask, svmatch_u16(a_progress_u16x, a_u16x, b_u16x));
107
+ equal_mask_b16x = svorr_z(svptrue_b16(), equal_mask_b16x, svmatch_u16(a_progress_b16x, a_u16x, b_u16x));
108
108
  }
109
- nk_size_t equal_count = svcntp_b16(svptrue_b16(), equal_mask);
109
+ nk_size_t equal_count = svcntp_b16(svptrue_b16(), equal_mask_b16x);
110
110
 
111
111
  // Manually compact and store matching elements (svcompact_u16 is not defined)
112
112
  if (result) {
@@ -114,7 +114,7 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
114
114
  nk_u16_t mask_data[16];
115
115
 
116
116
  svst1_u16(svptrue_b16(), a_data, a_u16x);
117
- svst1_u16(svptrue_b16(), mask_data, svdup_n_u16_z(equal_mask, 1));
117
+ svst1_u16(svptrue_b16(), mask_data, svdup_n_u16_z(equal_mask_b16x, 1));
118
118
 
119
119
  for (nk_size_t i = 0; i < svcnth(); i++)
120
120
  if (mask_data[i]) result[c++] = a_data[i];
@@ -142,31 +142,31 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
142
142
 
143
143
  while (a_idx < a_length && b_idx < b_length) {
144
144
  // Load `a_member` and broadcast it, load `b_members_vec` from memory
145
- svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
146
- svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
147
- svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
148
- svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
145
+ svbool_t a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
146
+ svbool_t b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
147
+ svuint32_t a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
148
+ svuint32_t b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
149
149
 
150
150
  // Intersecting registers with `svmatch_u16` involves a lot of shuffling
151
151
  // and comparisons, so we want to avoid it if the slices don't overlap at all..
152
152
  nk_u32_t a_min;
153
- nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
153
+ nk_u32_t a_max = svlastb(a_progress_b32x, a_u32x);
154
154
  nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
155
- nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);
155
+ nk_u32_t b_max = svlastb(b_progress_b32x, b_u32x);
156
156
 
157
157
  // If the slices don't overlap, advance the appropriate pointer
158
158
  while (a_max < b_min && (a_idx + register_size) <= a_length) {
159
159
  a_idx += register_size;
160
- a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
161
- a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
162
- a_max = svlastb(a_progress_u32x, a_u32x);
160
+ a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
161
+ a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
162
+ a_max = svlastb(a_progress_b32x, a_u32x);
163
163
  }
164
164
  a_min = svlasta(svpfalse_b(), a_u32x);
165
165
  while (b_max < a_min && (b_idx + register_size) <= b_length) {
166
166
  b_idx += register_size;
167
- b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
168
- b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
169
- b_max = svlastb(b_progress_u32x, b_u32x);
167
+ b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
168
+ b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
169
+ b_max = svlastb(b_progress_b32x, b_u32x);
170
170
  }
171
171
  b_min = svlasta(svpfalse_b(), b_u32x);
172
172
 
@@ -177,21 +177,21 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
177
177
  //
178
178
  // svuint32_t a_last_broadcasted = svdup_n_u32(a_max);
179
179
  // svuint32_t b_last_broadcasted = svdup_n_u32(b_max);
180
- svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
181
- svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
182
- nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
183
- nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);
180
+ svbool_t a_mask_b32x = svcmple_n_u32(a_progress_b32x, a_u32x, b_max);
181
+ svbool_t b_mask_b32x = svcmple_n_u32(b_progress_b32x, b_u32x, a_max);
182
+ nk_u64_t a_step = svcntp_b32(a_progress_b32x, a_mask_b32x);
183
+ nk_u64_t b_step = svcntp_b32(b_progress_b32x, b_mask_b32x);
184
184
 
185
185
  // Comparing `a_u32x` with each lane of `b_u32x` can't be done with `svmatch`,
186
186
  // the same way as in `nk_sparse_intersect_u16_sve2`, as that instruction is only
187
187
  // available for 8-bit and 16-bit integers.
188
188
  //
189
- // svbool_t equal_mask = svpfalse_b();
189
+ // svbool_t equal_mask_b32x = svpfalse_b();
190
190
  // for (nk_size_t i = 0; i < register_size; i++) {
191
- // equal_mask = svorr_z(svptrue_b32(), equal_mask, svcmpeq_u32(a_progress, a_u32x, b_u32x));
191
+ // equal_mask_b32x = svorr_z(svptrue_b32(), equal_mask_b32x, svcmpeq_u32(a_progress, a_u32x, b_u32x));
192
192
  // b_u32x = svext_u32(b_u32x, b_u32x, 1);
193
193
  // }
194
- // nk_size_t equal_count = svcntp_b32(a_progress, equal_mask);
194
+ // nk_size_t equal_count = svcntp_b32(a_progress, equal_mask_b32x);
195
195
  //
196
196
  // Alternatively, one can use histogram instructions, like `svhistcnt_u32_z`.
197
197
  // They practically compute the prefix-matching count, which is equivalent to
@@ -210,19 +210,19 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
210
210
  // C 1 1 1 0 B 1 1 1 0
211
211
  // D 1 1 1 1 A 1 1 1 1
212
212
  //
213
- svuint32_t hist_lower = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
213
+ svuint32_t hist_low_u32x = svhistcnt_u32_z(a_progress_b32x, a_u32x, b_u32x);
214
214
  svuint32_t a_rev_u32x = svrev_u32(a_u32x);
215
215
  svuint32_t b_rev_u32x = svrev_u32(b_u32x);
216
- svuint32_t hist_upper = svrev_u32(svhistcnt_u32_z(svptrue_b32(), a_rev_u32x, b_rev_u32x));
217
- svuint32_t hist = svorr_u32_x(a_progress_u32x, hist_lower, hist_upper);
218
- svbool_t equal_mask = svcmpne_n_u32(a_progress_u32x, hist, 0);
219
- nk_size_t equal_count = svcntp_b32(a_progress_u32x, equal_mask);
216
+ svuint32_t hist_high_u32x = svrev_u32(svhistcnt_u32_z(svptrue_b32(), a_rev_u32x, b_rev_u32x));
217
+ svuint32_t hist_u32x = svorr_u32_x(a_progress_b32x, hist_low_u32x, hist_high_u32x);
218
+ svbool_t equal_mask_b32x = svcmpne_n_u32(a_progress_b32x, hist_u32x, 0);
219
+ nk_size_t equal_count = svcntp_b32(a_progress_b32x, equal_mask_b32x);
220
220
 
221
221
  // Use SVE2 svcompact to compress matching elements and store to result buffer
222
222
  if (result) {
223
- svuint32_t compacted = svcompact_u32(equal_mask, a_u32x);
224
- svbool_t store_predicate = svwhilelt_b32_u64(0, equal_count);
225
- svst1_u32(store_predicate, result + c, compacted);
223
+ svuint32_t compacted_u32x = svcompact_u32(equal_mask_b32x, a_u32x);
224
+ svbool_t store_predicate_b32x = svwhilelt_b32_u64(0u, equal_count);
225
+ svst1_u32(store_predicate_b32x, result + c, compacted_u32x);
226
226
  }
227
227
 
228
228
  // Advance
@@ -246,56 +246,56 @@ NK_PUBLIC void nk_sparse_intersect_u64_sve2( //
246
246
 
247
247
  while (a_idx < a_length && b_idx < b_length) {
248
248
  // Load `a_member` and broadcast it, load `b_members_vec` from memory
249
- svbool_t a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
250
- svbool_t b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
251
- svuint64_t a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
252
- svuint64_t b_u64x = svld1_u64(b_progress_u64x, b + b_idx);
249
+ svbool_t a_progress_b64x = svwhilelt_b64_u64(a_idx, a_length);
250
+ svbool_t b_progress_b64x = svwhilelt_b64_u64(b_idx, b_length);
251
+ svuint64_t a_u64x = svld1_u64(a_progress_b64x, a + a_idx);
252
+ svuint64_t b_u64x = svld1_u64(b_progress_b64x, b + b_idx);
253
253
 
254
254
  // Intersecting registers involves comparisons,
255
255
  // so we want to avoid it if the slices don't overlap at all.
256
256
  nk_u64_t a_min;
257
- nk_u64_t a_max = svlastb(a_progress_u64x, a_u64x);
257
+ nk_u64_t a_max = svlastb(a_progress_b64x, a_u64x);
258
258
  nk_u64_t b_min = svlasta(svpfalse_b(), b_u64x);
259
- nk_u64_t b_max = svlastb(b_progress_u64x, b_u64x);
259
+ nk_u64_t b_max = svlastb(b_progress_b64x, b_u64x);
260
260
 
261
261
  // If the slices don't overlap, advance the appropriate pointer
262
262
  while (a_max < b_min && (a_idx + register_size) <= a_length) {
263
263
  a_idx += register_size;
264
- a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
265
- a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
266
- a_max = svlastb(a_progress_u64x, a_u64x);
264
+ a_progress_b64x = svwhilelt_b64_u64(a_idx, a_length);
265
+ a_u64x = svld1_u64(a_progress_b64x, a + a_idx);
266
+ a_max = svlastb(a_progress_b64x, a_u64x);
267
267
  }
268
268
  a_min = svlasta(svpfalse_b(), a_u64x);
269
269
  while (b_max < a_min && (b_idx + register_size) <= b_length) {
270
270
  b_idx += register_size;
271
- b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
272
- b_u64x = svld1_u64(b_progress_u64x, b + b_idx);
273
- b_max = svlastb(b_progress_u64x, b_u64x);
271
+ b_progress_b64x = svwhilelt_b64_u64(b_idx, b_length);
272
+ b_u64x = svld1_u64(b_progress_b64x, b + b_idx);
273
+ b_max = svlastb(b_progress_b64x, b_u64x);
274
274
  }
275
275
  b_min = svlasta(svpfalse_b(), b_u64x);
276
276
 
277
277
  // Estimate how much we will need to advance the pointers afterwards.
278
- svbool_t a_mask_u64x = svcmple_n_u64(a_progress_u64x, a_u64x, b_max);
279
- svbool_t b_mask_u64x = svcmple_n_u64(b_progress_u64x, b_u64x, a_max);
280
- nk_u64_t a_step = svcntp_b64(a_progress_u64x, a_mask_u64x);
281
- nk_u64_t b_step = svcntp_b64(b_progress_u64x, b_mask_u64x);
278
+ svbool_t a_mask_b64x = svcmple_n_u64(a_progress_b64x, a_u64x, b_max);
279
+ svbool_t b_mask_b64x = svcmple_n_u64(b_progress_b64x, b_u64x, a_max);
280
+ nk_u64_t a_step = svcntp_b64(a_progress_b64x, a_mask_b64x);
281
+ nk_u64_t b_step = svcntp_b64(b_progress_b64x, b_mask_b64x);
282
282
 
283
283
  // Use histogram instructions like `svhistcnt_u64_z` to compute intersection.
284
284
  // They compute the prefix-matching count, equivalent to the lower triangle
285
285
  // of the row-major intersection matrix.
286
- svuint64_t hist_lower = svhistcnt_u64_z(a_progress_u64x, a_u64x, b_u64x);
286
+ svuint64_t hist_low_u64x = svhistcnt_u64_z(a_progress_b64x, a_u64x, b_u64x);
287
287
  svuint64_t a_rev_u64x = svrev_u64(a_u64x);
288
288
  svuint64_t b_rev_u64x = svrev_u64(b_u64x);
289
- svuint64_t hist_upper = svrev_u64(svhistcnt_u64_z(svptrue_b64(), a_rev_u64x, b_rev_u64x));
290
- svuint64_t hist = svorr_u64_x(a_progress_u64x, hist_lower, hist_upper);
291
- svbool_t equal_mask = svcmpne_n_u64(a_progress_u64x, hist, 0);
292
- nk_size_t equal_count = svcntp_b64(a_progress_u64x, equal_mask);
289
+ svuint64_t hist_high_u64x = svrev_u64(svhistcnt_u64_z(svptrue_b64(), a_rev_u64x, b_rev_u64x));
290
+ svuint64_t hist_u64x = svorr_u64_x(a_progress_b64x, hist_low_u64x, hist_high_u64x);
291
+ svbool_t equal_mask_b64x = svcmpne_n_u64(a_progress_b64x, hist_u64x, 0);
292
+ nk_size_t equal_count = svcntp_b64(a_progress_b64x, equal_mask_b64x);
293
293
 
294
294
  // Use SVE2 svcompact to compress matching elements and store to result buffer
295
295
  if (result) {
296
- svuint64_t compacted = svcompact_u64(equal_mask, a_u64x);
297
- svbool_t store_predicate = svwhilelt_b64_u64(0, equal_count);
298
- svst1_u64(store_predicate, result + c, compacted);
296
+ svuint64_t compacted_u64x = svcompact_u64(equal_mask_b64x, a_u64x);
297
+ svbool_t store_predicate_b64x = svwhilelt_b64_u64(0u, equal_count);
298
+ svst1_u64(store_predicate_b64x, result + c, compacted_u64x);
299
299
  }
300
300
 
301
301
  // Advance
@@ -312,94 +312,90 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
312
312
  nk_size_t a_length, nk_size_t b_length, //
313
313
  nk_f64_t *product) {
314
314
 
315
- // A single SVE lane is 128 bits wide, so one lane fits 4 values.
316
315
  nk_size_t const register_size = svcntw();
317
316
  nk_size_t const vector_length_f64 = svcntd();
318
- nk_size_t const lanes_count = register_size / 4;
319
317
  nk_size_t a_idx = 0, b_idx = 0;
320
- svbool_t const predicate_all_f32x = svptrue_b32();
321
- svbool_t const predicate_all_f64x = svptrue_b64();
318
+ svbool_t const predicate_all_b32x = svptrue_b32();
319
+ svbool_t const predicate_all_b64x = svptrue_b64();
322
320
  svfloat64_t product_f64x = svdup_f64(0.0);
323
321
 
324
322
  while (a_idx < a_length && b_idx < b_length) {
325
323
  // Load indices with progress predicates
326
- svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
327
- svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
328
- svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
329
- svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
324
+ svbool_t a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
325
+ svbool_t b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
326
+ svuint32_t a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
327
+ svuint32_t b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
330
328
 
331
329
  // Avoid expensive intersection if slices don't overlap at all
332
330
  nk_u32_t a_min;
333
- nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
331
+ nk_u32_t a_max = svlastb(a_progress_b32x, a_u32x);
334
332
  nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
335
- nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);
333
+ nk_u32_t b_max = svlastb(b_progress_b32x, b_u32x);
336
334
 
337
335
  // If the slices don't overlap, advance the appropriate pointer
338
336
  while (a_max < b_min && (a_idx + register_size) <= a_length) {
339
337
  a_idx += register_size;
340
- a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
341
- a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
342
- a_max = svlastb(a_progress_u32x, a_u32x);
338
+ a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
339
+ a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
340
+ a_max = svlastb(a_progress_b32x, a_u32x);
343
341
  }
344
342
  a_min = svlasta(svpfalse_b(), a_u32x);
345
343
  while (b_max < a_min && (b_idx + register_size) <= b_length) {
346
344
  b_idx += register_size;
347
- b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
348
- b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
349
- b_max = svlastb(b_progress_u32x, b_u32x);
345
+ b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
346
+ b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
347
+ b_max = svlastb(b_progress_b32x, b_u32x);
350
348
  }
351
349
  b_min = svlasta(svpfalse_b(), b_u32x);
352
350
 
353
351
  // Calculate step sizes before modifying vectors
354
- svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
355
- svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
356
- nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
357
- nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);
352
+ svbool_t a_mask_b32x = svcmple_n_u32(a_progress_b32x, a_u32x, b_max);
353
+ svbool_t b_mask_b32x = svcmple_n_u32(b_progress_b32x, b_u32x, a_max);
354
+ nk_u64_t a_step = svcntp_b32(a_progress_b32x, a_mask_b32x);
355
+ nk_u64_t b_step = svcntp_b32(b_progress_b32x, b_mask_b32x);
358
356
 
359
357
  // Use histogram-based intersection (svmatch_u32 doesn't exist)
360
- svuint32_t hist_lower_u32x = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
358
+ svuint32_t hist_low_u32x = svhistcnt_u32_z(a_progress_b32x, a_u32x, b_u32x);
361
359
  svuint32_t a_rev_u32x = svrev_u32(a_u32x);
362
360
  svuint32_t b_rev_u32x = svrev_u32(b_u32x);
363
- svuint32_t hist_upper_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_f32x, a_rev_u32x, b_rev_u32x));
364
- svuint32_t hist_u32x = svorr_u32_x(a_progress_u32x, hist_lower_u32x, hist_upper_u32x);
365
- svbool_t a_equal_mask_u32x = svcmpne_n_u32(a_progress_u32x, hist_u32x, 0);
366
- svbool_t a_overlap_mask_u32x = svand_b_z(predicate_all_f32x, a_progress_u32x, a_equal_mask_u32x);
361
+ svuint32_t hist_high_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_b32x, a_rev_u32x, b_rev_u32x));
362
+ svuint32_t hist_u32x = svorr_u32_x(a_progress_b32x, hist_low_u32x, hist_high_u32x);
363
+ svbool_t a_equal_mask_b32x = svcmpne_n_u32(a_progress_b32x, hist_u32x, 0);
364
+ svbool_t a_overlap_mask_b32x = svand_b_z(predicate_all_b32x, a_progress_b32x, a_equal_mask_b32x);
367
365
 
368
- if (!svptest_any(a_progress_u32x, a_overlap_mask_u32x)) {
366
+ if (!svptest_any(a_progress_b32x, a_overlap_mask_b32x)) {
369
367
  a_idx += a_step;
370
368
  b_idx += b_step;
371
369
  continue;
372
370
  }
373
371
 
374
- // Load weights and mask by intersection
375
- svfloat32_t a_weights_f32x = svsel_f32(a_overlap_mask_u32x, svld1_f32(a_progress_u32x, a_weights + a_idx),
376
- svdup_f32(0.f));
377
- svfloat32_t b_weights_f32x = svld1_f32(b_progress_u32x, b_weights + b_idx);
378
- svbool_t predicate_low_f64x = svwhilelt_b64_u64(a_idx, a_length);
379
- svbool_t predicate_high_f64x = svwhilelt_b64_u64(a_idx + vector_length_f64, a_length);
380
- svfloat64_t a_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, a_weights_f32x);
381
- svfloat64_t a_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, a_weights_f32x);
382
-
383
- // For each position in a that matches something in b, we need the corresponding b weight.
384
- // Use lane-by-lane matching for dot product.
385
- for (nk_size_t i = 0; i < lanes_count; i++) {
386
- // Check which elements of a match the current rotation of b
387
- svbool_t equal_lane_u32x = svcmpeq_u32(a_progress_u32x, a_u32x, b_u32x);
388
- svfloat32_t b_equal_weights_f32x = svsel_f32(equal_lane_u32x, b_weights_f32x, svdup_f32(0.f));
389
- svfloat64_t b_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, b_equal_weights_f32x);
390
- svfloat64_t b_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, b_equal_weights_f32x);
391
- product_f64x = svmla_f64_x(predicate_low_f64x, product_f64x, a_low_f64x, b_low_f64x);
392
- product_f64x = svmla_f64_x(predicate_high_f64x, product_f64x, a_high_f64x, b_high_f64x);
393
- // Rotate b vectors
394
- b_u32x = svext_u32(b_u32x, b_u32x, 4);
395
- b_weights_f32x = svext_f32(b_weights_f32x, b_weights_f32x, 4);
396
- }
372
+ // Compute b overlap mask (symmetric histogram: which b elements match something in a)
373
+ svuint32_t b_hist_low_u32x = svhistcnt_u32_z(b_progress_b32x, b_u32x, a_u32x);
374
+ svuint32_t b_hist_high_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_b32x, b_rev_u32x, a_rev_u32x));
375
+ svuint32_t b_hist_u32x = svorr_u32_x(b_progress_b32x, b_hist_low_u32x, b_hist_high_u32x);
376
+ svbool_t b_overlap_mask_b32x = svand_b_z(predicate_all_b32x, b_progress_b32x,
377
+ svcmpne_n_u32(b_progress_b32x, b_hist_u32x, 0));
378
+
379
+ // Compact matching weights — both arrays are sorted, so svcompact
380
+ // preserves relative order and aligns corresponding intersection pairs.
381
+ svfloat32_t a_matched_f32x = svcompact_f32(a_overlap_mask_b32x, svld1_f32(a_progress_b32x, a_weights + a_idx));
382
+ svfloat32_t b_matched_f32x = svcompact_f32(b_overlap_mask_b32x, svld1_f32(b_progress_b32x, b_weights + b_idx));
383
+
384
+ // Widen to f64 and accumulate. svcvt_f64_f32 converts even-indexed f32
385
+ // elements; svcvtlt_f64_f32 converts odd-indexed f32 elements.
386
+ nk_size_t match_count = svcntp_b32(a_progress_b32x, a_overlap_mask_b32x);
387
+ svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (match_count + 1) / 2);
388
+ svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, match_count / 2);
389
+ product_f64x = svmla_f64_x(pred_even_b64x, product_f64x, svcvt_f64_f32_x(pred_even_b64x, a_matched_f32x),
390
+ svcvt_f64_f32_x(pred_even_b64x, b_matched_f32x));
391
+ product_f64x = svmla_f64_x(pred_odd_b64x, product_f64x, svcvtlt_f64_f32_x(pred_odd_b64x, a_matched_f32x),
392
+ svcvtlt_f64_f32_x(pred_odd_b64x, b_matched_f32x));
397
393
 
398
394
  // Advance
399
395
  a_idx += a_step;
400
396
  b_idx += b_step;
401
397
  }
402
- *product = svaddv_f64(predicate_all_f64x, product_f64x);
398
+ *product = svaddv_f64(predicate_all_b64x, product_f64x);
403
399
  }
404
400
 
405
401
  #if defined(__clang__)
@@ -431,31 +427,31 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
431
427
 
432
428
  while (a_idx < a_length && b_idx < b_length) {
433
429
  // Load the next slices of `a` and `b` indices, masked by their progress predicates
434
- svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
435
- svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
436
- svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
437
- svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
430
+ svbool_t a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
431
+ svbool_t b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
432
+ svuint16_t a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
433
+ svuint16_t b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
438
434
 
439
435
  // Intersecting registers with `svmatch_u16` involves a lot of shuffling
440
436
  // and comparisons, so we want to avoid it if the slices don't overlap at all.
441
437
  nk_u16_t a_min;
442
- nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
438
+ nk_u16_t a_max = svlastb(a_progress_b16x, a_u16x);
443
439
  nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
444
- nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);
440
+ nk_u16_t b_max = svlastb(b_progress_b16x, b_u16x);
445
441
 
446
442
  // If the slices don't overlap, advance the appropriate pointer
447
443
  while (a_max < b_min && (a_idx + register_size) <= a_length) {
448
444
  a_idx += register_size;
449
- a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
450
- a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
451
- a_max = svlastb(a_progress_u16x, a_u16x);
445
+ a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
446
+ a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
447
+ a_max = svlastb(a_progress_b16x, a_u16x);
452
448
  }
453
449
  a_min = svlasta(svpfalse_b(), a_u16x);
454
450
  while (b_max < a_min && (b_idx + register_size) <= b_length) {
455
451
  b_idx += register_size;
456
- b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
457
- b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
458
- b_max = svlastb(b_progress_u16x, b_u16x);
452
+ b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
453
+ b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
454
+ b_max = svlastb(b_progress_b16x, b_u16x);
459
455
  }
460
456
  b_min = svlasta(svpfalse_b(), b_u16x);
461
457
 
@@ -466,20 +462,20 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
466
462
  //
467
463
  // svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
468
464
  // svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
469
- svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
470
- svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
471
- nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
472
- nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);
465
+ svbool_t a_mask_b16x = svcmple_n_u16(a_progress_b16x, a_u16x, b_max);
466
+ svbool_t b_mask_b16x = svcmple_n_u16(b_progress_b16x, b_u16x, a_max);
467
+ nk_u64_t a_step = svcntp_b16(a_progress_b16x, a_mask_b16x);
468
+ nk_u64_t b_step = svcntp_b16(b_progress_b16x, b_mask_b16x);
473
469
 
474
470
  // Compare `a_u16x` with each lane of `b_u16x`
475
- svbfloat16_t a_weights_bf16x = svld1_bf16(a_progress_u16x, (__bf16 const *)a_weights + a_idx);
476
- svbfloat16_t b_weights_bf16x = svld1_bf16(b_progress_u16x, (__bf16 const *)b_weights + b_idx);
471
+ svbfloat16_t a_weights_bf16x = svld1_bf16(a_progress_b16x, (__bf16 const *)a_weights + a_idx);
472
+ svbfloat16_t b_weights_bf16x = svld1_bf16(b_progress_b16x, (__bf16 const *)b_weights + b_idx);
477
473
  for (nk_size_t i = 0; i < lanes_count; i++) {
478
- svbool_t equal_mask_u16x = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
474
+ svbool_t equal_mask_b16x = svmatch_u16(a_progress_b16x, a_u16x, b_u16x);
479
475
  //! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
480
476
  //! So we reinterpret floats as integers and apply `svsel_s16`, but the `svreinterpret_s16_bf16`
481
477
  //! and `svreinterpret_bf16_s16` are not always properly defined!
482
- svint16_t b_equal_weights_s16x = svsel_s16(equal_mask_u16x, svreinterpret_s16_bf16(b_weights_bf16x),
478
+ svint16_t b_equal_weights_s16x = svsel_s16(equal_mask_b16x, svreinterpret_s16_bf16(b_weights_bf16x),
483
479
  svdup_n_s16(0));
484
480
  product_f32x = svbfdot_f32(product_f32x, a_weights_bf16x, svreinterpret_bf16_s16(b_equal_weights_s16x));
485
481
  b_u16x = svext_u16(b_u16x, b_u16x, 8);
@@ -243,8 +243,8 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
243
243
  // Native VP2INTERSECTD works directly on u32 - no conversion needed!
244
244
  nk_u32_t const *const a_end = a + a_length;
245
245
  nk_u32_t const *const b_end = b + b_length;
246
- __m512d product_lower_f64x8 = _mm512_setzero_pd();
247
- __m512d product_upper_f64x8 = _mm512_setzero_pd();
246
+ __m512d product_low_f64x8 = _mm512_setzero_pd();
247
+ __m512d product_high_f64x8 = _mm512_setzero_pd();
248
248
  nk_b512_vec_t a_vec, b_vec;
249
249
 
250
250
  while (a + 16 <= a_end && b + 16 <= b_end) {
@@ -281,15 +281,15 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
281
281
  __m512 b_weights_f32x16 = _mm512_loadu_ps(b_weights);
282
282
  __m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
283
283
  __m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);
284
- __m256 a_matched_lower_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
285
- __m256 a_matched_upper_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
286
- __m256 b_matched_lower_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
287
- __m256 b_matched_upper_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
288
-
289
- product_lower_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_lower_f32x8),
290
- _mm512_cvtps_pd(b_matched_lower_f32x8), product_lower_f64x8);
291
- product_upper_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_upper_f32x8),
292
- _mm512_cvtps_pd(b_matched_upper_f32x8), product_upper_f64x8);
284
+ __m256 a_matched_low_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
285
+ __m256 a_matched_high_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
286
+ __m256 b_matched_low_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
287
+ __m256 b_matched_high_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
288
+
289
+ product_low_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_low_f32x8),
290
+ _mm512_cvtps_pd(b_matched_low_f32x8), product_low_f64x8);
291
+ product_high_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_high_f32x8),
292
+ _mm512_cvtps_pd(b_matched_high_f32x8), product_high_f64x8);
293
293
  }
294
294
 
295
295
  __m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
@@ -304,7 +304,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
304
304
 
305
305
  nk_f64_t tail_product = 0;
306
306
  nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
307
- *product = _mm512_reduce_add_pd(product_lower_f64x8) + _mm512_reduce_add_pd(product_upper_f64x8) + tail_product;
307
+ *product = _mm512_reduce_add_pd(product_low_f64x8) + _mm512_reduce_add_pd(product_high_f64x8) + tail_product;
308
308
  }
309
309
 
310
310
  #if defined(__clang__)
@@ -57,22 +57,22 @@
57
57
  * The Ice Lake kernels are shuffle/compare heavy; their throughput is often gated by port 5.
58
58
  * On Genoa, many integer ops dual-issue on FP ports, often improving throughput despite higher latency.
59
59
  *
60
- * Intrinsic Instruction Ice Genoa
61
- * _mm512_shuffle_epi32 VPSHUFD (ZMM, ZMM, I8) 1c @ p5 1c @ p123
62
- * _mm512_mask_cmpneq_epi32_mask VPCMPD (K, ZMM, ZMM, I8) 3c @ p5 5c @ p01
63
- * _mm512_alignr_epi32 VALIGND (ZMM, ZMM, ZMM, I8) 3c @ p5 6c @ p12
64
- * _mm512_conflict_epi32 VPCONFLICTD (ZMM, ZMM) 26c @ p0/5 7c @ p01/12
65
- * _mm256_maskz_compress_epi16 VPCOMPRESSW (YMM, K, YMM) 3-6c @ p5 4-8c @ p01/12
66
- * _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) 4-5c @ p01 4c @ p01
67
- * _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6c @ p01
60
+ * Intrinsic Instruction Icelake Genoa
61
+ * _mm512_shuffle_epi32 VPSHUFD (ZMM, ZMM, I8) 1cy @ p5 1cy @ p123
62
+ * _mm512_mask_cmpneq_epi32_mask VPCMPD (K, ZMM, ZMM, I8) 3cy @ p5 5cy @ p01
63
+ * _mm512_alignr_epi32 VALIGND (ZMM, ZMM, ZMM, I8) 3cy @ p5 6cy @ p12
64
+ * _mm512_conflict_epi32 VPCONFLICTD (ZMM, ZMM) 26cy @ p0+p05+p5 7cy @ p01+p12
65
+ * _mm256_maskz_compress_epi16 VPCOMPRESSW (YMM, K, YMM) 3-6cy @ p5+p5 4-8cy @ p01+p12
66
+ * _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) 4-5cy @ p01 4cy @ p01
67
+ * _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6cy @ p01
68
68
  *
69
69
  * VP2INTERSECTD is unsupported on Ice Lake and not yet covered by uops.info for Zen5/Turin.
70
- * Tiger Lake measures ~36-41c @ p5 for ZMM variants, which is why we always avoid it on Intel.
70
+ * Tiger Lake measures ~36-41cy @ p5 for ZMM variants, which is why we always avoid it on Intel.
71
71
  *
72
72
  * @section references References
73
73
  *
74
74
  * - uops.info: https://uops.info/
75
- * - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
75
+ * - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
76
76
  * - Arm Intrinsics Reference: https://developer.arm.com/architectures/instruction-sets/intrinsics/
77
77
  * - vp2intersect experiments: https://github.com/mozonaut/vp2intersect
78
78
  * - Diez-Canas "Faster-Than-Native Alternatives for x86 VP2INTERSECT Instructions":