numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Turin-accelerated Sparse Vector Operations.
|
|
3
|
+
* @file include/numkong/sparse/turin.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/sparse.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPARSE_TURIN_H
|
|
10
|
+
#define NK_SPARSE_TURIN_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_X86_
|
|
13
|
+
#if NK_TARGET_TURIN
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
|
|
17
|
+
#if defined(__cplusplus)
|
|
18
|
+
extern "C" {
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#if defined(__clang__)
|
|
22
|
+
#pragma clang attribute push( \
|
|
23
|
+
__attribute__((target( \
|
|
24
|
+
"avx2,avx512f,avx512vl,bmi,bmi2,lzcnt,popcnt,avx512bw,avx512vbmi2,avx512bf16,avx512vnni,avx512vp2intersect,avx512dq"))), \
|
|
25
|
+
apply_to = function)
|
|
26
|
+
#elif defined(__GNUC__)
|
|
27
|
+
#pragma GCC push_options
|
|
28
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "bmi", "bmi2", "lzcnt", "popcnt", "avx512bw", "avx512vbmi2", \
|
|
29
|
+
"avx512bf16", "avx512vnni", "avx512vp2intersect", "avx512dq")
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
/**
 * @brief Intersects two sorted, duplicate-free `u16` arrays, counting and optionally emitting common values.
 *
 * VP2INTERSECT has no 16-bit form, so each 16-element window is zero-extended to 32-bit lanes
 * and intersected with `_mm512_2intersect_epi32`. When `result` is non-null, the matching values
 * from `a` are compress-stored into it in order. Whatever remains after the last full pair of
 * windows is delegated to the serial kernel.
 *
 * @param[in] a, b Sorted arrays of unique indices.
 * @param[in] a_length, b_length Number of elements in `a` and `b`.
 * @param[out] result Optional output buffer for the common values (may be null).
 * @param[out] count Number of common values.
 */
NK_PUBLIC void nk_sparse_intersect_u16_turin( //
    nk_u16_t const *a, nk_u16_t const *b,       //
    nk_size_t a_length, nk_size_t b_length,     //
    nk_u16_t *result, nk_size_t *count) {

    nk_u16_t const *const a_end = a + a_length;
    nk_u16_t const *const b_end = b + b_length;
    nk_size_t matches_total = 0;

    // Lane 15 of a sorted window is its maximum; this permute index broadcasts it.
    __m256i const last_lane_i16x16 = _mm256_set1_epi16(15);

    for (; a + 16 <= a_end && b + 16 <= b_end;) {
        __m256i const a_u16x16 = _mm256_loadu_si256((__m256i const *)a);
        __m256i const b_u16x16 = _mm256_loadu_si256((__m256i const *)b);

        // Widen to 32-bit lanes, the narrowest width VP2INTERSECT supports
        __mmask16 a_hits, b_hits;
        _mm512_2intersect_epi32(_mm512_cvtepu16_epi32(a_u16x16), _mm512_cvtepu16_epi32(b_u16x16), &a_hits, &b_hits);

        if (result) { _mm256_mask_compressstoreu_epi16(result + matches_total, a_hits, a_u16x16); }
        matches_total += _mm_popcnt_u32(a_hits); // MSVC has no `_popcnt32`

        // Consume the leading elements of each window that do not exceed the other window's maximum
        __m256i const b_top_u16x16 = _mm256_permutexvar_epi16(last_lane_i16x16, b_u16x16);
        __m256i const a_top_u16x16 = _mm256_permutexvar_epi16(last_lane_i16x16, a_u16x16);
        __mmask16 const a_consumed = _mm256_cmple_epu16_mask(a_u16x16, b_top_u16x16);
        __mmask16 const b_consumed = _mm256_cmple_epu16_mask(b_u16x16, a_top_u16x16);
        a += _tzcnt_u32(~(nk_u32_t)a_consumed | 0x10000);
        b += _tzcnt_u32(~(nk_u32_t)b_consumed | 0x10000);
    }

    nk_size_t serial_matches = 0;
    nk_sparse_intersect_u16_serial(a, b, a_end - a, b_end - b, result ? result + matches_total : 0, &serial_matches);
    *count = matches_total + serial_matches;
}
|
|
73
|
+
|
|
74
|
+
/**
 * @brief Intersects two sorted, duplicate-free `u32` arrays, counting and optionally emitting common values.
 *
 * Uses native VP2INTERSECTD on 16-element windows. When `result` is non-null, the matching values
 * from `a` are compress-stored into it in order. Whatever remains after the last full pair of
 * windows is delegated to the serial kernel.
 *
 * @param[in] a, b Sorted arrays of unique indices.
 * @param[in] a_length, b_length Number of elements in `a` and `b`.
 * @param[out] result Optional output buffer for the common values (may be null).
 * @param[out] count Number of common values.
 */
NK_PUBLIC void nk_sparse_intersect_u32_turin( //
    nk_u32_t const *a, nk_u32_t const *b,       //
    nk_size_t a_length, nk_size_t b_length,     //
    nk_u32_t *result, nk_size_t *count) {

    nk_u32_t const *const a_end = a + a_length;
    nk_u32_t const *const b_end = b + b_length;
    nk_size_t matches_total = 0;

    // Lane 15 of a sorted window is its maximum; this permute index broadcasts it.
    __m512i const last_lane_i32x16 = _mm512_set1_epi32(15);

    for (; a + 16 <= a_end && b + 16 <= b_end;) {
        __m512i const a_u32x16 = _mm512_loadu_si512((__m512i const *)a);
        __m512i const b_u32x16 = _mm512_loadu_si512((__m512i const *)b);

        __mmask16 a_hits, b_hits;
        _mm512_2intersect_epi32(a_u32x16, b_u32x16, &a_hits, &b_hits);

        if (result) { _mm512_mask_compressstoreu_epi32(result + matches_total, a_hits, a_u32x16); }
        matches_total += _mm_popcnt_u32(a_hits); // MSVC has no `_popcnt32`

        // Consume the leading elements of each window that do not exceed the other window's maximum
        __m512i const b_top_u32x16 = _mm512_permutexvar_epi32(last_lane_i32x16, b_u32x16);
        __m512i const a_top_u32x16 = _mm512_permutexvar_epi32(last_lane_i32x16, a_u32x16);
        __mmask16 const a_consumed = _mm512_cmple_epu32_mask(a_u32x16, b_top_u32x16);
        __mmask16 const b_consumed = _mm512_cmple_epu32_mask(b_u32x16, a_top_u32x16);
        a += _tzcnt_u32(~(nk_u32_t)a_consumed | 0x10000);
        b += _tzcnt_u32(~(nk_u32_t)b_consumed | 0x10000);
    }

    nk_size_t serial_matches = 0;
    nk_sparse_intersect_u32_serial(a, b, a_end - a, b_end - b, result ? result + matches_total : 0, &serial_matches);
    *count = matches_total + serial_matches;
}
|
|
111
|
+
|
|
112
|
+
/**
 * @brief Intersects two sorted, duplicate-free `u64` arrays, counting and optionally emitting common values.
 *
 * Uses native VP2INTERSECTQ on 8-element windows. When `result` is non-null, the matching values
 * from `a` are compress-stored into it in order. Whatever remains after the last full pair of
 * windows is delegated to the serial kernel.
 *
 * @param[in] a, b Sorted arrays of unique indices.
 * @param[in] a_length, b_length Number of elements in `a` and `b`.
 * @param[out] result Optional output buffer for the common values (may be null).
 * @param[out] count Number of common values.
 */
NK_PUBLIC void nk_sparse_intersect_u64_turin( //
    nk_u64_t const *a, nk_u64_t const *b,       //
    nk_size_t a_length, nk_size_t b_length,     //
    nk_u64_t *result, nk_size_t *count) {

    nk_u64_t const *const a_end = a + a_length;
    nk_u64_t const *const b_end = b + b_length;
    nk_size_t matches_total = 0;

    // Lane 7 of a sorted window is its maximum; this permute index broadcasts it.
    __m512i const last_lane_i64x8 = _mm512_set1_epi64(7);

    for (; a + 8 <= a_end && b + 8 <= b_end;) {
        __m512i const a_u64x8 = _mm512_loadu_si512((__m512i const *)a);
        __m512i const b_u64x8 = _mm512_loadu_si512((__m512i const *)b);

        __mmask8 a_hits, b_hits;
        _mm512_2intersect_epi64(a_u64x8, b_u64x8, &a_hits, &b_hits);

        if (result) { _mm512_mask_compressstoreu_epi64(result + matches_total, a_hits, a_u64x8); }
        matches_total += _mm_popcnt_u32(a_hits); // MSVC has no `_popcnt32`

        // Consume the leading elements of each window that do not exceed the other window's maximum
        __m512i const b_top_u64x8 = _mm512_permutexvar_epi64(last_lane_i64x8, b_u64x8);
        __m512i const a_top_u64x8 = _mm512_permutexvar_epi64(last_lane_i64x8, a_u64x8);
        __mmask8 const a_consumed = _mm512_cmple_epu64_mask(a_u64x8, b_top_u64x8);
        __mmask8 const b_consumed = _mm512_cmple_epu64_mask(b_u64x8, a_top_u64x8);
        a += _tzcnt_u32(~(nk_u32_t)a_consumed | 0x100);
        b += _tzcnt_u32(~(nk_u32_t)b_consumed | 0x100);
    }

    nk_size_t serial_matches = 0;
    nk_sparse_intersect_u64_serial(a, b, a_end - a, b_end - b, result ? result + matches_total : 0, &serial_matches);
    *count = matches_total + serial_matches;
}
|
|
149
|
+
|
|
150
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16_turin( //
|
|
151
|
+
nk_u16_t const *a, nk_u16_t const *b, //
|
|
152
|
+
nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, //
|
|
153
|
+
nk_size_t a_length, nk_size_t b_length, //
|
|
154
|
+
nk_f32_t *product) {
|
|
155
|
+
|
|
156
|
+
#if NK_ALLOW_ISA_REDIRECT
|
|
157
|
+
// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
|
|
158
|
+
if (a_length < 64 && b_length < 64) {
|
|
159
|
+
nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
160
|
+
return;
|
|
161
|
+
}
|
|
162
|
+
#endif
|
|
163
|
+
|
|
164
|
+
//! There is no such thing as `_mm512_2intersect_epi16`, only the 32-bit variant!
|
|
165
|
+
//! So instead of jumping through 32 entries at a time, like on Ice Lake, we will
|
|
166
|
+
//! step through 16 entries at a time.
|
|
167
|
+
nk_u16_t const *const a_end = a + a_length;
|
|
168
|
+
nk_u16_t const *const b_end = b + b_length;
|
|
169
|
+
nk_b256_vec_t a_vec, b_vec;
|
|
170
|
+
__m256 product_f32x8 = _mm256_setzero_ps();
|
|
171
|
+
|
|
172
|
+
// Broadcast index for last element (hoisted outside loop)
|
|
173
|
+
__m256i const last_idx = _mm256_set1_epi16(15);
|
|
174
|
+
while (a + 16 <= a_end && b + 16 <= b_end) {
|
|
175
|
+
a_vec.ymm = _mm256_loadu_si256((__m256i const *)a);
|
|
176
|
+
b_vec.ymm = _mm256_loadu_si256((__m256i const *)b);
|
|
177
|
+
|
|
178
|
+
// Intersecting registers with `_mm512_2intersect_epi16_mask` involves a lot of shuffling
|
|
179
|
+
// and comparisons, so we want to avoid it if the slices don't overlap at all..
|
|
180
|
+
nk_u16_t a_min;
|
|
181
|
+
nk_u16_t a_max = a_vec.u16s[15];
|
|
182
|
+
nk_u16_t b_min = b_vec.u16s[0];
|
|
183
|
+
nk_u16_t b_max = b_vec.u16s[15];
|
|
184
|
+
|
|
185
|
+
// If the slices don't overlap, advance the appropriate pointer
|
|
186
|
+
while (a_max < b_min && a + 32 <= a_end) {
|
|
187
|
+
a += 16, a_weights += 16;
|
|
188
|
+
a_vec.ymm = _mm256_loadu_si256((__m256i const *)a);
|
|
189
|
+
a_max = a_vec.u16s[15];
|
|
190
|
+
}
|
|
191
|
+
a_min = a_vec.u16s[0];
|
|
192
|
+
while (b_max < a_min && b + 32 <= b_end) {
|
|
193
|
+
b += 16, b_weights += 16;
|
|
194
|
+
b_vec.ymm = _mm256_loadu_si256((__m256i const *)b);
|
|
195
|
+
b_max = b_vec.u16s[15];
|
|
196
|
+
}
|
|
197
|
+
b_min = b_vec.u16s[0];
|
|
198
|
+
|
|
199
|
+
// Now we are likely to have some overlap, so we can intersect the registers
|
|
200
|
+
__m512i a_i32x16 = _mm512_cvtepu16_epi32(a_vec.ymm);
|
|
201
|
+
__m512i b_i32x16 = _mm512_cvtepu16_epi32(b_vec.ymm);
|
|
202
|
+
__mmask16 a_matches_any_in_b, b_matches_any_in_a;
|
|
203
|
+
_mm512_2intersect_epi32(a_i32x16, b_i32x16, &a_matches_any_in_b, &b_matches_any_in_a);
|
|
204
|
+
|
|
205
|
+
// Load and shift all the relevant weights to the start of the vector before doing the dot product
|
|
206
|
+
if (a_matches_any_in_b) {
|
|
207
|
+
__m256i a_weights_bf16x16 = _mm256_loadu_si256((__m256i const *)a_weights);
|
|
208
|
+
a_weights_bf16x16 = _mm256_maskz_compress_epi16(a_matches_any_in_b, a_weights_bf16x16);
|
|
209
|
+
__m256i b_weights_bf16x16 = _mm256_loadu_si256((__m256i const *)b_weights);
|
|
210
|
+
b_weights_bf16x16 = _mm256_maskz_compress_epi16(b_matches_any_in_a, b_weights_bf16x16);
|
|
211
|
+
product_f32x8 = _mm256_dpbf16_ps(product_f32x8, nk_m256bh_from_m256i_(a_weights_bf16x16),
|
|
212
|
+
nk_m256bh_from_m256i_(b_weights_bf16x16));
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
__m256i a_max_u16x16 = _mm256_permutexvar_epi16(last_idx, a_vec.ymm);
|
|
216
|
+
__m256i b_max_u16x16 = _mm256_permutexvar_epi16(last_idx, b_vec.ymm);
|
|
217
|
+
__mmask16 a_step_mask = _mm256_cmple_epu16_mask(a_vec.ymm, b_max_u16x16);
|
|
218
|
+
__mmask16 b_step_mask = _mm256_cmple_epu16_mask(b_vec.ymm, a_max_u16x16);
|
|
219
|
+
nk_size_t a_step = _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
|
|
220
|
+
nk_size_t b_step = _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
|
|
221
|
+
a += a_step, a_weights += a_step;
|
|
222
|
+
b += b_step, b_weights += b_step;
|
|
223
|
+
}
|
|
224
|
+
nk_f32_t tail_product = 0;
|
|
225
|
+
nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
|
|
226
|
+
*product = tail_product + _mm512_reduce_add_ps(_mm512_insertf32x8(_mm512_setzero_ps(), product_f32x8, 0));
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
|
|
230
|
+
nk_u32_t const *a, nk_u32_t const *b, //
|
|
231
|
+
nk_f32_t const *a_weights, nk_f32_t const *b_weights, //
|
|
232
|
+
nk_size_t a_length, nk_size_t b_length, //
|
|
233
|
+
nk_f64_t *product) {
|
|
234
|
+
|
|
235
|
+
#if NK_ALLOW_ISA_REDIRECT
|
|
236
|
+
// The baseline implementation for very small arrays (2 registers or less) can be quite simple:
|
|
237
|
+
if (a_length < 32 && b_length < 32) {
|
|
238
|
+
nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
239
|
+
return;
|
|
240
|
+
}
|
|
241
|
+
#endif
|
|
242
|
+
|
|
243
|
+
// Native VP2INTERSECTD works directly on u32 - no conversion needed!
|
|
244
|
+
nk_u32_t const *const a_end = a + a_length;
|
|
245
|
+
nk_u32_t const *const b_end = b + b_length;
|
|
246
|
+
__m512d product_lower_f64x8 = _mm512_setzero_pd();
|
|
247
|
+
__m512d product_upper_f64x8 = _mm512_setzero_pd();
|
|
248
|
+
nk_b512_vec_t a_vec, b_vec;
|
|
249
|
+
|
|
250
|
+
while (a + 16 <= a_end && b + 16 <= b_end) {
|
|
251
|
+
a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
|
|
252
|
+
b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
|
|
253
|
+
|
|
254
|
+
// Avoid expensive intersection if slices don't overlap at all
|
|
255
|
+
nk_u32_t a_min;
|
|
256
|
+
nk_u32_t a_max = a_vec.u32s[15];
|
|
257
|
+
nk_u32_t b_min = b_vec.u32s[0];
|
|
258
|
+
nk_u32_t b_max = b_vec.u32s[15];
|
|
259
|
+
|
|
260
|
+
// If the slices don't overlap, advance the appropriate pointer
|
|
261
|
+
while (a_max < b_min && a + 32 <= a_end) {
|
|
262
|
+
a += 16, a_weights += 16;
|
|
263
|
+
a_vec.zmm = _mm512_loadu_si512((__m512i const *)a);
|
|
264
|
+
a_max = a_vec.u32s[15];
|
|
265
|
+
}
|
|
266
|
+
a_min = a_vec.u32s[0];
|
|
267
|
+
while (b_max < a_min && b + 32 <= b_end) {
|
|
268
|
+
b += 16, b_weights += 16;
|
|
269
|
+
b_vec.zmm = _mm512_loadu_si512((__m512i const *)b);
|
|
270
|
+
b_max = b_vec.u32s[15];
|
|
271
|
+
}
|
|
272
|
+
b_min = b_vec.u32s[0];
|
|
273
|
+
|
|
274
|
+
// Native u32 intersection - no conversion needed!
|
|
275
|
+
__mmask16 a_matches, b_matches;
|
|
276
|
+
_mm512_2intersect_epi32(a_vec.zmm, b_vec.zmm, &a_matches, &b_matches);
|
|
277
|
+
|
|
278
|
+
// Load and compress matching weights, then FMA
|
|
279
|
+
if (a_matches) {
|
|
280
|
+
__m512 a_weights_f32x16 = _mm512_loadu_ps(a_weights);
|
|
281
|
+
__m512 b_weights_f32x16 = _mm512_loadu_ps(b_weights);
|
|
282
|
+
__m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
|
|
283
|
+
__m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);
|
|
284
|
+
__m256 a_matched_lower_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
|
|
285
|
+
__m256 a_matched_upper_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
|
|
286
|
+
__m256 b_matched_lower_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
|
|
287
|
+
__m256 b_matched_upper_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
|
|
288
|
+
|
|
289
|
+
product_lower_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_lower_f32x8),
|
|
290
|
+
_mm512_cvtps_pd(b_matched_lower_f32x8), product_lower_f64x8);
|
|
291
|
+
product_upper_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_upper_f32x8),
|
|
292
|
+
_mm512_cvtps_pd(b_matched_upper_f32x8), product_upper_f64x8);
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
__m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
|
|
296
|
+
__m512i b_max_u32x16 = _mm512_set1_epi32(*(int const *)&b_max);
|
|
297
|
+
__mmask16 a_step_mask = _mm512_cmple_epu32_mask(a_vec.zmm, b_max_u32x16);
|
|
298
|
+
__mmask16 b_step_mask = _mm512_cmple_epu32_mask(b_vec.zmm, a_max_u32x16);
|
|
299
|
+
nk_size_t a_step = _tzcnt_u32(~(nk_u32_t)a_step_mask | 0x10000);
|
|
300
|
+
nk_size_t b_step = _tzcnt_u32(~(nk_u32_t)b_step_mask | 0x10000);
|
|
301
|
+
a += a_step, a_weights += a_step;
|
|
302
|
+
b += b_step, b_weights += b_step;
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
nk_f64_t tail_product = 0;
|
|
306
|
+
nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
|
|
307
|
+
*product = _mm512_reduce_add_pd(product_lower_f64x8) + _mm512_reduce_add_pd(product_upper_f64x8) + tail_product;
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
#if defined(__clang__)
|
|
311
|
+
#pragma clang attribute pop
|
|
312
|
+
#elif defined(__GNUC__)
|
|
313
|
+
#pragma GCC pop_options
|
|
314
|
+
#endif
|
|
315
|
+
|
|
316
|
+
#if defined(__cplusplus)
|
|
317
|
+
} // extern "C"
|
|
318
|
+
#endif
|
|
319
|
+
|
|
320
|
+
#endif // NK_TARGET_TURIN
|
|
321
|
+
#endif // NK_TARGET_X86_
|
|
322
|
+
#endif // NK_SPARSE_TURIN_H
|
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Sparse Vector Dot Products.
|
|
3
|
+
* @file include/numkong/sparse.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 21, 2024
|
|
6
|
+
*
|
|
7
|
+
* Contains:
|
|
8
|
+
*
|
|
9
|
+
* - Set intersection for sorted unique arrays → `u32` count
|
|
10
|
+
* - Sparse dot products for weighted sparse vectors
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
*
|
|
14
|
+
* - `u16`: indices for vocabularies under 64 thousand tokens
|
|
15
|
+
* - `u32`: indices for vocabularies under 4 billion tokens
|
|
16
|
+
* - `u64`: indices for trillion-scale combinatorics and graphs
|
|
17
|
+
* - `u16` indices + `bf16` weights → `f32` product
|
|
18
|
+
* - `u32` indices + `f32` weights → `f64` product
|
|
19
|
+
*
|
|
20
|
+
* For hardware architectures:
|
|
21
|
+
*
|
|
22
|
+
* - Arm: NEON, SVE2
|
|
23
|
+
* - x86: Ice Lake, Turin
|
|
24
|
+
*
|
|
25
|
+
* @section intersection_algorithm Intersection by Merge
|
|
26
|
+
*
|
|
27
|
+
* The core primitive is analogous to `std::set_intersection`, taking two sorted arrays
|
|
28
|
+
* of unique values and producing the intersection size:
|
|
29
|
+
*
|
|
30
|
+
* std::size_t intersection_size = 0;
|
|
31
|
+
* while (i != a_length && j != b_length) {
|
|
32
|
+
* scalar_t ai = a[i], bj = b[j];
|
|
33
|
+
* intersection_size += ai == bj;
|
|
34
|
+
* i += ai < bj;
|
|
35
|
+
* j += ai ≥ bj;
|
|
36
|
+
* }
|
|
37
|
+
*
|
|
38
|
+
* Weighted sparse dot-products follow the same merge loop, but accumulate a product
|
|
39
|
+
* for matching indices. For the `u32+f32` family the matched products are widened before
|
|
40
|
+
* accumulation, matching the widened `f64` public result.
|
|
41
|
+
*
|
|
42
|
+
* double product = 0;
|
|
43
|
+
* while (i != a_length && j != b_length) {
|
|
44
|
+
* scalar_t ai = a[i], bj = b[j];
|
|
45
|
+
* product += ai == bj ? a_weights[i] * b_weights[j] : 0;
|
|
46
|
+
* i += ai < bj;
|
|
47
|
+
* j += ai ≥ bj;
|
|
48
|
+
* }
|
|
49
|
+
*
|
|
50
|
+
* @section galloping_search Galloping vs Linear
|
|
51
|
+
*
|
|
52
|
+
* When the arrays are highly imbalanced, linear merge wastes cycles skipping elements.
|
|
53
|
+
* The serial implementation switches to a galloping search to jump over large gaps.
|
|
54
|
+
*
|
|
55
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
56
|
+
*
|
|
57
|
+
* The Ice Lake kernels are shuffle/compare heavy; their throughput is often gated by port 5.
|
|
58
|
+
* On Genoa, many integer ops dual-issue on FP ports, often improving throughput despite higher latency.
|
|
59
|
+
*
|
|
60
|
+
* Intrinsic Instruction Ice Genoa
|
|
61
|
+
* _mm512_shuffle_epi32 VPSHUFD (ZMM, ZMM, I8) 1c @ p5 1c @ p123
|
|
62
|
+
* _mm512_mask_cmpneq_epi32_mask VPCMPD (K, ZMM, ZMM, I8) 3c @ p5 5c @ p01
|
|
63
|
+
* _mm512_alignr_epi32 VALIGND (ZMM, ZMM, ZMM, I8) 3c @ p5 6c @ p12
|
|
64
|
+
* _mm512_conflict_epi32 VPCONFLICTD (ZMM, ZMM) 26c @ p0/5 7c @ p01/12
|
|
65
|
+
* _mm256_maskz_compress_epi16 VPCOMPRESSW (YMM, K, YMM) 3-6c @ p5 4-8c @ p01/12
|
|
66
|
+
* _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) 4-5c @ p01 4c @ p01
|
|
67
|
+
* _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6c @ p01
|
|
68
|
+
*
|
|
69
|
+
* VP2INTERSECTD is unsupported on Ice Lake and not yet covered by uops.info for Zen5/Turin.
|
|
70
|
+
* Tiger Lake measures ~36-41c @ p5 for ZMM variants, which is why we always avoid it on Intel.
|
|
71
|
+
*
|
|
72
|
+
* @section references References
|
|
73
|
+
*
|
|
74
|
+
* - uops.info: https://uops.info/
|
|
75
|
+
* - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
76
|
+
* - Arm Intrinsics Reference: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
77
|
+
* - vp2intersect experiments: https://github.com/mozonaut/vp2intersect
|
|
78
|
+
* - Diez-Canas "Faster-Than-Native Alternatives for x86 VP2INTERSECT Instructions":
|
|
79
|
+
* https://arxiv.org/pdf/2112.06342.pdf
|
|
80
|
+
*
|
|
81
|
+
*/
|
|
82
|
+
#ifndef NK_SPARSE_H
|
|
83
|
+
#define NK_SPARSE_H
|
|
84
|
+
|
|
85
|
+
#include "numkong/types.h"
|
|
86
|
+
|
|
87
|
+
#if defined(__cplusplus)
|
|
88
|
+
extern "C" {
|
|
89
|
+
#endif
|
|
90
|
+
|
|
91
|
+
/**
|
|
92
|
+
* @brief Set intersection between two sorted u16 arrays.
|
|
93
|
+
*
|
|
94
|
+
* @param[in] a The first sorted array of indices.
|
|
95
|
+
* @param[in] b The second sorted array of indices.
|
|
96
|
+
* @param[in] a_length The number of elements in the first array.
|
|
97
|
+
* @param[in] b_length The number of elements in the second array.
|
|
98
|
+
* @param[out] result Output buffer for intersection elements, or NULL to count only.
|
|
99
|
+
* @param[out] count The output intersection count.
|
|
100
|
+
*
|
|
101
|
+
* @note Inputs must be sorted in ascending order and contain unique elements.
|
|
102
|
+
*/
|
|
103
|
+
NK_DYNAMIC void nk_sparse_intersect_u16( //
|
|
104
|
+
nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
105
|
+
|
|
106
|
+
/**
|
|
107
|
+
* @brief Set intersection between two sorted u32 arrays.
|
|
108
|
+
*
|
|
109
|
+
* @param[in] a The first sorted array of indices.
|
|
110
|
+
* @param[in] b The second sorted array of indices.
|
|
111
|
+
* @param[in] a_length The number of elements in the first array.
|
|
112
|
+
* @param[in] b_length The number of elements in the second array.
|
|
113
|
+
* @param[out] result Output buffer for intersection elements, or NULL to count only.
|
|
114
|
+
* @param[out] count The output intersection count.
|
|
115
|
+
*
|
|
116
|
+
* @note Inputs must be sorted in ascending order and contain unique elements.
|
|
117
|
+
*/
|
|
118
|
+
NK_DYNAMIC void nk_sparse_intersect_u32( //
|
|
119
|
+
nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
120
|
+
|
|
121
|
+
/**
|
|
122
|
+
* @brief Set intersection between two sorted u64 arrays.
|
|
123
|
+
*
|
|
124
|
+
* @param[in] a The first sorted array of indices.
|
|
125
|
+
* @param[in] b The second sorted array of indices.
|
|
126
|
+
* @param[in] a_length The number of elements in the first array.
|
|
127
|
+
* @param[in] b_length The number of elements in the second array.
|
|
128
|
+
* @param[out] result Output buffer for intersection elements, or NULL to count only.
|
|
129
|
+
* @param[out] count The output intersection count.
|
|
130
|
+
*
|
|
131
|
+
* @note Inputs must be sorted in ascending order and contain unique elements.
|
|
132
|
+
*/
|
|
133
|
+
NK_DYNAMIC void nk_sparse_intersect_u64( //
|
|
134
|
+
nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length, nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
135
|
+
|
|
136
|
+
/**
|
|
137
|
+
* @brief Sparse dot-product over u16 indices with bf16 weights.
|
|
138
|
+
*
|
|
139
|
+
* @param[in] a The first sorted array of indices.
|
|
140
|
+
* @param[in] b The second sorted array of indices.
|
|
141
|
+
* @param[in] a_weights The bf16 weights for the first array.
|
|
142
|
+
* @param[in] b_weights The bf16 weights for the second array.
|
|
143
|
+
* @param[in] a_length The number of elements in the first array.
|
|
144
|
+
* @param[in] b_length The number of elements in the second array.
|
|
145
|
+
* @param[out] product The output dot product.
|
|
146
|
+
*
|
|
147
|
+
* @note Inputs must be sorted in ascending order and contain unique elements.
|
|
148
|
+
*/
|
|
149
|
+
NK_DYNAMIC void nk_sparse_dot_u16bf16( //
|
|
150
|
+
nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, nk_size_t a_length,
|
|
151
|
+
nk_size_t b_length, nk_f32_t *product);
|
|
152
|
+
|
|
153
|
+
/**
|
|
154
|
+
* @brief Sparse dot-product over u32 indices with f32 weights.
|
|
155
|
+
*
|
|
156
|
+
* @param[in] a The first sorted array of indices.
|
|
157
|
+
* @param[in] b The second sorted array of indices.
|
|
158
|
+
* @param[in] a_weights The f32 weights for the first array.
|
|
159
|
+
* @param[in] b_weights The f32 weights for the second array.
|
|
160
|
+
* @param[in] a_length The number of elements in the first array.
|
|
161
|
+
* @param[in] b_length The number of elements in the second array.
|
|
162
|
+
* @param[out] product The output dot product.
|
|
163
|
+
*
|
|
164
|
+
* @note Inputs must be sorted in ascending order and contain unique elements.
|
|
165
|
+
*/
|
|
166
|
+
NK_DYNAMIC void nk_sparse_dot_u32f32( //
|
|
167
|
+
nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights, nk_f32_t const *b_weights, nk_size_t a_length,
|
|
168
|
+
nk_size_t b_length, nk_f64_t *product);
|
|
169
|
+
|
|
170
|
+
/** @copydoc nk_sparse_intersect_u16 */
|
|
171
|
+
NK_PUBLIC void nk_sparse_intersect_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
|
|
172
|
+
nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
173
|
+
/** @copydoc nk_sparse_intersect_u32 */
|
|
174
|
+
NK_PUBLIC void nk_sparse_intersect_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
|
|
175
|
+
nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
176
|
+
/** @copydoc nk_sparse_intersect_u64 */
|
|
177
|
+
NK_PUBLIC void nk_sparse_intersect_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
|
|
178
|
+
nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
179
|
+
/** @copydoc nk_sparse_dot_u16bf16 */
|
|
180
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
|
|
181
|
+
nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
182
|
+
nk_f32_t *product);
|
|
183
|
+
/** @copydoc nk_sparse_dot_u32f32 */
|
|
184
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
|
|
185
|
+
nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
186
|
+
nk_f64_t *product);
|
|
187
|
+
|
|
188
|
+
#if NK_TARGET_NEON
|
|
189
|
+
/** @copydoc nk_sparse_intersect_u16 */
|
|
190
|
+
NK_PUBLIC void nk_sparse_intersect_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
|
|
191
|
+
nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
192
|
+
/** @copydoc nk_sparse_intersect_u32 */
|
|
193
|
+
NK_PUBLIC void nk_sparse_intersect_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
|
|
194
|
+
nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
195
|
+
/** @copydoc nk_sparse_intersect_u64 */
|
|
196
|
+
NK_PUBLIC void nk_sparse_intersect_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
|
|
197
|
+
nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
198
|
+
#endif // NK_TARGET_NEON
|
|
199
|
+
|
|
200
|
+
#if NK_TARGET_SVE2
|
|
201
|
+
/** @copydoc nk_sparse_intersect_u16 */
|
|
202
|
+
NK_PUBLIC void nk_sparse_intersect_u16_sve2(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
|
|
203
|
+
nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
204
|
+
/** @copydoc nk_sparse_intersect_u32 */
|
|
205
|
+
NK_PUBLIC void nk_sparse_intersect_u32_sve2(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
|
|
206
|
+
nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
207
|
+
/** @copydoc nk_sparse_intersect_u64 */
|
|
208
|
+
NK_PUBLIC void nk_sparse_intersect_u64_sve2(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
|
|
209
|
+
nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
210
|
+
/** @copydoc nk_sparse_dot_u32f32 */
|
|
211
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_sve2(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
|
|
212
|
+
nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
213
|
+
nk_f64_t *product);
|
|
214
|
+
#endif // NK_TARGET_SVE2
|
|
215
|
+
|
|
216
|
+
#if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
|
|
217
|
+
/** @copydoc nk_sparse_dot_u16bf16 */
|
|
218
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16_sve2(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
|
|
219
|
+
nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
220
|
+
nk_f32_t *product);
|
|
221
|
+
#endif // NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
|
|
222
|
+
|
|
223
|
+
#if NK_TARGET_ICELAKE
|
|
224
|
+
/** @copydoc nk_sparse_intersect_u16 */
|
|
225
|
+
NK_PUBLIC void nk_sparse_intersect_u16_icelake(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
|
|
226
|
+
nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
227
|
+
/** @copydoc nk_sparse_intersect_u32 */
|
|
228
|
+
NK_PUBLIC void nk_sparse_intersect_u32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
|
|
229
|
+
nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
230
|
+
/** @copydoc nk_sparse_intersect_u64 */
|
|
231
|
+
NK_PUBLIC void nk_sparse_intersect_u64_icelake(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
|
|
232
|
+
nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
233
|
+
/** @copydoc nk_sparse_dot_u32f32 */
|
|
234
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
|
|
235
|
+
nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
236
|
+
nk_f64_t *product);
|
|
237
|
+
#endif // NK_TARGET_ICELAKE
|
|
238
|
+
|
|
239
|
+
#if NK_TARGET_TURIN
|
|
240
|
+
/** @copydoc nk_sparse_intersect_u16 */
|
|
241
|
+
NK_PUBLIC void nk_sparse_intersect_u16_turin(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length,
|
|
242
|
+
nk_size_t b_length, nk_u16_t *result, nk_size_t *count);
|
|
243
|
+
/** @copydoc nk_sparse_intersect_u32 */
|
|
244
|
+
NK_PUBLIC void nk_sparse_intersect_u32_turin(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length,
|
|
245
|
+
nk_size_t b_length, nk_u32_t *result, nk_size_t *count);
|
|
246
|
+
/** @copydoc nk_sparse_intersect_u64 */
|
|
247
|
+
NK_PUBLIC void nk_sparse_intersect_u64_turin(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length,
|
|
248
|
+
nk_size_t b_length, nk_u64_t *result, nk_size_t *count);
|
|
249
|
+
/** @copydoc nk_sparse_dot_u16bf16 */
|
|
250
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16_turin(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
|
|
251
|
+
nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
252
|
+
nk_f32_t *product);
|
|
253
|
+
/** @copydoc nk_sparse_dot_u32f32 */
|
|
254
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_turin(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
|
|
255
|
+
nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
256
|
+
nk_f64_t *product);
|
|
257
|
+
#endif // NK_TARGET_TURIN
|
|
258
|
+
|
|
259
|
+
/**
|
|
260
|
+
* @brief Returns the output dtype for sparse dot products.
|
|
261
|
+
*/
|
|
262
|
+
NK_INTERNAL nk_dtype_t nk_sparse_dot_output_dtype(nk_dtype_t dtype) {
|
|
263
|
+
switch (dtype) {
|
|
264
|
+
case nk_f32_k: return nk_f64_k;
|
|
265
|
+
case nk_bf16_k: return nk_f32_k;
|
|
266
|
+
default: return nk_dtype_unknown_k;
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
#if defined(__cplusplus)
|
|
271
|
+
} // extern "C"
|
|
272
|
+
#endif
|
|
273
|
+
|
|
274
|
+
#include "numkong/sparse/serial.h"
|
|
275
|
+
#include "numkong/sparse/neon.h"
|
|
276
|
+
#include "numkong/sparse/sve2.h"
|
|
277
|
+
#include "numkong/sparse/icelake.h"
|
|
278
|
+
#include "numkong/sparse/turin.h"
|
|
279
|
+
|
|
280
|
+
#if defined(__cplusplus)
|
|
281
|
+
extern "C" {
|
|
282
|
+
#endif
|
|
283
|
+
|
|
284
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
285
|
+
|
|
286
|
+
NK_PUBLIC void nk_sparse_intersect_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t a_length, nk_size_t b_length,
|
|
287
|
+
nk_u16_t *result, nk_size_t *count) {
|
|
288
|
+
#if NK_TARGET_SVE2
|
|
289
|
+
nk_sparse_intersect_u16_sve2(a, b, a_length, b_length, result, count);
|
|
290
|
+
#elif NK_TARGET_NEON
|
|
291
|
+
nk_sparse_intersect_u16_neon(a, b, a_length, b_length, result, count);
|
|
292
|
+
#elif NK_TARGET_TURIN
|
|
293
|
+
nk_sparse_intersect_u16_turin(a, b, a_length, b_length, result, count);
|
|
294
|
+
#elif NK_TARGET_ICELAKE
|
|
295
|
+
nk_sparse_intersect_u16_icelake(a, b, a_length, b_length, result, count);
|
|
296
|
+
#else
|
|
297
|
+
nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
|
|
298
|
+
#endif
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
NK_PUBLIC void nk_sparse_intersect_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t a_length, nk_size_t b_length,
|
|
302
|
+
nk_u32_t *result, nk_size_t *count) {
|
|
303
|
+
#if NK_TARGET_SVE2
|
|
304
|
+
nk_sparse_intersect_u32_sve2(a, b, a_length, b_length, result, count);
|
|
305
|
+
#elif NK_TARGET_NEON
|
|
306
|
+
nk_sparse_intersect_u32_neon(a, b, a_length, b_length, result, count);
|
|
307
|
+
#elif NK_TARGET_TURIN
|
|
308
|
+
nk_sparse_intersect_u32_turin(a, b, a_length, b_length, result, count);
|
|
309
|
+
#elif NK_TARGET_ICELAKE
|
|
310
|
+
nk_sparse_intersect_u32_icelake(a, b, a_length, b_length, result, count);
|
|
311
|
+
#else
|
|
312
|
+
nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
|
|
313
|
+
#endif
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
NK_PUBLIC void nk_sparse_intersect_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t a_length, nk_size_t b_length,
|
|
317
|
+
nk_u64_t *result, nk_size_t *count) {
|
|
318
|
+
#if NK_TARGET_SVE2
|
|
319
|
+
nk_sparse_intersect_u64_sve2(a, b, a_length, b_length, result, count);
|
|
320
|
+
#elif NK_TARGET_NEON
|
|
321
|
+
nk_sparse_intersect_u64_neon(a, b, a_length, b_length, result, count);
|
|
322
|
+
#elif NK_TARGET_TURIN
|
|
323
|
+
nk_sparse_intersect_u64_turin(a, b, a_length, b_length, result, count);
|
|
324
|
+
#elif NK_TARGET_ICELAKE
|
|
325
|
+
nk_sparse_intersect_u64_icelake(a, b, a_length, b_length, result, count);
|
|
326
|
+
#else
|
|
327
|
+
nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
|
|
328
|
+
#endif
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16(nk_u16_t const *a, nk_u16_t const *b, nk_bf16_t const *a_weights,
|
|
332
|
+
nk_bf16_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
333
|
+
nk_f32_t *product) {
|
|
334
|
+
#if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
|
|
335
|
+
nk_sparse_dot_u16bf16_sve2(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
336
|
+
#elif NK_TARGET_TURIN
|
|
337
|
+
nk_sparse_dot_u16bf16_turin(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
338
|
+
#else
|
|
339
|
+
nk_sparse_dot_u16bf16_serial(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
340
|
+
#endif
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
NK_PUBLIC void nk_sparse_dot_u32f32(nk_u32_t const *a, nk_u32_t const *b, nk_f32_t const *a_weights,
|
|
344
|
+
nk_f32_t const *b_weights, nk_size_t a_length, nk_size_t b_length,
|
|
345
|
+
nk_f64_t *product) {
|
|
346
|
+
#if NK_TARGET_SVE2
|
|
347
|
+
nk_sparse_dot_u32f32_sve2(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
348
|
+
#elif NK_TARGET_TURIN
|
|
349
|
+
nk_sparse_dot_u32f32_turin(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
350
|
+
#elif NK_TARGET_ICELAKE
|
|
351
|
+
nk_sparse_dot_u32f32_icelake(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
352
|
+
#else
|
|
353
|
+
nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_length, b_length, product);
|
|
354
|
+
#endif
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
358
|
+
|
|
359
|
+
#if defined(__cplusplus)
|
|
360
|
+
} // extern "C"
|
|
361
|
+
#endif
|
|
362
|
+
|
|
363
|
+
#endif
|