numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -73,17 +73,17 @@ NK_DYNAMIC void nk_hammings_packed_u1(nk_u1x8_t const *v, void const *q_packed,
73
73
  /**
74
74
  * @brief Computes C = A × Aᵀ symmetric Gram matrix of Hamming distances.
75
75
  * @param[in] vectors Input matrix of row vectors in row-major order.
76
- * @param[in] n_vectors Number of vectors (rows) in the input matrix.
76
+ * @param[in] vectors_count Number of vectors (rows) in the input matrix.
77
77
  * @param[in] d Dimension of each vector (columns).
78
78
  * @param[in] stride Row stride in bytes for the input matrix.
79
- * @param[out] result Output symmetric matrix (n_vectors × n_vectors).
79
+ * @param[out] result Output symmetric matrix (vectors_count × vectors_count).
80
80
  * @param[in] result_stride Row stride in bytes for the result matrix.
81
81
  * @param[in] row_start Starting row offset of results to compute (needed for parallelism).
82
82
  * @param[in] row_count Number of rows of results to compute (needed for parallelism).
83
83
  */
84
- NK_DYNAMIC void nk_hammings_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d, nk_size_t stride,
85
- nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
86
- nk_size_t row_count);
84
+ NK_DYNAMIC void nk_hammings_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
85
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
86
+ nk_size_t row_start, nk_size_t row_count);
87
87
 
88
88
  /**
89
89
  * @brief Compute Jaccard distances between V rows and packed Q rows.
@@ -103,24 +103,24 @@ NK_DYNAMIC void nk_jaccards_packed_u1(nk_u1x8_t const *v, void const *q_packed,
103
103
  /**
104
104
  * @brief Computes C = f(A, Aᵀ) symmetric Gram matrix of Jaccard distances.
105
105
  * @param[in] vectors Input matrix of row vectors in row-major order.
106
- * @param[in] n_vectors Number of vectors (rows).
106
+ * @param[in] vectors_count Number of vectors (rows).
107
107
  * @param[in] d Dimension of each vector (columns).
108
108
  * @param[in] stride Row stride in bytes.
109
- * @param[out] result Output symmetric f32 matrix (n_vectors × n_vectors).
109
+ * @param[out] result Output symmetric f32 matrix (vectors_count × vectors_count).
110
110
  * @param[in] result_stride Row stride in bytes for the result matrix.
111
111
  * @param[in] row_start Starting row offset (for parallelism).
112
112
  * @param[in] row_count Number of rows to compute (for parallelism).
113
113
  */
114
- NK_DYNAMIC void nk_jaccards_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d, nk_size_t stride,
115
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
116
- nk_size_t row_count);
114
+ NK_DYNAMIC void nk_jaccards_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
115
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
116
+ nk_size_t row_start, nk_size_t row_count);
117
117
 
118
118
  /** @copydoc nk_hammings_packed_u1 */
119
119
  NK_PUBLIC void nk_hammings_packed_u1_serial(nk_u1x8_t const *v, void const *q_packed, nk_u32_t *result, nk_size_t rows,
120
120
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
121
121
  nk_size_t r_stride_in_bytes);
122
122
  /** @copydoc nk_hammings_symmetric_u1 */
123
- NK_PUBLIC void nk_hammings_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
123
+ NK_PUBLIC void nk_hammings_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
124
124
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
125
125
  nk_size_t row_start, nk_size_t row_count);
126
126
  /** @copydoc nk_jaccards_packed_u1 */
@@ -128,7 +128,7 @@ NK_PUBLIC void nk_jaccards_packed_u1_serial(nk_u1x8_t const *v, void const *q_pa
128
128
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
129
129
  nk_size_t r_stride_in_bytes);
130
130
  /** @copydoc nk_jaccards_symmetric_u1 */
131
- NK_PUBLIC void nk_jaccards_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
131
+ NK_PUBLIC void nk_jaccards_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
132
132
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
133
133
  nk_size_t row_start, nk_size_t row_count);
134
134
 
@@ -141,7 +141,7 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *v, void const *q_p
141
141
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
142
142
  nk_size_t r_stride_in_bytes);
143
143
  /** @copydoc nk_hammings_symmetric_u1 */
144
- NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
144
+ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
145
145
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
146
146
  nk_size_t row_start, nk_size_t row_count);
147
147
  /** @copydoc nk_jaccards_packed_u1 */
@@ -149,7 +149,7 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *v, void const *q_p
149
149
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
150
150
  nk_size_t r_stride_in_bytes);
151
151
  /** @copydoc nk_jaccards_symmetric_u1 */
152
- NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
152
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
153
153
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
154
154
  nk_size_t row_start, nk_size_t row_count);
155
155
  #endif // NK_TARGET_SMEBI32
@@ -163,7 +163,7 @@ NK_PUBLIC void nk_hammings_packed_u1_haswell(nk_u1x8_t const *v, void const *q_p
163
163
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
164
164
  nk_size_t r_stride_in_bytes);
165
165
  /** @copydoc nk_hammings_symmetric_u1 */
166
- NK_PUBLIC void nk_hammings_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
166
+ NK_PUBLIC void nk_hammings_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
167
167
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
168
168
  nk_size_t row_start, nk_size_t row_count);
169
169
  /** @copydoc nk_jaccards_packed_u1 */
@@ -171,7 +171,7 @@ NK_PUBLIC void nk_jaccards_packed_u1_haswell(nk_u1x8_t const *v, void const *q_p
171
171
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
172
172
  nk_size_t r_stride_in_bytes);
173
173
  /** @copydoc nk_jaccards_symmetric_u1 */
174
- NK_PUBLIC void nk_jaccards_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
174
+ NK_PUBLIC void nk_jaccards_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
175
175
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
176
176
  nk_size_t row_start, nk_size_t row_count);
177
177
  #endif // NK_TARGET_HASWELL
@@ -185,7 +185,7 @@ NK_PUBLIC void nk_hammings_packed_u1_icelake(nk_u1x8_t const *v, void const *q_p
185
185
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
186
186
  nk_size_t r_stride_in_bytes);
187
187
  /** @copydoc nk_hammings_symmetric_u1 */
188
- NK_PUBLIC void nk_hammings_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
188
+ NK_PUBLIC void nk_hammings_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
189
189
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
190
190
  nk_size_t row_start, nk_size_t row_count);
191
191
  /** @copydoc nk_jaccards_packed_u1 */
@@ -193,7 +193,7 @@ NK_PUBLIC void nk_jaccards_packed_u1_icelake(nk_u1x8_t const *v, void const *q_p
193
193
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
194
194
  nk_size_t r_stride_in_bytes);
195
195
  /** @copydoc nk_jaccards_symmetric_u1 */
196
- NK_PUBLIC void nk_jaccards_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
196
+ NK_PUBLIC void nk_jaccards_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
197
197
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
198
198
  nk_size_t row_start, nk_size_t row_count);
199
199
  #endif // NK_TARGET_ICELAKE
@@ -207,7 +207,7 @@ NK_PUBLIC void nk_hammings_packed_u1_neon(nk_u1x8_t const *v, void const *q_pack
207
207
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
208
208
  nk_size_t r_stride_in_bytes);
209
209
  /** @copydoc nk_hammings_symmetric_u1 */
210
- NK_PUBLIC void nk_hammings_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
210
+ NK_PUBLIC void nk_hammings_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
211
211
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
212
212
  nk_size_t row_start, nk_size_t row_count);
213
213
  /** @copydoc nk_jaccards_packed_u1 */
@@ -215,7 +215,7 @@ NK_PUBLIC void nk_jaccards_packed_u1_neon(nk_u1x8_t const *v, void const *q_pack
215
215
  nk_size_t cols, nk_size_t d, nk_size_t v_stride_in_bytes,
216
216
  nk_size_t r_stride_in_bytes);
217
217
  /** @copydoc nk_jaccards_symmetric_u1 */
218
- NK_PUBLIC void nk_jaccards_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
218
+ NK_PUBLIC void nk_jaccards_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
219
219
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
220
220
  nk_size_t row_start, nk_size_t row_count);
221
221
  #endif // NK_TARGET_NEON
@@ -228,7 +228,7 @@ NK_PUBLIC void nk_hammings_packed_u1_v128relaxed(nk_u1x8_t const *v, void const
228
228
  nk_size_t rows, nk_size_t cols, nk_size_t d,
229
229
  nk_size_t v_stride_in_bytes, nk_size_t r_stride_in_bytes);
230
230
  /** @copydoc nk_hammings_symmetric_u1 */
231
- NK_PUBLIC void nk_hammings_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
231
+ NK_PUBLIC void nk_hammings_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
232
232
  nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
233
233
  nk_size_t row_start, nk_size_t row_count);
234
234
  /** @copydoc nk_jaccards_packed_u1 */
@@ -236,11 +236,32 @@ NK_PUBLIC void nk_jaccards_packed_u1_v128relaxed(nk_u1x8_t const *v, void const
236
236
  nk_size_t rows, nk_size_t cols, nk_size_t d,
237
237
  nk_size_t v_stride_in_bytes, nk_size_t r_stride_in_bytes);
238
238
  /** @copydoc nk_jaccards_symmetric_u1 */
239
- NK_PUBLIC void nk_jaccards_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d,
239
+ NK_PUBLIC void nk_jaccards_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
240
240
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
241
241
  nk_size_t row_start, nk_size_t row_count);
242
242
  #endif // NK_TARGET_V128RELAXED
243
243
 
244
+ /* Loongson LASX backends using 256-bit SIMD with XVPCNT.W for popcount-based set distances.
245
+ */
246
+ #if NK_TARGET_LOONGSONASX
247
+ /** @copydoc nk_hammings_packed_u1 */
248
+ NK_PUBLIC void nk_hammings_packed_u1_loongsonasx(nk_u1x8_t const *v, void const *q_packed, nk_u32_t *result,
249
+ nk_size_t rows, nk_size_t cols, nk_size_t d,
250
+ nk_size_t v_stride_in_bytes, nk_size_t r_stride_in_bytes);
251
+ /** @copydoc nk_hammings_symmetric_u1 */
252
+ NK_PUBLIC void nk_hammings_symmetric_u1_loongsonasx(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
253
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
254
+ nk_size_t row_start, nk_size_t row_count);
255
+ /** @copydoc nk_jaccards_packed_u1 */
256
+ NK_PUBLIC void nk_jaccards_packed_u1_loongsonasx(nk_u1x8_t const *v, void const *q_packed, nk_f32_t *result,
257
+ nk_size_t rows, nk_size_t cols, nk_size_t d,
258
+ nk_size_t v_stride_in_bytes, nk_size_t r_stride_in_bytes);
259
+ /** @copydoc nk_jaccards_symmetric_u1 */
260
+ NK_PUBLIC void nk_jaccards_symmetric_u1_loongsonasx(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
261
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
262
+ nk_size_t row_start, nk_size_t row_count);
263
+ #endif // NK_TARGET_LOONGSONASX
264
+
244
265
  #if defined(__cplusplus)
245
266
  } // extern "C"
246
267
  #endif
@@ -251,6 +272,8 @@ NK_PUBLIC void nk_jaccards_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk
251
272
  #include "numkong/sets/haswell.h"
252
273
  #include "numkong/sets/smebi32.h"
253
274
  #include "numkong/sets/v128relaxed.h"
275
+ #include "numkong/sets/powervsx.h"
276
+ #include "numkong/sets/loongsonasx.h"
254
277
 
255
278
  #if defined(__cplusplus)
256
279
  extern "C" {
@@ -269,6 +292,10 @@ NK_PUBLIC void nk_hammings_packed_u1(nk_u1x8_t const *v, void const *q_packed, n
269
292
  nk_hammings_packed_u1_icelake(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
270
293
  #elif NK_TARGET_HASWELL
271
294
  nk_hammings_packed_u1_haswell(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
295
+ #elif NK_TARGET_POWERVSX
296
+ nk_hammings_packed_u1_powervsx(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
297
+ #elif NK_TARGET_LOONGSONASX
298
+ nk_hammings_packed_u1_loongsonasx(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
272
299
  #elif NK_TARGET_V128RELAXED
273
300
  nk_hammings_packed_u1_v128relaxed(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
274
301
  #else
@@ -276,21 +303,27 @@ NK_PUBLIC void nk_hammings_packed_u1(nk_u1x8_t const *v, void const *q_packed, n
276
303
  #endif
277
304
  }
278
305
 
279
- NK_PUBLIC void nk_hammings_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d, nk_size_t stride,
280
- nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
281
- nk_size_t row_count) {
306
+ NK_PUBLIC void nk_hammings_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
307
+ nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
308
+ nk_size_t row_start, nk_size_t row_count) {
282
309
  #if NK_TARGET_SMEBI32
283
- nk_hammings_symmetric_u1_smebi32(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
310
+ nk_hammings_symmetric_u1_smebi32(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
284
311
  #elif NK_TARGET_NEON
285
- nk_hammings_symmetric_u1_neon(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
312
+ nk_hammings_symmetric_u1_neon(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
286
313
  #elif NK_TARGET_ICELAKE
287
- nk_hammings_symmetric_u1_icelake(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
314
+ nk_hammings_symmetric_u1_icelake(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
288
315
  #elif NK_TARGET_HASWELL
289
- nk_hammings_symmetric_u1_haswell(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
316
+ nk_hammings_symmetric_u1_haswell(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
317
+ #elif NK_TARGET_POWERVSX
318
+ nk_hammings_symmetric_u1_powervsx(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
319
+ #elif NK_TARGET_LOONGSONASX
320
+ nk_hammings_symmetric_u1_loongsonasx(vectors, vectors_count, d, stride, result, result_stride, row_start,
321
+ row_count);
290
322
  #elif NK_TARGET_V128RELAXED
291
- nk_hammings_symmetric_u1_v128relaxed(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
323
+ nk_hammings_symmetric_u1_v128relaxed(vectors, vectors_count, d, stride, result, result_stride, row_start,
324
+ row_count);
292
325
  #else
293
- nk_hammings_symmetric_u1_serial(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
326
+ nk_hammings_symmetric_u1_serial(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
294
327
  #endif
295
328
  }
296
329
 
@@ -305,6 +338,10 @@ NK_PUBLIC void nk_jaccards_packed_u1(nk_u1x8_t const *v, void const *q_packed, n
305
338
  nk_jaccards_packed_u1_icelake(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
306
339
  #elif NK_TARGET_HASWELL
307
340
  nk_jaccards_packed_u1_haswell(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
341
+ #elif NK_TARGET_POWERVSX
342
+ nk_jaccards_packed_u1_powervsx(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
343
+ #elif NK_TARGET_LOONGSONASX
344
+ nk_jaccards_packed_u1_loongsonasx(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
308
345
  #elif NK_TARGET_V128RELAXED
309
346
  nk_jaccards_packed_u1_v128relaxed(v, q_packed, result, rows, cols, d, v_stride_in_bytes, r_stride_in_bytes);
310
347
  #else
@@ -312,21 +349,27 @@ NK_PUBLIC void nk_jaccards_packed_u1(nk_u1x8_t const *v, void const *q_packed, n
312
349
  #endif
313
350
  }
314
351
 
315
- NK_PUBLIC void nk_jaccards_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t d, nk_size_t stride,
316
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
317
- nk_size_t row_count) {
352
+ NK_PUBLIC void nk_jaccards_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t d,
353
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
354
+ nk_size_t row_start, nk_size_t row_count) {
318
355
  #if NK_TARGET_SMEBI32
319
- nk_jaccards_symmetric_u1_smebi32(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
356
+ nk_jaccards_symmetric_u1_smebi32(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
320
357
  #elif NK_TARGET_NEON
321
- nk_jaccards_symmetric_u1_neon(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
358
+ nk_jaccards_symmetric_u1_neon(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
322
359
  #elif NK_TARGET_ICELAKE
323
- nk_jaccards_symmetric_u1_icelake(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
360
+ nk_jaccards_symmetric_u1_icelake(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
324
361
  #elif NK_TARGET_HASWELL
325
- nk_jaccards_symmetric_u1_haswell(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
362
+ nk_jaccards_symmetric_u1_haswell(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
363
+ #elif NK_TARGET_POWERVSX
364
+ nk_jaccards_symmetric_u1_powervsx(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
365
+ #elif NK_TARGET_LOONGSONASX
366
+ nk_jaccards_symmetric_u1_loongsonasx(vectors, vectors_count, d, stride, result, result_stride, row_start,
367
+ row_count);
326
368
  #elif NK_TARGET_V128RELAXED
327
- nk_jaccards_symmetric_u1_v128relaxed(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
369
+ nk_jaccards_symmetric_u1_v128relaxed(vectors, vectors_count, d, stride, result, result_stride, row_start,
370
+ row_count);
328
371
  #else
329
- nk_jaccards_symmetric_u1_serial(vectors, n_vectors, d, stride, result, result_stride, row_start, row_count);
372
+ nk_jaccards_symmetric_u1_serial(vectors, vectors_count, d, stride, result, result_stride, row_start, row_count);
330
373
  #endif
331
374
  }
332
375
 
@@ -7,15 +7,15 @@ The separate index/weight stream design makes these primitives composable into b
7
7
 
8
8
  Set intersection:
9
9
 
10
- ```math
10
+ $$
11
11
  |A \cap B| = |\{i : i \in A \land i \in B\}|
12
- ```
12
+ $$
13
13
 
14
14
  Sparse dot product:
15
15
 
16
- ```math
16
+ $$
17
17
  \text{dot}(a, b) = \sum_{i \in A \cap B} w_a(i) \cdot w_b(i)
18
- ```
18
+ $$
19
19
 
20
20
  Reformulating as Python pseudocode:
21
21
 
@@ -45,58 +45,58 @@ extern "C" {
45
45
  * slightly faster than the native Tiger Lake implementation, but returns only one mask.
46
46
  */
47
47
  NK_INTERNAL nk_u32_t nk_intersect_u16x32_icelake_(__m512i a, __m512i b) {
48
- __m512i a1 = _mm512_alignr_epi32(a, a, 4);
49
- __m512i a2 = _mm512_alignr_epi32(a, a, 8);
50
- __m512i a3 = _mm512_alignr_epi32(a, a, 12);
48
+ __m512i a1_u16x32 = _mm512_alignr_epi32(a, a, 4);
49
+ __m512i a2_u16x32 = _mm512_alignr_epi32(a, a, 8);
50
+ __m512i a3_u16x32 = _mm512_alignr_epi32(a, a, 12);
51
51
 
52
- __m512i b1 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
53
- __m512i b2 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
54
- __m512i b3 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);
52
+ __m512i b1_u16x32 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
53
+ __m512i b2_u16x32 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
54
+ __m512i b3_u16x32 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);
55
55
 
56
- __m512i b01 = _mm512_shrdi_epi32(b, b, 16);
57
- __m512i b11 = _mm512_shrdi_epi32(b1, b1, 16);
58
- __m512i b21 = _mm512_shrdi_epi32(b2, b2, 16);
59
- __m512i b31 = _mm512_shrdi_epi32(b3, b3, 16);
56
+ __m512i b01_u16x32 = _mm512_shrdi_epi32(b, b, 16);
57
+ __m512i b11_u16x32 = _mm512_shrdi_epi32(b1_u16x32, b1_u16x32, 16);
58
+ __m512i b21_u16x32 = _mm512_shrdi_epi32(b2_u16x32, b2_u16x32, 16);
59
+ __m512i b31_u16x32 = _mm512_shrdi_epi32(b3_u16x32, b3_u16x32, 16);
60
60
 
61
61
  __mmask32 nm00 = _mm512_cmpneq_epi16_mask(a, b);
62
- __mmask32 nm01 = _mm512_cmpneq_epi16_mask(a1, b);
63
- __mmask32 nm02 = _mm512_cmpneq_epi16_mask(a2, b);
64
- __mmask32 nm03 = _mm512_cmpneq_epi16_mask(a3, b);
65
-
66
- __mmask32 nm10 = _mm512_mask_cmpneq_epi16_mask(nm00, a, b01);
67
- __mmask32 nm11 = _mm512_mask_cmpneq_epi16_mask(nm01, a1, b01);
68
- __mmask32 nm12 = _mm512_mask_cmpneq_epi16_mask(nm02, a2, b01);
69
- __mmask32 nm13 = _mm512_mask_cmpneq_epi16_mask(nm03, a3, b01);
70
-
71
- __mmask32 nm20 = _mm512_mask_cmpneq_epi16_mask(nm10, a, b1);
72
- __mmask32 nm21 = _mm512_mask_cmpneq_epi16_mask(nm11, a1, b1);
73
- __mmask32 nm22 = _mm512_mask_cmpneq_epi16_mask(nm12, a2, b1);
74
- __mmask32 nm23 = _mm512_mask_cmpneq_epi16_mask(nm13, a3, b1);
75
-
76
- __mmask32 nm30 = _mm512_mask_cmpneq_epi16_mask(nm20, a, b11);
77
- __mmask32 nm31 = _mm512_mask_cmpneq_epi16_mask(nm21, a1, b11);
78
- __mmask32 nm32 = _mm512_mask_cmpneq_epi16_mask(nm22, a2, b11);
79
- __mmask32 nm33 = _mm512_mask_cmpneq_epi16_mask(nm23, a3, b11);
80
-
81
- __mmask32 nm40 = _mm512_mask_cmpneq_epi16_mask(nm30, a, b2);
82
- __mmask32 nm41 = _mm512_mask_cmpneq_epi16_mask(nm31, a1, b2);
83
- __mmask32 nm42 = _mm512_mask_cmpneq_epi16_mask(nm32, a2, b2);
84
- __mmask32 nm43 = _mm512_mask_cmpneq_epi16_mask(nm33, a3, b2);
85
-
86
- __mmask32 nm50 = _mm512_mask_cmpneq_epi16_mask(nm40, a, b21);
87
- __mmask32 nm51 = _mm512_mask_cmpneq_epi16_mask(nm41, a1, b21);
88
- __mmask32 nm52 = _mm512_mask_cmpneq_epi16_mask(nm42, a2, b21);
89
- __mmask32 nm53 = _mm512_mask_cmpneq_epi16_mask(nm43, a3, b21);
90
-
91
- __mmask32 nm60 = _mm512_mask_cmpneq_epi16_mask(nm50, a, b3);
92
- __mmask32 nm61 = _mm512_mask_cmpneq_epi16_mask(nm51, a1, b3);
93
- __mmask32 nm62 = _mm512_mask_cmpneq_epi16_mask(nm52, a2, b3);
94
- __mmask32 nm63 = _mm512_mask_cmpneq_epi16_mask(nm53, a3, b3);
95
-
96
- __mmask32 nm70 = _mm512_mask_cmpneq_epi16_mask(nm60, a, b31);
97
- __mmask32 nm71 = _mm512_mask_cmpneq_epi16_mask(nm61, a1, b31);
98
- __mmask32 nm72 = _mm512_mask_cmpneq_epi16_mask(nm62, a2, b31);
99
- __mmask32 nm73 = _mm512_mask_cmpneq_epi16_mask(nm63, a3, b31);
62
+ __mmask32 nm01 = _mm512_cmpneq_epi16_mask(a1_u16x32, b);
63
+ __mmask32 nm02 = _mm512_cmpneq_epi16_mask(a2_u16x32, b);
64
+ __mmask32 nm03 = _mm512_cmpneq_epi16_mask(a3_u16x32, b);
65
+
66
+ __mmask32 nm10 = _mm512_mask_cmpneq_epi16_mask(nm00, a, b01_u16x32);
67
+ __mmask32 nm11 = _mm512_mask_cmpneq_epi16_mask(nm01, a1_u16x32, b01_u16x32);
68
+ __mmask32 nm12 = _mm512_mask_cmpneq_epi16_mask(nm02, a2_u16x32, b01_u16x32);
69
+ __mmask32 nm13 = _mm512_mask_cmpneq_epi16_mask(nm03, a3_u16x32, b01_u16x32);
70
+
71
+ __mmask32 nm20 = _mm512_mask_cmpneq_epi16_mask(nm10, a, b1_u16x32);
72
+ __mmask32 nm21 = _mm512_mask_cmpneq_epi16_mask(nm11, a1_u16x32, b1_u16x32);
73
+ __mmask32 nm22 = _mm512_mask_cmpneq_epi16_mask(nm12, a2_u16x32, b1_u16x32);
74
+ __mmask32 nm23 = _mm512_mask_cmpneq_epi16_mask(nm13, a3_u16x32, b1_u16x32);
75
+
76
+ __mmask32 nm30 = _mm512_mask_cmpneq_epi16_mask(nm20, a, b11_u16x32);
77
+ __mmask32 nm31 = _mm512_mask_cmpneq_epi16_mask(nm21, a1_u16x32, b11_u16x32);
78
+ __mmask32 nm32 = _mm512_mask_cmpneq_epi16_mask(nm22, a2_u16x32, b11_u16x32);
79
+ __mmask32 nm33 = _mm512_mask_cmpneq_epi16_mask(nm23, a3_u16x32, b11_u16x32);
80
+
81
+ __mmask32 nm40 = _mm512_mask_cmpneq_epi16_mask(nm30, a, b2_u16x32);
82
+ __mmask32 nm41 = _mm512_mask_cmpneq_epi16_mask(nm31, a1_u16x32, b2_u16x32);
83
+ __mmask32 nm42 = _mm512_mask_cmpneq_epi16_mask(nm32, a2_u16x32, b2_u16x32);
84
+ __mmask32 nm43 = _mm512_mask_cmpneq_epi16_mask(nm33, a3_u16x32, b2_u16x32);
85
+
86
+ __mmask32 nm50 = _mm512_mask_cmpneq_epi16_mask(nm40, a, b21_u16x32);
87
+ __mmask32 nm51 = _mm512_mask_cmpneq_epi16_mask(nm41, a1_u16x32, b21_u16x32);
88
+ __mmask32 nm52 = _mm512_mask_cmpneq_epi16_mask(nm42, a2_u16x32, b21_u16x32);
89
+ __mmask32 nm53 = _mm512_mask_cmpneq_epi16_mask(nm43, a3_u16x32, b21_u16x32);
90
+
91
+ __mmask32 nm60 = _mm512_mask_cmpneq_epi16_mask(nm50, a, b3_u16x32);
92
+ __mmask32 nm61 = _mm512_mask_cmpneq_epi16_mask(nm51, a1_u16x32, b3_u16x32);
93
+ __mmask32 nm62 = _mm512_mask_cmpneq_epi16_mask(nm52, a2_u16x32, b3_u16x32);
94
+ __mmask32 nm63 = _mm512_mask_cmpneq_epi16_mask(nm53, a3_u16x32, b3_u16x32);
95
+
96
+ __mmask32 nm70 = _mm512_mask_cmpneq_epi16_mask(nm60, a, b31_u16x32);
97
+ __mmask32 nm71 = _mm512_mask_cmpneq_epi16_mask(nm61, a1_u16x32, b31_u16x32);
98
+ __mmask32 nm72 = _mm512_mask_cmpneq_epi16_mask(nm62, a2_u16x32, b31_u16x32);
99
+ __mmask32 nm73 = _mm512_mask_cmpneq_epi16_mask(nm63, a3_u16x32, b31_u16x32);
100
100
 
101
101
  return ~(nk_u32_t)(nm70 & nk_u32_rol(nm71, 8) & nk_u32_rol(nm72, 16) & nk_u32_ror(nm73, 8));
102
102
  }
@@ -106,33 +106,33 @@ NK_INTERNAL nk_u32_t nk_intersect_u16x32_icelake_(__m512i a, __m512i b) {
106
106
  * slightly faster than the native Tiger Lake implementation, but returns only one mask.
107
107
  */
108
108
  NK_INTERNAL nk_u16_t nk_intersect_u32x16_icelake_(__m512i a, __m512i b) {
109
- __m512i a1 = _mm512_alignr_epi32(a, a, 4);
110
- __m512i b1 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
109
+ __m512i a1_u32x16 = _mm512_alignr_epi32(a, a, 4);
110
+ __m512i b1_u32x16 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
111
111
  __mmask16 nm00 = _mm512_cmpneq_epi32_mask(a, b);
112
112
 
113
- __m512i a2 = _mm512_alignr_epi32(a, a, 8);
114
- __m512i a3 = _mm512_alignr_epi32(a, a, 12);
115
- __mmask16 nm01 = _mm512_cmpneq_epi32_mask(a1, b);
116
- __mmask16 nm02 = _mm512_cmpneq_epi32_mask(a2, b);
113
+ __m512i a2_u32x16 = _mm512_alignr_epi32(a, a, 8);
114
+ __m512i a3_u32x16 = _mm512_alignr_epi32(a, a, 12);
115
+ __mmask16 nm01 = _mm512_cmpneq_epi32_mask(a1_u32x16, b);
116
+ __mmask16 nm02 = _mm512_cmpneq_epi32_mask(a2_u32x16, b);
117
117
 
118
- __mmask16 nm03 = _mm512_cmpneq_epi32_mask(a3, b);
119
- __mmask16 nm10 = _mm512_mask_cmpneq_epi32_mask(nm00, a, b1);
120
- __mmask16 nm11 = _mm512_mask_cmpneq_epi32_mask(nm01, a1, b1);
118
+ __mmask16 nm03 = _mm512_cmpneq_epi32_mask(a3_u32x16, b);
119
+ __mmask16 nm10 = _mm512_mask_cmpneq_epi32_mask(nm00, a, b1_u32x16);
120
+ __mmask16 nm11 = _mm512_mask_cmpneq_epi32_mask(nm01, a1_u32x16, b1_u32x16);
121
121
 
122
- __m512i b2 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
123
- __mmask16 nm12 = _mm512_mask_cmpneq_epi32_mask(nm02, a2, b1);
124
- __mmask16 nm13 = _mm512_mask_cmpneq_epi32_mask(nm03, a3, b1);
125
- __mmask16 nm20 = _mm512_mask_cmpneq_epi32_mask(nm10, a, b2);
122
+ __m512i b2_u32x16 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
123
+ __mmask16 nm12 = _mm512_mask_cmpneq_epi32_mask(nm02, a2_u32x16, b1_u32x16);
124
+ __mmask16 nm13 = _mm512_mask_cmpneq_epi32_mask(nm03, a3_u32x16, b1_u32x16);
125
+ __mmask16 nm20 = _mm512_mask_cmpneq_epi32_mask(nm10, a, b2_u32x16);
126
126
 
127
- __m512i b3 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);
128
- __mmask16 nm21 = _mm512_mask_cmpneq_epi32_mask(nm11, a1, b2);
129
- __mmask16 nm22 = _mm512_mask_cmpneq_epi32_mask(nm12, a2, b2);
130
- __mmask16 nm23 = _mm512_mask_cmpneq_epi32_mask(nm13, a3, b2);
127
+ __m512i b3_u32x16 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);
128
+ __mmask16 nm21 = _mm512_mask_cmpneq_epi32_mask(nm11, a1_u32x16, b2_u32x16);
129
+ __mmask16 nm22 = _mm512_mask_cmpneq_epi32_mask(nm12, a2_u32x16, b2_u32x16);
130
+ __mmask16 nm23 = _mm512_mask_cmpneq_epi32_mask(nm13, a3_u32x16, b2_u32x16);
131
131
 
132
- __mmask16 nm0 = _mm512_mask_cmpneq_epi32_mask(nm20, a, b3);
133
- __mmask16 nm1 = _mm512_mask_cmpneq_epi32_mask(nm21, a1, b3);
134
- __mmask16 nm2 = _mm512_mask_cmpneq_epi32_mask(nm22, a2, b3);
135
- __mmask16 nm3 = _mm512_mask_cmpneq_epi32_mask(nm23, a3, b3);
132
+ __mmask16 nm0 = _mm512_mask_cmpneq_epi32_mask(nm20, a, b3_u32x16);
133
+ __mmask16 nm1 = _mm512_mask_cmpneq_epi32_mask(nm21, a1_u32x16, b3_u32x16);
134
+ __mmask16 nm2 = _mm512_mask_cmpneq_epi32_mask(nm22, a2_u32x16, b3_u32x16);
135
+ __mmask16 nm3 = _mm512_mask_cmpneq_epi32_mask(nm23, a3_u32x16, b3_u32x16);
136
136
 
137
137
  return ~(nk_u16_t)(nm0 & nk_u16_rol(nm1, 4) & nk_u16_rol(nm2, 8) & nk_u16_ror(nm3, 4));
138
138
  }
@@ -268,33 +268,33 @@ NK_PUBLIC void nk_sparse_intersect_u32_icelake( //
268
268
  * returns only one mask indicating which elements in `a` have a match in `b`.
269
269
  */
270
270
  NK_INTERNAL nk_u8_t nk_intersect_u64x8_icelake_(__m512i a, __m512i b) {
271
- __m512i a1 = _mm512_alignr_epi64(a, a, 2);
272
- __m512i b1 = _mm512_permutex_epi64(b, _MM_PERM_ADCB);
271
+ __m512i a1_u64x8 = _mm512_alignr_epi64(a, a, 2);
272
+ __m512i b1_u64x8 = _mm512_permutex_epi64(b, _MM_PERM_ADCB);
273
273
  __mmask8 nm00 = _mm512_cmpneq_epi64_mask(a, b);
274
274
 
275
- __m512i a2 = _mm512_alignr_epi64(a, a, 4);
276
- __m512i a3 = _mm512_alignr_epi64(a, a, 6);
277
- __mmask8 nm01 = _mm512_cmpneq_epi64_mask(a1, b);
278
- __mmask8 nm02 = _mm512_cmpneq_epi64_mask(a2, b);
275
+ __m512i a2_u64x8 = _mm512_alignr_epi64(a, a, 4);
276
+ __m512i a3_u64x8 = _mm512_alignr_epi64(a, a, 6);
277
+ __mmask8 nm01 = _mm512_cmpneq_epi64_mask(a1_u64x8, b);
278
+ __mmask8 nm02 = _mm512_cmpneq_epi64_mask(a2_u64x8, b);
279
279
 
280
- __m512i b2 = _mm512_permutex_epi64(b, _MM_PERM_BADC);
281
- __mmask8 nm03 = _mm512_cmpneq_epi64_mask(a3, b);
282
- __mmask8 nm10 = _mm512_mask_cmpneq_epi64_mask(nm00, a, b1);
283
- __mmask8 nm11 = _mm512_mask_cmpneq_epi64_mask(nm01, a1, b1);
280
+ __m512i b2_u64x8 = _mm512_permutex_epi64(b, _MM_PERM_BADC);
281
+ __mmask8 nm03 = _mm512_cmpneq_epi64_mask(a3_u64x8, b);
282
+ __mmask8 nm10 = _mm512_mask_cmpneq_epi64_mask(nm00, a, b1_u64x8);
283
+ __mmask8 nm11 = _mm512_mask_cmpneq_epi64_mask(nm01, a1_u64x8, b1_u64x8);
284
284
 
285
- __m512i b3 = _mm512_permutex_epi64(b, _MM_PERM_CBAD);
286
- __mmask8 nm12 = _mm512_mask_cmpneq_epi64_mask(nm02, a2, b1);
287
- __mmask8 nm13 = _mm512_mask_cmpneq_epi64_mask(nm03, a3, b1);
288
- __mmask8 nm20 = _mm512_mask_cmpneq_epi64_mask(nm10, a, b2);
285
+ __m512i b3_u64x8 = _mm512_permutex_epi64(b, _MM_PERM_CBAD);
286
+ __mmask8 nm12 = _mm512_mask_cmpneq_epi64_mask(nm02, a2_u64x8, b1_u64x8);
287
+ __mmask8 nm13 = _mm512_mask_cmpneq_epi64_mask(nm03, a3_u64x8, b1_u64x8);
288
+ __mmask8 nm20 = _mm512_mask_cmpneq_epi64_mask(nm10, a, b2_u64x8);
289
289
 
290
- __mmask8 nm21 = _mm512_mask_cmpneq_epi64_mask(nm11, a1, b2);
291
- __mmask8 nm22 = _mm512_mask_cmpneq_epi64_mask(nm12, a2, b2);
292
- __mmask8 nm23 = _mm512_mask_cmpneq_epi64_mask(nm13, a3, b2);
290
+ __mmask8 nm21 = _mm512_mask_cmpneq_epi64_mask(nm11, a1_u64x8, b2_u64x8);
291
+ __mmask8 nm22 = _mm512_mask_cmpneq_epi64_mask(nm12, a2_u64x8, b2_u64x8);
292
+ __mmask8 nm23 = _mm512_mask_cmpneq_epi64_mask(nm13, a3_u64x8, b2_u64x8);
293
293
 
294
- __mmask8 nm0 = _mm512_mask_cmpneq_epi64_mask(nm20, a, b3);
295
- __mmask8 nm1 = _mm512_mask_cmpneq_epi64_mask(nm21, a1, b3);
296
- __mmask8 nm2 = _mm512_mask_cmpneq_epi64_mask(nm22, a2, b3);
297
- __mmask8 nm3 = _mm512_mask_cmpneq_epi64_mask(nm23, a3, b3);
294
+ __mmask8 nm0 = _mm512_mask_cmpneq_epi64_mask(nm20, a, b3_u64x8);
295
+ __mmask8 nm1 = _mm512_mask_cmpneq_epi64_mask(nm21, a1_u64x8, b3_u64x8);
296
+ __mmask8 nm2 = _mm512_mask_cmpneq_epi64_mask(nm22, a2_u64x8, b3_u64x8);
297
+ __mmask8 nm3 = _mm512_mask_cmpneq_epi64_mask(nm23, a3_u64x8, b3_u64x8);
298
298
 
299
299
  return ~(nk_u8_t)(nm0 & nk_u8_rol(nm1, 2) & nk_u8_rol(nm2, 4) & nk_u8_ror(nm3, 2));
300
300
  }
@@ -377,8 +377,8 @@ NK_PUBLIC void nk_sparse_dot_u32f32_icelake( //
377
377
 
378
378
  nk_u32_t const *const a_end = a + a_length;
379
379
  nk_u32_t const *const b_end = b + b_length;
380
- __m512d product_lower_f64x8 = _mm512_setzero_pd();
381
- __m512d product_upper_f64x8 = _mm512_setzero_pd();
380
+ __m512d product_low_f64x8 = _mm512_setzero_pd();
381
+ __m512d product_high_f64x8 = _mm512_setzero_pd();
382
382
  nk_b512_vec_t a_vec, b_vec;
383
383
 
384
384
  while (a + 16 <= a_end && b + 16 <= b_end) {
@@ -425,15 +425,15 @@ NK_PUBLIC void nk_sparse_dot_u32f32_icelake( //
425
425
  __m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
426
426
  __m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);
427
427
 
428
- __m256 a_matched_lower_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
429
- __m256 a_matched_upper_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
430
- __m256 b_matched_lower_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
431
- __m256 b_matched_upper_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
428
+ __m256 a_matched_low_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
429
+ __m256 a_matched_high_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
430
+ __m256 b_matched_low_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
431
+ __m256 b_matched_high_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
432
432
 
433
- product_lower_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_lower_f32x8),
434
- _mm512_cvtps_pd(b_matched_lower_f32x8), product_lower_f64x8);
435
- product_upper_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_upper_f32x8),
436
- _mm512_cvtps_pd(b_matched_upper_f32x8), product_upper_f64x8);
433
+ product_low_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_low_f32x8),
434
+ _mm512_cvtps_pd(b_matched_low_f32x8), product_low_f64x8);
435
+ product_high_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_high_f32x8),
436
+ _mm512_cvtps_pd(b_matched_high_f32x8), product_high_f64x8);
437
437
  }
438
438
 
439
439
  // Advance pointers after processing
@@ -445,7 +445,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_icelake( //
445
445
 
446
446
  nk_f64_t tail_product = 0;
447
447
  nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
448
- *product = _mm512_reduce_add_pd(product_lower_f64x8) + _mm512_reduce_add_pd(product_upper_f64x8) + tail_product;
448
+ *product = _mm512_reduce_add_pd(product_low_f64x8) + _mm512_reduce_add_pd(product_high_f64x8) + tail_product;
449
449
  }
450
450
 
451
451
  #if defined(__clang__)
@@ -96,7 +96,7 @@ extern "C" {
96
96
  int matches = ai == bj; \
97
97
  load_and_convert(a_weights + i, &awi); \
98
98
  load_and_convert(b_weights + j, &bwi); \
99
- weights_product += matches * awi * bwi; \
99
+ weights_product += (nk_##accumulator_type##_t)matches * awi * bwi; \
100
100
  i += ai < bj; \
101
101
  j += ai >= bj; \
102
102
  } \