numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -53,7 +53,7 @@ extern "C" {
|
|
|
53
53
|
#endif
|
|
54
54
|
|
|
55
55
|
#if defined(__clang__)
|
|
56
|
-
#pragma clang attribute push(__attribute__((target("sme
|
|
56
|
+
#pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
|
|
57
57
|
#elif defined(__GNUC__)
|
|
58
58
|
#pragma GCC push_options
|
|
59
59
|
#pragma GCC target("+sme")
|
|
@@ -112,8 +112,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
112
112
|
nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
|
|
113
113
|
document_header->norms_offset);
|
|
114
114
|
|
|
115
|
-
svbool_t const
|
|
116
|
-
svbool_t const
|
|
115
|
+
svbool_t const predicate_all_b16x = svptrue_b16();
|
|
116
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
117
117
|
|
|
118
118
|
nk_f32_t total_angular_distance = 0.0f;
|
|
119
119
|
|
|
@@ -121,10 +121,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
121
121
|
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
122
122
|
nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
|
|
123
123
|
: (query_count - row_start);
|
|
124
|
-
svbool_t const
|
|
124
|
+
svbool_t const row_predicate_b16x = (rows_remaining == tile_dimension)
|
|
125
125
|
? svptrue_b16()
|
|
126
126
|
: svwhilelt_b16_u64(0u, rows_remaining * 2);
|
|
127
|
-
svbool_t const
|
|
127
|
+
svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
|
|
128
128
|
: svwhilelt_b32_u64(0u, rows_remaining);
|
|
129
129
|
|
|
130
130
|
// Running max + argmax vectors for angular distance finalization
|
|
@@ -140,29 +140,29 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
140
140
|
// Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 FMOPAs
|
|
141
141
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
142
142
|
svfloat16_t query_packed_f16x = svld1_f16(
|
|
143
|
-
|
|
143
|
+
row_predicate_b16x,
|
|
144
144
|
(float16_t const *)(query_vecs +
|
|
145
145
|
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
146
146
|
svfloat16_t document_packed_0_f16x = svld1_f16(
|
|
147
|
-
|
|
147
|
+
predicate_all_b16x,
|
|
148
148
|
(float16_t const *)(document_vecs +
|
|
149
149
|
((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
|
|
150
150
|
svfloat16_t document_packed_1_f16x = svld1_f16(
|
|
151
|
-
|
|
151
|
+
predicate_all_b16x,
|
|
152
152
|
(float16_t const *)(document_vecs +
|
|
153
153
|
((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
|
|
154
154
|
svfloat16_t document_packed_2_f16x = svld1_f16(
|
|
155
|
-
|
|
155
|
+
predicate_all_b16x,
|
|
156
156
|
(float16_t const *)(document_vecs +
|
|
157
157
|
((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
|
|
158
158
|
svfloat16_t document_packed_3_f16x = svld1_f16(
|
|
159
|
-
|
|
159
|
+
predicate_all_b16x,
|
|
160
160
|
(float16_t const *)(document_vecs +
|
|
161
161
|
((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
|
|
162
|
-
svmopa_za32_f16_m(0,
|
|
163
|
-
svmopa_za32_f16_m(1,
|
|
164
|
-
svmopa_za32_f16_m(2,
|
|
165
|
-
svmopa_za32_f16_m(3,
|
|
162
|
+
svmopa_za32_f16_m(0, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_0_f16x);
|
|
163
|
+
svmopa_za32_f16_m(1, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_1_f16x);
|
|
164
|
+
svmopa_za32_f16_m(2, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_2_f16x);
|
|
165
|
+
svmopa_za32_f16_m(3, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_3_f16x);
|
|
166
166
|
}
|
|
167
167
|
|
|
168
168
|
// Vertical column extraction + argmax update (manually unrolled over 4 tiles)
|
|
@@ -170,36 +170,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
170
170
|
// Tile 0
|
|
171
171
|
{
|
|
172
172
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
|
|
173
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
173
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
|
|
174
174
|
column_within_tile);
|
|
175
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
175
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
176
176
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
177
177
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
178
178
|
}
|
|
179
179
|
// Tile 1
|
|
180
180
|
{
|
|
181
181
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
|
|
182
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
182
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 1,
|
|
183
183
|
column_within_tile);
|
|
184
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
184
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
185
185
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
186
186
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
187
187
|
}
|
|
188
188
|
// Tile 2
|
|
189
189
|
{
|
|
190
190
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
|
|
191
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
191
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 2,
|
|
192
192
|
column_within_tile);
|
|
193
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
193
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
194
194
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
195
195
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
196
196
|
}
|
|
197
197
|
// Tile 3
|
|
198
198
|
{
|
|
199
199
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
|
|
200
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
200
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 3,
|
|
201
201
|
column_within_tile);
|
|
202
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
202
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
203
203
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
204
204
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
205
205
|
}
|
|
@@ -212,7 +212,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
212
212
|
nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
|
|
213
213
|
? tile_dimension
|
|
214
214
|
: (document_count - col_start);
|
|
215
|
-
svbool_t const
|
|
215
|
+
svbool_t const column_predicate_b16x = (cols_remaining == tile_dimension)
|
|
216
216
|
? svptrue_b16()
|
|
217
217
|
: svwhilelt_b16_u64(0u, cols_remaining * 2);
|
|
218
218
|
|
|
@@ -220,23 +220,23 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
220
220
|
|
|
221
221
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
222
222
|
svfloat16_t query_packed_f16x = svld1_f16(
|
|
223
|
-
|
|
223
|
+
row_predicate_b16x,
|
|
224
224
|
(float16_t const *)(query_vecs +
|
|
225
225
|
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
226
226
|
svfloat16_t document_packed_f16x = svld1_f16(
|
|
227
|
-
|
|
227
|
+
column_predicate_b16x,
|
|
228
228
|
(float16_t const *)(document_vecs +
|
|
229
229
|
(column_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
230
|
-
svmopa_za32_f16_m(0,
|
|
230
|
+
svmopa_za32_f16_m(0, row_predicate_b16x, column_predicate_b16x, query_packed_f16x,
|
|
231
231
|
document_packed_f16x);
|
|
232
232
|
}
|
|
233
233
|
|
|
234
234
|
// Vertical column extraction from ZA0 + argmax update
|
|
235
235
|
for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
|
|
236
236
|
nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
|
|
237
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
237
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
|
|
238
238
|
column_within_tile);
|
|
239
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
239
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
240
240
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
241
241
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
242
242
|
}
|
|
@@ -246,19 +246,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
246
246
|
// Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
|
|
247
247
|
nk_u32_t best_document_indices[64];
|
|
248
248
|
nk_f32_t document_inverse_norms_gathered[64];
|
|
249
|
-
svst1_u32(
|
|
249
|
+
svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
|
|
250
250
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
|
|
251
251
|
document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
|
|
252
252
|
|
|
253
253
|
// SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
|
|
254
|
-
svfloat32_t query_inverse_norms_f32x = svld1_f32(
|
|
255
|
-
svfloat32_t document_inverse_norms_f32x = svld1_f32(
|
|
254
|
+
svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_b32x, query_inverse_norms + row_start);
|
|
255
|
+
svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_b32x, document_inverse_norms_gathered);
|
|
256
256
|
svfloat32_t cosine_f32x = svmul_f32_x(
|
|
257
|
-
|
|
257
|
+
row_predicate_b32x, svmul_f32_x(row_predicate_b32x, running_maximum_f32x, query_inverse_norms_f32x),
|
|
258
258
|
document_inverse_norms_f32x);
|
|
259
259
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
260
|
-
|
|
261
|
-
total_angular_distance += svaddv_f32(
|
|
260
|
+
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
261
|
+
total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
|
|
262
262
|
}
|
|
263
263
|
|
|
264
264
|
*result = total_angular_distance;
|
|
@@ -304,8 +304,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
304
304
|
nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
|
|
305
305
|
document_header->norms_offset);
|
|
306
306
|
|
|
307
|
-
svbool_t const
|
|
308
|
-
svbool_t const
|
|
307
|
+
svbool_t const predicate_all_b16x = svptrue_b16();
|
|
308
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
309
309
|
|
|
310
310
|
nk_f32_t total_angular_distance = 0.0f;
|
|
311
311
|
|
|
@@ -313,10 +313,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
313
313
|
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
314
314
|
nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
|
|
315
315
|
: (query_count - row_start);
|
|
316
|
-
svbool_t const
|
|
316
|
+
svbool_t const row_predicate_b16x = (rows_remaining == tile_dimension)
|
|
317
317
|
? svptrue_b16()
|
|
318
318
|
: svwhilelt_b16_u64(0u, rows_remaining * 2);
|
|
319
|
-
svbool_t const
|
|
319
|
+
svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
|
|
320
320
|
: svwhilelt_b32_u64(0u, rows_remaining);
|
|
321
321
|
|
|
322
322
|
// Running max + argmax vectors for angular distance finalization
|
|
@@ -332,32 +332,32 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
332
332
|
// Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 BFMOPAs
|
|
333
333
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
334
334
|
svbfloat16_t query_packed_bf16x = svld1_bf16(
|
|
335
|
-
|
|
335
|
+
row_predicate_b16x,
|
|
336
336
|
(bfloat16_t const *)(query_vecs +
|
|
337
337
|
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
338
338
|
svbfloat16_t document_packed_0_bf16x = svld1_bf16(
|
|
339
|
-
|
|
339
|
+
predicate_all_b16x,
|
|
340
340
|
(bfloat16_t const *)(document_vecs +
|
|
341
341
|
((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
|
|
342
342
|
svbfloat16_t document_packed_1_bf16x = svld1_bf16(
|
|
343
|
-
|
|
343
|
+
predicate_all_b16x,
|
|
344
344
|
(bfloat16_t const *)(document_vecs +
|
|
345
345
|
((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
|
|
346
346
|
svbfloat16_t document_packed_2_bf16x = svld1_bf16(
|
|
347
|
-
|
|
347
|
+
predicate_all_b16x,
|
|
348
348
|
(bfloat16_t const *)(document_vecs +
|
|
349
349
|
((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
|
|
350
350
|
svbfloat16_t document_packed_3_bf16x = svld1_bf16(
|
|
351
|
-
|
|
351
|
+
predicate_all_b16x,
|
|
352
352
|
(bfloat16_t const *)(document_vecs +
|
|
353
353
|
((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
|
|
354
|
-
svmopa_za32_bf16_m(0,
|
|
354
|
+
svmopa_za32_bf16_m(0, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
|
|
355
355
|
document_packed_0_bf16x);
|
|
356
|
-
svmopa_za32_bf16_m(1,
|
|
356
|
+
svmopa_za32_bf16_m(1, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
|
|
357
357
|
document_packed_1_bf16x);
|
|
358
|
-
svmopa_za32_bf16_m(2,
|
|
358
|
+
svmopa_za32_bf16_m(2, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
|
|
359
359
|
document_packed_2_bf16x);
|
|
360
|
-
svmopa_za32_bf16_m(3,
|
|
360
|
+
svmopa_za32_bf16_m(3, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
|
|
361
361
|
document_packed_3_bf16x);
|
|
362
362
|
}
|
|
363
363
|
|
|
@@ -366,36 +366,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
366
366
|
// Tile 0
|
|
367
367
|
{
|
|
368
368
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
|
|
369
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
369
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
|
|
370
370
|
column_within_tile);
|
|
371
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
371
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
372
372
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
373
373
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
374
374
|
}
|
|
375
375
|
// Tile 1
|
|
376
376
|
{
|
|
377
377
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
|
|
378
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
378
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 1,
|
|
379
379
|
column_within_tile);
|
|
380
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
380
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
381
381
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
382
382
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
383
383
|
}
|
|
384
384
|
// Tile 2
|
|
385
385
|
{
|
|
386
386
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
|
|
387
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
387
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 2,
|
|
388
388
|
column_within_tile);
|
|
389
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
389
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
390
390
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
391
391
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
392
392
|
}
|
|
393
393
|
// Tile 3
|
|
394
394
|
{
|
|
395
395
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
|
|
396
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
396
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 3,
|
|
397
397
|
column_within_tile);
|
|
398
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
398
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
399
399
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
400
400
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
401
401
|
}
|
|
@@ -408,7 +408,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
408
408
|
nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
|
|
409
409
|
? tile_dimension
|
|
410
410
|
: (document_count - col_start);
|
|
411
|
-
svbool_t const
|
|
411
|
+
svbool_t const column_predicate_b16x = (cols_remaining == tile_dimension)
|
|
412
412
|
? svptrue_b16()
|
|
413
413
|
: svwhilelt_b16_u64(0u, cols_remaining * 2);
|
|
414
414
|
|
|
@@ -416,23 +416,23 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
416
416
|
|
|
417
417
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
418
418
|
svbfloat16_t query_packed_bf16x = svld1_bf16(
|
|
419
|
-
|
|
419
|
+
row_predicate_b16x,
|
|
420
420
|
(bfloat16_t const *)(query_vecs +
|
|
421
421
|
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
422
422
|
svbfloat16_t document_packed_bf16x = svld1_bf16(
|
|
423
|
-
|
|
423
|
+
column_predicate_b16x,
|
|
424
424
|
(bfloat16_t const *)(document_vecs +
|
|
425
425
|
(column_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
426
|
-
svmopa_za32_bf16_m(0,
|
|
426
|
+
svmopa_za32_bf16_m(0, row_predicate_b16x, column_predicate_b16x, query_packed_bf16x,
|
|
427
427
|
document_packed_bf16x);
|
|
428
428
|
}
|
|
429
429
|
|
|
430
430
|
// Vertical column extraction from ZA0 + argmax update
|
|
431
431
|
for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
|
|
432
432
|
nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
|
|
433
|
-
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN),
|
|
433
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
|
|
434
434
|
column_within_tile);
|
|
435
|
-
svbool_t is_better_bx = svcmpgt_f32(
|
|
435
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
|
|
436
436
|
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
437
437
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
438
438
|
}
|
|
@@ -442,19 +442,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
442
442
|
// Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
|
|
443
443
|
nk_u32_t best_document_indices[64];
|
|
444
444
|
nk_f32_t document_inverse_norms_gathered[64];
|
|
445
|
-
svst1_u32(
|
|
445
|
+
svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
|
|
446
446
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
|
|
447
447
|
document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
|
|
448
448
|
|
|
449
449
|
// SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
|
|
450
|
-
svfloat32_t query_inverse_norms_f32x = svld1_f32(
|
|
451
|
-
svfloat32_t document_inverse_norms_f32x = svld1_f32(
|
|
450
|
+
svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_b32x, query_inverse_norms + row_start);
|
|
451
|
+
svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_b32x, document_inverse_norms_gathered);
|
|
452
452
|
svfloat32_t cosine_f32x = svmul_f32_x(
|
|
453
|
-
|
|
453
|
+
row_predicate_b32x, svmul_f32_x(row_predicate_b32x, running_maximum_f32x, query_inverse_norms_f32x),
|
|
454
454
|
document_inverse_norms_f32x);
|
|
455
455
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
456
|
-
|
|
457
|
-
total_angular_distance += svaddv_f32(
|
|
456
|
+
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
457
|
+
total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
|
|
458
458
|
}
|
|
459
459
|
|
|
460
460
|
*result = total_angular_distance;
|
|
@@ -468,20 +468,20 @@ NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
|
|
|
468
468
|
nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
469
469
|
}
|
|
470
470
|
|
|
471
|
-
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t
|
|
472
|
-
return nk_dots_packed_size_bf16_sme(
|
|
471
|
+
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t columns, nk_size_t depth) { //
|
|
472
|
+
return nk_dots_packed_size_bf16_sme(columns, depth);
|
|
473
473
|
}
|
|
474
474
|
|
|
475
|
-
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t
|
|
476
|
-
return nk_dots_packed_size_f16_sme(
|
|
475
|
+
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t columns, nk_size_t depth) { //
|
|
476
|
+
return nk_dots_packed_size_f16_sme(columns, depth);
|
|
477
477
|
}
|
|
478
478
|
|
|
479
|
-
NK_PUBLIC void nk_maxsim_pack_bf16_sme(
|
|
480
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
479
|
+
NK_PUBLIC void nk_maxsim_pack_bf16_sme( //
|
|
480
|
+
nk_bf16_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
|
|
481
481
|
|
|
482
482
|
// Delegate tile interleaving and squared norms computation to dots pack.
|
|
483
483
|
// Both headers are 64 bytes with identical layout for the first 6 fields.
|
|
484
|
-
nk_dots_pack_bf16_sme(vectors,
|
|
484
|
+
nk_dots_pack_bf16_sme(vectors, columns, depth, stride_in_bytes, packed);
|
|
485
485
|
|
|
486
486
|
// Set maxsim-specific header fields (overlaps dots reserved area)
|
|
487
487
|
nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
|
|
@@ -491,18 +491,18 @@ NK_PUBLIC void nk_maxsim_pack_bf16_sme(
|
|
|
491
491
|
|
|
492
492
|
// Convert squared norms → inverse norms in-place
|
|
493
493
|
nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
494
|
-
for (nk_size_t i = 0; i <
|
|
494
|
+
for (nk_size_t i = 0; i < columns; i++) {
|
|
495
495
|
nk_f32_t norm_sq = norms[i];
|
|
496
496
|
norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
|
|
497
497
|
}
|
|
498
498
|
}
|
|
499
499
|
|
|
500
|
-
NK_PUBLIC void nk_maxsim_pack_f16_sme(
|
|
501
|
-
nk_f16_t const *vectors, nk_size_t
|
|
500
|
+
NK_PUBLIC void nk_maxsim_pack_f16_sme( //
|
|
501
|
+
nk_f16_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
|
|
502
502
|
|
|
503
503
|
// Delegate tile interleaving and squared norms computation to dots pack.
|
|
504
504
|
// Both headers are 64 bytes with identical layout for the first 6 fields.
|
|
505
|
-
nk_dots_pack_f16_sme(vectors,
|
|
505
|
+
nk_dots_pack_f16_sme(vectors, columns, depth, stride_in_bytes, packed);
|
|
506
506
|
|
|
507
507
|
// Set maxsim-specific header fields (overlaps dots reserved area)
|
|
508
508
|
nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
|
|
@@ -512,7 +512,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_sme(
|
|
|
512
512
|
|
|
513
513
|
// Convert squared norms → inverse norms in-place
|
|
514
514
|
nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
515
|
-
for (nk_size_t i = 0; i <
|
|
515
|
+
for (nk_size_t i = 0; i < columns; i++) {
|
|
516
516
|
nk_f32_t norm_sq = norms[i];
|
|
517
517
|
norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
|
|
518
518
|
}
|
|
@@ -527,45 +527,45 @@ NK_PUBLIC void nk_maxsim_pack_f16_sme(
|
|
|
527
527
|
* Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
|
|
528
528
|
* Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
|
|
529
529
|
*/
|
|
530
|
-
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t
|
|
531
|
-
nk_size_t const expansion = 4;
|
|
532
|
-
nk_size_t const tile_dimension =
|
|
533
|
-
nk_size_t const vector_elements =
|
|
534
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
535
|
-
nk_size_t const depth_step_count = nk_size_divide_round_up_(
|
|
536
|
-
nk_size_t const original_stride = nk_size_round_up_to_multiple_(
|
|
530
|
+
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t columns, nk_size_t depth) { //
|
|
531
|
+
nk_size_t const expansion = 4; // i8->i32 SMOPA
|
|
532
|
+
nk_size_t const tile_dimension = nk_sme_cntw_(); // 16 for SVL=512
|
|
533
|
+
nk_size_t const vector_elements = nk_sme_cntb_(); // 64 for SVL=512
|
|
534
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
|
|
535
|
+
nk_size_t const depth_step_count = nk_size_divide_round_up_(depth, expansion);
|
|
536
|
+
nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
|
|
537
537
|
|
|
538
538
|
nk_size_t size = sizeof(nk_maxsim_sme_packed_header_t); // 64 B header
|
|
539
539
|
size += column_tile_count * depth_step_count * vector_elements; // i8 tiles
|
|
540
|
-
size +=
|
|
541
|
-
size +=
|
|
540
|
+
size += columns * sizeof(nk_f32_t); // f32 squared norms
|
|
541
|
+
size += columns * original_stride; // f32 originals
|
|
542
542
|
return size;
|
|
543
543
|
}
|
|
544
544
|
|
|
545
|
-
NK_PUBLIC void nk_maxsim_pack_f32_sme(
|
|
546
|
-
nk_f32_t const *vectors, nk_size_t
|
|
545
|
+
NK_PUBLIC void nk_maxsim_pack_f32_sme( //
|
|
546
|
+
nk_f32_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
|
|
547
547
|
|
|
548
|
-
nk_size_t const expansion = 4;
|
|
549
|
-
nk_size_t const tile_dimension =
|
|
550
|
-
nk_size_t const vector_elements =
|
|
551
|
-
nk_size_t const stride_elements =
|
|
548
|
+
nk_size_t const expansion = 4; // i8->i32 SMOPA
|
|
549
|
+
nk_size_t const tile_dimension = nk_sme_cntw_(); // 16 for SVL=512
|
|
550
|
+
nk_size_t const vector_elements = nk_sme_cntb_(); // 64 for SVL=512
|
|
551
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
552
552
|
|
|
553
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
554
|
-
nk_size_t const depth_step_count = nk_size_divide_round_up_(
|
|
553
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
|
|
554
|
+
nk_size_t const depth_step_count = nk_size_divide_round_up_(depth, expansion);
|
|
555
555
|
nk_size_t const total_vectors = column_tile_count * depth_step_count;
|
|
556
|
-
nk_size_t const original_stride = nk_size_round_up_to_multiple_(
|
|
556
|
+
nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
|
|
557
557
|
|
|
558
558
|
// Set up header
|
|
559
559
|
nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
|
|
560
560
|
header->column_tile_count = (nk_u32_t)column_tile_count;
|
|
561
561
|
header->depth_tile_count = (nk_u32_t)depth_step_count;
|
|
562
|
-
header->columns = (nk_u32_t)
|
|
563
|
-
header->depth = (nk_u32_t)
|
|
564
|
-
header->svl_bytes = (nk_u32_t)(
|
|
562
|
+
header->columns = (nk_u32_t)columns;
|
|
563
|
+
header->depth = (nk_u32_t)depth;
|
|
564
|
+
header->svl_bytes = (nk_u32_t)(tile_dimension * sizeof(nk_f32_t));
|
|
565
565
|
|
|
566
566
|
nk_size_t const tiles_size = total_vectors * vector_elements;
|
|
567
567
|
nk_size_t const norms_offset = sizeof(nk_maxsim_sme_packed_header_t) + tiles_size;
|
|
568
|
-
nk_size_t const originals_offset = norms_offset +
|
|
568
|
+
nk_size_t const originals_offset = norms_offset + columns * sizeof(nk_f32_t);
|
|
569
569
|
|
|
570
570
|
header->norms_offset = (nk_u32_t)norms_offset;
|
|
571
571
|
header->originals_offset = (nk_u32_t)originals_offset;
|
|
@@ -580,13 +580,13 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
|
|
|
580
580
|
for (nk_size_t i = 0; i < tiles_size; i++) tiles[i] = 0;
|
|
581
581
|
|
|
582
582
|
// For each vector: quantize metadata, quantize+interleave into tiles, copy originals
|
|
583
|
-
for (nk_size_t vector_index = 0; vector_index <
|
|
584
|
-
nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index *
|
|
583
|
+
for (nk_size_t vector_index = 0; vector_index < columns; vector_index++) {
|
|
584
|
+
nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index * stride_in_bytes);
|
|
585
585
|
|
|
586
586
|
// Pass 1: Compute absmax and norm_sq simultaneously
|
|
587
587
|
nk_f32_t absmax = 0.0f;
|
|
588
588
|
nk_f32_t norm_sq = 0.0f;
|
|
589
|
-
for (nk_size_t dim = 0; dim <
|
|
589
|
+
for (nk_size_t dim = 0; dim < depth; dim++) {
|
|
590
590
|
nk_f32_t val = source[dim];
|
|
591
591
|
nk_f32_t abs_val = nk_f32_abs_(val);
|
|
592
592
|
if (abs_val > absmax) absmax = abs_val;
|
|
@@ -601,7 +601,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
|
|
|
601
601
|
nk_size_t const column_tile = vector_index / tile_dimension;
|
|
602
602
|
nk_size_t const column_in_tile = vector_index % tile_dimension;
|
|
603
603
|
|
|
604
|
-
for (nk_size_t dim = 0; dim <
|
|
604
|
+
for (nk_size_t dim = 0; dim < depth; dim++) {
|
|
605
605
|
nk_size_t const depth_step = dim / expansion;
|
|
606
606
|
nk_size_t const sub_element = dim % expansion;
|
|
607
607
|
nk_size_t const vec_index = column_tile * depth_step_count + depth_step;
|
|
@@ -619,8 +619,8 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
|
|
|
619
619
|
|
|
620
620
|
// Pass 3: Copy originals (64B-aligned stride, zero-pad tail)
|
|
621
621
|
char *dest_original = originals + vector_index * original_stride;
|
|
622
|
-
nk_copy_bytes_(dest_original, source,
|
|
623
|
-
for (nk_size_t byte =
|
|
622
|
+
nk_copy_bytes_(dest_original, source, depth * sizeof(nk_f32_t));
|
|
623
|
+
for (nk_size_t byte = depth * sizeof(nk_f32_t); byte < original_stride; byte++) dest_original[byte] = 0;
|
|
624
624
|
}
|
|
625
625
|
}
|
|
626
626
|
|
|
@@ -628,16 +628,28 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
|
|
|
628
628
|
* Streaming-compatible f32 dot product with f64 accumulation.
|
|
629
629
|
* Follows the svcntd()-stride + svcvt_f64_f32_x pattern from nk_dots_reduce_sumsq_f32_ssve_.
|
|
630
630
|
*/
|
|
631
|
-
NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_(
|
|
632
|
-
nk_f32_t const *a, nk_f32_t const *b, nk_size_t count)
|
|
633
|
-
svfloat64_t
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
631
|
+
NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
|
|
632
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_size_t count) NK_STREAMING_ { //
|
|
633
|
+
svfloat64_t accumulator_even_f64x = svdup_f64(0.0);
|
|
634
|
+
svfloat64_t accumulator_odd_f64x = svdup_f64(0.0);
|
|
635
|
+
nk_size_t const vector_length = svcntw();
|
|
636
|
+
nk_size_t const half_vector_length = svcntd();
|
|
637
|
+
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
638
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(i, count);
|
|
639
|
+
svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
|
|
640
|
+
svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
|
|
641
|
+
|
|
642
|
+
svbool_t predicate_even_b64x = svwhilelt_b64_u64(i, count);
|
|
643
|
+
svfloat64_t a_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, a_f32x);
|
|
644
|
+
svfloat64_t b_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, b_f32x);
|
|
645
|
+
accumulator_even_f64x = svmla_f64_m(predicate_even_b64x, accumulator_even_f64x, a_even_f64x, b_even_f64x);
|
|
646
|
+
|
|
647
|
+
svbool_t predicate_odd_b64x = svwhilelt_b64_u64(i + half_vector_length, count);
|
|
648
|
+
svfloat64_t a_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, a_f32x);
|
|
649
|
+
svfloat64_t b_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, b_f32x);
|
|
650
|
+
accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, a_odd_f64x, b_odd_f64x);
|
|
639
651
|
}
|
|
640
|
-
return svaddv_f64(svptrue_b64(),
|
|
652
|
+
return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
|
|
641
653
|
}
|
|
642
654
|
|
|
643
655
|
/**
|
|
@@ -680,8 +692,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
680
692
|
|
|
681
693
|
nk_size_t const expansion = 4; // i8->i32 SMOPA
|
|
682
694
|
|
|
683
|
-
svbool_t const
|
|
684
|
-
svbool_t const
|
|
695
|
+
svbool_t const predicate_all_b8x = svptrue_b8();
|
|
696
|
+
svbool_t const predicate_all_b32x = svptrue_b32();
|
|
685
697
|
|
|
686
698
|
nk_f64_t total_angular_distance_f64 = 0.0;
|
|
687
699
|
|
|
@@ -689,10 +701,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
689
701
|
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
690
702
|
nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
|
|
691
703
|
: (query_count - row_start);
|
|
692
|
-
svbool_t const
|
|
704
|
+
svbool_t const row_predicate_b8x = (rows_remaining == tile_dimension)
|
|
693
705
|
? svptrue_b8()
|
|
694
706
|
: svwhilelt_b8_u64(0u, rows_remaining * expansion);
|
|
695
|
-
svbool_t const
|
|
707
|
+
svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
|
|
696
708
|
: svwhilelt_b32_u64(0u, rows_remaining);
|
|
697
709
|
|
|
698
710
|
svint32_t running_max_i32x = svdup_s32(NK_I32_MIN);
|
|
@@ -706,28 +718,29 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
706
718
|
|
|
707
719
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
708
720
|
svint8_t query_packed_i8x = svld1_s8(
|
|
709
|
-
|
|
710
|
-
(
|
|
721
|
+
row_predicate_b8x,
|
|
722
|
+
(nk_i8_t const *)(query_tiles +
|
|
723
|
+
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
711
724
|
svint8_t document_packed_0_i8x = svld1_s8(
|
|
712
|
-
|
|
713
|
-
(
|
|
714
|
-
|
|
725
|
+
predicate_all_b8x,
|
|
726
|
+
(nk_i8_t const *)(document_tiles +
|
|
727
|
+
((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
|
|
715
728
|
svint8_t document_packed_1_i8x = svld1_s8(
|
|
716
|
-
|
|
717
|
-
(
|
|
718
|
-
|
|
729
|
+
predicate_all_b8x,
|
|
730
|
+
(nk_i8_t const *)(document_tiles +
|
|
731
|
+
((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
|
|
719
732
|
svint8_t document_packed_2_i8x = svld1_s8(
|
|
720
|
-
|
|
721
|
-
(
|
|
722
|
-
|
|
733
|
+
predicate_all_b8x,
|
|
734
|
+
(nk_i8_t const *)(document_tiles +
|
|
735
|
+
((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
|
|
723
736
|
svint8_t document_packed_3_i8x = svld1_s8(
|
|
724
|
-
|
|
725
|
-
(
|
|
726
|
-
|
|
727
|
-
svmopa_za32_s8_m(0,
|
|
728
|
-
svmopa_za32_s8_m(1,
|
|
729
|
-
svmopa_za32_s8_m(2,
|
|
730
|
-
svmopa_za32_s8_m(3,
|
|
737
|
+
predicate_all_b8x,
|
|
738
|
+
(nk_i8_t const *)(document_tiles +
|
|
739
|
+
((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
|
|
740
|
+
svmopa_za32_s8_m(0, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_0_i8x);
|
|
741
|
+
svmopa_za32_s8_m(1, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_1_i8x);
|
|
742
|
+
svmopa_za32_s8_m(2, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_2_i8x);
|
|
743
|
+
svmopa_za32_s8_m(3, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_3_i8x);
|
|
731
744
|
}
|
|
732
745
|
|
|
733
746
|
// Vertical column extraction + argmax update (manually unrolled over 4 tiles)
|
|
@@ -735,36 +748,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
735
748
|
// Tile 0
|
|
736
749
|
{
|
|
737
750
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
|
|
738
|
-
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN),
|
|
751
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 0,
|
|
739
752
|
column_within_tile);
|
|
740
|
-
svbool_t is_better_bx = svcmpgt_s32(
|
|
753
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
|
|
741
754
|
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
742
755
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
743
756
|
}
|
|
744
757
|
// Tile 1
|
|
745
758
|
{
|
|
746
759
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
|
|
747
|
-
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN),
|
|
760
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 1,
|
|
748
761
|
column_within_tile);
|
|
749
|
-
svbool_t is_better_bx = svcmpgt_s32(
|
|
762
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
|
|
750
763
|
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
751
764
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
752
765
|
}
|
|
753
766
|
// Tile 2
|
|
754
767
|
{
|
|
755
768
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
|
|
756
|
-
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN),
|
|
769
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 2,
|
|
757
770
|
column_within_tile);
|
|
758
|
-
svbool_t is_better_bx = svcmpgt_s32(
|
|
771
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
|
|
759
772
|
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
760
773
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
761
774
|
}
|
|
762
775
|
// Tile 3
|
|
763
776
|
{
|
|
764
777
|
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
|
|
765
|
-
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN),
|
|
778
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 3,
|
|
766
779
|
column_within_tile);
|
|
767
|
-
svbool_t is_better_bx = svcmpgt_s32(
|
|
780
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
|
|
768
781
|
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
769
782
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
770
783
|
}
|
|
@@ -777,7 +790,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
777
790
|
nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
|
|
778
791
|
? tile_dimension
|
|
779
792
|
: (document_count - col_start);
|
|
780
|
-
svbool_t const
|
|
793
|
+
svbool_t const column_predicate_b8x = (cols_remaining == tile_dimension)
|
|
781
794
|
? svptrue_b8()
|
|
782
795
|
: svwhilelt_b8_u64(0u, cols_remaining * expansion);
|
|
783
796
|
|
|
@@ -785,20 +798,21 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
785
798
|
|
|
786
799
|
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
787
800
|
svint8_t query_packed_i8x = svld1_s8(
|
|
788
|
-
|
|
789
|
-
(
|
|
801
|
+
row_predicate_b8x,
|
|
802
|
+
(nk_i8_t const *)(query_tiles +
|
|
803
|
+
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
790
804
|
svint8_t document_packed_i8x = svld1_s8(
|
|
791
|
-
|
|
792
|
-
(
|
|
793
|
-
|
|
794
|
-
svmopa_za32_s8_m(0,
|
|
805
|
+
column_predicate_b8x,
|
|
806
|
+
(nk_i8_t const *)(document_tiles +
|
|
807
|
+
(column_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
808
|
+
svmopa_za32_s8_m(0, row_predicate_b8x, column_predicate_b8x, query_packed_i8x, document_packed_i8x);
|
|
795
809
|
}
|
|
796
810
|
|
|
797
811
|
for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
|
|
798
812
|
nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
|
|
799
|
-
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN),
|
|
813
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 0,
|
|
800
814
|
column_within_tile);
|
|
801
|
-
svbool_t is_better_bx = svcmpgt_s32(
|
|
815
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
|
|
802
816
|
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
803
817
|
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
804
818
|
}
|
|
@@ -806,7 +820,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
806
820
|
|
|
807
821
|
// Refinement: tile-wide interleaved f64 dot products
|
|
808
822
|
nk_u32_t best_document_indices[64]; // max tile_dimension across all SVL values
|
|
809
|
-
svst1_u32(
|
|
823
|
+
svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
|
|
810
824
|
|
|
811
825
|
// Pointer setup: one (query, document) pair per row in the tile
|
|
812
826
|
nk_f32_t const *query_original_ptrs[64];
|
|
@@ -828,46 +842,57 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
828
842
|
svfloat64_t accumulator_1_f64x = svdup_f64(0.0);
|
|
829
843
|
svfloat64_t accumulator_2_f64x = svdup_f64(0.0);
|
|
830
844
|
svfloat64_t accumulator_3_f64x = svdup_f64(0.0);
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
845
|
+
nk_size_t const depth_vector_length = svcntw();
|
|
846
|
+
nk_size_t const depth_half_length = svcntd();
|
|
847
|
+
|
|
848
|
+
for (nk_size_t depth_index = 0; depth_index < depth; depth_index += depth_vector_length) {
|
|
849
|
+
svbool_t predicate_depth_b32x = svwhilelt_b32_u64(depth_index, depth);
|
|
850
|
+
svbool_t predicate_even_b64x = svwhilelt_b64_u64(depth_index, depth);
|
|
851
|
+
svbool_t predicate_odd_b64x = svwhilelt_b64_u64(depth_index + depth_half_length, depth);
|
|
852
|
+
|
|
853
|
+
svfloat32_t query_values_0_f32x = svld1_f32(predicate_depth_b32x,
|
|
854
|
+
query_original_ptrs[row_batch_start + 0] + depth_index);
|
|
855
|
+
svfloat32_t document_values_0_f32x = svld1_f32(
|
|
856
|
+
predicate_depth_b32x, document_original_ptrs[row_batch_start + 0] + depth_index);
|
|
857
|
+
accumulator_0_f64x = svmla_f64_m(predicate_even_b64x, accumulator_0_f64x,
|
|
858
|
+
svcvt_f64_f32_x(predicate_even_b64x, query_values_0_f32x),
|
|
859
|
+
svcvt_f64_f32_x(predicate_even_b64x, document_values_0_f32x));
|
|
860
|
+
accumulator_0_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_0_f64x,
|
|
861
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_0_f32x),
|
|
862
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_0_f32x));
|
|
863
|
+
|
|
864
|
+
svfloat32_t query_values_1_f32x = svld1_f32(predicate_depth_b32x,
|
|
865
|
+
query_original_ptrs[row_batch_start + 1] + depth_index);
|
|
866
|
+
svfloat32_t document_values_1_f32x = svld1_f32(
|
|
867
|
+
predicate_depth_b32x, document_original_ptrs[row_batch_start + 1] + depth_index);
|
|
868
|
+
accumulator_1_f64x = svmla_f64_m(predicate_even_b64x, accumulator_1_f64x,
|
|
869
|
+
svcvt_f64_f32_x(predicate_even_b64x, query_values_1_f32x),
|
|
870
|
+
svcvt_f64_f32_x(predicate_even_b64x, document_values_1_f32x));
|
|
871
|
+
accumulator_1_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_1_f64x,
|
|
872
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_1_f32x),
|
|
873
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_1_f32x));
|
|
874
|
+
|
|
875
|
+
svfloat32_t query_values_2_f32x = svld1_f32(predicate_depth_b32x,
|
|
876
|
+
query_original_ptrs[row_batch_start + 2] + depth_index);
|
|
877
|
+
svfloat32_t document_values_2_f32x = svld1_f32(
|
|
878
|
+
predicate_depth_b32x, document_original_ptrs[row_batch_start + 2] + depth_index);
|
|
879
|
+
accumulator_2_f64x = svmla_f64_m(predicate_even_b64x, accumulator_2_f64x,
|
|
880
|
+
svcvt_f64_f32_x(predicate_even_b64x, query_values_2_f32x),
|
|
881
|
+
svcvt_f64_f32_x(predicate_even_b64x, document_values_2_f32x));
|
|
882
|
+
accumulator_2_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_2_f64x,
|
|
883
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_2_f32x),
|
|
884
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_2_f32x));
|
|
885
|
+
|
|
886
|
+
svfloat32_t query_values_3_f32x = svld1_f32(predicate_depth_b32x,
|
|
887
|
+
query_original_ptrs[row_batch_start + 3] + depth_index);
|
|
888
|
+
svfloat32_t document_values_3_f32x = svld1_f32(
|
|
889
|
+
predicate_depth_b32x, document_original_ptrs[row_batch_start + 3] + depth_index);
|
|
890
|
+
accumulator_3_f64x = svmla_f64_m(predicate_even_b64x, accumulator_3_f64x,
|
|
891
|
+
svcvt_f64_f32_x(predicate_even_b64x, query_values_3_f32x),
|
|
892
|
+
svcvt_f64_f32_x(predicate_even_b64x, document_values_3_f32x));
|
|
893
|
+
accumulator_3_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_3_f64x,
|
|
894
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_3_f32x),
|
|
895
|
+
svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_3_f32x));
|
|
871
896
|
}
|
|
872
897
|
|
|
873
898
|
// Reduce accumulators and compute angular distance per row
|