numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Set Similarity Measures for NEON.
|
|
3
|
+
* @file include/numkong/set/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/set.h
|
|
8
|
+
*
|
|
9
|
+
* @section set_neon_instructions NEON Set Instructions
|
|
10
|
+
*
|
|
11
|
+
* Key NEON instructions for binary/bitwise operations (Cortex-A76 class):
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction Latency Throughput
|
|
14
|
+
* vcntq_u8 CNT (V.16B, V.16B) 2cy 2/cy
|
|
15
|
+
* veorq_u8 EOR (V.16B, V.16B, V.16B) 1cy 4/cy
|
|
16
|
+
* vandq_u8 AND (V.16B, V.16B, V.16B) 1cy 4/cy
|
|
17
|
+
* vorrq_u8 ORR (V.16B, V.16B, V.16B) 1cy 4/cy
|
|
18
|
+
* vpaddlq_u8 UADDLP (V.8H, V.16B) 2cy 2/cy
|
|
19
|
+
* vaddvq_u32 ADDV (S, V.4S) 3cy 1/cy
|
|
20
|
+
*
|
|
21
|
+
* According to the available literature, the throughput for those basic integer ops is
|
|
22
|
+
* identical across most Apple, Qualcomm, and AWS Graviton chips. As long as we avoid widening
|
|
23
|
+
* operations and horizontal reductions inside hot loops, we should not hit any significant bottlenecks.
|
|
24
|
+
*
|
|
25
|
+
* @section set_neon_stateful Stateful Streaming Logic
|
|
26
|
+
*
|
|
27
|
+
* To build memory-optimal tiled algorithms, this file defines:
|
|
28
|
+
*
|
|
29
|
+
* - nk_hamming_u1x128_state_neon_t for streaming Hamming distance
|
|
30
|
+
* - nk_jaccard_u1x128_state_neon_t for streaming Jaccard similarity
|
|
31
|
+
*
|
|
32
|
+
* @code{c}
|
|
33
|
+
* nk_jaccard_u1x128_state_neon_t state_first, state_second, state_third, state_fourth;
|
|
34
|
+
* nk_jaccard_u1x128_init_neon(&state_first); // ...and likewise for state_second, state_third, state_fourth
|
|
35
|
+
* // ... stream through packed binary vectors ...
|
|
36
|
+
* nk_jaccard_u1x128_finalize_neon(&state_first, &state_second, &state_third, &state_fourth,
|
|
37
|
+
* query_popcount, target_popcount_a, target_popcount_b, target_popcount_c, target_popcount_d,
|
|
38
|
+
* total_dimensions, &results);
|
|
39
|
+
* @endcode
|
|
40
|
+
*/
|
|
41
|
+
#ifndef NK_SET_NEON_H
|
|
42
|
+
#define NK_SET_NEON_H
|
|
43
|
+
|
|
44
|
+
#if NK_TARGET_ARM_
|
|
45
|
+
#if NK_TARGET_NEON
|
|
46
|
+
|
|
47
|
+
#include "numkong/types.h" // `nk_u1x8_t`
|
|
48
|
+
#include "numkong/set/serial.h" // `nk_u1x8_popcount_`
|
|
49
|
+
|
|
50
|
+
#if defined(__cplusplus)
|
|
51
|
+
extern "C" {
|
|
52
|
+
#endif
|
|
53
|
+
|
|
54
|
+
#if defined(__clang__)
|
|
55
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
|
|
56
|
+
#elif defined(__GNUC__)
|
|
57
|
+
#pragma GCC push_options
|
|
58
|
+
#pragma GCC target("arch=armv8-a+simd")
|
|
59
|
+
#endif
|
|
60
|
+
|
|
61
|
+
#pragma region - Binary Sets
|
|
62
|
+
|
|
63
|
+
/**
 * @brief Hamming distance between two packed bit-vectors of `n` bits (NEON).
 *
 * Counts set bits of `a XOR b` over `ceil(n / 8)` bytes. Bits past `n` in the final byte are counted too,
 * so they are presumably expected to be zero-padded by the caller.
 */
NK_PUBLIC void nk_hamming_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const bytes_total = nk_size_divide_round_up_(n, NK_BITS_PER_BYTE);
    nk_size_t offset = 0;
    nk_u32_t mismatched_bits = 0;

    // A byte lane gains at most 8 per 16-byte step, so up to 31 steps (31 * 8 = 248 <= 255) can be
    // accumulated in a u8 vector before it must be widened and flushed into the scalar total.
    while (offset + 16 <= bytes_total) {
        uint8x16_t batch_u8x16 = vdupq_n_u8(0);
        nk_size_t steps_left = 31;
        do {
            uint8x16_t const xor_u8x16 = veorq_u8(vld1q_u8(a + offset), vld1q_u8(b + offset));
            batch_u8x16 = vaddq_u8(batch_u8x16, vcntq_u8(xor_u8x16));
            offset += 16;
        } while (--steps_left != 0 && offset + 16 <= bytes_total);
        mismatched_bits += (nk_u32_t)vaddlvq_u8(batch_u8x16);
    }

    // Scalar tail for the remaining < 16 bytes.
    for (; offset < bytes_total; ++offset) mismatched_bits += nk_u1x8_popcount_(a[offset] ^ b[offset]);
    *result = mismatched_bits;
}
|
|
85
|
+
|
|
86
|
+
/**
 * @brief Jaccard distance `1 - |A & B| / |A | B|` between two packed bit-vectors of `n` bits (NEON).
 *
 * Returns 0.0 when the union is empty. Bits past `n` in the final byte are counted, so they are
 * presumably zero-padded by the caller.
 */
NK_PUBLIC void nk_jaccard_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t const bytes_total = nk_size_divide_round_up_(n, NK_BITS_PER_BYTE);
    nk_size_t offset = 0;
    nk_u32_t shared_bits = 0, any_bits = 0;

    // Both per-byte counters grow by at most 8 per step, so 31 steps (248) fit in u8 lanes.
    while (offset + 16 <= bytes_total) {
        uint8x16_t shared_u8x16 = vdupq_n_u8(0);
        uint8x16_t any_u8x16 = vdupq_n_u8(0);
        nk_size_t steps_left = 31;
        do {
            uint8x16_t const a_u8x16 = vld1q_u8(a + offset);
            uint8x16_t const b_u8x16 = vld1q_u8(b + offset);
            shared_u8x16 = vaddq_u8(shared_u8x16, vcntq_u8(vandq_u8(a_u8x16, b_u8x16)));
            any_u8x16 = vaddq_u8(any_u8x16, vcntq_u8(vorrq_u8(a_u8x16, b_u8x16)));
            offset += 16;
        } while (--steps_left != 0 && offset + 16 <= bytes_total);
        shared_bits += (nk_u32_t)vaddlvq_u8(shared_u8x16);
        any_bits += (nk_u32_t)vaddlvq_u8(any_u8x16);
    }

    // Scalar tail for the remaining < 16 bytes.
    for (; offset < bytes_total; ++offset) {
        shared_bits += nk_u1x8_popcount_(a[offset] & b[offset]);
        any_bits += nk_u1x8_popcount_(a[offset] | b[offset]);
    }

    if (any_bits == 0) {
        *result = 0.0f;
        return;
    }
    *result = 1.0f - (nk_f32_t)shared_bits / (nk_f32_t)any_bits;
}
|
|
111
|
+
|
|
112
|
+
#pragma endregion - Binary Sets
|
|
113
|
+
|
|
114
|
+
#pragma region - Integer Sets
|
|
115
|
+
|
|
116
|
+
/**
 * @brief Position-wise "Jaccard" distance for u32 arrays: `1 - (#positions with a[i] == b[i]) / n` (NEON).
 *
 * Returns 0.0 for `n == 0`.
 */
NK_PUBLIC void nk_jaccard_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f32_t *result) {
    uint32x4_t equal_lanes_u32x4 = vdupq_n_u32(0);
    nk_size_t idx = 0;
    for (; idx + 4 <= n; idx += 4) {
        // CMEQ sets equal lanes to all-ones; shifting right by 31 turns that into 1.
        uint32x4_t const mask_u32x4 = vceqq_u32(vld1q_u32(a + idx), vld1q_u32(b + idx));
        equal_lanes_u32x4 = vaddq_u32(equal_lanes_u32x4, vshrq_n_u32(mask_u32x4, 31));
    }
    nk_u32_t equal_count = vaddvq_u32(equal_lanes_u32x4);
    while (idx < n) {
        equal_count += (a[idx] == b[idx]);
        ++idx;
    }
    if (n == 0) *result = 0.0f;
    else *result = 1.0f - (nk_f32_t)equal_count / (nk_f32_t)n;
}
|
|
130
|
+
|
|
131
|
+
/**
 * @brief Number of byte positions where `a[i] != b[i]` over `n` bytes (NEON).
 */
NK_PUBLIC void nk_hamming_u8_neon(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    uint32x4_t mismatches_u32x4 = vdupq_n_u32(0);
    nk_size_t idx = 0;
    for (; idx + 16 <= n; idx += 16) {
        // CMEQ yields 0xFF for equal bytes: invert it, then keep only the top bit -> 1 per mismatching byte.
        uint8x16_t const equal_u8x16 = vceqq_u8(vld1q_u8(a + idx), vld1q_u8(b + idx));
        uint8x16_t const mismatch_u8x16 = vshrq_n_u8(vmvnq_u8(equal_u8x16), 7);
        // Widen 16 x u8 -> 8 x u16 -> 4 x u32 so the running totals cannot overflow byte lanes.
        uint32x4_t const mismatch_u32x4 = vpaddlq_u16(vpaddlq_u8(mismatch_u8x16));
        mismatches_u32x4 = vaddq_u32(mismatches_u32x4, mismatch_u32x4);
    }
    nk_u32_t mismatches = vaddvq_u32(mismatches_u32x4);
    while (idx < n) {
        mismatches += (a[idx] != b[idx]);
        ++idx;
    }
    *result = mismatches;
}
|
|
153
|
+
|
|
154
|
+
/**
 * @brief Position-wise "Jaccard" distance for u16 arrays: `1 - (#positions with a[i] == b[i]) / n` (NEON).
 *
 * Returns 0.0 for `n == 0`.
 */
NK_PUBLIC void nk_jaccard_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t *result) {
    uint32x4_t equal_lanes_u32x4 = vdupq_n_u32(0);
    nk_size_t idx = 0;
    for (; idx + 8 <= n; idx += 8) {
        uint16x8_t const mask_u16x8 = vceqq_u16(vld1q_u16(a + idx), vld1q_u16(b + idx));
        // 0xFFFF >> 15 == 1 for equal lanes, 0 otherwise; pairwise widening folds 8 x u16 into 4 x u32.
        equal_lanes_u32x4 = vaddq_u32(equal_lanes_u32x4, vpaddlq_u16(vshrq_n_u16(mask_u16x8, 15)));
    }
    nk_u32_t equal_count = vaddvq_u32(equal_lanes_u32x4);
    while (idx < n) {
        equal_count += (a[idx] == b[idx]);
        ++idx;
    }
    if (n == 0) *result = 0.0f;
    else *result = 1.0f - (nk_f32_t)equal_count / (nk_f32_t)n;
}
|
|
176
|
+
|
|
177
|
+
#pragma endregion - Integer Sets
|
|
178
|
+
|
|
179
|
+
#pragma region - Stateful Streaming
|
|
180
|
+
|
|
181
|
+
/**
 * @brief Running state for streaming Hamming distance over 128-bit chunks of packed bits (NEON).
 *
 * Despite its name, `intersection_count_u32x4` holds popcounts of `a XOR b`, i.e. the number of
 * differing bits, split across four u32 lanes and summed horizontally at finalize time.
 * NOTE(review): consider renaming the field to `difference_count_u32x4` once all users are updated.
 */
typedef struct nk_hamming_u1x128_state_neon_t {
    uint32x4_t intersection_count_u32x4; // four partial sums of differing bits
} nk_hamming_u1x128_state_neon_t;
|
|
184
|
+
|
|
185
|
+
/** @brief Resets a streaming Hamming state: all four partial-sum lanes become zero. */
NK_INTERNAL void nk_hamming_u1x128_init_neon(nk_hamming_u1x128_state_neon_t *state) {
    uint32x4_t const zeros_u32x4 = vdupq_n_u32(0);
    state->intersection_count_u32x4 = zeros_u32x4;
}
|
|
188
|
+
|
|
189
|
+
/**
 * @brief Accumulates the number of differing bits between two 128-bit chunks into `state`.
 *
 * All 128 bits are counted; `depth_offset` and `active_dimensions` are not consulted.
 * NOTE(review): presumably callers zero-pad bits beyond `active_dimensions` — confirm.
 */
NK_INTERNAL void nk_hamming_u1x128_update_neon(nk_hamming_u1x128_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);

    // Differing positions are exactly the set bits of `a XOR b`; CNT yields 0..8 per byte.
    uint8x16_t const differing_per_byte_u8x16 = vcntq_u8(veorq_u8(a.u8x16, b.u8x16));
    // Pairwise-widen u8 -> u16 -> u32 so per-chunk counts land in 32-bit lanes without overflow.
    uint32x4_t const differing_per_word_u32x4 = vpaddlq_u16(vpaddlq_u8(differing_per_byte_u8x16));
    // Lane-wise accumulation only; the horizontal reduction is done once in the finalize step.
    state->intersection_count_u32x4 = vaddq_u32(state->intersection_count_u32x4, differing_per_word_u32x4);
}
|
|
220
|
+
|
|
221
|
+
/**
 * @brief Reduces four streaming Hamming states into four distances.
 *
 * @param result Output; `u32x4` lanes receive the totals of states a, b, c, d in that order.
 * @param total_dimensions Unused.
 */
NK_INTERNAL void nk_hamming_u1x128_finalize_neon( //
    nk_hamming_u1x128_state_neon_t const *state_a, nk_hamming_u1x128_state_neon_t const *state_b,
    nk_hamming_u1x128_state_neon_t const *state_c, nk_hamming_u1x128_state_neon_t const *state_d,
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);

    // ADDP(x, y) = [x0+x1, x2+x3, y0+y1, y2+y3]; applying it twice more folds four
    // accumulators into [sum_a, sum_b, sum_c, sum_d] with three instructions total.
    uint32x4_t const first_pair_u32x4 = vpaddq_u32(state_a->intersection_count_u32x4,
                                                   state_b->intersection_count_u32x4);
    uint32x4_t const second_pair_u32x4 = vpaddq_u32(state_c->intersection_count_u32x4,
                                                    state_d->intersection_count_u32x4);
    result->u32x4 = vpaddq_u32(first_pair_u32x4, second_pair_u32x4);
}
|
|
233
|
+
|
|
234
|
+
/**
 * @brief Running state for streaming Jaccard distance over 128-bit chunks of packed bits (NEON).
 *
 * Holds popcounts of `a AND b` (the intersection), split across four u32 lanes and summed
 * horizontally at finalize time.
 */
typedef struct nk_jaccard_u1x128_state_neon_t {
    uint32x4_t intersection_count_u32x4; // four partial sums of shared set bits
} nk_jaccard_u1x128_state_neon_t;
|
|
237
|
+
|
|
238
|
+
/** @brief Resets a streaming Jaccard state: all four partial-sum lanes become zero. */
NK_INTERNAL void nk_jaccard_u1x128_init_neon(nk_jaccard_u1x128_state_neon_t *state) {
    uint32x4_t const zeros_u32x4 = vdupq_n_u32(0);
    state->intersection_count_u32x4 = zeros_u32x4;
}
|
|
241
|
+
|
|
242
|
+
/**
 * @brief Accumulates the intersection popcount `|a AND b|` of two 128-bit chunks into `state`.
 *
 * All 128 bits are counted; `depth_offset` and `active_dimensions` are not consulted.
 * NOTE(review): presumably callers zero-pad bits beyond `active_dimensions` — confirm.
 */
NK_INTERNAL void nk_jaccard_u1x128_update_neon(nk_jaccard_u1x128_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);

    // Shared set bits are the set bits of `a AND b`; CNT yields 0..8 per byte.
    uint8x16_t const shared_per_byte_u8x16 = vcntq_u8(vandq_u8(a.u8x16, b.u8x16));
    // Pairwise-widen u8 -> u16 -> u32 so per-chunk counts land in 32-bit lanes without overflow.
    uint32x4_t const shared_per_word_u32x4 = vpaddlq_u16(vpaddlq_u8(shared_per_byte_u8x16));
    // Lane-wise accumulation only; the horizontal reduction is done once in the finalize step.
    state->intersection_count_u32x4 = vaddq_u32(state->intersection_count_u32x4, shared_per_word_u32x4);
}
|
|
273
|
+
|
|
274
|
+
/**
 * @brief Finalizes four streaming Jaccard states into four Jaccard distances.
 *
 * Each state holds lane-wise partial intersection popcounts |A AND B|. The union is rebuilt as
 * |A OR B| = |A| + |B| - |A AND B| from the caller-provided popcounts, and the distance is
 * `1 - intersection / union`. Pairs whose union is zero yield 0.0.
 *
 * @param state_a,state_b,state_c,state_d Accumulated states for the four targets.
 * @param query_popcount Popcount of the query vector (shared by all four pairs).
 * @param target_popcount_a,target_popcount_b,target_popcount_c,target_popcount_d Popcounts of the targets.
 * @param total_dimensions Unused.
 * @param result Output; `f32x4` lanes receive the distances for targets a, b, c, d in that order.
 */
NK_INTERNAL void nk_jaccard_u1x128_finalize_neon( //
    nk_jaccard_u1x128_state_neon_t const *state_a, nk_jaccard_u1x128_state_neon_t const *state_b,
    nk_jaccard_u1x128_state_neon_t const *state_c, nk_jaccard_u1x128_state_neon_t const *state_d,
    nk_f32_t query_popcount, nk_f32_t target_popcount_a, nk_f32_t target_popcount_b, nk_f32_t target_popcount_c,
    nk_f32_t target_popcount_d, nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);

    // Horizontal sums via three pairwise adds (ADDP) rather than four ADDV + lane shuffling:
    // ADDP(x, y) = [x0+x1, x2+x3, y0+y1, y2+y3]; a final ADDP yields [sum_a, sum_b, sum_c, sum_d].
    uint32x4_t ab_sum_u32x4 = vpaddq_u32(state_a->intersection_count_u32x4, state_b->intersection_count_u32x4);
    uint32x4_t cd_sum_u32x4 = vpaddq_u32(state_c->intersection_count_u32x4, state_d->intersection_count_u32x4);
    uint32x4_t intersection_u32x4 = vpaddq_u32(ab_sum_u32x4, cd_sum_u32x4);
    float32x4_t intersection_f32x4 = vcvtq_f32_u32(intersection_u32x4);

    // |A OR B| = |A| + |B| - |A AND B|. Target popcounts are inserted lane by lane.
    float32x4_t query_f32x4 = vdupq_n_f32(query_popcount);
    float32x4_t targets_f32x4 = vdupq_n_f32(target_popcount_a);
    targets_f32x4 = vsetq_lane_f32(target_popcount_b, targets_f32x4, 1);
    targets_f32x4 = vsetq_lane_f32(target_popcount_c, targets_f32x4, 2);
    targets_f32x4 = vsetq_lane_f32(target_popcount_d, targets_f32x4, 3);
    float32x4_t union_f32x4 = vsubq_f32(vaddq_f32(query_f32x4, targets_f32x4), intersection_f32x4);

    // Zero-union lanes (both vectors empty) get distance 0.0; divide by 1.0 there to stay finite.
    float32x4_t one_f32x4 = vdupq_n_f32(1.0f);
    uint32x4_t zero_union_mask = vceqq_f32(union_f32x4, vdupq_n_f32(0.0f));
    float32x4_t safe_union_f32x4 = vbslq_f32(zero_union_mask, one_f32x4, union_f32x4);

    // Reciprocal via estimate + Newton-Raphson refinement:
    // - `vrecpeq_f32` (FRECPE) is only accurate to ~8 bits on ARM (unlike x86 `rcpps`, which gives ~12 bits).
    // - Each `vrecpsq_f32` step computes (2 - a * x) and multiplying by it roughly doubles the correct bits.
    // One step stops at ~16 bits (relative error ~1e-5). Two steps (~8 -> ~16 -> ~23 bits) approach full
    // f32 precision, which keeps these results consistent with the scalar `1 - i / u` path.
    float32x4_t union_reciprocal_f32x4 = vrecpeq_f32(safe_union_f32x4);
    union_reciprocal_f32x4 = vmulq_f32(union_reciprocal_f32x4, vrecpsq_f32(safe_union_f32x4, union_reciprocal_f32x4));
    union_reciprocal_f32x4 = vmulq_f32(union_reciprocal_f32x4, vrecpsq_f32(safe_union_f32x4, union_reciprocal_f32x4));

    // Jaccard distance = 1 - intersection / union, forced to 0.0 on zero-union lanes.
    float32x4_t ratio_f32x4 = vmulq_f32(intersection_f32x4, union_reciprocal_f32x4);
    float32x4_t jaccard_f32x4 = vsubq_f32(one_f32x4, ratio_f32x4);
    result->f32x4 = vbslq_f32(zero_union_mask, vdupq_n_f32(0.0f), jaccard_f32x4);
}
|
|
320
|
+
|
|
321
|
+
/** @brief Hamming from_dot: computes pop_a + pop_b - 2*dot for 4 pairs (NEON). */
|
|
322
|
+
NK_INTERNAL void nk_hamming_u32x4_from_dot_neon_(nk_b128_vec_t dots, nk_u32_t query_pop, nk_b128_vec_t target_pops,
|
|
323
|
+
nk_b128_vec_t *results) {
|
|
324
|
+
uint32x4_t dots_u32x4 = dots.u32x4;
|
|
325
|
+
uint32x4_t query_u32x4 = vdupq_n_u32(query_pop);
|
|
326
|
+
uint32x4_t target_u32x4 = target_pops.u32x4;
|
|
327
|
+
results->u32x4 = vsubq_u32(vaddq_u32(query_u32x4, target_u32x4), vshlq_n_u32(dots_u32x4, 1));
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
/** @brief Jaccard from_dot: computes 1 - dot / (pop_a + pop_b - dot) for 4 pairs (NEON). */
|
|
331
|
+
NK_INTERNAL void nk_jaccard_f32x4_from_dot_neon_(nk_b128_vec_t dots, nk_u32_t query_pop, nk_b128_vec_t target_pops,
|
|
332
|
+
nk_b128_vec_t *results) {
|
|
333
|
+
float32x4_t dot_f32x4 = vcvtq_f32_u32(dots.u32x4);
|
|
334
|
+
float32x4_t query_f32x4 = vdupq_n_f32((nk_f32_t)query_pop);
|
|
335
|
+
float32x4_t target_f32x4 = vcvtq_f32_u32(target_pops.u32x4);
|
|
336
|
+
float32x4_t union_f32x4 = vsubq_f32(vaddq_f32(query_f32x4, target_f32x4), dot_f32x4);
|
|
337
|
+
|
|
338
|
+
float32x4_t one_f32x4 = vdupq_n_f32(1.0f);
|
|
339
|
+
uint32x4_t zero_union_mask = vceqq_f32(union_f32x4, vdupq_n_f32(0.0f));
|
|
340
|
+
float32x4_t safe_union_f32x4 = vbslq_f32(zero_union_mask, one_f32x4, union_f32x4);
|
|
341
|
+
|
|
342
|
+
float32x4_t union_reciprocal_f32x4 = vrecpeq_f32(safe_union_f32x4);
|
|
343
|
+
union_reciprocal_f32x4 = vmulq_f32(union_reciprocal_f32x4, vrecpsq_f32(safe_union_f32x4, union_reciprocal_f32x4));
|
|
344
|
+
|
|
345
|
+
float32x4_t ratio_f32x4 = vmulq_f32(dot_f32x4, union_reciprocal_f32x4);
|
|
346
|
+
float32x4_t jaccard_f32x4 = vsubq_f32(one_f32x4, ratio_f32x4);
|
|
347
|
+
results->f32x4 = vbslq_f32(zero_union_mask, vdupq_n_f32(0.0f), jaccard_f32x4);
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
#pragma endregion - Stateful Streaming
|
|
351
|
+
|
|
352
|
+
#if defined(__clang__)
|
|
353
|
+
#pragma clang attribute pop
|
|
354
|
+
#elif defined(__GNUC__)
|
|
355
|
+
#pragma GCC pop_options
|
|
356
|
+
#endif
|
|
357
|
+
|
|
358
|
+
#if defined(__cplusplus)
|
|
359
|
+
} // extern "C"
|
|
360
|
+
#endif
|
|
361
|
+
|
|
362
|
+
#endif // NK_TARGET_NEON
|
|
363
|
+
#endif // NK_TARGET_ARM_
|
|
364
|
+
#endif // NK_SET_NEON_H
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Set Similarity Measures for RISC-V.
|
|
3
|
+
* @file include/numkong/set/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/set.h
|
|
8
|
+
*
|
|
9
|
+
* SpacemiT K1 and similar chips implement RVA22 profile with base RVV 1.0.
|
|
10
|
+
* This does NOT include the Zvbb extension, so we lack native element-wise popcount (`vcpop.v`).
|
|
11
|
+
*
|
|
12
|
+
* @section rvv_popcount_lut Popcount via arithmetic SWAR
*
* Without `vcpop.v`, per-byte popcount is computed with the classic SWAR bit-counting sequence:
* - Sum adjacent bits into 2-bit fields: `(v & 0x55) + ((v >> 1) & 0x55)`
* - Sum adjacent 2-bit fields into nibbles: `(v & 0x33) + ((v >> 2) & 0x33)`
* - Fold the high nibble onto the low one: `(v + (v >> 4)) & 0x0F`, giving 0-8 per byte
*
* This deliberately avoids a `vrgather`-based 16-entry nibble LUT. The 11 shift/mask/add instructions
* scale linearly with LMUL, whereas the LUT lookup does not, so the binary kernels can run at LMUL=4.
|
|
21
|
+
*
|
|
22
|
+
* @section set_rvv_instructions Key RVV Set Instructions
|
|
23
|
+
*
|
|
24
|
+
* Intrinsic                          Purpose
* vxor_vv_u8m4                       XOR for Hamming difference
* vand_vv_u8m4                       AND for Jaccard intersection
* vor_vv_u8m4                        OR for Jaccard union
* vsrl_vx_u8m4, vand_vx_u8m4         Shift/mask steps of the SWAR popcount
* vwaddu_vx_u16m8                    Widen u8 → u16 before reduction
* vwredsumu_vs_u16m8_u32m1           Widening reduction sum into a u32 accumulator
* vmsne_vv / vmseq_vv + vcpop_m      Element-wise compare and mask popcount for integer sets
|
|
32
|
+
*/
|
|
33
|
+
#ifndef NK_SET_RVV_H
|
|
34
|
+
#define NK_SET_RVV_H
|
|
35
|
+
|
|
36
|
+
#if NK_TARGET_RISCV_
|
|
37
|
+
#if NK_TARGET_RVV
|
|
38
|
+
|
|
39
|
+
#include "numkong/types.h"
|
|
40
|
+
#include "numkong/set/serial.h" // `nk_u1x8_popcount_`
|
|
41
|
+
|
|
42
|
+
#if defined(__clang__)
|
|
43
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
44
|
+
#elif defined(__GNUC__)
|
|
45
|
+
#pragma GCC push_options
|
|
46
|
+
#pragma GCC target("arch=+v")
|
|
47
|
+
#endif
|
|
48
|
+
|
|
49
|
+
#if defined(__cplusplus)
|
|
50
|
+
extern "C" {
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
#pragma region - Binary Sets
|
|
54
|
+
|
|
55
|
+
/**
|
|
56
|
+
* @brief Compute byte-level popcount using arithmetic SWAR.
|
|
57
|
+
*
|
|
58
|
+
* Uses parallel bit counting (Hamming weight) — no vrgather, so scales
|
|
59
|
+
* linearly with LMUL unlike the nibble-LUT approach.
|
|
60
|
+
* Cost: 11 ALU instructions regardless of LMUL.
|
|
61
|
+
*
|
|
62
|
+
* @param[in] v_u8m4 Input vector of bytes
|
|
63
|
+
* @param[in] vector_length Vector length
|
|
64
|
+
* @return Vector where each byte contains its popcount (0-8)
|
|
65
|
+
*/
|
|
66
|
+
NK_INTERNAL vuint8m4_t nk_popcount_u8m4_rvv_(vuint8m4_t v_u8m4, nk_size_t vector_length) {
|
|
67
|
+
// Step 1: count pairs — v = (v & 0x55) + ((v >> 1) & 0x55)
|
|
68
|
+
vuint8m4_t t_u8m4 = __riscv_vsrl_vx_u8m4(v_u8m4, 1, vector_length);
|
|
69
|
+
t_u8m4 = __riscv_vand_vx_u8m4(t_u8m4, 0x55, vector_length);
|
|
70
|
+
v_u8m4 = __riscv_vand_vx_u8m4(v_u8m4, 0x55, vector_length);
|
|
71
|
+
v_u8m4 = __riscv_vadd_vv_u8m4(v_u8m4, t_u8m4, vector_length);
|
|
72
|
+
// Step 2: count nibbles — v = (v & 0x33) + ((v >> 2) & 0x33)
|
|
73
|
+
t_u8m4 = __riscv_vsrl_vx_u8m4(v_u8m4, 2, vector_length);
|
|
74
|
+
t_u8m4 = __riscv_vand_vx_u8m4(t_u8m4, 0x33, vector_length);
|
|
75
|
+
v_u8m4 = __riscv_vand_vx_u8m4(v_u8m4, 0x33, vector_length);
|
|
76
|
+
v_u8m4 = __riscv_vadd_vv_u8m4(v_u8m4, t_u8m4, vector_length);
|
|
77
|
+
// Step 3: count bytes — v = (v + (v >> 4)) & 0x0F
|
|
78
|
+
t_u8m4 = __riscv_vsrl_vx_u8m4(v_u8m4, 4, vector_length);
|
|
79
|
+
v_u8m4 = __riscv_vadd_vv_u8m4(v_u8m4, t_u8m4, vector_length);
|
|
80
|
+
return __riscv_vand_vx_u8m4(v_u8m4, 0x0F, vector_length);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/**
 * @brief Hamming distance between two packed bit-vectors (RVV).
 *
 * Processes `ceil(n / 8)` whole bytes, so every bit of a trailing partially-used byte is compared too.
 *
 * @param[in] a First bit-vector, 8 bits per byte.
 * @param[in] b Second bit-vector, 8 bits per byte.
 * @param[in] n Number of bits.
 * @param[out] result Number of differing bits.
 */
NK_PUBLIC void nk_hamming_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const total_bytes = nk_size_divide_round_up_(n, NK_BITS_PER_BYTE);
    vuint32m1_t differences_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);

    nk_size_t offset = 0;
    while (offset < total_bytes) {
        nk_size_t const chunk = __riscv_vsetvl_e8m4(total_bytes - offset);

        // Differing bits are exactly the set bits of a XOR b.
        vuint8m4_t const diff_bits_u8m4 = __riscv_vxor_vv_u8m4(__riscv_vle8_v_u8m4(a + offset, chunk),
                                                               __riscv_vle8_v_u8m4(b + offset, chunk), chunk);
        vuint8m4_t const diff_counts_u8m4 = nk_popcount_u8m4_rvv_(diff_bits_u8m4, chunk);

        // Widen to u16 first so the widening reduction accumulates into a 32-bit scalar
        // (reducing u8 directly would only produce a 16-bit sum).
        vuint16m8_t const diff_counts_u16m8 = __riscv_vwaddu_vx_u16m8(diff_counts_u8m4, 0, chunk);
        differences_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(diff_counts_u16m8, differences_u32m1, chunk);

        offset += chunk;
    }

    *result = __riscv_vmv_x_s_u32m1_u32(differences_u32m1);
}
|
|
108
|
+
|
|
109
|
+
/**
 * @brief Jaccard distance between two packed bit-vectors (RVV): `1 - |a AND b| / |a OR b|`.
 *
 * Processes `ceil(n / 8)` whole bytes. When the union is empty (no set bits in either input), the result is 0.0.
 *
 * @param[in] a First bit-vector, 8 bits per byte.
 * @param[in] b Second bit-vector, 8 bits per byte.
 * @param[in] n Number of bits.
 * @param[out] result Jaccard distance in [0, 1].
 */
NK_PUBLIC void nk_jaccard_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t const total_bytes = nk_size_divide_round_up_(n, NK_BITS_PER_BYTE);
    vuint32m1_t shared_bits_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
    vuint32m1_t any_bits_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);

    nk_size_t offset = 0;
    while (offset < total_bytes) {
        nk_size_t const chunk = __riscv_vsetvl_e8m4(total_bytes - offset);
        vuint8m4_t const lhs_u8m4 = __riscv_vle8_v_u8m4(a + offset, chunk);
        vuint8m4_t const rhs_u8m4 = __riscv_vle8_v_u8m4(b + offset, chunk);

        // Per-byte popcounts of the intersection (AND) and the union (OR).
        vuint8m4_t const and_counts_u8m4 =
            nk_popcount_u8m4_rvv_(__riscv_vand_vv_u8m4(lhs_u8m4, rhs_u8m4, chunk), chunk);
        vuint8m4_t const or_counts_u8m4 =
            nk_popcount_u8m4_rvv_(__riscv_vor_vv_u8m4(lhs_u8m4, rhs_u8m4, chunk), chunk);

        // Widen to u16 so the widening reductions accumulate into 32-bit scalars.
        shared_bits_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(__riscv_vwaddu_vx_u16m8(and_counts_u8m4, 0, chunk),
                                                             shared_bits_u32m1, chunk);
        any_bits_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(__riscv_vwaddu_vx_u16m8(or_counts_u8m4, 0, chunk),
                                                          any_bits_u32m1, chunk);

        offset += chunk;
    }

    nk_u32_t const intersection_count = __riscv_vmv_x_s_u32m1_u32(shared_bits_u32m1);
    nk_u32_t const union_count = __riscv_vmv_x_s_u32m1_u32(any_bits_u32m1);
    if (union_count == 0) {
        *result = 0.0f;
        return;
    }
    *result = 1.0f - (nk_f32_t)intersection_count / (nk_f32_t)union_count;
}
|
|
144
|
+
|
|
145
|
+
#pragma endregion - Binary Sets
|
|
146
|
+
|
|
147
|
+
#pragma region - Integer Sets
|
|
148
|
+
|
|
149
|
+
/**
 * @brief Element-wise Hamming distance between two byte arrays (RVV): counts indices where `a[i] != b[i]`.
 *
 * @param[in] a First array of @p n bytes.
 * @param[in] b Second array of @p n bytes.
 * @param[in] n Number of elements.
 * @param[out] result Number of mismatching positions (accumulated modulo 2^32).
 */
NK_PUBLIC void nk_hamming_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Scalar accumulator: `vcpop.m` already reduces each chunk to a scalar. Keeping the running total in a
    // 1-element e32 vector (vl=1) forced a vtype/vl switch away from e8m1 and back on every iteration.
    nk_u32_t difference_count_u32 = 0;

    nk_size_t i = 0;
    for (nk_size_t vector_length; i < n; i += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(n - i);

        vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a + i, vector_length);
        vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1(b + i, vector_length);

        // Mask of positions where a != b; `vcpop.m` is part of base RVV 1.0 (no Zvbb needed).
        vbool8_t not_equal_mask_b8 = __riscv_vmsne_vv_u8m1_b8(a_u8m1, b_u8m1, vector_length);
        difference_count_u32 += (nk_u32_t)__riscv_vcpop_m_b8(not_equal_mask_b8, vector_length);
    }

    *result = difference_count_u32;
}
|
|
171
|
+
|
|
172
|
+
/**
 * @brief Position-wise distance for u32 arrays (RVV): `1 - matches / n`, where `matches` counts indices
 *        with `a[i] == b[i]`. Returns 0.0 when `n == 0`.
 *
 * @param[in] a First array of @p n elements.
 * @param[in] b Second array of @p n elements.
 * @param[in] n Number of elements.
 * @param[out] result Distance in [0, 1].
 */
NK_PUBLIC void nk_jaccard_u32_rvv(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f32_t *result) {
    // Count in `nk_size_t`, the same type as `n`, so the tally cannot wrap for n >= 2^32
    // (a 32-bit counter would silently corrupt the ratio there).
    nk_size_t match_count = 0;

    nk_size_t i = 0;
    for (nk_size_t vector_length; i < n; i += vector_length) {
        vector_length = __riscv_vsetvl_e32m1(n - i);

        vuint32m1_t a_u32m1 = __riscv_vle32_v_u32m1(a + i, vector_length);
        vuint32m1_t b_u32m1 = __riscv_vle32_v_u32m1(b + i, vector_length);

        // Mask of positions where a == b, counted with `vcpop.m`.
        vbool32_t equal_mask_b32 = __riscv_vmseq_vv_u32m1_b32(a_u32m1, b_u32m1, vector_length);
        match_count += __riscv_vcpop_m_b32(equal_mask_b32, vector_length);
    }

    *result = (n != 0) ? 1.0f - (nk_f32_t)match_count / (nk_f32_t)n : 0.0f;
}
|
|
191
|
+
|
|
192
|
+
/**
 * @brief Position-wise distance for u16 arrays (RVV): `1 - matches / n`, where `matches` counts indices
 *        with `a[i] == b[i]`. Returns 0.0 when `n == 0`.
 *
 * @param[in] a First array of @p n elements.
 * @param[in] b Second array of @p n elements.
 * @param[in] n Number of elements.
 * @param[out] result Distance in [0, 1].
 */
NK_PUBLIC void nk_jaccard_u16_rvv(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Count in `nk_size_t`, the same type as `n`, so the tally cannot wrap for n >= 2^32
    // (a 32-bit counter would silently corrupt the ratio there).
    nk_size_t match_count = 0;

    nk_size_t i = 0;
    for (nk_size_t vector_length; i < n; i += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(n - i);

        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1(a + i, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1(b + i, vector_length);

        // Mask of positions where a == b, counted with `vcpop.m`.
        vbool16_t equal_mask_b16 = __riscv_vmseq_vv_u16m1_b16(a_u16m1, b_u16m1, vector_length);
        match_count += __riscv_vcpop_m_b16(equal_mask_b16, vector_length);
    }

    *result = (n != 0) ? 1.0f - (nk_f32_t)match_count / (nk_f32_t)n : 0.0f;
}
|
|
211
|
+
|
|
212
|
+
#pragma endregion - Integer Sets
|
|
213
|
+
|
|
214
|
+
#if defined(__cplusplus)
|
|
215
|
+
} // extern "C"
|
|
216
|
+
#endif
|
|
217
|
+
|
|
218
|
+
#if defined(__clang__)
|
|
219
|
+
#pragma clang attribute pop
|
|
220
|
+
#elif defined(__GNUC__)
|
|
221
|
+
#pragma GCC pop_options
|
|
222
|
+
#endif
|
|
223
|
+
|
|
224
|
+
#endif // NK_TARGET_RVV
|
|
225
|
+
#endif // NK_TARGET_RISCV_
|
|
226
|
+
#endif // NK_SET_RVV_H
|