numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,463 @@
1
+ /**
2
+ * @brief Ice Lake-accelerated Sparse Vector Operations.
3
+ * @file include/numkong/sparse/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/sparse.h
8
+ *
9
+ * The AVX-512 implementations are inspired by the "Faster-Than-Native Alternatives
10
+ * for x86 VP2INTERSECT Instructions" paper by Guille Diez-Canas, 2022.
11
+ *
12
+ * https://github.com/mozonaut/vp2intersect
13
+ * https://arxiv.org/pdf/2112.06342.pdf
14
+ *
15
+ * For R&D purposes, it's important to keep the following latencies in mind:
16
+ *
17
+ * - `_mm512_permutex_epi64` (VPERMQ) - needs F - 3 cy latency, 1 cy throughput @ p5
18
+ * - `_mm512_shuffle_epi8` (VPSHUFB) - needs BW - 1 cy latency, 1 cy throughput @ p5
19
+ * - `_mm512_permutexvar_epi16` (VPERMW) - needs BW - 4-6 cy latency, 1 cy throughput @ p5
20
+ * - `_mm512_permutexvar_epi8` (VPERMB) - needs VBMI - 3 cy latency, 1 cy throughput @ p5
21
+ */
22
+ #ifndef NK_SPARSE_ICELAKE_H
23
+ #define NK_SPARSE_ICELAKE_H
24
+
25
+ #if NK_TARGET_X86_
26
+ #if NK_TARGET_ICELAKE
27
+
28
+ #include "numkong/types.h"
29
+
30
+ #if defined(__cplusplus)
31
+ extern "C" {
32
+ #endif
33
+
34
+ #if defined(__clang__)
35
+ #pragma clang attribute push( \
36
+ __attribute__((target("avx2,avx512f,avx512vl,avx512dq,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2"))), \
37
+ apply_to = function)
38
+ #elif defined(__GNUC__)
39
+ #pragma GCC push_options
40
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512dq", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2")
41
+ #endif
42
+
43
/**
 * @brief Analogous to `_mm512_2intersect_epi16_mask`, but compatible with Ice Lake CPUs,
 * slightly faster than the native Tiger Lake implementation, but returns only one mask.
 *
 * Compares every 16-bit lane of @p a against all 32 lanes of @p b by cycling @p b
 * through all 32 relative alignments: 4 cross-lane dword rotations of @p a, combined
 * with 4 in-lane dword shuffles of @p b, each also rotated by 16 bits to swap the two
 * `u16` halves of every dword. Inequality masks are accumulated with masked compares,
 * so a bit survives only if the lane differed in @b every alignment.
 *
 * @return A 32-bit mask with bit `i` set iff `a[i]` equals at least one element of `b`.
 */
NK_INTERNAL nk_u32_t nk_intersect_u16x32_icelake_(__m512i a, __m512i b) {
    // Rotations of `a` across the full register, in steps of 4 dwords (8 `u16` lanes).
    __m512i a1 = _mm512_alignr_epi32(a, a, 4);
    __m512i a2 = _mm512_alignr_epi32(a, a, 8);
    __m512i a3 = _mm512_alignr_epi32(a, a, 12);

    // Dword rotations of `b` within each 128-bit lane.
    __m512i b1 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
    __m512i b2 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
    __m512i b3 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);

    // Rotate every 32-bit element right by 16 bits, swapping its two `u16` halves,
    // to cover the odd relative offsets.
    __m512i b01 = _mm512_shrdi_epi32(b, b, 16);
    __m512i b11 = _mm512_shrdi_epi32(b1, b1, 16);
    __m512i b21 = _mm512_shrdi_epi32(b2, b2, 16);
    __m512i b31 = _mm512_shrdi_epi32(b3, b3, 16);

    // First round: unmasked "not equal" comparisons for each rotation of `a`.
    __mmask32 nm00 = _mm512_cmpneq_epi16_mask(a, b);
    __mmask32 nm01 = _mm512_cmpneq_epi16_mask(a1, b);
    __mmask32 nm02 = _mm512_cmpneq_epi16_mask(a2, b);
    __mmask32 nm03 = _mm512_cmpneq_epi16_mask(a3, b);

    // Subsequent rounds AND-in the previous mask: a bit stays set only while the
    // lane has not matched in any alignment tried so far.
    __mmask32 nm10 = _mm512_mask_cmpneq_epi16_mask(nm00, a, b01);
    __mmask32 nm11 = _mm512_mask_cmpneq_epi16_mask(nm01, a1, b01);
    __mmask32 nm12 = _mm512_mask_cmpneq_epi16_mask(nm02, a2, b01);
    __mmask32 nm13 = _mm512_mask_cmpneq_epi16_mask(nm03, a3, b01);

    __mmask32 nm20 = _mm512_mask_cmpneq_epi16_mask(nm10, a, b1);
    __mmask32 nm21 = _mm512_mask_cmpneq_epi16_mask(nm11, a1, b1);
    __mmask32 nm22 = _mm512_mask_cmpneq_epi16_mask(nm12, a2, b1);
    __mmask32 nm23 = _mm512_mask_cmpneq_epi16_mask(nm13, a3, b1);

    __mmask32 nm30 = _mm512_mask_cmpneq_epi16_mask(nm20, a, b11);
    __mmask32 nm31 = _mm512_mask_cmpneq_epi16_mask(nm21, a1, b11);
    __mmask32 nm32 = _mm512_mask_cmpneq_epi16_mask(nm22, a2, b11);
    __mmask32 nm33 = _mm512_mask_cmpneq_epi16_mask(nm23, a3, b11);

    __mmask32 nm40 = _mm512_mask_cmpneq_epi16_mask(nm30, a, b2);
    __mmask32 nm41 = _mm512_mask_cmpneq_epi16_mask(nm31, a1, b2);
    __mmask32 nm42 = _mm512_mask_cmpneq_epi16_mask(nm32, a2, b2);
    __mmask32 nm43 = _mm512_mask_cmpneq_epi16_mask(nm33, a3, b2);

    __mmask32 nm50 = _mm512_mask_cmpneq_epi16_mask(nm40, a, b21);
    __mmask32 nm51 = _mm512_mask_cmpneq_epi16_mask(nm41, a1, b21);
    __mmask32 nm52 = _mm512_mask_cmpneq_epi16_mask(nm42, a2, b21);
    __mmask32 nm53 = _mm512_mask_cmpneq_epi16_mask(nm43, a3, b21);

    __mmask32 nm60 = _mm512_mask_cmpneq_epi16_mask(nm50, a, b3);
    __mmask32 nm61 = _mm512_mask_cmpneq_epi16_mask(nm51, a1, b3);
    __mmask32 nm62 = _mm512_mask_cmpneq_epi16_mask(nm52, a2, b3);
    __mmask32 nm63 = _mm512_mask_cmpneq_epi16_mask(nm53, a3, b3);

    __mmask32 nm70 = _mm512_mask_cmpneq_epi16_mask(nm60, a, b31);
    __mmask32 nm71 = _mm512_mask_cmpneq_epi16_mask(nm61, a1, b31);
    __mmask32 nm72 = _mm512_mask_cmpneq_epi16_mask(nm62, a2, b31);
    __mmask32 nm73 = _mm512_mask_cmpneq_epi16_mask(nm63, a3, b31);

    // The nm7x masks are in the rotated coordinate systems of a1/a2/a3; rotate them
    // back into `a`'s lane order (8 `u16` lanes per 4-dword rotation), AND them, and
    // complement: a cleared bit in the AND means the lane matched somewhere.
    return ~(nk_u32_t)(nm70 & nk_u32_rol(nm71, 8) & nk_u32_rol(nm72, 16) & nk_u32_ror(nm73, 8));
}
103
+
104
/**
 * @brief Analogous to `_mm512_2intersect_epi32`, but compatible with Ice Lake CPUs,
 * slightly faster than the native Tiger Lake implementation, but returns only one mask.
 *
 * Compares every 32-bit lane of @p a against all 16 lanes of @p b: the 4 cross-lane
 * dword rotations of @p a combined with 4 in-lane dword shuffles of @p b cover all 16
 * relative alignments. Shuffle and compare instructions are interleaved to give the
 * scheduler independent work between the long-latency permutes.
 *
 * @return A 16-bit mask with bit `i` set iff `a[i]` equals at least one element of `b`.
 */
NK_INTERNAL nk_u16_t nk_intersect_u32x16_icelake_(__m512i a, __m512i b) {
    __m512i a1 = _mm512_alignr_epi32(a, a, 4);
    __m512i b1 = _mm512_shuffle_epi32(b, _MM_PERM_ADCB);
    __mmask16 nm00 = _mm512_cmpneq_epi32_mask(a, b);

    __m512i a2 = _mm512_alignr_epi32(a, a, 8);
    __m512i a3 = _mm512_alignr_epi32(a, a, 12);
    __mmask16 nm01 = _mm512_cmpneq_epi32_mask(a1, b);
    __mmask16 nm02 = _mm512_cmpneq_epi32_mask(a2, b);

    __mmask16 nm03 = _mm512_cmpneq_epi32_mask(a3, b);
    // Masked compares AND-in the previous round, so a bit survives only while
    // the lane has not matched in any alignment tried so far.
    __mmask16 nm10 = _mm512_mask_cmpneq_epi32_mask(nm00, a, b1);
    __mmask16 nm11 = _mm512_mask_cmpneq_epi32_mask(nm01, a1, b1);

    __m512i b2 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
    __mmask16 nm12 = _mm512_mask_cmpneq_epi32_mask(nm02, a2, b1);
    __mmask16 nm13 = _mm512_mask_cmpneq_epi32_mask(nm03, a3, b1);
    __mmask16 nm20 = _mm512_mask_cmpneq_epi32_mask(nm10, a, b2);

    __m512i b3 = _mm512_shuffle_epi32(b, _MM_PERM_CBAD);
    __mmask16 nm21 = _mm512_mask_cmpneq_epi32_mask(nm11, a1, b2);
    __mmask16 nm22 = _mm512_mask_cmpneq_epi32_mask(nm12, a2, b2);
    __mmask16 nm23 = _mm512_mask_cmpneq_epi32_mask(nm13, a3, b2);

    __mmask16 nm0 = _mm512_mask_cmpneq_epi32_mask(nm20, a, b3);
    __mmask16 nm1 = _mm512_mask_cmpneq_epi32_mask(nm21, a1, b3);
    __mmask16 nm2 = _mm512_mask_cmpneq_epi32_mask(nm22, a2, b3);
    __mmask16 nm3 = _mm512_mask_cmpneq_epi32_mask(nm23, a3, b3);

    // Rotate the partial masks back into `a`'s lane order (4 lanes per rotation),
    // AND them, and complement: a cleared bit means the lane matched somewhere.
    return ~(nk_u16_t)(nm0 & nk_u16_rol(nm1, 4) & nk_u16_rol(nm2, 8) & nk_u16_ror(nm3, 4));
}
139
+
140
+ NK_PUBLIC void nk_sparse_intersect_u16_icelake( //
141
+ nk_u16_t const *a, nk_u16_t const *b, //
142
+ nk_size_t a_length, nk_size_t b_length, //
143
+ nk_u16_t *result, nk_size_t *count) {
144
+
145
+ #if NK_ALLOW_ISA_REDIRECT
146
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
147
+ if (a_length < 64 && b_length < 64) {
148
+ nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
149
+ return;
150
+ }
151
+ #endif
152
+
153
+ nk_u16_t const *const a_end = a + a_length;
154
+ nk_u16_t const *const b_end = b + b_length;
155
+ nk_size_t c = 0;
156
+ nk_b512_vec_t a_vec, b_vec;
157
+
158
+ while (a + 32 <= a_end && b + 32 <= b_end) {
159
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
160
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
161
+
162
+ // Intersecting registers with `nk_intersect_u16x32_icelake_` involves a lot of shuffling
163
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
164
+ nk_u16_t a_min;
165
+ nk_u16_t a_max = a_vec.u16s[31];
166
+ nk_u16_t b_min = b_vec.u16s[0];
167
+ nk_u16_t b_max = b_vec.u16s[31];
168
+
169
+ // If the slices don't overlap, advance the appropriate pointer
170
+ while (a_max < b_min && a + 64 <= a_end) {
171
+ a += 32;
172
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
173
+ a_max = a_vec.u16s[31];
174
+ }
175
+ a_min = a_vec.u16s[0];
176
+ while (b_max < a_min && b + 64 <= b_end) {
177
+ b += 32;
178
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
179
+ b_max = b_vec.u16s[31];
180
+ }
181
+ b_min = b_vec.u16s[0];
182
+
183
+ __m512i a_max_u16x32 = _mm512_set1_epi16(*(short const *)&a_max);
184
+ __m512i b_max_u16x32 = _mm512_set1_epi16(*(short const *)&b_max);
185
+ __mmask32 a_step_mask = _mm512_cmple_epu16_mask(a_vec.zmm, b_max_u16x32);
186
+ __mmask32 b_step_mask = _mm512_cmple_epu16_mask(b_vec.zmm, a_max_u16x32);
187
+ a += 32 - _lzcnt_u32((nk_u32_t)a_step_mask);
188
+ b += 32 - _lzcnt_u32((nk_u32_t)b_step_mask);
189
+
190
+ // Now we are likely to have some overlap, so we can intersect the registers
191
+ __mmask32 a_matches = nk_intersect_u16x32_icelake_(a_vec.zmm, b_vec.zmm);
192
+
193
+ // Export matches if result buffer is provided
194
+ if (result) { _mm512_mask_compressstoreu_epi16(result + c, a_matches, a_vec.zmm); }
195
+ c += _mm_popcnt_u32(a_matches); // MSVC has no `_popcnt32`
196
+ }
197
+
198
+ nk_size_t tail_count = 0;
199
+ nk_sparse_intersect_u16_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
200
+ *count = c + tail_count;
201
+ }
202
+
203
+ NK_PUBLIC void nk_sparse_intersect_u32_icelake( //
204
+ nk_u32_t const *a, nk_u32_t const *b, //
205
+ nk_size_t a_length, nk_size_t b_length, //
206
+ nk_u32_t *result, nk_size_t *count) {
207
+
208
+ #if NK_ALLOW_ISA_REDIRECT
209
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
210
+ if (a_length < 32 && b_length < 32) {
211
+ nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
212
+ return;
213
+ }
214
+ #endif
215
+
216
+ nk_u32_t const *const a_end = a + a_length;
217
+ nk_u32_t const *const b_end = b + b_length;
218
+ nk_size_t c = 0;
219
+ nk_b512_vec_t a_vec, b_vec;
220
+
221
+ while (a + 16 <= a_end && b + 16 <= b_end) {
222
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
223
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
224
+
225
+ // Intersecting registers with `nk_intersect_u32x16_icelake_` involves a lot of shuffling
226
+ // and comparisons, so we want to avoid it if the slices don't overlap at all..
227
+ nk_u32_t a_min;
228
+ nk_u32_t a_max = a_vec.u32s[15];
229
+ nk_u32_t b_min = b_vec.u32s[0];
230
+ nk_u32_t b_max = b_vec.u32s[15];
231
+
232
+ // If the slices don't overlap, advance the appropriate pointer
233
+ while (a_max < b_min && a + 32 <= a_end) {
234
+ a += 16;
235
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
236
+ a_max = a_vec.u32s[15];
237
+ }
238
+ a_min = a_vec.u32s[0];
239
+ while (b_max < a_min && b + 32 <= b_end) {
240
+ b += 16;
241
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
242
+ b_max = b_vec.u32s[15];
243
+ }
244
+ b_min = b_vec.u32s[0];
245
+
246
+ __m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
247
+ __m512i b_max_u32x16 = _mm512_set1_epi32(*(int const *)&b_max);
248
+ __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_max_u32x16);
249
+ __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_max_u32x16);
250
+ a += 32 - _lzcnt_u32((nk_u32_t)a_step_mask);
251
+ b += 32 - _lzcnt_u32((nk_u32_t)b_step_mask);
252
+
253
+ // Now we are likely to have some overlap, so we can intersect the registers
254
+ __mmask16 a_matches = nk_intersect_u32x16_icelake_(a_vec.zmm, b_vec.zmm);
255
+
256
+ // Export matches if result buffer is provided
257
+ if (result) { _mm512_mask_compressstoreu_epi32(result + c, a_matches, a_vec.zmm); }
258
+ c += _mm_popcnt_u32(a_matches); // MSVC has no `_popcnt32`
259
+ }
260
+
261
+ nk_size_t tail_count = 0;
262
+ nk_sparse_intersect_u32_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
263
+ *count = c + tail_count;
264
+ }
265
+
266
/**
 * @brief Analogous to `_mm512_2intersect_epi64`, but compatible with Ice Lake CPUs,
 * returns only one mask indicating which elements in `a` have a match in `b`.
 *
 * Compares every 64-bit lane of @p a against all 8 lanes of @p b: the 4 cross-lane
 * qword rotations of @p a combined with 4 in-lane qword permutes of @p b cover all 8
 * relative alignments. Inequality masks are accumulated with masked compares, so a
 * bit survives only if the lane differed in every alignment.
 *
 * @return An 8-bit mask with bit `i` set iff `a[i]` equals at least one element of `b`.
 */
NK_INTERNAL nk_u8_t nk_intersect_u64x8_icelake_(__m512i a, __m512i b) {
    __m512i a1 = _mm512_alignr_epi64(a, a, 2);
    __m512i b1 = _mm512_permutex_epi64(b, _MM_PERM_ADCB);
    __mmask8 nm00 = _mm512_cmpneq_epi64_mask(a, b);

    __m512i a2 = _mm512_alignr_epi64(a, a, 4);
    __m512i a3 = _mm512_alignr_epi64(a, a, 6);
    __mmask8 nm01 = _mm512_cmpneq_epi64_mask(a1, b);
    __mmask8 nm02 = _mm512_cmpneq_epi64_mask(a2, b);

    __m512i b2 = _mm512_permutex_epi64(b, _MM_PERM_BADC);
    __mmask8 nm03 = _mm512_cmpneq_epi64_mask(a3, b);
    // Masked compares AND-in the previous round, so a bit survives only while
    // the lane has not matched in any alignment tried so far.
    __mmask8 nm10 = _mm512_mask_cmpneq_epi64_mask(nm00, a, b1);
    __mmask8 nm11 = _mm512_mask_cmpneq_epi64_mask(nm01, a1, b1);

    __m512i b3 = _mm512_permutex_epi64(b, _MM_PERM_CBAD);
    __mmask8 nm12 = _mm512_mask_cmpneq_epi64_mask(nm02, a2, b1);
    __mmask8 nm13 = _mm512_mask_cmpneq_epi64_mask(nm03, a3, b1);
    __mmask8 nm20 = _mm512_mask_cmpneq_epi64_mask(nm10, a, b2);

    __mmask8 nm21 = _mm512_mask_cmpneq_epi64_mask(nm11, a1, b2);
    __mmask8 nm22 = _mm512_mask_cmpneq_epi64_mask(nm12, a2, b2);
    __mmask8 nm23 = _mm512_mask_cmpneq_epi64_mask(nm13, a3, b2);

    __mmask8 nm0 = _mm512_mask_cmpneq_epi64_mask(nm20, a, b3);
    __mmask8 nm1 = _mm512_mask_cmpneq_epi64_mask(nm21, a1, b3);
    __mmask8 nm2 = _mm512_mask_cmpneq_epi64_mask(nm22, a2, b3);
    __mmask8 nm3 = _mm512_mask_cmpneq_epi64_mask(nm23, a3, b3);

    // Rotate the partial masks back into `a`'s lane order (2 lanes per rotation),
    // AND them, and complement: a cleared bit means the lane matched somewhere.
    return ~(nk_u8_t)(nm0 & nk_u8_rol(nm1, 2) & nk_u8_rol(nm2, 4) & nk_u8_ror(nm3, 2));
}
301
+
302
+ NK_PUBLIC void nk_sparse_intersect_u64_icelake( //
303
+ nk_u64_t const *a, nk_u64_t const *b, //
304
+ nk_size_t a_length, nk_size_t b_length, //
305
+ nk_u64_t *result, nk_size_t *count) {
306
+
307
+ #if NK_ALLOW_ISA_REDIRECT
308
+ // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
309
+ if (a_length < 16 && b_length < 16) {
310
+ nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
311
+ return;
312
+ }
313
+ #endif
314
+
315
+ nk_u64_t const *const a_end = a + a_length;
316
+ nk_u64_t const *const b_end = b + b_length;
317
+ nk_size_t c = 0;
318
+ nk_b512_vec_t a_vec, b_vec;
319
+
320
+ while (a + 8 <= a_end && b + 8 <= b_end) {
321
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
322
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
323
+
324
+ // Intersecting registers with `nk_intersect_u64x8_icelake_` involves a lot of shuffling
325
+ // and comparisons, so we want to avoid it if the slices don't overlap at all.
326
+ nk_u64_t a_min;
327
+ nk_u64_t a_max = a_vec.u64s[7];
328
+ nk_u64_t b_min = b_vec.u64s[0];
329
+ nk_u64_t b_max = b_vec.u64s[7];
330
+
331
+ // If the slices don't overlap, advance the appropriate pointer
332
+ while (a_max < b_min && a + 16 <= a_end) {
333
+ a += 8;
334
+ a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
335
+ a_max = a_vec.u64s[7];
336
+ }
337
+ a_min = a_vec.u64s[0];
338
+ while (b_max < a_min && b + 16 <= b_end) {
339
+ b += 8;
340
+ b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
341
+ b_max = b_vec.u64s[7];
342
+ }
343
+ b_min = b_vec.u64s[0];
344
+
345
+ __m512i a_max_u64x8 = _mm512_set1_epi64(*(long long const *)&a_max);
346
+ __m512i b_max_u64x8 = _mm512_set1_epi64(*(long long const *)&b_max);
347
+ __mmask8 a_step_mask = _mm512_cmple_epu64_mask(a_vec.zmm, b_max_u64x8);
348
+ __mmask8 b_step_mask = _mm512_cmple_epu64_mask(b_vec.zmm, a_max_u64x8);
349
+ a += 32 - _lzcnt_u32((nk_u32_t)a_step_mask);
350
+ b += 32 - _lzcnt_u32((nk_u32_t)b_step_mask);
351
+
352
+ // Now we are likely to have some overlap, so we can intersect the registers
353
+ __mmask8 a_matches = nk_intersect_u64x8_icelake_(a_vec.zmm, b_vec.zmm);
354
+
355
+ // Export matches if result buffer is provided
356
+ if (result) { _mm512_mask_compressstoreu_epi64(result + c, a_matches, a_vec.zmm); }
357
+ c += _mm_popcnt_u32(a_matches); // MSVC has no `_popcnt32`
358
+ }
359
+
360
+ nk_size_t tail_count = 0;
361
+ nk_sparse_intersect_u64_serial(a, b, a_end - a, b_end - b, result ? result + c : 0, &tail_count);
362
+ *count = c + tail_count;
363
+ }
364
+
365
/**
 * @brief Weighted sparse dot product over two sorted `u32` key arrays with `f32` weights.
 *
 * For every key present in both @p a and @p b, multiplies the corresponding entries of
 * @p a_weights and @p b_weights and accumulates the products in `f64` precision.
 * Uses the same galloping register scheme as the pure intersection kernels; the tails
 * are delegated to the serial kernel and the final sum is written to @p product.
 *
 * Note: weight pointers are advanced in lock-step with their key pointers, so
 * `a_weights[i]` always corresponds to `a[i]` (same for `b`).
 */
NK_PUBLIC void nk_sparse_dot_u32f32_icelake( //
    nk_u32_t const *a, nk_u32_t const *b,    //
    nk_f32_t const *a_weights, nk_f32_t const *b_weights, //
    nk_size_t a_length, nk_size_t b_length, nk_f64_t *product) {

#if NK_ALLOW_ISA_REDIRECT
    // The baseline implementation for very small arrays (2 registers or less) can be quite simple:
    if (a_length < 32 && b_length < 32) {
        nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_length, b_length, product);
        return;
    }
#endif

    nk_u32_t const *const a_end = a + a_length;
    nk_u32_t const *const b_end = b + b_length;
    // Two f64 accumulators: one for the lower and one for the upper 8 of the 16 f32 lanes.
    __m512d product_lower_f64x8 = _mm512_setzero_pd();
    __m512d product_upper_f64x8 = _mm512_setzero_pd();
    nk_b512_vec_t a_vec, b_vec;

    while (a + 16 <= a_end && b + 16 <= b_end) {
        a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
        b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);

        // Intersecting registers with `nk_intersect_u32x16_icelake_` involves a lot of shuffling
        // and comparisons, so we want to avoid it if the slices don't overlap at all.
        nk_u32_t a_min;
        nk_u32_t a_max = a_vec.u32s[15];
        nk_u32_t b_min = b_vec.u32s[0];
        nk_u32_t b_max = b_vec.u32s[15];

        // If the slices don't overlap, advance the appropriate pointer
        // (keys and weights together, to keep them aligned).
        while (a_max < b_min && a + 32 <= a_end) {
            a += 16;
            a_weights += 16;
            a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
            a_max = a_vec.u32s[15];
        }
        a_min = a_vec.u32s[0];
        while (b_max < a_min && b + 32 <= b_end) {
            b += 16;
            b_weights += 16;
            b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
            b_max = b_vec.u32s[15];
        }
        b_min = b_vec.u32s[0];

        // Compute (but don't yet apply) how far each side may advance: past every
        // element not exceeding the other side's maximum. The masks are contiguous
        // prefixes, so `32 - lzcnt` equals their popcount.
        __m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
        __m512i b_max_u32x16 = _mm512_set1_epi32(*(int const *)&b_max);
        __mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_max_u32x16);
        __mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_max_u32x16);
        nk_u32_t a_advance = 32 - _lzcnt_u32((nk_u32_t)a_step_mask);
        nk_u32_t b_advance = 32 - _lzcnt_u32((nk_u32_t)b_step_mask);

        // Now we are likely to have some overlap, so we can intersect the registers.
        // We need both directions: which `a` lanes matched and which `b` lanes matched.
        __mmask16 a_matches = nk_intersect_u32x16_icelake_(a_vec.zmm, b_vec.zmm);
        __mmask16 b_matches = nk_intersect_u32x16_icelake_(b_vec.zmm, a_vec.zmm);
        if (a_matches) {
            // Load and compress matching weights at current position. Because both
            // key arrays are sorted and the matched key sets are identical, the two
            // compressed weight vectors line up pairwise in the same key order.
            __m512 a_weights_f32x16 = _mm512_loadu_ps(a_weights);
            __m512 b_weights_f32x16 = _mm512_loadu_ps(b_weights);
            __m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
            __m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);

            __m256 a_matched_lower_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
            __m256 a_matched_upper_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
            __m256 b_matched_lower_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
            __m256 b_matched_upper_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);

            // Widen to f64 before multiply-accumulate; unmatched lanes were zeroed
            // by the mask-zero compress and contribute nothing to the sums.
            product_lower_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_lower_f32x8),
                                                  _mm512_cvtps_pd(b_matched_lower_f32x8), product_lower_f64x8);
            product_upper_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_upper_f32x8),
                                                  _mm512_cvtps_pd(b_matched_upper_f32x8), product_upper_f64x8);
        }

        // Advance pointers after processing
        a += a_advance;
        a_weights += a_advance;
        b += b_advance;
        b_weights += b_advance;
    }

    nk_f64_t tail_product = 0;
    nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
    *product = _mm512_reduce_add_pd(product_lower_f64x8) + _mm512_reduce_add_pd(product_upper_f64x8) + tail_product;
}
450
+
451
+ #if defined(__clang__)
452
+ #pragma clang attribute pop
453
+ #elif defined(__GNUC__)
454
+ #pragma GCC pop_options
455
+ #endif
456
+
457
+ #if defined(__cplusplus)
458
+ } // extern "C"
459
+ #endif
460
+
461
+ #endif // NK_TARGET_ICELAKE
462
+ #endif // NK_TARGET_X86_
463
+ #endif // NK_SPARSE_ICELAKE_H