numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -25,7 +25,7 @@ extern "C" {
|
|
|
25
25
|
#endif
|
|
26
26
|
|
|
27
27
|
#if defined(__clang__)
|
|
28
|
-
#pragma clang attribute push(__attribute__((target("sme2
|
|
28
|
+
#pragma clang attribute push(__attribute__((target("sme2"))), apply_to = function)
|
|
29
29
|
#elif defined(__GNUC__)
|
|
30
30
|
#pragma GCC push_options
|
|
31
31
|
#pragma GCC target("+sme2")
|
|
@@ -50,28 +50,32 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
50
50
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
51
51
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
52
52
|
nk_size_t const tile_elements = tile_dim * depth_tile_size;
|
|
53
|
-
|
|
53
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
54
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
55
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
56
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
54
57
|
|
|
55
58
|
nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
|
|
56
59
|
nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
|
|
57
60
|
: (nk_u32_t const *)0;
|
|
58
61
|
|
|
59
|
-
svbool_t const
|
|
60
|
-
|
|
61
|
-
|
|
62
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
63
|
+
// Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
|
|
64
|
+
// so the effective depth for the matching→intersection conversion is the rounded-up bit count.
|
|
65
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
|
|
62
66
|
nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
|
|
63
67
|
|
|
64
68
|
for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
|
|
65
69
|
nk_size_t const row_start_a = row_tile_a * tile_dim;
|
|
66
70
|
nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
|
|
67
71
|
: (row_count_a - row_start_a);
|
|
68
|
-
svbool_t const
|
|
72
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
|
|
69
73
|
|
|
70
74
|
// Compute A row popcounts for this tile
|
|
71
75
|
nk_u32_t a_popcounts[16];
|
|
72
76
|
for (nk_size_t r = 0; r < rows_a_remaining; r++) {
|
|
73
77
|
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
|
|
74
|
-
a_popcounts[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row,
|
|
78
|
+
a_popcounts[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
|
|
75
79
|
}
|
|
76
80
|
|
|
77
81
|
// Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
|
|
@@ -81,21 +85,21 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
81
85
|
|
|
82
86
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
83
87
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
84
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
88
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
85
89
|
? depth_tile_size
|
|
86
|
-
: (
|
|
87
|
-
: 0);
|
|
90
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
88
91
|
if (u32s_this_tile == 0) break;
|
|
89
92
|
|
|
90
93
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
91
94
|
|
|
92
|
-
svbool_t const
|
|
95
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
93
96
|
|
|
97
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
94
98
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
100
|
+
d_start_u32 * 4;
|
|
101
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
102
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
99
103
|
}
|
|
100
104
|
|
|
101
105
|
nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
|
|
@@ -103,47 +107,47 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
103
107
|
nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;
|
|
104
108
|
|
|
105
109
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
106
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
107
|
-
|
|
108
|
-
svbmopa_za32_u32_m(1,
|
|
109
|
-
svld1_u32(
|
|
110
|
-
svbmopa_za32_u32_m(2,
|
|
111
|
-
svld1_u32(
|
|
112
|
-
svbmopa_za32_u32_m(3,
|
|
113
|
-
svld1_u32(
|
|
110
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
111
|
+
|
|
112
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
113
|
+
svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
|
|
114
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
115
|
+
svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
|
|
116
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
117
|
+
svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
|
|
114
118
|
}
|
|
115
119
|
}
|
|
116
120
|
|
|
117
121
|
// Extract: dot = (pop_a + pop_b - depth + matching) / 2
|
|
118
122
|
// matching = ZA[i][j]
|
|
119
|
-
svuint32_t b_pop0_u32x = svld1_u32(
|
|
120
|
-
svuint32_t b_pop1_u32x = svld1_u32(
|
|
121
|
-
svuint32_t b_pop2_u32x = svld1_u32(
|
|
123
|
+
svuint32_t b_pop0_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 0) * tile_dim);
|
|
124
|
+
svuint32_t b_pop1_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 1) * tile_dim);
|
|
125
|
+
svuint32_t b_pop2_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 2) * tile_dim);
|
|
122
126
|
|
|
123
127
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
124
128
|
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
125
129
|
svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
|
|
126
130
|
|
|
127
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
128
|
-
svuint32_t sum_pops0_u32x = svadd_u32_x(
|
|
131
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
132
|
+
svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop0_u32x);
|
|
129
133
|
svuint32_t numerator0_u32x = svadd_u32_x(
|
|
130
|
-
|
|
131
|
-
svst1_u32(
|
|
132
|
-
svlsr_n_u32_x(
|
|
134
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops0_u32x, depth_u32x), za1_u32x);
|
|
135
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim,
|
|
136
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator0_u32x, 1));
|
|
133
137
|
|
|
134
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
135
|
-
svuint32_t sum_pops1_u32x = svadd_u32_x(
|
|
138
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
139
|
+
svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop1_u32x);
|
|
136
140
|
svuint32_t numerator1_u32x = svadd_u32_x(
|
|
137
|
-
|
|
138
|
-
svst1_u32(
|
|
139
|
-
svlsr_n_u32_x(
|
|
141
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops1_u32x, depth_u32x), za2_u32x);
|
|
142
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim,
|
|
143
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator1_u32x, 1));
|
|
140
144
|
|
|
141
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
142
|
-
svuint32_t sum_pops2_u32x = svadd_u32_x(
|
|
145
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
146
|
+
svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop2_u32x);
|
|
143
147
|
svuint32_t numerator2_u32x = svadd_u32_x(
|
|
144
|
-
|
|
145
|
-
svst1_u32(
|
|
146
|
-
svlsr_n_u32_x(
|
|
148
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops2_u32x, depth_u32x), za3_u32x);
|
|
149
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim,
|
|
150
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator2_u32x, 1));
|
|
147
151
|
}
|
|
148
152
|
}
|
|
149
153
|
|
|
@@ -152,49 +156,49 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
152
156
|
nk_size_t const row_start_b = row_tile_b * tile_dim;
|
|
153
157
|
nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
|
|
154
158
|
: (row_count_b - row_start_b);
|
|
155
|
-
svbool_t const
|
|
159
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
|
|
156
160
|
|
|
157
161
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
158
162
|
|
|
159
163
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
160
164
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
161
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
165
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
162
166
|
? depth_tile_size
|
|
163
|
-
: (
|
|
164
|
-
: 0);
|
|
167
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
165
168
|
if (u32s_this_tile == 0) break;
|
|
166
169
|
|
|
167
170
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
168
171
|
|
|
169
|
-
svbool_t const
|
|
172
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
170
173
|
|
|
174
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
171
175
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
177
|
+
d_start_u32 * 4;
|
|
178
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
179
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
176
180
|
}
|
|
177
181
|
|
|
178
182
|
nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
|
|
179
183
|
|
|
180
184
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
181
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
182
|
-
svuint32_t b_u32x = svld1_u32(
|
|
183
|
-
svbmopa_za32_u32_m(1,
|
|
185
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
186
|
+
svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
|
|
187
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
|
|
184
188
|
}
|
|
185
189
|
}
|
|
186
190
|
|
|
187
191
|
// Extract: dot = (pop_a + pop_b - depth + matching) / 2
|
|
188
|
-
svuint32_t b_pop_u32x = svld1_u32(
|
|
192
|
+
svuint32_t b_pop_u32x = svld1_u32(predicate_all_b32x, b_norms + row_start_b);
|
|
189
193
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
190
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
194
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
191
195
|
svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
|
|
192
|
-
svuint32_t sum_pops_u32x = svadd_u32_x(
|
|
196
|
+
svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop_u32x);
|
|
193
197
|
svuint32_t numerator_u32x = svadd_u32_x(
|
|
194
|
-
|
|
198
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops_u32x, depth_u32x), za1_u32x);
|
|
195
199
|
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
196
|
-
svst1_u32(
|
|
197
|
-
svlsr_n_u32_x(
|
|
200
|
+
svst1_u32(column_predicate_b32x, c_row + row_start_b,
|
|
201
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator_u32x, 1));
|
|
198
202
|
}
|
|
199
203
|
}
|
|
200
204
|
}
|
|
@@ -212,39 +216,46 @@ NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packe
|
|
|
212
216
|
* Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
|
|
213
217
|
*/
|
|
214
218
|
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_(
|
|
215
|
-
nk_u1x8_t const *vectors, nk_size_t
|
|
216
|
-
nk_size_t
|
|
219
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
220
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
217
221
|
|
|
218
222
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
219
223
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
nk_size_t const
|
|
224
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
225
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
226
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
227
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
228
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
|
|
223
229
|
|
|
224
|
-
svbool_t const
|
|
225
|
-
|
|
230
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
231
|
+
// Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
|
|
232
|
+
// so the effective depth for the matching→intersection conversion is the rounded-up bit count.
|
|
233
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
|
|
226
234
|
|
|
227
235
|
NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
|
|
228
236
|
|
|
229
237
|
nk_size_t const row_end = row_start + row_count;
|
|
230
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
238
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
|
|
231
239
|
|
|
232
|
-
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start <
|
|
240
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
|
|
233
241
|
row_tile_start += tile_dim) {
|
|
234
242
|
nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
|
|
235
|
-
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <=
|
|
236
|
-
|
|
237
|
-
|
|
243
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
|
|
244
|
+
? rows_remaining
|
|
245
|
+
: (vectors_count - row_tile_start);
|
|
246
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
|
|
238
247
|
|
|
239
248
|
// Compute A tile popcounts
|
|
240
249
|
NK_ALIGN64 nk_u32_t a_tile_pops[16];
|
|
241
250
|
for (nk_size_t r = 0; r < rows_clamped; r++) {
|
|
242
|
-
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
|
|
243
|
-
|
|
251
|
+
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
|
|
252
|
+
(row_tile_start + r) * stride_in_bytes);
|
|
253
|
+
a_tile_pops[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
|
|
244
254
|
}
|
|
245
255
|
for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_pops[r] = 0;
|
|
246
256
|
|
|
247
|
-
|
|
257
|
+
// Upper triangle: start from this row tile's column
|
|
258
|
+
nk_size_t column_tile_index = row_tile_start / tile_dim;
|
|
248
259
|
|
|
249
260
|
// Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
|
|
250
261
|
for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
|
|
@@ -252,73 +263,73 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
|
|
|
252
263
|
|
|
253
264
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
254
265
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
255
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
266
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
256
267
|
? depth_tile_size
|
|
257
|
-
: (
|
|
258
|
-
: 0);
|
|
268
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
259
269
|
if (u32s_this_tile == 0) break;
|
|
260
270
|
|
|
261
271
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
262
|
-
svbool_t const
|
|
272
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
263
273
|
|
|
274
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
264
275
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
276
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
277
|
+
d_start_u32 * 4;
|
|
278
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
279
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
269
280
|
}
|
|
270
281
|
|
|
271
282
|
// Save A columns
|
|
272
283
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
273
|
-
svst1_u32(
|
|
274
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
284
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
285
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
275
286
|
|
|
276
287
|
// B column tile 0
|
|
277
288
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
278
289
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
279
290
|
nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
|
|
280
|
-
if (col_abs <
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
291
|
+
if (col_abs < vectors_count) {
|
|
292
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
293
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
294
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
284
295
|
}
|
|
285
296
|
}
|
|
286
297
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
287
|
-
svuint32_t a_u32x = svld1_u32(
|
|
288
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
289
|
-
svbmopa_za32_u32_m(1,
|
|
298
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
299
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
300
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
290
301
|
}
|
|
291
302
|
|
|
292
303
|
// B column tile 1
|
|
293
304
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
294
305
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
295
306
|
nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
|
|
296
|
-
if (col_abs <
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
307
|
+
if (col_abs < vectors_count) {
|
|
308
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
309
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
310
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
300
311
|
}
|
|
301
312
|
}
|
|
302
313
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
303
|
-
svuint32_t a_u32x = svld1_u32(
|
|
304
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
305
|
-
svbmopa_za32_u32_m(2,
|
|
314
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
315
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
316
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
306
317
|
}
|
|
307
318
|
|
|
308
319
|
// B column tile 2
|
|
309
320
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
310
321
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
311
322
|
nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
|
|
312
|
-
if (col_abs <
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
323
|
+
if (col_abs < vectors_count) {
|
|
324
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
325
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
326
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
316
327
|
}
|
|
317
328
|
}
|
|
318
329
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
319
|
-
svuint32_t a_u32x = svld1_u32(
|
|
320
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
321
|
-
svbmopa_za32_u32_m(3,
|
|
330
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
331
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
332
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
322
333
|
}
|
|
323
334
|
}
|
|
324
335
|
|
|
@@ -328,88 +339,89 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
|
|
|
328
339
|
for (nk_size_t t = 0; t < 3; t++) {
|
|
329
340
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
330
341
|
nk_size_t const col_abs = (column_tile_index + t) * tile_dim + col;
|
|
331
|
-
if (col_abs <
|
|
332
|
-
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs *
|
|
333
|
-
b_pops[t][col] = nk_sets_reduce_sumsq_u1_streaming_(b_row,
|
|
342
|
+
if (col_abs < vectors_count) {
|
|
343
|
+
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes);
|
|
344
|
+
b_pops[t][col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_bytes);
|
|
334
345
|
}
|
|
335
346
|
else { b_pops[t][col] = 0; }
|
|
336
347
|
}
|
|
337
348
|
}
|
|
338
349
|
|
|
339
350
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
340
|
-
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) *
|
|
351
|
+
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
341
352
|
svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
|
|
342
353
|
|
|
343
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
344
|
-
svuint32_t b_popcount_0_u32x = svld1_u32(
|
|
345
|
-
svuint32_t sum_pops0_u32x = svadd_u32_x(
|
|
354
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
355
|
+
svuint32_t b_popcount_0_u32x = svld1_u32(predicate_all_b32x, b_pops[0]);
|
|
356
|
+
svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_0_u32x);
|
|
346
357
|
svuint32_t numerator0_u32x = svadd_u32_x(
|
|
347
|
-
|
|
348
|
-
svst1_u32(
|
|
349
|
-
svlsr_n_u32_x(
|
|
358
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops0_u32x, depth_u32x), za1_u32x);
|
|
359
|
+
svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 0) * tile_dim,
|
|
360
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator0_u32x, 1));
|
|
350
361
|
|
|
351
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
352
|
-
svuint32_t b_popcount_1_u32x = svld1_u32(
|
|
353
|
-
svuint32_t sum_pops1_u32x = svadd_u32_x(
|
|
362
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
363
|
+
svuint32_t b_popcount_1_u32x = svld1_u32(predicate_all_b32x, b_pops[1]);
|
|
364
|
+
svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_1_u32x);
|
|
354
365
|
svuint32_t numerator1_u32x = svadd_u32_x(
|
|
355
|
-
|
|
356
|
-
svst1_u32(
|
|
357
|
-
svlsr_n_u32_x(
|
|
366
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops1_u32x, depth_u32x), za2_u32x);
|
|
367
|
+
svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 1) * tile_dim,
|
|
368
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator1_u32x, 1));
|
|
358
369
|
|
|
359
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
360
|
-
svuint32_t b_popcount_2_u32x = svld1_u32(
|
|
361
|
-
svuint32_t sum_pops2_u32x = svadd_u32_x(
|
|
370
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
371
|
+
svuint32_t b_popcount_2_u32x = svld1_u32(predicate_all_b32x, b_pops[2]);
|
|
372
|
+
svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_2_u32x);
|
|
362
373
|
svuint32_t numerator2_u32x = svadd_u32_x(
|
|
363
|
-
|
|
364
|
-
svst1_u32(
|
|
365
|
-
svlsr_n_u32_x(
|
|
374
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops2_u32x, depth_u32x), za3_u32x);
|
|
375
|
+
svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 2) * tile_dim,
|
|
376
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator2_u32x, 1));
|
|
366
377
|
}
|
|
367
378
|
}
|
|
368
379
|
|
|
369
380
|
// Remainder: 1 column tile at a time using ZA1
|
|
370
381
|
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
371
382
|
nk_size_t const col_tile_start = column_tile_index * tile_dim;
|
|
372
|
-
nk_size_t const cols_remaining = (col_tile_start + tile_dim <=
|
|
373
|
-
|
|
374
|
-
|
|
383
|
+
nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
|
|
384
|
+
? tile_dim
|
|
385
|
+
: (vectors_count - col_tile_start);
|
|
386
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
|
|
375
387
|
|
|
376
388
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
377
389
|
|
|
378
390
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
379
391
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
380
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
392
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
381
393
|
? depth_tile_size
|
|
382
|
-
: (
|
|
383
|
-
: 0);
|
|
394
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
384
395
|
if (u32s_this_tile == 0) break;
|
|
385
396
|
|
|
386
397
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
387
|
-
svbool_t const
|
|
398
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
388
399
|
|
|
400
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
389
401
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
402
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
403
|
+
d_start_u32 * 4;
|
|
404
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
405
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
394
406
|
}
|
|
395
407
|
|
|
396
408
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
397
|
-
svst1_u32(
|
|
398
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
409
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
410
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
399
411
|
|
|
400
412
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
401
413
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
402
414
|
nk_size_t const col_abs = col_tile_start + col;
|
|
403
|
-
if (col_abs <
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
415
|
+
if (col_abs < vectors_count) {
|
|
416
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
417
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
418
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
407
419
|
}
|
|
408
420
|
}
|
|
409
421
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
410
|
-
svuint32_t a_u32x = svld1_u32(
|
|
411
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
412
|
-
svbmopa_za32_u32_m(1,
|
|
422
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
423
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
|
|
424
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
|
|
413
425
|
}
|
|
414
426
|
}
|
|
415
427
|
|
|
@@ -417,33 +429,34 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
|
|
|
417
429
|
NK_ALIGN64 nk_u32_t b_pops_r[16];
|
|
418
430
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
419
431
|
nk_size_t const col_abs = col_tile_start + col;
|
|
420
|
-
if (col_abs <
|
|
421
|
-
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs *
|
|
422
|
-
b_pops_r[col] = nk_sets_reduce_sumsq_u1_streaming_(b_row,
|
|
432
|
+
if (col_abs < vectors_count) {
|
|
433
|
+
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes);
|
|
434
|
+
b_pops_r[col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_bytes);
|
|
423
435
|
}
|
|
424
436
|
else { b_pops_r[col] = 0; }
|
|
425
437
|
}
|
|
426
438
|
|
|
427
439
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
428
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
440
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
429
441
|
svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
|
|
430
|
-
svuint32_t b_popcount_u32x = svld1_u32(
|
|
431
|
-
svuint32_t sum_pops_u32x = svadd_u32_x(
|
|
442
|
+
svuint32_t b_popcount_u32x = svld1_u32(predicate_all_b32x, b_pops_r);
|
|
443
|
+
svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_u32x);
|
|
432
444
|
svuint32_t numerator_u32x = svadd_u32_x(
|
|
433
|
-
|
|
434
|
-
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) *
|
|
435
|
-
svst1_u32(
|
|
436
|
-
svlsr_n_u32_x(
|
|
445
|
+
predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops_u32x, depth_u32x), za1_u32x);
|
|
446
|
+
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
447
|
+
svst1_u32(column_predicate_b32x, result_row + col_tile_start,
|
|
448
|
+
svlsr_n_u32_x(predicate_all_b32x, numerator_u32x, 1));
|
|
437
449
|
}
|
|
438
450
|
}
|
|
439
451
|
}
|
|
440
452
|
}
|
|
441
453
|
|
|
442
|
-
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t
|
|
443
|
-
nk_size_t
|
|
444
|
-
nk_size_t
|
|
445
|
-
|
|
446
|
-
|
|
454
|
+
// Public entry point: forwards all eight arguments, unchanged and in the same order, to
// nk_dots_symmetric_u1_smebi32_streaming_ and returns nothing.
// In the visible part of that kernel, `vectors` is addressed as a byte pointer with vector `i` at byte
// offset `i * stride_in_bytes`, and each output row is an `nk_u32_t` array located at byte offset
// `row * result_stride_in_bytes` from `result`.
// NOTE(review): `row_start`/`row_count` presumably select the band of result rows to compute, and
// `depth_bits` presumably is the per-vector length in bits; their use is not visible here, so confirm
// against the kernel.
// NOTE(review): the thin-wrapper split looks intended to keep streaming-mode/ZA setup inside the callee
// (the hunk headers show `__arm_locally_streaming __arm_new("za")` on a function with this name prefix);
// confirm.
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
|
|
455
|
+
nk_size_t stride_in_bytes, nk_u32_t *result,
|
|
456
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start,
|
|
457
|
+
nk_size_t row_count) {
|
|
458
|
+
nk_dots_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
459
|
+
result_stride_in_bytes, row_start, row_count);
|
|
447
460
|
}
|
|
448
461
|
|
|
449
462
|
#if defined(__clang__)
|