numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NEON-accelerated Sparse Vector Operations.
|
|
3
|
+
* @file include/numkong/sparse/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/sparse.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPARSE_NEON_H
|
|
10
|
+
#define NK_SPARSE_NEON_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
#if NK_TARGET_NEON
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
|
|
17
|
+
#if defined(__cplusplus)
|
|
18
|
+
extern "C" {
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#if defined(__clang__)
|
|
22
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a"))), apply_to = function)
|
|
23
|
+
#elif defined(__GNUC__)
|
|
24
|
+
#pragma GCC push_options
|
|
25
|
+
#pragma GCC target("arch=armv8-a")
|
|
26
|
+
#endif
|
|
27
|
+
|
|
28
|
+
/**
 *  @brief  Flags the lanes of `a` whose value occurs anywhere in `b`.
 *  @return All-ones in lane `i` when `a[i]` equals some lane of `b`, zero otherwise.
 *
 *  `a` is compared against each of the 4 cyclic rotations of `b`, and the equality masks are OR-ed together.
 */
NK_INTERNAL uint32x4_t nk_intersect_u32x4_neon_(uint32x4_t a, uint32x4_t b) {
    uint32x4_t any_equal_u32x4 = vceqq_u32(a, b);
    any_equal_u32x4 = vorrq_u32(any_equal_u32x4, vceqq_u32(a, vextq_u32(b, b, 1)));
    any_equal_u32x4 = vorrq_u32(any_equal_u32x4, vceqq_u32(a, vextq_u32(b, b, 2)));
    any_equal_u32x4 = vorrq_u32(any_equal_u32x4, vceqq_u32(a, vextq_u32(b, b, 3)));
    return any_equal_u32x4;
}
|
|
39
|
+
|
|
40
|
+
/**
 *  @brief  Flags the lanes of `a` whose value occurs anywhere in `b`.
 *  @return All-ones in lane `i` when `a[i]` equals some lane of `b`, zero otherwise.
 *
 *  `a` is compared against each of the 8 cyclic rotations of `b`, and the equality masks are OR-ed together.
 */
NK_INTERNAL uint16x8_t nk_intersect_u16x8_neon_(uint16x8_t a, uint16x8_t b) {
    uint16x8_t any_equal_u16x8 = vceqq_u16(a, b);
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 1)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 2)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 3)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 4)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 5)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 6)));
    any_equal_u16x8 = vorrq_u16(any_equal_u16x8, vceqq_u16(a, vextq_u16(b, b, 7)));
    return any_equal_u16x8;
}
|
|
61
|
+
|
|
62
|
+
/**
 *  @brief  Intersects two `u16` arrays, vectorizing the count-only path with 8-lane NEON windows.
 *  @param[in] a, b  Input arrays. The window-skipping logic assumes ascending order; the per-element counting
 *                   presumably also assumes no duplicates within each array - TODO confirm with callers.
 *  @param[in] a_length, b_length  Number of elements in `a` and `b`.
 *  @param[out] result  Optional buffer for the common elements; when non-NULL the serial kernel does all the work.
 *  @param[out] count  Receives the number of common elements.
 */
NK_PUBLIC void nk_sparse_intersect_u16_neon( //
    nk_u16_t const *a, nk_u16_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u16_t *result, nk_size_t *count) {

    // NEON lacks compress-store, so fall back to serial for result output
    if (result) {
        nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
        return;
    }

#if NK_ALLOW_ISA_REDIRECT
    // Inputs shorter than 32 elements each (at most 4 eight-lane registers) are cheaper to handle serially:
    if (a_length < 32 && b_length < 32) {
        nk_sparse_intersect_u16_serial(a, b, a_length, b_length, result, count);
        return;
    }
#endif

    nk_u16_t const *const a_end = a + a_length;
    nk_u16_t const *const b_end = b + b_length;
    nk_b128_vec_t a_vec, b_vec;
    uint16x8_t c_counts_u16x8 = vdupq_n_u16(0);

    while (a + 8 <= a_end && b + 8 <= b_end) {
        a_vec.u16x8 = vld1q_u16(a);
        b_vec.u16x8 = vld1q_u16(b);

        // Intersecting registers with `nk_intersect_u16x8_neon_` involves a lot of shuffling
        // and comparisons, so we want to avoid it if the slices don't overlap at all.
        nk_u16_t a_max = a_vec.u16s[7];
        nk_u16_t b_min = b_vec.u16s[0];
        nk_u16_t b_max = b_vec.u16s[7];

        // If the slices don't overlap, advance the appropriate pointer
        while (a_max < b_min && a + 16 <= a_end) {
            a += 8;
            a_vec.u16x8 = vld1q_u16(a);
            a_max = a_vec.u16s[7];
        }
        nk_u16_t a_min = a_vec.u16s[0];
        while (b_max < a_min && b + 16 <= b_end) {
            b += 8;
            b_vec.u16x8 = vld1q_u16(b);
            b_max = b_vec.u16s[7];
        }

        // Transform match-masks into "ones", accumulate them between the cycles,
        // and merge all together in the end.
        uint16x8_t a_matches = nk_intersect_u16x8_neon_(a_vec.u16x8, b_vec.u16x8);
        c_counts_u16x8 = vaddq_u16(c_counts_u16x8, vandq_u16(a_matches, vdupq_n_u16(1)));

        // Lanes not exceeding the other window's maximum form a prefix of each sorted window. Narrow both masks
        // to 4 bits per lane (16->8 bits via `vmovn_u16`, then a shift-right-narrow by 4 that fuses lane pairs),
        // so each `uint32x2_t` lane holds one window's prefix mask and `(32 - clz) / 4` is the prefix length.
        uint16x8_t a_inrange_u16x8 = vcleq_u16(a_vec.u16x8, vdupq_n_u16(b_max));
        uint16x8_t b_inrange_u16x8 = vcleq_u16(b_vec.u16x8, vdupq_n_u16(a_max));
        uint8x8_t a_narrow_u8x8 = vmovn_u16(a_inrange_u16x8);
        uint8x8_t b_narrow_u8x8 = vmovn_u16(b_inrange_u16x8);
        uint8x8_t packed_u8x8 = vshrn_n_u16(vreinterpretq_u16_u8(vcombine_u8(a_narrow_u8x8, b_narrow_u8x8)), 4);
        uint32x2_t clz_u32x2 = vclz_u32(vreinterpret_u32_u8(packed_u8x8));
        a += (32 - vget_lane_u32(clz_u32x2, 0)) / 4;
        b += (32 - vget_lane_u32(clz_u32x2, 1)) / 4;
    }

    nk_size_t tail_count = 0;
    nk_sparse_intersect_u16_serial(a, b, (nk_size_t)(a_end - a), (nk_size_t)(b_end - b), 0, &tail_count);
    // Reduce with the widening `vaddlvq_u16`: `vaddvq_u16` returns a `uint16_t` and wraps to zero once the
    // vectorized count reaches 65536, e.g. when both inputs hold every `u16` value.
    *count = tail_count + (nk_size_t)vaddlvq_u16(c_counts_u16x8);
}
|
|
132
|
+
|
|
133
|
+
/**
 *  @brief  Intersects two `u32` arrays, vectorizing the count-only path with 4-lane NEON windows.
 *  @param[in] a, b  Input arrays. The window-skipping logic assumes ascending order; the per-element counting
 *                   presumably also assumes no duplicates within each array - TODO confirm with callers.
 *  @param[in] a_length, b_length  Number of elements in `a` and `b`.
 *  @param[out] result  Optional buffer for the common elements; when non-NULL the serial kernel does all the work.
 *  @param[out] count  Receives the number of common elements.
 */
NK_PUBLIC void nk_sparse_intersect_u32_neon( //
    nk_u32_t const *a, nk_u32_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u32_t *result, nk_size_t *count) {

    // NEON lacks compress-store, so fall back to serial for result output
    if (result) {
        nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
        return;
    }

#if NK_ALLOW_ISA_REDIRECT
    // Inputs shorter than 32 elements each (at most 8 four-lane registers) are cheaper to handle serially:
    if (a_length < 32 && b_length < 32) {
        nk_sparse_intersect_u32_serial(a, b, a_length, b_length, result, count);
        return;
    }
#endif

    nk_u32_t const *const a_end = a + a_length;
    nk_u32_t const *const b_end = b + b_length;
    nk_b128_vec_t a_vec, b_vec;
    uint32x4_t c_counts_u32x4 = vdupq_n_u32(0);

    while (a + 4 <= a_end && b + 4 <= b_end) {
        a_vec.u32x4 = vld1q_u32(a);
        b_vec.u32x4 = vld1q_u32(b);

        // Intersecting registers with `nk_intersect_u32x4_neon_` involves a lot of shuffling
        // and comparisons, so we want to avoid it if the slices don't overlap at all.
        nk_u32_t a_max = a_vec.u32s[3];
        nk_u32_t b_min = b_vec.u32s[0];
        nk_u32_t b_max = b_vec.u32s[3];

        // If the slices don't overlap, advance the appropriate pointer
        while (a_max < b_min && a + 8 <= a_end) {
            a += 4;
            a_vec.u32x4 = vld1q_u32(a);
            a_max = a_vec.u32s[3];
        }
        nk_u32_t a_min = a_vec.u32s[0];
        while (b_max < a_min && b + 8 <= b_end) {
            b += 4;
            b_vec.u32x4 = vld1q_u32(b);
            b_max = b_vec.u32s[3];
        }

        // Transform match-masks into "ones", accumulate them between the cycles,
        // and merge all together in the end.
        uint32x4_t a_matches = nk_intersect_u32x4_neon_(a_vec.u32x4, b_vec.u32x4);
        c_counts_u32x4 = vaddq_u32(c_counts_u32x4, vandq_u32(a_matches, vdupq_n_u32(1)));

        // Lanes not exceeding the other window's maximum form a prefix of each sorted window. Narrow both masks
        // to one byte per lane, so each `uint32x2_t` lane holds one window's prefix mask and
        // `(32 - clz) / 8` is the prefix length.
        uint32x4_t a_inrange_u32x4 = vcleq_u32(a_vec.u32x4, vdupq_n_u32(b_max));
        uint32x4_t b_inrange_u32x4 = vcleq_u32(b_vec.u32x4, vdupq_n_u32(a_max));
        uint8x8_t packed_u8x8 = vmovn_u16(vcombine_u16(vmovn_u32(a_inrange_u32x4), vmovn_u32(b_inrange_u32x4)));
        uint32x2_t clz_u32x2 = vclz_u32(vreinterpret_u32_u8(packed_u8x8));
        a += (32 - vget_lane_u32(clz_u32x2, 0)) / 8;
        b += (32 - vget_lane_u32(clz_u32x2, 1)) / 8;
    }

    nk_size_t tail_count = 0;
    nk_sparse_intersect_u32_serial(a, b, (nk_size_t)(a_end - a), (nk_size_t)(b_end - b), 0, &tail_count);
    // Reduce with the widening `vaddlvq_u32`: `vaddvq_u32` returns a `uint32_t` that wraps once the vectorized
    // count reaches 2^32 (only possible when both inputs hold every `u32` value).
    *count = tail_count + (nk_size_t)vaddlvq_u32(c_counts_u32x4);
}
|
|
199
|
+
|
|
200
|
+
/**
 *  @brief  Flags the lanes of `a` whose value occurs anywhere in `b`.
 *  @return All-ones in lane `i` when `a[i]` equals either lane of `b`, zero otherwise.
 *
 *  With only two lanes, `b` and its swapped copy cover every pairing.
 */
NK_INTERNAL uint64x2_t nk_intersect_u64x2_neon_(uint64x2_t a, uint64x2_t b) {
    uint64x2_t b_swapped_u64x2 = vextq_u64(b, b, 1);
    return vorrq_u64(vceqq_u64(a, b), vceqq_u64(a, b_swapped_u64x2));
}
|
|
207
|
+
|
|
208
|
+
/**
 *  @brief  Intersects two `u64` arrays, vectorizing the count-only path with 2-lane NEON windows.
 *  @param[in] a, b  Input arrays. The window-skipping logic assumes ascending order; the per-element counting
 *                   presumably also assumes no duplicates within each array - TODO confirm with callers.
 *  @param[in] a_length, b_length  Number of elements in `a` and `b`.
 *  @param[out] result  Optional buffer for the common elements; when non-NULL the serial kernel does all the work.
 *  @param[out] count  Receives the number of common elements.
 */
NK_PUBLIC void nk_sparse_intersect_u64_neon( //
    nk_u64_t const *a, nk_u64_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u64_t *result, nk_size_t *count) {

    // Writing the intersection out needs a compress-store, which NEON lacks: let the serial kernel do it.
    if (result) {
        nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
        return;
    }

#if NK_ALLOW_ISA_REDIRECT
    // Inputs shorter than 8 elements each (at most 4 two-lane registers) go straight to the serial kernel.
    if (a_length < 8 && b_length < 8) {
        nk_sparse_intersect_u64_serial(a, b, a_length, b_length, result, count);
        return;
    }
#endif

    nk_u64_t const *a_cursor = a;
    nk_u64_t const *b_cursor = b;
    nk_u64_t const *const a_limit = a + a_length;
    nk_u64_t const *const b_limit = b + b_length;
    uint64x2_t const ones_u64x2 = vdupq_n_u64(1);
    uint64x2_t hits_u64x2 = vdupq_n_u64(0);
    nk_b128_vec_t a_window, b_window;

    while (a_cursor + 2 <= a_limit && b_cursor + 2 <= b_limit) {
        a_window.u64x2 = vld1q_u64(a_cursor);
        b_window.u64x2 = vld1q_u64(b_cursor);
        nk_u64_t a_last = a_window.u64s[1];
        nk_u64_t b_first = b_window.u64s[0];
        nk_u64_t b_last = b_window.u64s[1];

        // Skip whole `a` windows that end below the start of the `b` window, as long as another window fits.
        for (; a_last < b_first && a_cursor + 4 <= a_limit; a_last = a_window.u64s[1]) {
            a_cursor += 2;
            a_window.u64x2 = vld1q_u64(a_cursor);
        }
        // Then skip whole `b` windows that end below the start of the (possibly advanced) `a` window.
        nk_u64_t a_first = a_window.u64s[0];
        for (; b_last < a_first && b_cursor + 4 <= b_limit; b_last = b_window.u64s[1]) {
            b_cursor += 2;
            b_window.u64x2 = vld1q_u64(b_cursor);
        }

        // Each all-ones equality mask contributes exactly one hit to its lane's running total.
        uint64x2_t equal_u64x2 = nk_intersect_u64x2_neon_(a_window.u64x2, b_window.u64x2);
        hits_u64x2 = vaddq_u64(hits_u64x2, vandq_u64(equal_u64x2, ones_u64x2));

        // Lanes not exceeding the other window's maximum form a prefix of each sorted window. Narrow both masks
        // to 16 bits per lane, so each `uint32x2_t` lane holds one prefix and `(32 - clz) / 16` is its length.
        uint32x2_t a_mask_u32x2 = vmovn_u64(vcleq_u64(a_window.u64x2, vdupq_n_u64(b_last)));
        uint32x2_t b_mask_u32x2 = vmovn_u64(vcleq_u64(b_window.u64x2, vdupq_n_u64(a_last)));
        uint16x4_t prefixes_u16x4 = vmovn_u32(vcombine_u32(a_mask_u32x2, b_mask_u32x2));
        uint32x2_t zeros_u32x2 = vclz_u32(vreinterpret_u32_u16(prefixes_u16x4));
        a_cursor += (32 - vget_lane_u32(zeros_u32x2, 0)) / 16;
        b_cursor += (32 - vget_lane_u32(zeros_u32x2, 1)) / 16;
    }

    // Whatever no longer fills a full window on either side is finished serially.
    nk_size_t tail_count = 0;
    nk_sparse_intersect_u64_serial(a_cursor, b_cursor, (nk_size_t)(a_limit - a_cursor),
                                   (nk_size_t)(b_limit - b_cursor), 0, &tail_count);
    *count = tail_count + (nk_size_t)vaddvq_u64(hits_u64x2);
}
|
|
275
|
+
|
|
276
|
+
#if defined(__clang__)
|
|
277
|
+
#pragma clang attribute pop
|
|
278
|
+
#elif defined(__GNUC__)
|
|
279
|
+
#pragma GCC pop_options
|
|
280
|
+
#endif
|
|
281
|
+
|
|
282
|
+
#if defined(__cplusplus)
|
|
283
|
+
} // extern "C"
|
|
284
|
+
#endif
|
|
285
|
+
|
|
286
|
+
#endif // NK_TARGET_NEON
|
|
287
|
+
#endif // NK_TARGET_ARM_
|
|
288
|
+
#endif // NK_SPARSE_NEON_H
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Serial Sparse Vector Operations.
|
|
3
|
+
* @file include/numkong/sparse/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/sparse.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPARSE_SERIAL_H
|
|
10
|
+
#define NK_SPARSE_SERIAL_H
|
|
11
|
+
|
|
12
|
+
#include "numkong/types.h"
|
|
13
|
+
#include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`, `nk_assign_from_to_`
|
|
14
|
+
|
|
15
|
+
#if defined(__cplusplus)
|
|
16
|
+
extern "C" {
|
|
17
|
+
#endif
|
|
18
|
+
|
|
19
|
+
#define nk_define_sparse_intersect_(input_type)                                                                    \
    /* Returns the smallest index in `[start, length]` whose element is not less than `val` (`length` if none). */ \
    /* Probes at offsets 1, 2, 4, ... past `start`, then binary-searches the bracket holding the answer.        */ \
    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_galloping_search_(                                     \
        nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) {          \
        /* A previous search may already have run off the end: never read `array[length]`. */                    \
        if (start >= length) return length;                                                                       \
        nk_size_t low = start;                                                                                    \
        nk_size_t high = start + 1;                                                                               \
        nk_size_t step = 1;                                                                                       \
        while (high < length && array[high] < val) {                                                              \
            low = high + 1; /* `array[high] < val`, so the answer lies strictly past `high` */                   \
            step *= 2;                                                                                            \
            high = (length - high > step) ? high + step : length;                                                \
        }                                                                                                         \
        while (low < high) {                                                                                      \
            nk_size_t mid = low + (high - low) / 2;                                                               \
            if (array[mid] < val) { low = mid + 1; }                                                              \
            else { high = mid; }                                                                                  \
        }                                                                                                         \
        return low;                                                                                               \
    }                                                                                                             \
    /* Merge-style intersection of two sorted arrays; also writes the matches when `result` is non-NULL. */      \
    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_linear_scan_(                                          \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length,       \
        nk_##input_type##_t *result) {                                                                            \
        nk_size_t intersection_size = 0;                                                                          \
        nk_size_t i = 0, j = 0;                                                                                   \
        while (i != a_length && j != b_length) {                                                                  \
            nk_##input_type##_t ai = a[i];                                                                        \
            nk_##input_type##_t bj = b[j];                                                                        \
            if (ai == bj) {                                                                                       \
                if (result) result[intersection_size] = ai;                                                       \
                intersection_size++;                                                                              \
            }                                                                                                     \
            i += ai <= bj;                                                                                        \
            j += ai >= bj;                                                                                        \
        }                                                                                                         \
        return intersection_size;                                                                                 \
    }                                                                                                             \
    /* Counts (and, when `result` is non-NULL, writes) the elements common to two sorted arrays.          */     \
    /* The arrays may be passed in either order; they are swapped so `shorter` names the shorter one.     */     \
    NK_PUBLIC void nk_sparse_intersect_##input_type##_serial(                                                     \
        nk_##input_type##_t const *shorter, nk_##input_type##_t const *longer, nk_size_t shorter_length,          \
        nk_size_t longer_length, nk_##input_type##_t *result, nk_size_t *count) {                                 \
        /* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */                           \
        if (longer_length < shorter_length) {                                                                     \
            nk_##input_type##_t const *temp = shorter;                                                            \
            shorter = longer;                                                                                     \
            longer = temp;                                                                                        \
            nk_size_t temp_length = shorter_length;                                                               \
            shorter_length = longer_length;                                                                       \
            longer_length = temp_length;                                                                          \
        }                                                                                                         \
                                                                                                                  \
        /* Galloping only pays off when one side is much longer; otherwise a linear merge is cheaper */          \
        if (longer_length < 64 * shorter_length) {                                                                \
            *count = nk_sparse_intersect_##input_type##_linear_scan_(shorter, longer, shorter_length, longer_length, \
                                                                     result);                                     \
            return;                                                                                               \
        }                                                                                                         \
                                                                                                                  \
        /* Gallop through `longer` for each element of `shorter`, resuming from the previous position */         \
        nk_size_t intersection_size = 0;                                                                          \
        nk_size_t j = 0;                                                                                          \
        for (nk_size_t i = 0; i < shorter_length; ++i) {                                                          \
            nk_##input_type##_t shorter_i = shorter[i];                                                           \
            j = nk_sparse_intersect_##input_type##_galloping_search_(longer, j, longer_length, shorter_i);        \
            /* `shorter_i` exceeds every remaining `longer` element, and so will all later sorted elements */    \
            if (j == longer_length) break;                                                                        \
            if (longer[j] == shorter_i) {                                                                         \
                if (result) result[intersection_size] = shorter_i;                                                \
                intersection_size++;                                                                              \
            }                                                                                                     \
        }                                                                                                         \
        *count = intersection_size;                                                                               \
    }
|
|
85
|
+
|
|
86
|
+
#define nk_define_sparse_dot_(input_type, weight_type, accumulator_type, load_and_convert)                        \
    /* Sums `a_weights[i] * b_weights[j]` over the index pairs with `a[i] == b[j]`, merging the two sorted    */ \
    /* index lists in one pass. Weights are widened to the accumulator type through `load_and_convert`.     */ \
    NK_PUBLIC void nk_sparse_dot_##input_type##weight_type##_serial(                                              \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_##weight_type##_t const *a_weights,        \
        nk_##weight_type##_t const *b_weights, nk_size_t a_length, nk_size_t b_length,                            \
        nk_##accumulator_type##_t *product) {                                                                     \
        nk_##accumulator_type##_t weights_product = 0;                                                            \
        nk_size_t i = 0, j = 0;                                                                                   \
        while (i != a_length && j != b_length) {                                                                  \
            nk_##input_type##_t ai = a[i];                                                                        \
            nk_##input_type##_t bj = b[j];                                                                        \
            /* Only matched pairs may touch the sum: a branchless `matches * awi * bwi` turns an infinite   */   \
            /* or NaN weight at a non-matching index into NaN, because `0 * inf` is NaN in IEEE 754.        */   \
            if (ai == bj) {                                                                                       \
                nk_##accumulator_type##_t awi, bwi;                                                               \
                load_and_convert(a_weights + i, &awi);                                                            \
                load_and_convert(b_weights + j, &bwi);                                                            \
                weights_product += awi * bwi;                                                                     \
            }                                                                                                     \
            i += ai < bj;                                                                                         \
            j += ai >= bj;                                                                                        \
        }                                                                                                         \
        *product = weights_product;                                                                               \
    }
|
|
105
|
+
|
|
106
|
+
nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
|
|
107
|
+
nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
|
|
108
|
+
nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
|
|
109
|
+
|
|
110
|
+
nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
|
|
111
|
+
nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_) // nk_sparse_dot_u32f32_serial
|
|
112
|
+
|
|
113
|
+
#if defined(__cplusplus)
|
|
114
|
+
} // extern "C"
|
|
115
|
+
#endif
|
|
116
|
+
|
|
117
|
+
#endif // NK_SPARSE_SERIAL_H
|