numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SVE2-accelerated Sparse Vector Operations.
|
|
3
|
+
* @file include/numkong/sparse/sve2.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/sparse.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPARSE_SVE2_H
|
|
10
|
+
#define NK_SPARSE_SVE2_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
|
|
14
|
+
#include "numkong/types.h"
|
|
15
|
+
|
|
16
|
+
#if defined(__cplusplus)
|
|
17
|
+
extern "C" {
|
|
18
|
+
#endif
|
|
19
|
+
|
|
20
|
+
/* SVE2 introduces many new integer-oriented instructions, extending some of the NEON functionality
|
|
21
|
+
* to variable-length SVE registers. Those include "compare multiple" intrinsics:
|
|
22
|
+
*
|
|
23
|
+
* - `svmatch[_u16]` matches each element of the first vector against all members of the corresponding 128-bit segment of the second.
|
|
24
|
+
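* - `svhistcnt[_s32]_z` does something similar across the whole register. For each element of the first vector, it counts the equal elements at the same or lower positions in the second vector (an inclusive prefix scan).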
|
|
25
|
+
* - `svtbx[_u16]` performs an extended table lookup, keeping the fallback values for out-of-range indices.
|
|
26
|
+
*
|
|
27
|
+
* Other notable instructions:
|
|
28
|
+
*
|
|
29
|
+
* - `DUP`: Broadcast indexed predicate element
|
|
30
|
+
* https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/DUP--predicate---Broadcast-indexed-predicate-element-?lang=en
|
|
31
|
+
* - `SCLAMP` and `UCLAMP`: clamp values, i.e. combined min+max
|
|
32
|
+
* https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/SCLAMP--Signed-clamp-to-minimum-maximum-vector-?lang=en
|
|
33
|
+
* https://developer.arm.com/documentation/ddi0602/2021-06/SVE-Instructions/UCLAMP--Unsigned-clamp-to-minimum-maximum-vector-?lang=en
|
|
34
|
+
* - `TBLQ`: Table lookup quadword
|
|
35
|
+
* https://developer.arm.com/documentation/ddi0602/2022-12/SVE-Instructions/TBLQ--Programmable-table-lookup-within-each-quadword-vector-segment--zeroing--?lang=en
|
|
36
|
+
*
|
|
37
|
+
* Great resources for SVE2 intrinsics:
|
|
38
|
+
*
|
|
39
|
+
* > ARM's Scalable Vector Extensions: A Critical Look at SVE2 For Integer Workloads
|
|
40
|
+
* https://gist.github.com/zingaburga/805669eb891c820bd220418ee3f0d6bd
|
|
41
|
+
*/
|
|
42
|
+
#if NK_TARGET_SVE2
|
|
43
|
+
#if defined(__clang__)
|
|
44
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+sve2"))), apply_to = function)
|
|
45
|
+
#elif defined(__GNUC__)
|
|
46
|
+
#pragma GCC push_options
|
|
47
|
+
#pragma GCC target("arch=armv8.2-a+sve+sve2")
|
|
48
|
+
#endif
|
|
49
|
+
|
|
50
|
+
/**
 *  @brief Intersects two sorted `u16` arrays using the SVE2 `MATCH` instruction.
 *
 *  Both inputs are processed one SVE register at a time. Register-sized slices that cannot overlap
 *  are skipped, and overlapping slices are compared all-to-all with `svmatch_u16`.
 *
 *  @param a        First input array, assumed sorted in ascending order (the skipping logic relies on it).
 *  @param b        Second input array, assumed sorted in ascending order.
 *  @param a_length Number of elements in @p a.
 *  @param b_length Number of elements in @p b.
 *  @param result   Optional output. When non-NULL, it receives the common values in the order they
 *                  appear in @p a, and must have room for all of them.
 *  @param count    Receives the number of common values.
 *
 *  NOTE(review): inputs are presumably duplicate-free sets - confirm against `sparse.h`.
 */
NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
    nk_u16_t const *a, nk_u16_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u16_t *result, nk_size_t *count) {

    // A single 128-bit SVE segment fits 8 values. `svmatch_u16` only compares within a segment,
    // so `b` is rotated segment-by-segment to cover the whole register.
    nk_size_t const register_size = svcnth();
    nk_size_t const segments_count = register_size / 8;
    nk_size_t a_idx = 0, b_idx = 0;
    nk_size_t c = 0;

    while (a_idx < a_length && b_idx < b_length) {
        // Load the next register-sized slice of each input. The tail lanes are zeroed by the load.
        svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
        svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
        svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
        svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);

        // Intersecting registers with `svmatch_u16` involves a lot of shuffling and comparisons,
        // so we want to avoid it if the slices don't overlap at all.
        // `svlastb` extracts the last active element; `svlasta` with an all-false predicate extracts the first one.
        nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
        nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
        nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);

        // If the slices don't overlap, advance the appropriate pointer by whole registers
        while (a_max < b_min && (a_idx + register_size) <= a_length) {
            a_idx += register_size;
            a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
            a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
            a_max = svlastb(a_progress_u16x, a_u16x);
        }
        nk_u16_t const a_min = svlasta(svpfalse_b(), a_u16x);
        while (b_max < a_min && (b_idx + register_size) <= b_length) {
            b_idx += register_size;
            b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
            b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
            b_max = svlastb(b_progress_u16x, b_u16x);
        }

        // Before the rotations below scramble the order in `b_u16x`, estimate how far to advance
        // afterwards. Every `a` element not above `b_max` and every `b` element not above `a_max`
        // is fully resolved by this step. A whole register can be compared against a scalar directly.
        svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
        svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
        nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
        nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);

        // `svmatch_u16` does not predicate its second operand. The zeroed tail of a partially
        // loaded `b` would therefore falsely match a zero in `a`. Filling the tail with `b_max`,
        // which is already a member of `b`, leaves the set of matchable values unchanged.
        b_u16x = svsel_u16(b_progress_u16x, b_u16x, svdup_n_u16(b_max));

        // Compare `a_u16x` with each 128-bit segment of `b_u16x`
        svbool_t equal_mask = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
        for (nk_size_t i = 1; i < segments_count; i++) {
            b_u16x = svext_u16(b_u16x, b_u16x, 8);
            equal_mask = svorr_z(svptrue_b16(), equal_mask, svmatch_u16(a_progress_u16x, a_u16x, b_u16x));
        }
        nk_size_t equal_count = svcntp_b16(svptrue_b16(), equal_mask);

        // `svcompact` only exists for 32-bit and 64-bit elements. So each half of `a` is widened,
        // its matches are compacted, and the results are narrowed back with truncating `svst1h`
        // stores. This avoids any fixed-size scratch buffer, whose required size depends on the
        // vector length.
        if (result) {
            svbool_t equal_low_mask = svunpklo_b(equal_mask);
            svbool_t equal_high_mask = svunpkhi_b(equal_mask);
            nk_size_t equal_low_count = svcntp_b32(svptrue_b32(), equal_low_mask);
            nk_size_t equal_high_count = svcntp_b32(svptrue_b32(), equal_high_mask);
            svuint32_t matches_low_u32x = svcompact_u32(equal_low_mask, svunpklo_u32(a_u16x));
            svuint32_t matches_high_u32x = svcompact_u32(equal_high_mask, svunpkhi_u32(a_u16x));
            svst1h_u32(svwhilelt_b32_u64(0, equal_low_count), result + c, matches_low_u32x);
            svst1h_u32(svwhilelt_b32_u64(0, equal_high_count), result + c + equal_low_count, matches_high_u32x);
        }

        // Advance
        a_idx += a_step;
        b_idx += b_step;
        c += equal_count;
    }
    *count = c;
}
|
|
131
|
+
|
|
132
|
+
/**
 *  @brief Intersects two sorted `u32` arrays using the SVE2 `HISTCNT` instruction.
 *
 *  Both inputs are processed one SVE register at a time. Register-sized slices that cannot overlap
 *  are skipped, and overlapping slices are compared all-to-all with two histogram passes.
 *
 *  @param a        First input array, assumed sorted in ascending order (the skipping logic relies on it).
 *  @param b        Second input array, assumed sorted in ascending order.
 *  @param a_length Number of elements in @p a.
 *  @param b_length Number of elements in @p b.
 *  @param result   Optional output. When non-NULL, it receives the common values in the order they
 *                  appear in @p a, and must have room for all of them.
 *  @param count    Receives the number of common values.
 *
 *  NOTE(review): inputs are presumably duplicate-free sets - confirm against `sparse.h`.
 */
NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
    nk_u32_t const *a, nk_u32_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u32_t *result, nk_size_t *count) {

    nk_size_t const register_size = svcntw();
    nk_size_t a_idx = 0, b_idx = 0;
    nk_size_t c = 0;

    while (a_idx < a_length && b_idx < b_length) {
        // Load the next register-sized slice of each input. The tail lanes are zeroed by the load.
        svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
        svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
        svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
        svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);

        // Intersecting registers involves two histogram passes and several shuffles,
        // so we want to avoid it if the slices don't overlap at all.
        // `svlastb` extracts the last active element; `svlasta` with an all-false predicate extracts the first one.
        nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
        nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
        nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);

        // If the slices don't overlap, advance the appropriate pointer by whole registers
        while (a_max < b_min && (a_idx + register_size) <= a_length) {
            a_idx += register_size;
            a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
            a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
            a_max = svlastb(a_progress_u32x, a_u32x);
        }
        nk_u32_t const a_min = svlasta(svpfalse_b(), a_u32x);
        while (b_max < a_min && (b_idx + register_size) <= b_length) {
            b_idx += register_size;
            b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
            b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
            b_max = svlastb(b_progress_u32x, b_u32x);
        }

        // Before the reversals below scramble the order, estimate how far to advance afterwards.
        // Every `a` element not above `b_max` and every `b` element not above `a_max` is fully
        // resolved by this step. A whole register can be compared against a scalar directly.
        svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
        svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
        nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
        nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);

        // `b` lanes past `b_length` are zeroed by the load, but they still take part in the histogram
        // comparisons below, so they would falsely match a zero in `a`. Filling them with `b_max`,
        // which is already a member of `b`, leaves the set of matchable values unchanged.
        b_u32x = svsel_u32(b_progress_u32x, b_u32x, svdup_n_u32(b_max));

        // `svmatch` only exists for 8-bit and 16-bit integers, so 32-bit values use `svhistcnt_u32_z` instead.
        // For every element of `a`, it counts the equal elements of `b` at the same or lower positions.
        // That covers the lower triangle of the all-pairs comparison matrix.
        // Reversing both registers with `svrev_u32`, repeating the pass, and reversing the result
        // covers the upper triangle. For 4-element registers with α = {A, B, C, D} and
        // β = {X, Y, Z, W}, the pairs compared by each pass are:
        //
        //      hist(α, β):          hist(α_rev, β_rev):
        //
        //        X Y Z W              W Z Y X
        //      A 1 0 0 0            D 1 0 0 0
        //      B 1 1 0 0            C 1 1 0 0
        //      C 1 1 1 0            B 1 1 1 0
        //      D 1 1 1 1            A 1 1 1 1
        //
        svuint32_t hist_lower = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
        svuint32_t a_rev_u32x = svrev_u32(a_u32x);
        svuint32_t b_rev_u32x = svrev_u32(b_u32x);
        svuint32_t hist_upper = svrev_u32(svhistcnt_u32_z(svptrue_b32(), a_rev_u32x, b_rev_u32x));
        svuint32_t hist = svorr_u32_x(a_progress_u32x, hist_lower, hist_upper);
        svbool_t equal_mask = svcmpne_n_u32(a_progress_u32x, hist, 0);
        nk_size_t equal_count = svcntp_b32(a_progress_u32x, equal_mask);

        // Use `svcompact` (base SVE) to pack the matching elements and store them to the result buffer
        if (result) {
            svuint32_t compacted = svcompact_u32(equal_mask, a_u32x);
            svbool_t store_predicate = svwhilelt_b32_u64(0, equal_count);
            svst1_u32(store_predicate, result + c, compacted);
        }

        // Advance
        a_idx += a_step;
        b_idx += b_step;
        c += equal_count;
    }
    *count = c;
}
|
|
235
|
+
|
|
236
|
+
/**
 *  @brief Intersects two sorted `u64` arrays using the SVE2 `HISTCNT` instruction.
 *
 *  Both inputs are processed one SVE register at a time. Register-sized slices that cannot overlap
 *  are skipped, and overlapping slices are compared all-to-all with two histogram passes.
 *
 *  @param a        First input array, assumed sorted in ascending order (the skipping logic relies on it).
 *  @param b        Second input array, assumed sorted in ascending order.
 *  @param a_length Number of elements in @p a.
 *  @param b_length Number of elements in @p b.
 *  @param result   Optional output. When non-NULL, it receives the common values in the order they
 *                  appear in @p a, and must have room for all of them.
 *  @param count    Receives the number of common values.
 *
 *  NOTE(review): inputs are presumably duplicate-free sets - confirm against `sparse.h`.
 */
NK_PUBLIC void nk_sparse_intersect_u64_sve2( //
    nk_u64_t const *a, nk_u64_t const *b,    //
    nk_size_t a_length, nk_size_t b_length,  //
    nk_u64_t *result, nk_size_t *count) {

    nk_size_t const register_size = svcntd();
    nk_size_t a_idx = 0, b_idx = 0;
    nk_size_t c = 0;

    while (a_idx < a_length && b_idx < b_length) {
        // Load the next register-sized slice of each input. The tail lanes are zeroed by the load.
        svbool_t a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
        svbool_t b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
        svuint64_t a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
        svuint64_t b_u64x = svld1_u64(b_progress_u64x, b + b_idx);

        // Intersecting registers involves comparisons,
        // so we want to avoid it if the slices don't overlap at all.
        // `svlastb` extracts the last active element; `svlasta` with an all-false predicate extracts the first one.
        nk_u64_t a_max = svlastb(a_progress_u64x, a_u64x);
        nk_u64_t b_min = svlasta(svpfalse_b(), b_u64x);
        nk_u64_t b_max = svlastb(b_progress_u64x, b_u64x);

        // If the slices don't overlap, advance the appropriate pointer by whole registers
        while (a_max < b_min && (a_idx + register_size) <= a_length) {
            a_idx += register_size;
            a_progress_u64x = svwhilelt_b64_u64(a_idx, a_length);
            a_u64x = svld1_u64(a_progress_u64x, a + a_idx);
            a_max = svlastb(a_progress_u64x, a_u64x);
        }
        nk_u64_t const a_min = svlasta(svpfalse_b(), a_u64x);
        while (b_max < a_min && (b_idx + register_size) <= b_length) {
            b_idx += register_size;
            b_progress_u64x = svwhilelt_b64_u64(b_idx, b_length);
            b_u64x = svld1_u64(b_progress_u64x, b + b_idx);
            b_max = svlastb(b_progress_u64x, b_u64x);
        }

        // Estimate how much we will need to advance the pointers afterwards. Every `a` element not
        // above `b_max` and every `b` element not above `a_max` is fully resolved by this step.
        svbool_t a_mask_u64x = svcmple_n_u64(a_progress_u64x, a_u64x, b_max);
        svbool_t b_mask_u64x = svcmple_n_u64(b_progress_u64x, b_u64x, a_max);
        nk_u64_t a_step = svcntp_b64(a_progress_u64x, a_mask_u64x);
        nk_u64_t b_step = svcntp_b64(b_progress_u64x, b_mask_u64x);

        // `b` lanes past `b_length` are zeroed by the load, but they still take part in the histogram
        // comparisons below, so they would falsely match a zero in `a`. Filling them with `b_max`,
        // which is already a member of `b`, leaves the set of matchable values unchanged.
        b_u64x = svsel_u64(b_progress_u64x, b_u64x, svdup_n_u64(b_max));

        // `svhistcnt_u64_z` counts, for every element of `a`, the equal elements of `b` at the same or
        // lower positions. That covers the lower triangle of the all-pairs comparison matrix.
        // The reversed pass covers the upper triangle; OR-ing both covers every pair.
        svuint64_t hist_lower = svhistcnt_u64_z(a_progress_u64x, a_u64x, b_u64x);
        svuint64_t a_rev_u64x = svrev_u64(a_u64x);
        svuint64_t b_rev_u64x = svrev_u64(b_u64x);
        svuint64_t hist_upper = svrev_u64(svhistcnt_u64_z(svptrue_b64(), a_rev_u64x, b_rev_u64x));
        svuint64_t hist = svorr_u64_x(a_progress_u64x, hist_lower, hist_upper);
        svbool_t equal_mask = svcmpne_n_u64(a_progress_u64x, hist, 0);
        nk_size_t equal_count = svcntp_b64(a_progress_u64x, equal_mask);

        // Use `svcompact` (base SVE) to pack the matching elements and store them to the result buffer
        if (result) {
            svuint64_t compacted = svcompact_u64(equal_mask, a_u64x);
            svbool_t store_predicate = svwhilelt_b64_u64(0, equal_count);
            svst1_u64(store_predicate, result + c, compacted);
        }

        // Advance
        a_idx += a_step;
        b_idx += b_step;
        c += equal_count;
    }
    *count = c;
}
|
|
308
|
+
|
|
309
|
+
NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
|
|
310
|
+
nk_u32_t const *a, nk_u32_t const *b, //
|
|
311
|
+
nk_f32_t const *a_weights, nk_f32_t const *b_weights, //
|
|
312
|
+
nk_size_t a_length, nk_size_t b_length, //
|
|
313
|
+
nk_f64_t *product) {
|
|
314
|
+
|
|
315
|
+
// A single SVE lane is 128 bits wide, so one lane fits 4 values.
|
|
316
|
+
nk_size_t const register_size = svcntw();
|
|
317
|
+
nk_size_t const vector_length_f64 = svcntd();
|
|
318
|
+
nk_size_t const lanes_count = register_size / 4;
|
|
319
|
+
nk_size_t a_idx = 0, b_idx = 0;
|
|
320
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
321
|
+
svbool_t const predicate_all_f64x = svptrue_b64();
|
|
322
|
+
svfloat64_t product_f64x = svdup_f64(0.0);
|
|
323
|
+
|
|
324
|
+
while (a_idx < a_length && b_idx < b_length) {
|
|
325
|
+
// Load indices with progress predicates
|
|
326
|
+
svbool_t a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
327
|
+
svbool_t b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
328
|
+
svuint32_t a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
|
|
329
|
+
svuint32_t b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
|
|
330
|
+
|
|
331
|
+
// Avoid expensive intersection if slices don't overlap at all
|
|
332
|
+
nk_u32_t a_min;
|
|
333
|
+
nk_u32_t a_max = svlastb(a_progress_u32x, a_u32x);
|
|
334
|
+
nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
|
|
335
|
+
nk_u32_t b_max = svlastb(b_progress_u32x, b_u32x);
|
|
336
|
+
|
|
337
|
+
// If the slices don't overlap, advance the appropriate pointer
|
|
338
|
+
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
339
|
+
a_idx += register_size;
|
|
340
|
+
a_progress_u32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
341
|
+
a_u32x = svld1_u32(a_progress_u32x, a + a_idx);
|
|
342
|
+
a_max = svlastb(a_progress_u32x, a_u32x);
|
|
343
|
+
}
|
|
344
|
+
a_min = svlasta(svpfalse_b(), a_u32x);
|
|
345
|
+
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
346
|
+
b_idx += register_size;
|
|
347
|
+
b_progress_u32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
348
|
+
b_u32x = svld1_u32(b_progress_u32x, b + b_idx);
|
|
349
|
+
b_max = svlastb(b_progress_u32x, b_u32x);
|
|
350
|
+
}
|
|
351
|
+
b_min = svlasta(svpfalse_b(), b_u32x);
|
|
352
|
+
|
|
353
|
+
// Calculate step sizes before modifying vectors
|
|
354
|
+
svbool_t a_mask_u32x = svcmple_n_u32(a_progress_u32x, a_u32x, b_max);
|
|
355
|
+
svbool_t b_mask_u32x = svcmple_n_u32(b_progress_u32x, b_u32x, a_max);
|
|
356
|
+
nk_u64_t a_step = svcntp_b32(a_progress_u32x, a_mask_u32x);
|
|
357
|
+
nk_u64_t b_step = svcntp_b32(b_progress_u32x, b_mask_u32x);
|
|
358
|
+
|
|
359
|
+
// Use histogram-based intersection (svmatch_u32 doesn't exist)
|
|
360
|
+
svuint32_t hist_lower_u32x = svhistcnt_u32_z(a_progress_u32x, a_u32x, b_u32x);
|
|
361
|
+
svuint32_t a_rev_u32x = svrev_u32(a_u32x);
|
|
362
|
+
svuint32_t b_rev_u32x = svrev_u32(b_u32x);
|
|
363
|
+
svuint32_t hist_upper_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_f32x, a_rev_u32x, b_rev_u32x));
|
|
364
|
+
svuint32_t hist_u32x = svorr_u32_x(a_progress_u32x, hist_lower_u32x, hist_upper_u32x);
|
|
365
|
+
svbool_t a_equal_mask_u32x = svcmpne_n_u32(a_progress_u32x, hist_u32x, 0);
|
|
366
|
+
svbool_t a_overlap_mask_u32x = svand_b_z(predicate_all_f32x, a_progress_u32x, a_equal_mask_u32x);
|
|
367
|
+
|
|
368
|
+
if (!svptest_any(a_progress_u32x, a_overlap_mask_u32x)) {
|
|
369
|
+
a_idx += a_step;
|
|
370
|
+
b_idx += b_step;
|
|
371
|
+
continue;
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
// Load weights and mask by intersection
|
|
375
|
+
svfloat32_t a_weights_f32x = svsel_f32(a_overlap_mask_u32x, svld1_f32(a_progress_u32x, a_weights + a_idx),
|
|
376
|
+
svdup_f32(0.f));
|
|
377
|
+
svfloat32_t b_weights_f32x = svld1_f32(b_progress_u32x, b_weights + b_idx);
|
|
378
|
+
svbool_t predicate_low_f64x = svwhilelt_b64_u64(a_idx, a_length);
|
|
379
|
+
svbool_t predicate_high_f64x = svwhilelt_b64_u64(a_idx + vector_length_f64, a_length);
|
|
380
|
+
svfloat64_t a_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, a_weights_f32x);
|
|
381
|
+
svfloat64_t a_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, a_weights_f32x);
|
|
382
|
+
|
|
383
|
+
// For each position in a that matches something in b, we need the corresponding b weight.
|
|
384
|
+
// Use lane-by-lane matching for dot product.
|
|
385
|
+
for (nk_size_t i = 0; i < lanes_count; i++) {
|
|
386
|
+
// Check which elements of a match the current rotation of b
|
|
387
|
+
svbool_t equal_lane_u32x = svcmpeq_u32(a_progress_u32x, a_u32x, b_u32x);
|
|
388
|
+
svfloat32_t b_equal_weights_f32x = svsel_f32(equal_lane_u32x, b_weights_f32x, svdup_f32(0.f));
|
|
389
|
+
svfloat64_t b_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, b_equal_weights_f32x);
|
|
390
|
+
svfloat64_t b_high_f64x = svcvtlt_f64_f32_x(predicate_high_f64x, b_equal_weights_f32x);
|
|
391
|
+
product_f64x = svmla_f64_x(predicate_low_f64x, product_f64x, a_low_f64x, b_low_f64x);
|
|
392
|
+
product_f64x = svmla_f64_x(predicate_high_f64x, product_f64x, a_high_f64x, b_high_f64x);
|
|
393
|
+
// Rotate b vectors
|
|
394
|
+
b_u32x = svext_u32(b_u32x, b_u32x, 4);
|
|
395
|
+
b_weights_f32x = svext_f32(b_weights_f32x, b_weights_f32x, 4);
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
// Advance
|
|
399
|
+
a_idx += a_step;
|
|
400
|
+
b_idx += b_step;
|
|
401
|
+
}
|
|
402
|
+
*product = svaddv_f64(predicate_all_f64x, product_f64x);
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
#if defined(__clang__)
|
|
406
|
+
#pragma clang attribute pop
|
|
407
|
+
#elif defined(__GNUC__)
|
|
408
|
+
#pragma GCC pop_options
|
|
409
|
+
#endif
|
|
410
|
+
#endif // NK_TARGET_SVE2
|
|
411
|
+
|
|
412
|
+
#if NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
|
|
413
|
+
#if defined(__clang__)
|
|
414
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function)
|
|
415
|
+
#elif defined(__GNUC__)
|
|
416
|
+
#pragma GCC push_options
|
|
417
|
+
#pragma GCC target("arch=armv8.6-a+sve+sve2+bf16")
|
|
418
|
+
#endif
|
|
419
|
+
|
|
420
|
+
NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
|
|
421
|
+
nk_u16_t const *a, nk_u16_t const *b, //
|
|
422
|
+
nk_bf16_t const *a_weights, nk_bf16_t const *b_weights, //
|
|
423
|
+
nk_size_t a_length, nk_size_t b_length, //
|
|
424
|
+
nk_f32_t *product) {
|
|
425
|
+
|
|
426
|
+
// A single SVE lane is 128 bits wide, so one lane fits 8 values.
|
|
427
|
+
nk_size_t const register_size = svcnth();
|
|
428
|
+
nk_size_t const lanes_count = register_size / 8;
|
|
429
|
+
nk_size_t a_idx = 0, b_idx = 0;
|
|
430
|
+
svfloat32_t product_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
431
|
+
|
|
432
|
+
while (a_idx < a_length && b_idx < b_length) {
|
|
433
|
+
// Load `a_member` and broadcast it, load `b_members_vec` from memory
|
|
434
|
+
svbool_t a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
435
|
+
svbool_t b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
436
|
+
svuint16_t a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
|
|
437
|
+
svuint16_t b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
|
|
438
|
+
|
|
439
|
+
// Intersecting registers with `svmatch_u16` involves a lot of shuffling
|
|
440
|
+
// and comparisons, so we want to avoid it if the slices don't overlap at all..
|
|
441
|
+
nk_u16_t a_min;
|
|
442
|
+
nk_u16_t a_max = svlastb(a_progress_u16x, a_u16x);
|
|
443
|
+
nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
|
|
444
|
+
nk_u16_t b_max = svlastb(b_progress_u16x, b_u16x);
|
|
445
|
+
|
|
446
|
+
// If the slices don't overlap, advance the appropriate pointer
|
|
447
|
+
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
448
|
+
a_idx += register_size;
|
|
449
|
+
a_progress_u16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
450
|
+
a_u16x = svld1_u16(a_progress_u16x, a + a_idx);
|
|
451
|
+
a_max = svlastb(a_progress_u16x, a_u16x);
|
|
452
|
+
}
|
|
453
|
+
a_min = svlasta(svpfalse_b(), a_u16x);
|
|
454
|
+
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
455
|
+
b_idx += register_size;
|
|
456
|
+
b_progress_u16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
457
|
+
b_u16x = svld1_u16(b_progress_u16x, b + b_idx);
|
|
458
|
+
b_max = svlastb(b_progress_u16x, b_u16x);
|
|
459
|
+
}
|
|
460
|
+
b_min = svlasta(svpfalse_b(), b_u16x);
|
|
461
|
+
|
|
462
|
+
// Before we evaluate the intersection size, obfurscating the order in `b_u16x`,
|
|
463
|
+
// let's estimate how much we will need to advance the pointers afterwards.
|
|
464
|
+
// For that, we don't even need to broadcast the values in SVE, as the whole
|
|
465
|
+
// register can be compared against a scalar:
|
|
466
|
+
//
|
|
467
|
+
// svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
|
|
468
|
+
// svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
|
|
469
|
+
svbool_t a_mask_u16x = svcmple_n_u16(a_progress_u16x, a_u16x, b_max);
|
|
470
|
+
svbool_t b_mask_u16x = svcmple_n_u16(b_progress_u16x, b_u16x, a_max);
|
|
471
|
+
nk_u64_t a_step = svcntp_b16(a_progress_u16x, a_mask_u16x);
|
|
472
|
+
nk_u64_t b_step = svcntp_b16(b_progress_u16x, b_mask_u16x);
|
|
473
|
+
|
|
474
|
+
// Compare `a_u16x` with each lane of `b_u16x`
|
|
475
|
+
svbfloat16_t a_weights_bf16x = svld1_bf16(a_progress_u16x, (__bf16 const *)a_weights + a_idx);
|
|
476
|
+
svbfloat16_t b_weights_bf16x = svld1_bf16(b_progress_u16x, (__bf16 const *)b_weights + b_idx);
|
|
477
|
+
for (nk_size_t i = 0; i < lanes_count; i++) {
|
|
478
|
+
svbool_t equal_mask_u16x = svmatch_u16(a_progress_u16x, a_u16x, b_u16x);
|
|
479
|
+
//! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
|
|
480
|
+
//! So we reinterprete floats as integers and apply `svsel_s16`, but the `svreinterpret_s16_bs16`
|
|
481
|
+
//! and `svreinterpret_bf16_s16` are not always properly defined!
|
|
482
|
+
svint16_t b_equal_weights_s16x = svsel_s16(equal_mask_u16x, svreinterpret_s16_bf16(b_weights_bf16x),
|
|
483
|
+
svdup_n_s16(0));
|
|
484
|
+
product_f32x = svbfdot_f32(product_f32x, a_weights_bf16x, svreinterpret_bf16_s16(b_equal_weights_s16x));
|
|
485
|
+
b_u16x = svext_u16(b_u16x, b_u16x, 8);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
// Advance
|
|
489
|
+
a_idx += a_step;
|
|
490
|
+
b_idx += b_step;
|
|
491
|
+
}
|
|
492
|
+
*product = svaddv_f32(svptrue_b32(), product_f32x);
|
|
493
|
+
}
|
|
494
|
+
|
|
495
|
+
#if defined(__clang__)
|
|
496
|
+
#pragma clang attribute pop
|
|
497
|
+
#elif defined(__GNUC__)
|
|
498
|
+
#pragma GCC pop_options
|
|
499
|
+
#endif
|
|
500
|
+
#endif // NK_TARGET_SVE2 && NK_TARGET_SVEBFDOT
|
|
501
|
+
|
|
502
|
+
#if defined(__cplusplus)
|
|
503
|
+
} // extern "C"
|
|
504
|
+
#endif
|
|
505
|
+
|
|
506
|
+
#endif // NK_TARGET_ARM_
|
|
507
|
+
#endif // NK_SPARSE_SVE2_H
|