numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -65,7 +65,7 @@ extern "C" {
|
|
|
65
65
|
*/
|
|
66
66
|
|
|
67
67
|
#if defined(__clang__)
|
|
68
|
-
#pragma clang attribute push(__attribute__((target("sme2
|
|
68
|
+
#pragma clang attribute push(__attribute__((target("sme2"))), apply_to = function)
|
|
69
69
|
#elif defined(__GNUC__)
|
|
70
70
|
#pragma GCC push_options
|
|
71
71
|
#pragma GCC target("+sme2")
|
|
@@ -93,13 +93,12 @@ typedef struct {
|
|
|
93
93
|
|
|
94
94
|
/** Count total set bits across a byte vector using streaming SVE.
|
|
95
95
|
* Accumulates per-byte popcounts into u32 lanes via svdot; single horizontal reduction at end. */
|
|
96
|
-
NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data,
|
|
97
|
-
nk_size_t n_bytes) NK_STREAMING_COMPATIBLE_ {
|
|
96
|
+
NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_size_t n_bytes) NK_STREAMING_ {
|
|
98
97
|
svuint32_t acc_u32x = svdup_u32(0);
|
|
99
98
|
svuint8_t const ones_u8x = svdup_u8(1);
|
|
100
99
|
for (nk_size_t offset = 0; offset < n_bytes; offset += svcntb()) {
|
|
101
|
-
svbool_t
|
|
102
|
-
acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(
|
|
100
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
|
|
101
|
+
acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
|
|
103
102
|
}
|
|
104
103
|
return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
|
|
105
104
|
}
|
|
@@ -128,11 +127,13 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
128
127
|
nk_size_t const tile_dim = nk_smebi32_tile_dim_(); // 16 rows per tile
|
|
129
128
|
nk_size_t const depth_tile_size = nk_smebi32_tile_dim_(); // 16 u32 per depth tile
|
|
130
129
|
nk_size_t const tile_elements = tile_dim * depth_tile_size;
|
|
131
|
-
nk_size_t const
|
|
130
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
132
131
|
|
|
133
|
-
|
|
132
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
133
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
134
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
134
135
|
nk_size_t const row_tile_count = nk_size_divide_round_up_(row_count, tile_dim);
|
|
135
|
-
nk_size_t const depth_tile_count = nk_size_divide_round_up_(
|
|
136
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
|
|
136
137
|
nk_size_t const total_tiles = row_tile_count * depth_tile_count;
|
|
137
138
|
nk_size_t const data_size = total_tiles * tile_elements * sizeof(nk_u32_t);
|
|
138
139
|
|
|
@@ -160,18 +161,24 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
160
161
|
nk_size_t const src_u32_start = depth_tile * depth_tile_size;
|
|
161
162
|
nk_size_t const rows_to_pack = (src_row_start + tile_dim <= row_count) ? tile_dim
|
|
162
163
|
: (row_count - src_row_start);
|
|
163
|
-
nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <=
|
|
164
|
+
nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <= depth_words)
|
|
164
165
|
? depth_tile_size
|
|
165
|
-
: (
|
|
166
|
-
: 0);
|
|
166
|
+
: (depth_words > src_u32_start ? depth_words - src_u32_start : 0);
|
|
167
167
|
|
|
168
168
|
// Column-major packing: tile_output[col * tile_dim + row]
|
|
169
|
+
// Copy byte-by-byte for the last u32 to avoid garbage bits when depth_bits % 32 != 0
|
|
170
|
+
nk_size_t const tail_bytes = depth_bytes % 4;
|
|
171
|
+
nk_size_t const last_col = u32s_to_pack > 0 ? u32s_to_pack - 1 : 0;
|
|
172
|
+
nk_size_t const is_last_depth_tile = (src_u32_start + u32s_to_pack >= depth_words);
|
|
169
173
|
for (nk_size_t row = 0; row < rows_to_pack; row++) {
|
|
170
174
|
nk_u32_t const *src_row = (nk_u32_t const *)((char const *)b +
|
|
171
175
|
(src_row_start + row) * b_stride_in_bytes);
|
|
172
176
|
for (nk_size_t col = 0; col < u32s_to_pack; col++) {
|
|
173
177
|
nk_size_t const dst_idx = col * tile_dim + row; // Column-major!
|
|
174
|
-
|
|
178
|
+
if (tail_bytes && is_last_depth_tile && col == last_col) {
|
|
179
|
+
nk_copy_bytes_(&tile_output[dst_idx], &src_row[src_u32_start + col], tail_bytes);
|
|
180
|
+
}
|
|
181
|
+
else { tile_output[dst_idx] = src_row[src_u32_start + col]; }
|
|
175
182
|
}
|
|
176
183
|
}
|
|
177
184
|
}
|
|
@@ -182,7 +189,7 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
182
189
|
nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
|
|
183
190
|
{
|
|
184
191
|
nk_u64_t nk_local_sum_, nk_local_sumsq_;
|
|
185
|
-
nk_reduce_moments_u1(src_row,
|
|
192
|
+
nk_reduce_moments_u1(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
|
|
186
193
|
norms_ptr[row] = (nk_u32_t)nk_local_sum_;
|
|
187
194
|
}
|
|
188
195
|
}
|
|
@@ -207,19 +214,24 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
207
214
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
208
215
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
209
216
|
nk_size_t const tile_elements = tile_dim * depth_tile_size;
|
|
210
|
-
|
|
217
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
218
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
219
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
220
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
211
221
|
|
|
212
222
|
nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
|
|
213
223
|
|
|
214
|
-
svbool_t const
|
|
215
|
-
|
|
224
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
225
|
+
// Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
|
|
226
|
+
// so the effective depth for the matching→hamming conversion is the rounded-up bit count.
|
|
227
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
|
|
216
228
|
nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
|
|
217
229
|
|
|
218
230
|
for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
|
|
219
231
|
nk_size_t const row_start_a = row_tile_a * tile_dim;
|
|
220
232
|
nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
|
|
221
233
|
: (row_count_a - row_start_a);
|
|
222
|
-
svbool_t const
|
|
234
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
|
|
223
235
|
|
|
224
236
|
// Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
|
|
225
237
|
nk_size_t row_tile_b = 0;
|
|
@@ -228,22 +240,23 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
228
240
|
|
|
229
241
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
230
242
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
231
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
243
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
232
244
|
? depth_tile_size
|
|
233
|
-
: (
|
|
234
|
-
: 0);
|
|
245
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
235
246
|
if (u32s_this_tile == 0) break;
|
|
236
247
|
|
|
237
248
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
238
249
|
|
|
239
|
-
svbool_t const
|
|
250
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
251
|
+
|
|
252
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
240
253
|
|
|
241
|
-
// Load A rows into ZA0.S
|
|
254
|
+
// Load A rows into ZA0.S, byte-predicated to zero garbage bits
|
|
242
255
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
256
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
257
|
+
d_start_u32 * 4;
|
|
258
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
259
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
247
260
|
}
|
|
248
261
|
|
|
249
262
|
// B tile pointers for 3 column tiles
|
|
@@ -253,14 +266,14 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
253
266
|
|
|
254
267
|
// Vertical read + BMOPA for each depth step
|
|
255
268
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
256
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
257
|
-
|
|
258
|
-
svbmopa_za32_u32_m(1,
|
|
259
|
-
svld1_u32(
|
|
260
|
-
svbmopa_za32_u32_m(2,
|
|
261
|
-
svld1_u32(
|
|
262
|
-
svbmopa_za32_u32_m(3,
|
|
263
|
-
svld1_u32(
|
|
269
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
270
|
+
|
|
271
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
272
|
+
svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
|
|
273
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
274
|
+
svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
|
|
275
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
276
|
+
svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
|
|
264
277
|
}
|
|
265
278
|
}
|
|
266
279
|
|
|
@@ -268,16 +281,16 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
268
281
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
269
282
|
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
270
283
|
|
|
271
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
272
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
273
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
284
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
285
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
286
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
274
287
|
|
|
275
|
-
svst1_u32(
|
|
276
|
-
svsub_u32_x(
|
|
277
|
-
svst1_u32(
|
|
278
|
-
svsub_u32_x(
|
|
279
|
-
svst1_u32(
|
|
280
|
-
svsub_u32_x(
|
|
288
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim,
|
|
289
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x));
|
|
290
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim,
|
|
291
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za2_u32x));
|
|
292
|
+
svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim,
|
|
293
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za3_u32x));
|
|
281
294
|
}
|
|
282
295
|
}
|
|
283
296
|
|
|
@@ -286,46 +299,46 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
286
299
|
nk_size_t const row_start_b = row_tile_b * tile_dim;
|
|
287
300
|
nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
|
|
288
301
|
: (row_count_b - row_start_b);
|
|
289
|
-
svbool_t const
|
|
302
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
|
|
290
303
|
|
|
291
304
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
292
305
|
|
|
293
306
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
294
307
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
295
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
308
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
296
309
|
? depth_tile_size
|
|
297
|
-
: (
|
|
298
|
-
: 0);
|
|
310
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
299
311
|
if (u32s_this_tile == 0) break;
|
|
300
312
|
|
|
301
313
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
302
314
|
|
|
303
|
-
svbool_t const
|
|
315
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
304
316
|
|
|
305
317
|
// Load A rows into ZA0.S horizontally
|
|
318
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
306
319
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
320
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
321
|
+
d_start_u32 * 4;
|
|
322
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
323
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
311
324
|
}
|
|
312
325
|
|
|
313
326
|
nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
|
|
314
327
|
|
|
315
328
|
// Vertical read + BMOPA
|
|
316
329
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
317
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
318
|
-
svuint32_t b_u32x = svld1_u32(
|
|
319
|
-
svbmopa_za32_u32_m(1,
|
|
330
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
331
|
+
svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
|
|
332
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
|
|
320
333
|
}
|
|
321
334
|
}
|
|
322
335
|
|
|
323
336
|
// Extract from ZA1: Hamming = depth_bits - matching_bits
|
|
324
337
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
325
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
326
|
-
svuint32_t hamming_u32x = svsub_u32_x(
|
|
338
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
339
|
+
svuint32_t hamming_u32x = svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x);
|
|
327
340
|
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
328
|
-
svst1_u32(
|
|
341
|
+
svst1_u32(column_predicate_b32x, c_row + row_start_b, hamming_u32x);
|
|
329
342
|
}
|
|
330
343
|
}
|
|
331
344
|
}
|
|
@@ -345,30 +358,37 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
345
358
|
* Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
|
|
346
359
|
*/
|
|
347
360
|
__arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
|
|
348
|
-
nk_u1x8_t const *vectors, nk_size_t
|
|
349
|
-
nk_size_t
|
|
361
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
362
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
350
363
|
|
|
351
364
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
352
365
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
353
|
-
|
|
354
|
-
|
|
366
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
367
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
368
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
369
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
370
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
|
|
355
371
|
|
|
356
|
-
svbool_t const
|
|
357
|
-
|
|
372
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
373
|
+
// Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
|
|
374
|
+
// so the effective depth for the matching→hamming conversion is the rounded-up bit count.
|
|
375
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
|
|
358
376
|
|
|
359
377
|
NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
|
|
360
378
|
|
|
361
379
|
nk_size_t const row_end = row_start + row_count;
|
|
362
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
380
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
|
|
363
381
|
|
|
364
|
-
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start <
|
|
382
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
|
|
365
383
|
row_tile_start += tile_dim) {
|
|
366
384
|
nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
|
|
367
|
-
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <=
|
|
368
|
-
|
|
369
|
-
|
|
385
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
|
|
386
|
+
? rows_remaining
|
|
387
|
+
: (vectors_count - row_tile_start);
|
|
388
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
|
|
370
389
|
|
|
371
|
-
|
|
390
|
+
// Upper triangle: start from this row tile's column
|
|
391
|
+
nk_size_t column_tile_index = row_tile_start / tile_dim;
|
|
372
392
|
|
|
373
393
|
// Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
|
|
374
394
|
for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
|
|
@@ -376,162 +396,164 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
|
|
|
376
396
|
|
|
377
397
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
378
398
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
379
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
399
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
380
400
|
? depth_tile_size
|
|
381
|
-
: (
|
|
382
|
-
: 0);
|
|
401
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
383
402
|
if (u32s_this_tile == 0) break;
|
|
384
403
|
|
|
385
404
|
// Load A rows into ZA0 horizontally
|
|
386
405
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
387
|
-
svbool_t const
|
|
406
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
388
407
|
|
|
408
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
389
409
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
410
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
411
|
+
d_start_u32 * 4;
|
|
412
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
413
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
394
414
|
}
|
|
395
415
|
|
|
396
416
|
// Save A columns from ZA0 to stack buffer
|
|
397
417
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
398
|
-
svst1_u32(
|
|
399
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
418
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
419
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
400
420
|
|
|
401
421
|
// B column tile 0
|
|
402
422
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
403
423
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
404
424
|
nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
|
|
405
|
-
if (col_abs <
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
425
|
+
if (col_abs < vectors_count) {
|
|
426
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
427
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
428
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
409
429
|
}
|
|
410
430
|
}
|
|
411
431
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
412
|
-
svuint32_t a_u32x = svld1_u32(
|
|
413
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
414
|
-
svbmopa_za32_u32_m(1,
|
|
432
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
433
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
434
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
415
435
|
}
|
|
416
436
|
|
|
417
437
|
// B column tile 1
|
|
418
438
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
419
439
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
420
440
|
nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
|
|
421
|
-
if (col_abs <
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
441
|
+
if (col_abs < vectors_count) {
|
|
442
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
443
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
444
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
425
445
|
}
|
|
426
446
|
}
|
|
427
447
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
428
|
-
svuint32_t a_u32x = svld1_u32(
|
|
429
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
430
|
-
svbmopa_za32_u32_m(2,
|
|
448
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
449
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
450
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
431
451
|
}
|
|
432
452
|
|
|
433
453
|
// B column tile 2
|
|
434
454
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
435
455
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
436
456
|
nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
|
|
437
|
-
if (col_abs <
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
457
|
+
if (col_abs < vectors_count) {
|
|
458
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
459
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
460
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
441
461
|
}
|
|
442
462
|
}
|
|
443
463
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
444
|
-
svuint32_t a_u32x = svld1_u32(
|
|
445
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
446
|
-
svbmopa_za32_u32_m(3,
|
|
464
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
465
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
466
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
447
467
|
}
|
|
448
468
|
}
|
|
449
469
|
|
|
450
470
|
// Extract ZA1-3: hamming = depth_bits - ZA[i][j]
|
|
451
471
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
452
|
-
nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) *
|
|
453
|
-
|
|
454
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
455
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
456
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
457
|
-
|
|
458
|
-
svst1_u32(
|
|
459
|
-
svsub_u32_x(
|
|
460
|
-
svst1_u32(
|
|
461
|
-
svsub_u32_x(
|
|
462
|
-
svst1_u32(
|
|
463
|
-
svsub_u32_x(
|
|
472
|
+
nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
473
|
+
|
|
474
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
475
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
476
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
477
|
+
|
|
478
|
+
svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 0) * tile_dim,
|
|
479
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x));
|
|
480
|
+
svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 1) * tile_dim,
|
|
481
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za2_u32x));
|
|
482
|
+
svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 2) * tile_dim,
|
|
483
|
+
svsub_u32_x(predicate_all_b32x, depth_u32x, za3_u32x));
|
|
464
484
|
}
|
|
465
485
|
}
|
|
466
486
|
|
|
467
487
|
// Remainder: 1 column tile at a time using ZA1
|
|
468
488
|
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
469
489
|
nk_size_t const col_tile_start = column_tile_index * tile_dim;
|
|
470
|
-
nk_size_t const cols_remaining = (col_tile_start + tile_dim <=
|
|
471
|
-
|
|
472
|
-
|
|
490
|
+
nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
|
|
491
|
+
? tile_dim
|
|
492
|
+
: (vectors_count - col_tile_start);
|
|
493
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
|
|
473
494
|
|
|
474
495
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
475
496
|
|
|
476
497
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
477
498
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
478
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
499
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
479
500
|
? depth_tile_size
|
|
480
|
-
: (
|
|
481
|
-
: 0);
|
|
501
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
482
502
|
if (u32s_this_tile == 0) break;
|
|
483
503
|
|
|
484
504
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
485
|
-
svbool_t const
|
|
505
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
486
506
|
|
|
487
507
|
// Load A rows into ZA0 horizontally
|
|
508
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
488
509
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
510
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
511
|
+
d_start_u32 * 4;
|
|
512
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
513
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
493
514
|
}
|
|
494
515
|
|
|
495
516
|
// Save A columns from ZA0 to stack buffer
|
|
496
517
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
497
|
-
svst1_u32(
|
|
498
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
518
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
519
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
499
520
|
|
|
500
521
|
// Load B column tile into ZA0
|
|
501
522
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
502
523
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
503
524
|
nk_size_t const col_abs = col_tile_start + col;
|
|
504
|
-
if (col_abs <
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
525
|
+
if (col_abs < vectors_count) {
|
|
526
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
527
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
528
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
508
529
|
}
|
|
509
530
|
}
|
|
510
531
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
511
|
-
svuint32_t a_u32x = svld1_u32(
|
|
512
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
513
|
-
svbmopa_za32_u32_m(1,
|
|
532
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
533
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
|
|
534
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
|
|
514
535
|
}
|
|
515
536
|
}
|
|
516
537
|
|
|
517
538
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
518
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
519
|
-
svuint32_t hamming_u32x = svsub_u32_x(
|
|
520
|
-
nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) *
|
|
521
|
-
svst1_u32(
|
|
539
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
540
|
+
svuint32_t hamming_u32x = svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x);
|
|
541
|
+
nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
542
|
+
svst1_u32(column_predicate_b32x, c_row + col_tile_start, hamming_u32x);
|
|
522
543
|
}
|
|
523
544
|
}
|
|
524
545
|
}
|
|
525
546
|
}
|
|
526
547
|
|
|
527
|
-
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t
|
|
528
|
-
nk_size_t
|
|
529
|
-
nk_size_t
|
|
530
|
-
|
|
531
|
-
|
|
548
|
+
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
|
|
549
|
+
nk_size_t stride_in_bytes, nk_u32_t *result,
|
|
550
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start,
|
|
551
|
+
nk_size_t row_count) {
|
|
552
|
+
nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
553
|
+
result_stride_in_bytes, row_start, row_count);
|
|
532
554
|
}
|
|
533
555
|
|
|
534
|
-
#pragma endregion
|
|
556
|
+
#pragma endregion Hamming Distance
|
|
535
557
|
|
|
536
558
|
/*
|
|
537
559
|
* Jaccard distance via BMOPA matching counts + algebraic normalization.
|
|
@@ -570,31 +592,33 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
570
592
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
571
593
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
572
594
|
nk_size_t const tile_elements = tile_dim * depth_tile_size;
|
|
573
|
-
|
|
595
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
596
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
597
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
598
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
574
599
|
|
|
575
600
|
nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
|
|
576
601
|
nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
|
|
577
602
|
: (nk_u32_t const *)0;
|
|
578
603
|
|
|
579
|
-
svbool_t const
|
|
580
|
-
svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)
|
|
604
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
605
|
+
svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)(depth_words * 32));
|
|
581
606
|
svfloat32_t const half_f32x = svdup_f32(0.5f);
|
|
582
607
|
svfloat32_t const one_f32x = svdup_f32(1.0f);
|
|
583
608
|
svfloat32_t const zero_f32x = svdup_f32(0.0f);
|
|
584
|
-
nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
|
|
585
609
|
nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
|
|
586
610
|
|
|
587
611
|
for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
|
|
588
612
|
nk_size_t const row_start_a = row_tile_a * tile_dim;
|
|
589
613
|
nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
|
|
590
614
|
: (row_count_a - row_start_a);
|
|
591
|
-
svbool_t const
|
|
615
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
|
|
592
616
|
|
|
593
617
|
// Compute A tile norms using streaming SVE popcount
|
|
594
618
|
NK_ALIGN64 nk_f32_t a_tile_norms[16];
|
|
595
619
|
for (nk_size_t r = 0; r < rows_a_remaining; r++) {
|
|
596
620
|
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
|
|
597
|
-
a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row,
|
|
621
|
+
a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
|
|
598
622
|
}
|
|
599
623
|
|
|
600
624
|
// Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
|
|
@@ -604,22 +628,23 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
604
628
|
|
|
605
629
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
606
630
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
607
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
631
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
608
632
|
? depth_tile_size
|
|
609
|
-
: (
|
|
610
|
-
: 0);
|
|
633
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
611
634
|
if (u32s_this_tile == 0) break;
|
|
612
635
|
|
|
613
636
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
614
637
|
|
|
615
|
-
svbool_t const
|
|
638
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
639
|
+
|
|
640
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
616
641
|
|
|
617
|
-
// Load A rows into ZA0.S
|
|
642
|
+
// Load A rows into ZA0.S, byte-predicated to zero garbage bits
|
|
618
643
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
644
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
645
|
+
d_start_u32 * 4;
|
|
646
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
647
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
623
648
|
}
|
|
624
649
|
|
|
625
650
|
// B tile pointers for 3 column tiles
|
|
@@ -629,25 +654,25 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
629
654
|
|
|
630
655
|
// Vertical read + BMOPA for each depth step
|
|
631
656
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
632
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
633
|
-
|
|
634
|
-
svbmopa_za32_u32_m(1,
|
|
635
|
-
svld1_u32(
|
|
636
|
-
svbmopa_za32_u32_m(2,
|
|
637
|
-
svld1_u32(
|
|
638
|
-
svbmopa_za32_u32_m(3,
|
|
639
|
-
svld1_u32(
|
|
657
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
658
|
+
|
|
659
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
660
|
+
svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
|
|
661
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
662
|
+
svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
|
|
663
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
|
|
664
|
+
svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
|
|
640
665
|
}
|
|
641
666
|
}
|
|
642
667
|
|
|
643
668
|
// Extract from ZA1-3: Jaccard normalization via streaming SVE
|
|
644
669
|
// Hoist B norms outside row loop (same for all A rows in this tile-pair)
|
|
645
670
|
svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(
|
|
646
|
-
|
|
671
|
+
predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 0) * tile_dim));
|
|
647
672
|
svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(
|
|
648
|
-
|
|
673
|
+
predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 1) * tile_dim));
|
|
649
674
|
svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(
|
|
650
|
-
|
|
675
|
+
predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 2) * tile_dim));
|
|
651
676
|
|
|
652
677
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
653
678
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
@@ -655,54 +680,54 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
655
680
|
|
|
656
681
|
// ZA1
|
|
657
682
|
{
|
|
658
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
659
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
660
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
683
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
684
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
|
|
685
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_0_f32x);
|
|
661
686
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
662
|
-
|
|
663
|
-
svadd_f32_x(
|
|
687
|
+
predicate_all_b32x,
|
|
688
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
664
689
|
matching_f32x),
|
|
665
690
|
half_f32x);
|
|
666
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
667
|
-
svbool_t
|
|
668
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
691
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
692
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
693
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
669
694
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
670
|
-
|
|
671
|
-
svst1_f32(
|
|
695
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
696
|
+
svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim, jaccard_f32x);
|
|
672
697
|
}
|
|
673
698
|
// ZA2
|
|
674
699
|
{
|
|
675
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
676
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
677
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
700
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
701
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za2_u32x);
|
|
702
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_1_f32x);
|
|
678
703
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
679
|
-
|
|
680
|
-
svadd_f32_x(
|
|
704
|
+
predicate_all_b32x,
|
|
705
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
681
706
|
matching_f32x),
|
|
682
707
|
half_f32x);
|
|
683
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
684
|
-
svbool_t
|
|
685
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
708
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
709
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
710
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
686
711
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
687
|
-
|
|
688
|
-
svst1_f32(
|
|
712
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
713
|
+
svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim, jaccard_f32x);
|
|
689
714
|
}
|
|
690
715
|
// ZA3
|
|
691
716
|
{
|
|
692
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
693
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
694
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
717
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
718
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za3_u32x);
|
|
719
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_2_f32x);
|
|
695
720
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
696
|
-
|
|
697
|
-
svadd_f32_x(
|
|
721
|
+
predicate_all_b32x,
|
|
722
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
698
723
|
matching_f32x),
|
|
699
724
|
half_f32x);
|
|
700
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
701
|
-
svbool_t
|
|
702
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
725
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
726
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
727
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
703
728
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
704
|
-
|
|
705
|
-
svst1_f32(
|
|
729
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
730
|
+
svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim, jaccard_f32x);
|
|
706
731
|
}
|
|
707
732
|
}
|
|
708
733
|
}
|
|
@@ -712,60 +737,60 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
712
737
|
nk_size_t const row_start_b = row_tile_b * tile_dim;
|
|
713
738
|
nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
|
|
714
739
|
: (row_count_b - row_start_b);
|
|
715
|
-
svbool_t const
|
|
740
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
|
|
716
741
|
|
|
717
742
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
718
743
|
|
|
719
744
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
720
745
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
721
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
746
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
722
747
|
? depth_tile_size
|
|
723
|
-
: (
|
|
724
|
-
: 0);
|
|
748
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
725
749
|
if (u32s_this_tile == 0) break;
|
|
726
750
|
|
|
727
751
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
728
752
|
|
|
729
|
-
svbool_t const
|
|
753
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
730
754
|
|
|
731
755
|
// Load A rows into ZA0.S horizontally
|
|
756
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
732
757
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
758
|
+
nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
|
|
759
|
+
d_start_u32 * 4;
|
|
760
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
761
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
737
762
|
}
|
|
738
763
|
|
|
739
764
|
nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
|
|
740
765
|
|
|
741
766
|
// Vertical read + BMOPA
|
|
742
767
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
743
|
-
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
744
|
-
svuint32_t b_u32x = svld1_u32(
|
|
745
|
-
svbmopa_za32_u32_m(1,
|
|
768
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
|
|
769
|
+
svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
|
|
770
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
|
|
746
771
|
}
|
|
747
772
|
}
|
|
748
773
|
|
|
749
774
|
// Extract from ZA1: Jaccard normalization
|
|
750
|
-
svfloat32_t b_norms_f32x = svcvt_f32_u32_x(
|
|
751
|
-
svld1_u32(
|
|
775
|
+
svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_b32x,
|
|
776
|
+
svld1_u32(predicate_all_b32x, b_norms + row_start_b));
|
|
752
777
|
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
753
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
754
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
778
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
779
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
|
|
755
780
|
svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
|
|
756
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
781
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_f32x);
|
|
757
782
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
758
|
-
|
|
759
|
-
svadd_f32_x(
|
|
783
|
+
predicate_all_b32x,
|
|
784
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
760
785
|
matching_f32x),
|
|
761
786
|
half_f32x);
|
|
762
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
763
|
-
svbool_t
|
|
764
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
765
|
-
svfloat32_t jaccard_f32x = svsel_f32(
|
|
766
|
-
svsub_f32_x(
|
|
787
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
788
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
789
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
790
|
+
svfloat32_t jaccard_f32x = svsel_f32(nonzero_b32x,
|
|
791
|
+
svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
767
792
|
nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
768
|
-
svst1_f32(
|
|
793
|
+
svst1_f32(column_predicate_b32x, c_row + row_start_b, jaccard_f32x);
|
|
769
794
|
}
|
|
770
795
|
}
|
|
771
796
|
}
|
|
@@ -784,17 +809,19 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
784
809
|
* Norms computed on-the-fly using streaming SVE popcount.
|
|
785
810
|
*/
|
|
786
811
|
__arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
|
|
787
|
-
nk_u1x8_t const *vectors, nk_size_t
|
|
788
|
-
nk_size_t
|
|
812
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
813
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
789
814
|
|
|
790
815
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
791
816
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
nk_size_t const
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
817
|
+
// BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
|
|
818
|
+
// handles one u32 (32 bits) across all row×column pairs simultaneously.
|
|
819
|
+
nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
|
|
820
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
|
|
821
|
+
nk_size_t const depth_bytes = depth_bits / 8;
|
|
822
|
+
|
|
823
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
824
|
+
svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)(depth_words * 32));
|
|
798
825
|
svfloat32_t const half_f32x = svdup_f32(0.5f);
|
|
799
826
|
svfloat32_t const one_f32x = svdup_f32(1.0f);
|
|
800
827
|
svfloat32_t const zero_f32x = svdup_f32(0.0f);
|
|
@@ -802,20 +829,22 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
802
829
|
NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
|
|
803
830
|
|
|
804
831
|
nk_size_t const row_end = row_start + row_count;
|
|
805
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
832
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
|
|
806
833
|
|
|
807
|
-
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start <
|
|
834
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
|
|
808
835
|
row_tile_start += tile_dim) {
|
|
809
836
|
nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
|
|
810
|
-
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <=
|
|
811
|
-
|
|
812
|
-
|
|
837
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
|
|
838
|
+
? rows_remaining
|
|
839
|
+
: (vectors_count - row_tile_start);
|
|
840
|
+
svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
|
|
813
841
|
|
|
814
842
|
// Compute A tile norms
|
|
815
843
|
NK_ALIGN64 nk_f32_t a_tile_norms[16];
|
|
816
844
|
for (nk_size_t r = 0; r < rows_clamped; r++) {
|
|
817
|
-
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
|
|
818
|
-
|
|
845
|
+
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
|
|
846
|
+
(row_tile_start + r) * stride_in_bytes);
|
|
847
|
+
a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
|
|
819
848
|
}
|
|
820
849
|
for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_norms[r] = 0.0f;
|
|
821
850
|
|
|
@@ -828,74 +857,74 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
828
857
|
|
|
829
858
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
830
859
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
831
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
860
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
832
861
|
? depth_tile_size
|
|
833
|
-
: (
|
|
834
|
-
: 0);
|
|
862
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
835
863
|
if (u32s_this_tile == 0) break;
|
|
836
864
|
|
|
837
865
|
// Load A rows into ZA0 horizontally
|
|
838
866
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
839
|
-
svbool_t const
|
|
867
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
840
868
|
|
|
869
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
841
870
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
871
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
872
|
+
d_start_u32 * 4;
|
|
873
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
874
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
846
875
|
}
|
|
847
876
|
|
|
848
877
|
// Save A columns from ZA0 to stack buffer
|
|
849
878
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
850
|
-
svst1_u32(
|
|
851
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
879
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
880
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
852
881
|
|
|
853
882
|
// B column tile 0
|
|
854
883
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
855
884
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
856
885
|
nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
|
|
857
|
-
if (col_abs <
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
886
|
+
if (col_abs < vectors_count) {
|
|
887
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
888
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
889
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
861
890
|
}
|
|
862
891
|
}
|
|
863
892
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
864
|
-
svuint32_t a_u32x = svld1_u32(
|
|
865
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
866
|
-
svbmopa_za32_u32_m(1,
|
|
893
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
894
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
895
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
867
896
|
}
|
|
868
897
|
|
|
869
898
|
// B column tile 1
|
|
870
899
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
871
900
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
872
901
|
nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
|
|
873
|
-
if (col_abs <
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
902
|
+
if (col_abs < vectors_count) {
|
|
903
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
904
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
905
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
877
906
|
}
|
|
878
907
|
}
|
|
879
908
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
880
|
-
svuint32_t a_u32x = svld1_u32(
|
|
881
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
882
|
-
svbmopa_za32_u32_m(2,
|
|
909
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
910
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
911
|
+
svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
883
912
|
}
|
|
884
913
|
|
|
885
914
|
// B column tile 2
|
|
886
915
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
887
916
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
888
917
|
nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
|
|
889
|
-
if (col_abs <
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
918
|
+
if (col_abs < vectors_count) {
|
|
919
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
920
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
921
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
893
922
|
}
|
|
894
923
|
}
|
|
895
924
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
896
|
-
svuint32_t a_u32x = svld1_u32(
|
|
897
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
898
|
-
svbmopa_za32_u32_m(3,
|
|
925
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
926
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
|
|
927
|
+
svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
|
|
899
928
|
}
|
|
900
929
|
}
|
|
901
930
|
|
|
@@ -907,85 +936,85 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
907
936
|
nk_size_t const col_abs_0 = (column_tile_index + 0) * tile_dim + col;
|
|
908
937
|
nk_size_t const col_abs_1 = (column_tile_index + 1) * tile_dim + col;
|
|
909
938
|
nk_size_t const col_abs_2 = (column_tile_index + 2) * tile_dim + col;
|
|
910
|
-
b_tile_norms_0[col] =
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
b_tile_norms_1[col] =
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
b_tile_norms_2[col] =
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
939
|
+
b_tile_norms_0[col] =
|
|
940
|
+
(col_abs_0 < vectors_count)
|
|
941
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
942
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_0 * stride_in_bytes), depth_bytes)
|
|
943
|
+
: 0;
|
|
944
|
+
b_tile_norms_1[col] =
|
|
945
|
+
(col_abs_1 < vectors_count)
|
|
946
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
947
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_1 * stride_in_bytes), depth_bytes)
|
|
948
|
+
: 0;
|
|
949
|
+
b_tile_norms_2[col] =
|
|
950
|
+
(col_abs_2 < vectors_count)
|
|
951
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
952
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_2 * stride_in_bytes), depth_bytes)
|
|
953
|
+
: 0;
|
|
925
954
|
}
|
|
926
955
|
|
|
927
956
|
// Extract ZA1-3: Jaccard normalization
|
|
928
|
-
svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(
|
|
929
|
-
svld1_u32(
|
|
930
|
-
svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(
|
|
931
|
-
svld1_u32(
|
|
932
|
-
svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(
|
|
933
|
-
svld1_u32(
|
|
957
|
+
svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(predicate_all_b32x,
|
|
958
|
+
svld1_u32(predicate_all_b32x, b_tile_norms_0));
|
|
959
|
+
svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(predicate_all_b32x,
|
|
960
|
+
svld1_u32(predicate_all_b32x, b_tile_norms_1));
|
|
961
|
+
svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(predicate_all_b32x,
|
|
962
|
+
svld1_u32(predicate_all_b32x, b_tile_norms_2));
|
|
934
963
|
|
|
935
964
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
936
|
-
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) *
|
|
965
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
937
966
|
svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
|
|
938
967
|
|
|
939
968
|
// ZA1
|
|
940
969
|
{
|
|
941
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
942
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
943
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
970
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
971
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
|
|
972
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_0_f32x);
|
|
944
973
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
945
|
-
|
|
946
|
-
svadd_f32_x(
|
|
974
|
+
predicate_all_b32x,
|
|
975
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
947
976
|
matching_f32x),
|
|
948
977
|
half_f32x);
|
|
949
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
950
|
-
svbool_t
|
|
951
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
978
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
979
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
980
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
952
981
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
953
|
-
|
|
954
|
-
svst1_f32(
|
|
982
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
983
|
+
svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 0) * tile_dim, jaccard_f32x);
|
|
955
984
|
}
|
|
956
985
|
// ZA2
|
|
957
986
|
{
|
|
958
|
-
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
959
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
960
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
987
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
|
|
988
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za2_u32x);
|
|
989
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_1_f32x);
|
|
961
990
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
962
|
-
|
|
963
|
-
svadd_f32_x(
|
|
991
|
+
predicate_all_b32x,
|
|
992
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
964
993
|
matching_f32x),
|
|
965
994
|
half_f32x);
|
|
966
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
967
|
-
svbool_t
|
|
968
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
995
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
996
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
997
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
969
998
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
970
|
-
|
|
971
|
-
svst1_f32(
|
|
999
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
1000
|
+
svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 1) * tile_dim, jaccard_f32x);
|
|
972
1001
|
}
|
|
973
1002
|
// ZA3
|
|
974
1003
|
{
|
|
975
|
-
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
976
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
977
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
1004
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
|
|
1005
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za3_u32x);
|
|
1006
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_2_f32x);
|
|
978
1007
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
979
|
-
|
|
980
|
-
svadd_f32_x(
|
|
1008
|
+
predicate_all_b32x,
|
|
1009
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
981
1010
|
matching_f32x),
|
|
982
1011
|
half_f32x);
|
|
983
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
984
|
-
svbool_t
|
|
985
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
1012
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
1013
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
1014
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
986
1015
|
svfloat32_t jaccard_f32x = svsel_f32(
|
|
987
|
-
|
|
988
|
-
svst1_f32(
|
|
1016
|
+
nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
1017
|
+
svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 2) * tile_dim, jaccard_f32x);
|
|
989
1018
|
}
|
|
990
1019
|
}
|
|
991
1020
|
}
|
|
@@ -993,50 +1022,51 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
993
1022
|
// Remainder: 1 column tile at a time using ZA1
|
|
994
1023
|
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
995
1024
|
nk_size_t const col_tile_start = column_tile_index * tile_dim;
|
|
996
|
-
nk_size_t const cols_remaining = (col_tile_start + tile_dim <=
|
|
997
|
-
|
|
998
|
-
|
|
1025
|
+
nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
|
|
1026
|
+
? tile_dim
|
|
1027
|
+
: (vectors_count - col_tile_start);
|
|
1028
|
+
svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
|
|
999
1029
|
|
|
1000
1030
|
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
1001
1031
|
|
|
1002
1032
|
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
1003
1033
|
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
1004
|
-
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <=
|
|
1034
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
|
|
1005
1035
|
? depth_tile_size
|
|
1006
|
-
: (
|
|
1007
|
-
: 0);
|
|
1036
|
+
: (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
|
|
1008
1037
|
if (u32s_this_tile == 0) break;
|
|
1009
1038
|
|
|
1010
1039
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1011
|
-
svbool_t const
|
|
1040
|
+
svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
1012
1041
|
|
|
1013
1042
|
// Load A rows into ZA0 horizontally
|
|
1043
|
+
svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
|
|
1014
1044
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1045
|
+
nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
|
|
1046
|
+
d_start_u32 * 4;
|
|
1047
|
+
svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
|
|
1048
|
+
svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
|
|
1019
1049
|
}
|
|
1020
1050
|
|
|
1021
1051
|
// Save A columns from ZA0 to stack buffer
|
|
1022
1052
|
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
1023
|
-
svst1_u32(
|
|
1024
|
-
svread_ver_za32_u32_m(svdup_u32(0),
|
|
1053
|
+
svst1_u32(predicate_all_b32x, a_buffer[s],
|
|
1054
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
|
|
1025
1055
|
|
|
1026
1056
|
// Load B column tile into ZA0
|
|
1027
1057
|
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1028
1058
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
1029
1059
|
nk_size_t const col_abs = col_tile_start + col;
|
|
1030
|
-
if (col_abs <
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1060
|
+
if (col_abs < vectors_count) {
|
|
1061
|
+
nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
|
|
1062
|
+
svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
|
|
1063
|
+
svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
|
|
1034
1064
|
}
|
|
1035
1065
|
}
|
|
1036
1066
|
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
1037
|
-
svuint32_t a_u32x = svld1_u32(
|
|
1038
|
-
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0),
|
|
1039
|
-
svbmopa_za32_u32_m(1,
|
|
1067
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
|
|
1068
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
|
|
1069
|
+
svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
|
|
1040
1070
|
}
|
|
1041
1071
|
}
|
|
1042
1072
|
|
|
@@ -1044,44 +1074,45 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
1044
1074
|
NK_ALIGN64 nk_u32_t b_tile_norms[16];
|
|
1045
1075
|
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
1046
1076
|
nk_size_t const col_abs = col_tile_start + col;
|
|
1047
|
-
b_tile_norms[col] = (col_abs <
|
|
1077
|
+
b_tile_norms[col] = (col_abs < vectors_count)
|
|
1048
1078
|
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
1049
|
-
(nk_u1x8_t const *)((char const *)vectors + col_abs *
|
|
1050
|
-
|
|
1079
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes),
|
|
1080
|
+
depth_bytes)
|
|
1051
1081
|
: 0;
|
|
1052
1082
|
}
|
|
1053
1083
|
|
|
1054
|
-
svfloat32_t b_norms_f32x = svcvt_f32_u32_x(
|
|
1084
|
+
svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_b32x, svld1_u32(predicate_all_b32x, b_tile_norms));
|
|
1055
1085
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
1056
|
-
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0),
|
|
1057
|
-
svfloat32_t matching_f32x = svcvt_f32_u32_x(
|
|
1086
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
|
|
1087
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
|
|
1058
1088
|
svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
|
|
1059
|
-
svfloat32_t sum_norms_f32x = svadd_f32_x(
|
|
1089
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_f32x);
|
|
1060
1090
|
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
1061
|
-
|
|
1062
|
-
svadd_f32_x(
|
|
1091
|
+
predicate_all_b32x,
|
|
1092
|
+
svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
|
|
1063
1093
|
matching_f32x),
|
|
1064
1094
|
half_f32x);
|
|
1065
|
-
svfloat32_t union_val_f32x = svsub_f32_x(
|
|
1066
|
-
svbool_t
|
|
1067
|
-
svfloat32_t ratio_f32x = svdiv_f32_x(
|
|
1068
|
-
svfloat32_t jaccard_f32x = svsel_f32(
|
|
1069
|
-
svsub_f32_x(
|
|
1070
|
-
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) *
|
|
1071
|
-
svst1_f32(
|
|
1095
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
|
|
1096
|
+
svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
|
|
1097
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
|
|
1098
|
+
svfloat32_t jaccard_f32x = svsel_f32(nonzero_b32x,
|
|
1099
|
+
svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
|
|
1100
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
|
|
1101
|
+
svst1_f32(column_predicate_b32x, c_row + col_tile_start, jaccard_f32x);
|
|
1072
1102
|
}
|
|
1073
1103
|
}
|
|
1074
1104
|
}
|
|
1075
1105
|
}
|
|
1076
1106
|
|
|
1077
|
-
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t
|
|
1078
|
-
nk_size_t
|
|
1079
|
-
nk_size_t
|
|
1080
|
-
|
|
1081
|
-
|
|
1107
|
+
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
|
|
1108
|
+
nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1109
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start,
|
|
1110
|
+
nk_size_t row_count) {
|
|
1111
|
+
nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
1112
|
+
result_stride_in_bytes, row_start, row_count);
|
|
1082
1113
|
}
|
|
1083
1114
|
|
|
1084
|
-
#pragma endregion
|
|
1115
|
+
#pragma endregion Jaccard Distance
|
|
1085
1116
|
|
|
1086
1117
|
#if defined(__clang__)
|
|
1087
1118
|
#pragma clang attribute pop
|