numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -60,31 +60,31 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
|
|
|
60
60
|
|
|
61
61
|
while (a_idx < a_length && b_idx < b_length) {
|
|
62
62
|
// Load `a_member` and broadcast it, load `b_members_vec` from memory
|
|
63
|
-
svbool_t
|
|
64
|
-
svbool_t
|
|
65
|
-
svuint16_t a_u16x = svld1_u16(
|
|
66
|
-
svuint16_t b_u16x = svld1_u16(
|
|
63
|
+
svbool_t a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
64
|
+
svbool_t b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
65
|
+
svuint16_t a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
|
|
66
|
+
svuint16_t b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
|
|
67
67
|
|
|
68
68
|
// Intersecting registers with `svmatch_u16` involves a lot of shuffling
|
|
69
69
|
// and comparisons, so we want to avoid it if the slices don't overlap at all..
|
|
70
70
|
nk_u16_t a_min;
|
|
71
|
-
nk_u16_t a_max = svlastb(
|
|
71
|
+
nk_u16_t a_max = svlastb(a_progress_b16x, a_u16x);
|
|
72
72
|
nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
|
|
73
|
-
nk_u16_t b_max = svlastb(
|
|
73
|
+
nk_u16_t b_max = svlastb(b_progress_b16x, b_u16x);
|
|
74
74
|
|
|
75
75
|
// If the slices don't overlap, advance the appropriate pointer
|
|
76
76
|
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
77
77
|
a_idx += register_size;
|
|
78
|
-
|
|
79
|
-
a_u16x = svld1_u16(
|
|
80
|
-
a_max = svlastb(
|
|
78
|
+
a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
79
|
+
a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
|
|
80
|
+
a_max = svlastb(a_progress_b16x, a_u16x);
|
|
81
81
|
}
|
|
82
82
|
a_min = svlasta(svpfalse_b(), a_u16x);
|
|
83
83
|
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
84
84
|
b_idx += register_size;
|
|
85
|
-
|
|
86
|
-
b_u16x = svld1_u16(
|
|
87
|
-
b_max = svlastb(
|
|
85
|
+
b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
86
|
+
b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
|
|
87
|
+
b_max = svlastb(b_progress_b16x, b_u16x);
|
|
88
88
|
}
|
|
89
89
|
b_min = svlasta(svpfalse_b(), b_u16x);
|
|
90
90
|
|
|
@@ -95,18 +95,18 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
|
|
|
95
95
|
//
|
|
96
96
|
// svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
|
|
97
97
|
// svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
|
|
98
|
-
svbool_t
|
|
99
|
-
svbool_t
|
|
100
|
-
nk_u64_t a_step = svcntp_b16(
|
|
101
|
-
nk_u64_t b_step = svcntp_b16(
|
|
98
|
+
svbool_t a_mask_b16x = svcmple_n_u16(a_progress_b16x, a_u16x, b_max);
|
|
99
|
+
svbool_t b_mask_b16x = svcmple_n_u16(b_progress_b16x, b_u16x, a_max);
|
|
100
|
+
nk_u64_t a_step = svcntp_b16(a_progress_b16x, a_mask_b16x);
|
|
101
|
+
nk_u64_t b_step = svcntp_b16(b_progress_b16x, b_mask_b16x);
|
|
102
102
|
|
|
103
103
|
// Compare `a_u16x` with each lane of `b_u16x`
|
|
104
|
-
svbool_t
|
|
104
|
+
svbool_t equal_mask_b16x = svmatch_u16(a_progress_b16x, a_u16x, b_u16x);
|
|
105
105
|
for (nk_size_t i = 1; i < lanes_count; i++) {
|
|
106
106
|
b_u16x = svext_u16(b_u16x, b_u16x, 8);
|
|
107
|
-
|
|
107
|
+
equal_mask_b16x = svorr_z(svptrue_b16(), equal_mask_b16x, svmatch_u16(a_progress_b16x, a_u16x, b_u16x));
|
|
108
108
|
}
|
|
109
|
-
nk_size_t equal_count = svcntp_b16(svptrue_b16(),
|
|
109
|
+
nk_size_t equal_count = svcntp_b16(svptrue_b16(), equal_mask_b16x);
|
|
110
110
|
|
|
111
111
|
// Manually compact and store matching elements (svcompact_u16 is not defined)
|
|
112
112
|
if (result) {
|
|
@@ -114,7 +114,7 @@ NK_PUBLIC void nk_sparse_intersect_u16_sve2( //
|
|
|
114
114
|
nk_u16_t mask_data[16];
|
|
115
115
|
|
|
116
116
|
svst1_u16(svptrue_b16(), a_data, a_u16x);
|
|
117
|
-
svst1_u16(svptrue_b16(), mask_data, svdup_n_u16_z(
|
|
117
|
+
svst1_u16(svptrue_b16(), mask_data, svdup_n_u16_z(equal_mask_b16x, 1));
|
|
118
118
|
|
|
119
119
|
for (nk_size_t i = 0; i < svcnth(); i++)
|
|
120
120
|
if (mask_data[i]) result[c++] = a_data[i];
|
|
@@ -142,31 +142,31 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
|
|
|
142
142
|
|
|
143
143
|
while (a_idx < a_length && b_idx < b_length) {
|
|
144
144
|
// Load `a_member` and broadcast it, load `b_members_vec` from memory
|
|
145
|
-
svbool_t
|
|
146
|
-
svbool_t
|
|
147
|
-
svuint32_t a_u32x = svld1_u32(
|
|
148
|
-
svuint32_t b_u32x = svld1_u32(
|
|
145
|
+
svbool_t a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
146
|
+
svbool_t b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
147
|
+
svuint32_t a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
|
|
148
|
+
svuint32_t b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
|
|
149
149
|
|
|
150
150
|
// Intersecting registers with `svmatch_u16` involves a lot of shuffling
|
|
151
151
|
// and comparisons, so we want to avoid it if the slices don't overlap at all..
|
|
152
152
|
nk_u32_t a_min;
|
|
153
|
-
nk_u32_t a_max = svlastb(
|
|
153
|
+
nk_u32_t a_max = svlastb(a_progress_b32x, a_u32x);
|
|
154
154
|
nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
|
|
155
|
-
nk_u32_t b_max = svlastb(
|
|
155
|
+
nk_u32_t b_max = svlastb(b_progress_b32x, b_u32x);
|
|
156
156
|
|
|
157
157
|
// If the slices don't overlap, advance the appropriate pointer
|
|
158
158
|
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
159
159
|
a_idx += register_size;
|
|
160
|
-
|
|
161
|
-
a_u32x = svld1_u32(
|
|
162
|
-
a_max = svlastb(
|
|
160
|
+
a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
161
|
+
a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
|
|
162
|
+
a_max = svlastb(a_progress_b32x, a_u32x);
|
|
163
163
|
}
|
|
164
164
|
a_min = svlasta(svpfalse_b(), a_u32x);
|
|
165
165
|
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
166
166
|
b_idx += register_size;
|
|
167
|
-
|
|
168
|
-
b_u32x = svld1_u32(
|
|
169
|
-
b_max = svlastb(
|
|
167
|
+
b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
168
|
+
b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
|
|
169
|
+
b_max = svlastb(b_progress_b32x, b_u32x);
|
|
170
170
|
}
|
|
171
171
|
b_min = svlasta(svpfalse_b(), b_u32x);
|
|
172
172
|
|
|
@@ -177,21 +177,21 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
|
|
|
177
177
|
//
|
|
178
178
|
// svuint32_t a_last_broadcasted = svdup_n_u32(a_max);
|
|
179
179
|
// svuint32_t b_last_broadcasted = svdup_n_u32(b_max);
|
|
180
|
-
svbool_t
|
|
181
|
-
svbool_t
|
|
182
|
-
nk_u64_t a_step = svcntp_b32(
|
|
183
|
-
nk_u64_t b_step = svcntp_b32(
|
|
180
|
+
svbool_t a_mask_b32x = svcmple_n_u32(a_progress_b32x, a_u32x, b_max);
|
|
181
|
+
svbool_t b_mask_b32x = svcmple_n_u32(b_progress_b32x, b_u32x, a_max);
|
|
182
|
+
nk_u64_t a_step = svcntp_b32(a_progress_b32x, a_mask_b32x);
|
|
183
|
+
nk_u64_t b_step = svcntp_b32(b_progress_b32x, b_mask_b32x);
|
|
184
184
|
|
|
185
185
|
// Comparing `a_u32x` with each lane of `b_u32x` can't be done with `svmatch`,
|
|
186
186
|
// the same way as in `nk_sparse_intersect_u16_sve2`, as that instruction is only
|
|
187
187
|
// available for 8-bit and 16-bit integers.
|
|
188
188
|
//
|
|
189
|
-
// svbool_t
|
|
189
|
+
// svbool_t equal_mask_b32x = svpfalse_b();
|
|
190
190
|
// for (nk_size_t i = 0; i < register_size; i++) {
|
|
191
|
-
//
|
|
191
|
+
// equal_mask_b32x = svorr_z(svptrue_b32(), equal_mask_b32x, svcmpeq_u32(a_progress, a_u32x, b_u32x));
|
|
192
192
|
// b_u32x = svext_u32(b_u32x, b_u32x, 1);
|
|
193
193
|
// }
|
|
194
|
-
// nk_size_t equal_count = svcntp_b32(a_progress,
|
|
194
|
+
// nk_size_t equal_count = svcntp_b32(a_progress, equal_mask_b32x);
|
|
195
195
|
//
|
|
196
196
|
// Alternatively, one can use histogram instructions, like `svhistcnt_u32_z`.
|
|
197
197
|
// They practically compute the prefix-matching count, which is equivalent to
|
|
@@ -210,19 +210,19 @@ NK_PUBLIC void nk_sparse_intersect_u32_sve2( //
|
|
|
210
210
|
// C 1 1 1 0 B 1 1 1 0
|
|
211
211
|
// D 1 1 1 1 A 1 1 1 1
|
|
212
212
|
//
|
|
213
|
-
svuint32_t
|
|
213
|
+
svuint32_t hist_low_u32x = svhistcnt_u32_z(a_progress_b32x, a_u32x, b_u32x);
|
|
214
214
|
svuint32_t a_rev_u32x = svrev_u32(a_u32x);
|
|
215
215
|
svuint32_t b_rev_u32x = svrev_u32(b_u32x);
|
|
216
|
-
svuint32_t
|
|
217
|
-
svuint32_t
|
|
218
|
-
svbool_t
|
|
219
|
-
nk_size_t equal_count = svcntp_b32(
|
|
216
|
+
svuint32_t hist_high_u32x = svrev_u32(svhistcnt_u32_z(svptrue_b32(), a_rev_u32x, b_rev_u32x));
|
|
217
|
+
svuint32_t hist_u32x = svorr_u32_x(a_progress_b32x, hist_low_u32x, hist_high_u32x);
|
|
218
|
+
svbool_t equal_mask_b32x = svcmpne_n_u32(a_progress_b32x, hist_u32x, 0);
|
|
219
|
+
nk_size_t equal_count = svcntp_b32(a_progress_b32x, equal_mask_b32x);
|
|
220
220
|
|
|
221
221
|
// Use SVE2 svcompact to compress matching elements and store to result buffer
|
|
222
222
|
if (result) {
|
|
223
|
-
svuint32_t
|
|
224
|
-
svbool_t
|
|
225
|
-
svst1_u32(
|
|
223
|
+
svuint32_t compacted_u32x = svcompact_u32(equal_mask_b32x, a_u32x);
|
|
224
|
+
svbool_t store_predicate_b32x = svwhilelt_b32_u64(0u, equal_count);
|
|
225
|
+
svst1_u32(store_predicate_b32x, result + c, compacted_u32x);
|
|
226
226
|
}
|
|
227
227
|
|
|
228
228
|
// Advance
|
|
@@ -246,56 +246,56 @@ NK_PUBLIC void nk_sparse_intersect_u64_sve2( //
|
|
|
246
246
|
|
|
247
247
|
while (a_idx < a_length && b_idx < b_length) {
|
|
248
248
|
// Load `a_member` and broadcast it, load `b_members_vec` from memory
|
|
249
|
-
svbool_t
|
|
250
|
-
svbool_t
|
|
251
|
-
svuint64_t a_u64x = svld1_u64(
|
|
252
|
-
svuint64_t b_u64x = svld1_u64(
|
|
249
|
+
svbool_t a_progress_b64x = svwhilelt_b64_u64(a_idx, a_length);
|
|
250
|
+
svbool_t b_progress_b64x = svwhilelt_b64_u64(b_idx, b_length);
|
|
251
|
+
svuint64_t a_u64x = svld1_u64(a_progress_b64x, a + a_idx);
|
|
252
|
+
svuint64_t b_u64x = svld1_u64(b_progress_b64x, b + b_idx);
|
|
253
253
|
|
|
254
254
|
// Intersecting registers involves comparisons,
|
|
255
255
|
// so we want to avoid it if the slices don't overlap at all.
|
|
256
256
|
nk_u64_t a_min;
|
|
257
|
-
nk_u64_t a_max = svlastb(
|
|
257
|
+
nk_u64_t a_max = svlastb(a_progress_b64x, a_u64x);
|
|
258
258
|
nk_u64_t b_min = svlasta(svpfalse_b(), b_u64x);
|
|
259
|
-
nk_u64_t b_max = svlastb(
|
|
259
|
+
nk_u64_t b_max = svlastb(b_progress_b64x, b_u64x);
|
|
260
260
|
|
|
261
261
|
// If the slices don't overlap, advance the appropriate pointer
|
|
262
262
|
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
263
263
|
a_idx += register_size;
|
|
264
|
-
|
|
265
|
-
a_u64x = svld1_u64(
|
|
266
|
-
a_max = svlastb(
|
|
264
|
+
a_progress_b64x = svwhilelt_b64_u64(a_idx, a_length);
|
|
265
|
+
a_u64x = svld1_u64(a_progress_b64x, a + a_idx);
|
|
266
|
+
a_max = svlastb(a_progress_b64x, a_u64x);
|
|
267
267
|
}
|
|
268
268
|
a_min = svlasta(svpfalse_b(), a_u64x);
|
|
269
269
|
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
270
270
|
b_idx += register_size;
|
|
271
|
-
|
|
272
|
-
b_u64x = svld1_u64(
|
|
273
|
-
b_max = svlastb(
|
|
271
|
+
b_progress_b64x = svwhilelt_b64_u64(b_idx, b_length);
|
|
272
|
+
b_u64x = svld1_u64(b_progress_b64x, b + b_idx);
|
|
273
|
+
b_max = svlastb(b_progress_b64x, b_u64x);
|
|
274
274
|
}
|
|
275
275
|
b_min = svlasta(svpfalse_b(), b_u64x);
|
|
276
276
|
|
|
277
277
|
// Estimate how much we will need to advance the pointers afterwards.
|
|
278
|
-
svbool_t
|
|
279
|
-
svbool_t
|
|
280
|
-
nk_u64_t a_step = svcntp_b64(
|
|
281
|
-
nk_u64_t b_step = svcntp_b64(
|
|
278
|
+
svbool_t a_mask_b64x = svcmple_n_u64(a_progress_b64x, a_u64x, b_max);
|
|
279
|
+
svbool_t b_mask_b64x = svcmple_n_u64(b_progress_b64x, b_u64x, a_max);
|
|
280
|
+
nk_u64_t a_step = svcntp_b64(a_progress_b64x, a_mask_b64x);
|
|
281
|
+
nk_u64_t b_step = svcntp_b64(b_progress_b64x, b_mask_b64x);
|
|
282
282
|
|
|
283
283
|
// Use histogram instructions like `svhistcnt_u64_z` to compute intersection.
|
|
284
284
|
// They compute the prefix-matching count, equivalent to the lower triangle
|
|
285
285
|
// of the row-major intersection matrix.
|
|
286
|
-
svuint64_t
|
|
286
|
+
svuint64_t hist_low_u64x = svhistcnt_u64_z(a_progress_b64x, a_u64x, b_u64x);
|
|
287
287
|
svuint64_t a_rev_u64x = svrev_u64(a_u64x);
|
|
288
288
|
svuint64_t b_rev_u64x = svrev_u64(b_u64x);
|
|
289
|
-
svuint64_t
|
|
290
|
-
svuint64_t
|
|
291
|
-
svbool_t
|
|
292
|
-
nk_size_t equal_count = svcntp_b64(
|
|
289
|
+
svuint64_t hist_high_u64x = svrev_u64(svhistcnt_u64_z(svptrue_b64(), a_rev_u64x, b_rev_u64x));
|
|
290
|
+
svuint64_t hist_u64x = svorr_u64_x(a_progress_b64x, hist_low_u64x, hist_high_u64x);
|
|
291
|
+
svbool_t equal_mask_b64x = svcmpne_n_u64(a_progress_b64x, hist_u64x, 0);
|
|
292
|
+
nk_size_t equal_count = svcntp_b64(a_progress_b64x, equal_mask_b64x);
|
|
293
293
|
|
|
294
294
|
// Use SVE2 svcompact to compress matching elements and store to result buffer
|
|
295
295
|
if (result) {
|
|
296
|
-
svuint64_t
|
|
297
|
-
svbool_t
|
|
298
|
-
svst1_u64(
|
|
296
|
+
svuint64_t compacted_u64x = svcompact_u64(equal_mask_b64x, a_u64x);
|
|
297
|
+
svbool_t store_predicate_b64x = svwhilelt_b64_u64(0u, equal_count);
|
|
298
|
+
svst1_u64(store_predicate_b64x, result + c, compacted_u64x);
|
|
299
299
|
}
|
|
300
300
|
|
|
301
301
|
// Advance
|
|
@@ -312,94 +312,90 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
|
|
|
312
312
|
nk_size_t a_length, nk_size_t b_length, //
|
|
313
313
|
nk_f64_t *product) {
|
|
314
314
|
|
|
315
|
-
// A single SVE lane is 128 bits wide, so one lane fits 4 values.
|
|
316
315
|
nk_size_t const register_size = svcntw();
|
|
317
316
|
nk_size_t const vector_length_f64 = svcntd();
|
|
318
|
-
nk_size_t const lanes_count = register_size / 4;
|
|
319
317
|
nk_size_t a_idx = 0, b_idx = 0;
|
|
320
|
-
svbool_t const
|
|
321
|
-
svbool_t const
|
|
318
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
319
|
+
svbool_t const predicate_all_b64x = svptrue_b64();
|
|
322
320
|
svfloat64_t product_f64x = svdup_f64(0.0);
|
|
323
321
|
|
|
324
322
|
while (a_idx < a_length && b_idx < b_length) {
|
|
325
323
|
// Load indices with progress predicates
|
|
326
|
-
svbool_t
|
|
327
|
-
svbool_t
|
|
328
|
-
svuint32_t a_u32x = svld1_u32(
|
|
329
|
-
svuint32_t b_u32x = svld1_u32(
|
|
324
|
+
svbool_t a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
325
|
+
svbool_t b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
326
|
+
svuint32_t a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
|
|
327
|
+
svuint32_t b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
|
|
330
328
|
|
|
331
329
|
// Avoid expensive intersection if slices don't overlap at all
|
|
332
330
|
nk_u32_t a_min;
|
|
333
|
-
nk_u32_t a_max = svlastb(
|
|
331
|
+
nk_u32_t a_max = svlastb(a_progress_b32x, a_u32x);
|
|
334
332
|
nk_u32_t b_min = svlasta(svpfalse_b(), b_u32x);
|
|
335
|
-
nk_u32_t b_max = svlastb(
|
|
333
|
+
nk_u32_t b_max = svlastb(b_progress_b32x, b_u32x);
|
|
336
334
|
|
|
337
335
|
// If the slices don't overlap, advance the appropriate pointer
|
|
338
336
|
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
339
337
|
a_idx += register_size;
|
|
340
|
-
|
|
341
|
-
a_u32x = svld1_u32(
|
|
342
|
-
a_max = svlastb(
|
|
338
|
+
a_progress_b32x = svwhilelt_b32_u64(a_idx, a_length);
|
|
339
|
+
a_u32x = svld1_u32(a_progress_b32x, a + a_idx);
|
|
340
|
+
a_max = svlastb(a_progress_b32x, a_u32x);
|
|
343
341
|
}
|
|
344
342
|
a_min = svlasta(svpfalse_b(), a_u32x);
|
|
345
343
|
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
346
344
|
b_idx += register_size;
|
|
347
|
-
|
|
348
|
-
b_u32x = svld1_u32(
|
|
349
|
-
b_max = svlastb(
|
|
345
|
+
b_progress_b32x = svwhilelt_b32_u64(b_idx, b_length);
|
|
346
|
+
b_u32x = svld1_u32(b_progress_b32x, b + b_idx);
|
|
347
|
+
b_max = svlastb(b_progress_b32x, b_u32x);
|
|
350
348
|
}
|
|
351
349
|
b_min = svlasta(svpfalse_b(), b_u32x);
|
|
352
350
|
|
|
353
351
|
// Calculate step sizes before modifying vectors
|
|
354
|
-
svbool_t
|
|
355
|
-
svbool_t
|
|
356
|
-
nk_u64_t a_step = svcntp_b32(
|
|
357
|
-
nk_u64_t b_step = svcntp_b32(
|
|
352
|
+
svbool_t a_mask_b32x = svcmple_n_u32(a_progress_b32x, a_u32x, b_max);
|
|
353
|
+
svbool_t b_mask_b32x = svcmple_n_u32(b_progress_b32x, b_u32x, a_max);
|
|
354
|
+
nk_u64_t a_step = svcntp_b32(a_progress_b32x, a_mask_b32x);
|
|
355
|
+
nk_u64_t b_step = svcntp_b32(b_progress_b32x, b_mask_b32x);
|
|
358
356
|
|
|
359
357
|
// Use histogram-based intersection (svmatch_u32 doesn't exist)
|
|
360
|
-
svuint32_t
|
|
358
|
+
svuint32_t hist_low_u32x = svhistcnt_u32_z(a_progress_b32x, a_u32x, b_u32x);
|
|
361
359
|
svuint32_t a_rev_u32x = svrev_u32(a_u32x);
|
|
362
360
|
svuint32_t b_rev_u32x = svrev_u32(b_u32x);
|
|
363
|
-
svuint32_t
|
|
364
|
-
svuint32_t hist_u32x = svorr_u32_x(
|
|
365
|
-
svbool_t
|
|
366
|
-
svbool_t
|
|
361
|
+
svuint32_t hist_high_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_b32x, a_rev_u32x, b_rev_u32x));
|
|
362
|
+
svuint32_t hist_u32x = svorr_u32_x(a_progress_b32x, hist_low_u32x, hist_high_u32x);
|
|
363
|
+
svbool_t a_equal_mask_b32x = svcmpne_n_u32(a_progress_b32x, hist_u32x, 0);
|
|
364
|
+
svbool_t a_overlap_mask_b32x = svand_b_z(predicate_all_b32x, a_progress_b32x, a_equal_mask_b32x);
|
|
367
365
|
|
|
368
|
-
if (!svptest_any(
|
|
366
|
+
if (!svptest_any(a_progress_b32x, a_overlap_mask_b32x)) {
|
|
369
367
|
a_idx += a_step;
|
|
370
368
|
b_idx += b_step;
|
|
371
369
|
continue;
|
|
372
370
|
}
|
|
373
371
|
|
|
374
|
-
//
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
svbool_t
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
b_weights_f32x = svext_f32(b_weights_f32x, b_weights_f32x, 4);
|
|
396
|
-
}
|
|
372
|
+
// Compute b overlap mask (symmetric histogram: which b elements match something in a)
|
|
373
|
+
svuint32_t b_hist_low_u32x = svhistcnt_u32_z(b_progress_b32x, b_u32x, a_u32x);
|
|
374
|
+
svuint32_t b_hist_high_u32x = svrev_u32(svhistcnt_u32_z(predicate_all_b32x, b_rev_u32x, a_rev_u32x));
|
|
375
|
+
svuint32_t b_hist_u32x = svorr_u32_x(b_progress_b32x, b_hist_low_u32x, b_hist_high_u32x);
|
|
376
|
+
svbool_t b_overlap_mask_b32x = svand_b_z(predicate_all_b32x, b_progress_b32x,
|
|
377
|
+
svcmpne_n_u32(b_progress_b32x, b_hist_u32x, 0));
|
|
378
|
+
|
|
379
|
+
// Compact matching weights — both arrays are sorted, so svcompact
|
|
380
|
+
// preserves relative order and aligns corresponding intersection pairs.
|
|
381
|
+
svfloat32_t a_matched_f32x = svcompact_f32(a_overlap_mask_b32x, svld1_f32(a_progress_b32x, a_weights + a_idx));
|
|
382
|
+
svfloat32_t b_matched_f32x = svcompact_f32(b_overlap_mask_b32x, svld1_f32(b_progress_b32x, b_weights + b_idx));
|
|
383
|
+
|
|
384
|
+
// Widen to f64 and accumulate. svcvt_f64_f32 converts even-indexed f32
|
|
385
|
+
// elements; svcvtlt_f64_f32 converts odd-indexed f32 elements.
|
|
386
|
+
nk_size_t match_count = svcntp_b32(a_progress_b32x, a_overlap_mask_b32x);
|
|
387
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (match_count + 1) / 2);
|
|
388
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, match_count / 2);
|
|
389
|
+
product_f64x = svmla_f64_x(pred_even_b64x, product_f64x, svcvt_f64_f32_x(pred_even_b64x, a_matched_f32x),
|
|
390
|
+
svcvt_f64_f32_x(pred_even_b64x, b_matched_f32x));
|
|
391
|
+
product_f64x = svmla_f64_x(pred_odd_b64x, product_f64x, svcvtlt_f64_f32_x(pred_odd_b64x, a_matched_f32x),
|
|
392
|
+
svcvtlt_f64_f32_x(pred_odd_b64x, b_matched_f32x));
|
|
397
393
|
|
|
398
394
|
// Advance
|
|
399
395
|
a_idx += a_step;
|
|
400
396
|
b_idx += b_step;
|
|
401
397
|
}
|
|
402
|
-
*product = svaddv_f64(
|
|
398
|
+
*product = svaddv_f64(predicate_all_b64x, product_f64x);
|
|
403
399
|
}
|
|
404
400
|
|
|
405
401
|
#if defined(__clang__)
|
|
@@ -431,31 +427,31 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
|
|
|
431
427
|
|
|
432
428
|
while (a_idx < a_length && b_idx < b_length) {
|
|
433
429
|
// Load the next slices of `a` and `b` indices from memory under their tail predicates
|
|
434
|
-
svbool_t
|
|
435
|
-
svbool_t
|
|
436
|
-
svuint16_t a_u16x = svld1_u16(
|
|
437
|
-
svuint16_t b_u16x = svld1_u16(
|
|
430
|
+
svbool_t a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
431
|
+
svbool_t b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
432
|
+
svuint16_t a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
|
|
433
|
+
svuint16_t b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
|
|
438
434
|
|
|
439
435
|
// Intersecting registers with `svmatch_u16` involves a lot of shuffling
|
|
440
436
|
// and comparisons, so we want to avoid it if the slices don't overlap at all.
|
|
441
437
|
nk_u16_t a_min;
|
|
442
|
-
nk_u16_t a_max = svlastb(
|
|
438
|
+
nk_u16_t a_max = svlastb(a_progress_b16x, a_u16x);
|
|
443
439
|
nk_u16_t b_min = svlasta(svpfalse_b(), b_u16x);
|
|
444
|
-
nk_u16_t b_max = svlastb(
|
|
440
|
+
nk_u16_t b_max = svlastb(b_progress_b16x, b_u16x);
|
|
445
441
|
|
|
446
442
|
// If the slices don't overlap, advance the appropriate pointer
|
|
447
443
|
while (a_max < b_min && (a_idx + register_size) <= a_length) {
|
|
448
444
|
a_idx += register_size;
|
|
449
|
-
|
|
450
|
-
a_u16x = svld1_u16(
|
|
451
|
-
a_max = svlastb(
|
|
445
|
+
a_progress_b16x = svwhilelt_b16_u64(a_idx, a_length);
|
|
446
|
+
a_u16x = svld1_u16(a_progress_b16x, a + a_idx);
|
|
447
|
+
a_max = svlastb(a_progress_b16x, a_u16x);
|
|
452
448
|
}
|
|
453
449
|
a_min = svlasta(svpfalse_b(), a_u16x);
|
|
454
450
|
while (b_max < a_min && (b_idx + register_size) <= b_length) {
|
|
455
451
|
b_idx += register_size;
|
|
456
|
-
|
|
457
|
-
b_u16x = svld1_u16(
|
|
458
|
-
b_max = svlastb(
|
|
452
|
+
b_progress_b16x = svwhilelt_b16_u64(b_idx, b_length);
|
|
453
|
+
b_u16x = svld1_u16(b_progress_b16x, b + b_idx);
|
|
454
|
+
b_max = svlastb(b_progress_b16x, b_u16x);
|
|
459
455
|
}
|
|
460
456
|
b_min = svlasta(svpfalse_b(), b_u16x);
|
|
461
457
|
|
|
@@ -466,20 +462,20 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
|
|
|
466
462
|
//
|
|
467
463
|
// svuint16_t a_last_broadcasted = svdup_n_u16(a_max);
|
|
468
464
|
// svuint16_t b_last_broadcasted = svdup_n_u16(b_max);
|
|
469
|
-
svbool_t
|
|
470
|
-
svbool_t
|
|
471
|
-
nk_u64_t a_step = svcntp_b16(
|
|
472
|
-
nk_u64_t b_step = svcntp_b16(
|
|
465
|
+
svbool_t a_mask_b16x = svcmple_n_u16(a_progress_b16x, a_u16x, b_max);
|
|
466
|
+
svbool_t b_mask_b16x = svcmple_n_u16(b_progress_b16x, b_u16x, a_max);
|
|
467
|
+
nk_u64_t a_step = svcntp_b16(a_progress_b16x, a_mask_b16x);
|
|
468
|
+
nk_u64_t b_step = svcntp_b16(b_progress_b16x, b_mask_b16x);
|
|
473
469
|
|
|
474
470
|
// Compare `a_u16x` with each lane of `b_u16x`
|
|
475
|
-
svbfloat16_t a_weights_bf16x = svld1_bf16(
|
|
476
|
-
svbfloat16_t b_weights_bf16x = svld1_bf16(
|
|
471
|
+
svbfloat16_t a_weights_bf16x = svld1_bf16(a_progress_b16x, (__bf16 const *)a_weights + a_idx);
|
|
472
|
+
svbfloat16_t b_weights_bf16x = svld1_bf16(b_progress_b16x, (__bf16 const *)b_weights + b_idx);
|
|
477
473
|
for (nk_size_t i = 0; i < lanes_count; i++) {
|
|
478
|
-
svbool_t
|
|
474
|
+
svbool_t equal_mask_b16x = svmatch_u16(a_progress_b16x, a_u16x, b_u16x);
|
|
479
475
|
//! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
|
|
480
476
|
//! So we reinterpret floats as integers and apply `svsel_s16`, but the `svreinterpret_s16_bf16`
|
|
481
477
|
//! and `svreinterpret_bf16_s16` are not always properly defined!
|
|
482
|
-
svint16_t b_equal_weights_s16x = svsel_s16(
|
|
478
|
+
svint16_t b_equal_weights_s16x = svsel_s16(equal_mask_b16x, svreinterpret_s16_bf16(b_weights_bf16x),
|
|
483
479
|
svdup_n_s16(0));
|
|
484
480
|
product_f32x = svbfdot_f32(product_f32x, a_weights_bf16x, svreinterpret_bf16_s16(b_equal_weights_s16x));
|
|
485
481
|
b_u16x = svext_u16(b_u16x, b_u16x, 8);
|
|
@@ -243,8 +243,8 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
|
|
|
243
243
|
// Native VP2INTERSECTD works directly on u32 - no conversion needed!
|
|
244
244
|
nk_u32_t const *const a_end = a + a_length;
|
|
245
245
|
nk_u32_t const *const b_end = b + b_length;
|
|
246
|
-
__m512d
|
|
247
|
-
__m512d
|
|
246
|
+
__m512d product_low_f64x8 = _mm512_setzero_pd();
|
|
247
|
+
__m512d product_high_f64x8 = _mm512_setzero_pd();
|
|
248
248
|
nk_b512_vec_t a_vec, b_vec;
|
|
249
249
|
|
|
250
250
|
while (a + 16 <= a_end && b + 16 <= b_end) {
|
|
@@ -281,15 +281,15 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
|
|
|
281
281
|
__m512 b_weights_f32x16 = _mm512_loadu_ps(b_weights);
|
|
282
282
|
__m512 a_matched_f32x16 = _mm512_maskz_compress_ps(a_matches, a_weights_f32x16);
|
|
283
283
|
__m512 b_matched_f32x16 = _mm512_maskz_compress_ps(b_matches, b_weights_f32x16);
|
|
284
|
-
__m256
|
|
285
|
-
__m256
|
|
286
|
-
__m256
|
|
287
|
-
__m256
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
284
|
+
__m256 a_matched_low_f32x8 = _mm512_castps512_ps256(a_matched_f32x16);
|
|
285
|
+
__m256 a_matched_high_f32x8 = _mm512_extractf32x8_ps(a_matched_f32x16, 1);
|
|
286
|
+
__m256 b_matched_low_f32x8 = _mm512_castps512_ps256(b_matched_f32x16);
|
|
287
|
+
__m256 b_matched_high_f32x8 = _mm512_extractf32x8_ps(b_matched_f32x16, 1);
|
|
288
|
+
|
|
289
|
+
product_low_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_low_f32x8),
|
|
290
|
+
_mm512_cvtps_pd(b_matched_low_f32x8), product_low_f64x8);
|
|
291
|
+
product_high_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(a_matched_high_f32x8),
|
|
292
|
+
_mm512_cvtps_pd(b_matched_high_f32x8), product_high_f64x8);
|
|
293
293
|
}
|
|
294
294
|
|
|
295
295
|
__m512i a_max_u32x16 = _mm512_set1_epi32(*(int const *)&a_max);
|
|
@@ -304,7 +304,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_turin( //
|
|
|
304
304
|
|
|
305
305
|
nk_f64_t tail_product = 0;
|
|
306
306
|
nk_sparse_dot_u32f32_serial(a, b, a_weights, b_weights, a_end - a, b_end - b, &tail_product);
|
|
307
|
-
*product = _mm512_reduce_add_pd(
|
|
307
|
+
*product = _mm512_reduce_add_pd(product_low_f64x8) + _mm512_reduce_add_pd(product_high_f64x8) + tail_product;
|
|
308
308
|
}
|
|
309
309
|
|
|
310
310
|
#if defined(__clang__)
|
package/include/numkong/sparse.h
CHANGED
|
@@ -57,22 +57,22 @@
|
|
|
57
57
|
* The Ice Lake kernels are shuffle/compare heavy; their throughput is often gated by port 5.
|
|
58
58
|
* On Genoa, many integer ops dual-issue on FP ports, often improving throughput despite higher latency.
|
|
59
59
|
*
|
|
60
|
-
* Intrinsic
|
|
61
|
-
* _mm512_shuffle_epi32
|
|
62
|
-
* _mm512_mask_cmpneq_epi32_mask
|
|
63
|
-
* _mm512_alignr_epi32
|
|
64
|
-
* _mm512_conflict_epi32
|
|
65
|
-
* _mm256_maskz_compress_epi16
|
|
66
|
-
* _mm256_dpwssds_epi32
|
|
67
|
-
* _mm256_dpbf16_ps
|
|
60
|
+
* Intrinsic Instruction Icelake Genoa
|
|
61
|
+
* _mm512_shuffle_epi32 VPSHUFD (ZMM, ZMM, I8) 1cy @ p5 1cy @ p123
|
|
62
|
+
* _mm512_mask_cmpneq_epi32_mask VPCMPD (K, ZMM, ZMM, I8) 3cy @ p5 5cy @ p01
|
|
63
|
+
* _mm512_alignr_epi32 VALIGND (ZMM, ZMM, ZMM, I8) 3cy @ p5 6cy @ p12
|
|
64
|
+
* _mm512_conflict_epi32 VPCONFLICTD (ZMM, ZMM) 26cy @ p0+p05+p5 7cy @ p01+p12
|
|
65
|
+
* _mm256_maskz_compress_epi16 VPCOMPRESSW (YMM, K, YMM) 3-6cy @ p5+p5 4-8cy @ p01+p12
|
|
66
|
+
* _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) 4-5cy @ p01 4cy @ p01
|
|
67
|
+
* _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6cy @ p01
|
|
68
68
|
*
|
|
69
69
|
* VP2INTERSECTD is unsupported on Ice Lake and not yet covered by uops.info for Zen5/Turin.
|
|
70
|
-
* Tiger Lake measures ~36-
|
|
70
|
+
* Tiger Lake measures ~36-41cy @ p5 for ZMM variants, which is why we always avoid it on Intel.
|
|
71
71
|
*
|
|
72
72
|
* @section references References
|
|
73
73
|
*
|
|
74
74
|
* - uops.info: https://uops.info/
|
|
75
|
-
* - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
75
|
+
* - Intel Intrinsics Guide: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
76
76
|
* - Arm Intrinsics Reference: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
77
77
|
* - vp2intersect experiments: https://github.com/mozonaut/vp2intersect
|
|
78
78
|
* - Diez-Canas "Faster-Than-Native Alternatives for x86 VP2INTERSECT Instructions":
|