numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,929 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for SME.
|
|
3
|
+
* @file include/numkong/maxsim/sme.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 10, 2026
|
|
6
|
+
*
|
|
7
|
+
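* Computes MaxSim(Q, D) = Σᵢ maxⱼ dot(qᵢ, dⱼ) using ARM SME outer products.
* The half-precision (f16/bf16) kernels then turn each query token's best match into an angular
* distance max(1 − cos, 0) and return the sum of those distances.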
|
|
8
|
+
*
|
|
9
|
+
* Both Q and D are pre-packed with `nk_dots_pack_bf16_sme` from `dots/sme.h`.
|
|
10
|
+
* This frees all 4 ZA tiles for accumulation (vs 3 with A-side staging).
|
|
11
|
+
*
|
|
12
|
+
* Key optimization: vertical column reads for max reduction.
|
|
13
|
+
* Traditional extraction reads tile rows then calls `svmaxv` (horizontal max, ~8cy).
|
|
14
|
+
* Our approach reads tile columns with `svread_ver_za32_f32_m`:
|
|
15
|
+
*
|
|
16
|
+
* - Each column read gives dot products of all query tokens vs one doc token.
|
|
17
|
+
* - Element-wise compare + select (`svcmpgt`, `svsel`) updates a running max vector — and a matching argmax vector — across doc tokens.
|
|
18
|
+
* - One `svaddv` per query row tile, only at the very end: ⌈n_q/16⌉ horizontal reductions in total (2 for n_q = 32).
|
|
19
|
+
*
|
|
20
|
+
* This is ~100x fewer horizontal reductions for typical ColBERT dimensions.
|
|
21
|
+
*
|
|
22
|
+
* ZA tile layout after BFMOPA accumulation (16x16 f32):
|
|
23
|
+
*
|
|
24
|
+
* - Row i, Column j = dot(q_{tile_row_start + i}, d_{tile_col_start + j})
|
|
25
|
+
* - Vertical column read of column j → similarities of all 16 q tokens to doc token j
|
|
26
|
+
* - Element-wise max across columns → per-query-token max over doc tokens in this tile group
|
|
27
|
+
*
|
|
28
|
+
* Benchmark results (Apple M4, SVL=512):
|
|
29
|
+
*
|
|
30
|
+
* Dimensions dots_packed GEMM maxsim fused GEMM speedup End-to-end speedup
|
|
31
|
+
* 32×128×128 (ColBERT) 840 GFLOPS 1516 GFLOPS 1.81× 5.10×
|
|
32
|
+
* 32×256×128 1037 GFLOPS 1591 GFLOPS 1.53× 5.17×
|
|
33
|
+
* 64×512×128 1016 GFLOPS 1651 GFLOPS 1.62× 5.42×
|
|
34
|
+
* 32×128×256 859 GFLOPS 1725 GFLOPS 2.01× 4.06×
|
|
35
|
+
* 32×1024×768 (BERT) 1124 GFLOPS 1932 GFLOPS 1.72× 2.61×
|
|
36
|
+
*
|
|
37
|
+
* Speedup sources:
|
|
38
|
+
*
|
|
39
|
+
* 1. Pre-packing both sides → 4 ZA tiles for accumulation (vs 3 with A-staging): +33% MOPA throughput
|
|
40
|
+
* 2. No output matrix materialization → eliminates M×N f32 memory round-trip
|
|
41
|
+
* 3. Vertical column reads → ~128 element-wise svmax (1cy) vs ~256 svmaxv horizontal reductions (8cy)
|
|
42
|
+
*/
|
|
43
|
+
#ifndef NK_MAXSIM_SME_H
|
|
44
|
+
#define NK_MAXSIM_SME_H
|
|
45
|
+
|
|
46
|
+
#if NK_TARGET_ARM_
|
|
47
|
+
#if NK_TARGET_SME
|
|
48
|
+
|
|
49
|
+
#include "numkong/dots/sme.h" // nk_dots_sme_packed_header_t, nk_dots_pack_{f16,bf16}_sme, nk_dots_packed_size_{f16,bf16}_sme
|
|
50
|
+
|
|
51
|
+
#if defined(__cplusplus)
|
|
52
|
+
extern "C" {
|
|
53
|
+
#endif
|
|
54
|
+
|
|
55
|
+
#if defined(__clang__)
|
|
56
|
+
#pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
|
|
57
|
+
#elif defined(__GNUC__)
|
|
58
|
+
#pragma GCC push_options
|
|
59
|
+
#pragma GCC target("+sme")
|
|
60
|
+
#endif
|
|
61
|
+
|
|
62
|
+
/**
|
|
63
|
+
* Packed header for MaxSim SME kernels. Used by f32 (i8 screening + f32 refinement)
|
|
64
|
+
* and bf16/f16 (BFMOPA/FMOPA + angular normalization) kernels.
|
|
65
|
+
*
|
|
66
|
+
* For f32: stores i8 tile-interleaved data, f32 squared norms, AND f32 originals.
|
|
67
|
+
* For bf16/f16: stores tile-interleaved data and f32 inverse norms (1/||v||).
|
|
68
|
+
* originals_offset and original_stride are 0 (unused).
|
|
69
|
+
*/
|
|
70
|
+
typedef struct {
|
|
71
|
+
nk_u32_t column_tile_count; // ceil(n / tile_dimension)
|
|
72
|
+
nk_u32_t depth_tile_count; // ceil(depth / expansion)
|
|
73
|
+
nk_u32_t columns; // actual vector count (for predicates)
|
|
74
|
+
nk_u32_t depth; // actual depth
|
|
75
|
+
nk_u32_t svl_bytes; // SVL in bytes at pack time (validation)
|
|
76
|
+
nk_u32_t norms_offset; // byte offset -> per-vector norms (squared for f32, inverse for bf16/f16)
|
|
77
|
+
nk_u32_t originals_offset; // byte offset -> f32 original vectors (0 for bf16/f16)
|
|
78
|
+
nk_u32_t original_stride; // row stride in bytes for originals (64B-aligned, 0 for bf16/f16)
|
|
79
|
+
nk_u32_t reserved[8]; // padding to 64 bytes
|
|
80
|
+
} nk_maxsim_sme_packed_header_t;
|
|
81
|
+
|
|
82
|
+
NK_STATIC_ASSERT(sizeof(nk_maxsim_sme_packed_header_t) == 64, nk_maxsim_sme_packed_header_must_be_64_bytes);
|
|
83
|
+
|
|
84
|
+
/**
|
|
85
|
+
* MaxSim f16 kernel: both Q and D pre-packed, vertical column read extraction.
|
|
86
|
+
*
|
|
87
|
+
* 4-tile fast path: processes 4 doc column tiles simultaneously using ZA0-ZA3.
|
|
88
|
+
* Inner loop per depth_step: 1 Q load + 4 D loads + 4 FMOPA = 9 ops.
|
|
89
|
+
* Extraction per 4-tile group: 4×16 = 64 vertical reads + 64 svmax = ~128 cycles.
|
|
90
|
+
*
|
|
91
|
+
* 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
|
|
92
|
+
*/
|
|
93
|
+
__arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
|
|
94
|
+
void const *query_packed, void const *document_packed, //
|
|
95
|
+
nk_size_t query_count, nk_size_t document_count, //
|
|
96
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
97
|
+
|
|
98
|
+
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
99
|
+
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
100
|
+
nk_size_t const depth_step_count = query_header->depth_tile_count;
|
|
101
|
+
nk_size_t const query_row_tiles = query_header->column_tile_count;
|
|
102
|
+
nk_size_t const document_col_tiles = document_header->column_tile_count;
|
|
103
|
+
|
|
104
|
+
nk_size_t const tile_dimension = svcntw(); // 16: ZA32 tile dimension
|
|
105
|
+
nk_size_t const vector_elements = svcnth(); // 32: f16 elements per SVE vector
|
|
106
|
+
|
|
107
|
+
nk_f16_t const *query_vecs = (nk_f16_t const *)((char const *)query_packed + sizeof(nk_maxsim_sme_packed_header_t));
|
|
108
|
+
nk_f16_t const *document_vecs = (nk_f16_t const *)((char const *)document_packed +
|
|
109
|
+
sizeof(nk_maxsim_sme_packed_header_t));
|
|
110
|
+
|
|
111
|
+
nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
|
|
112
|
+
nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
|
|
113
|
+
document_header->norms_offset);
|
|
114
|
+
|
|
115
|
+
svbool_t const predicate_all_f16x = svptrue_b16();
|
|
116
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
117
|
+
|
|
118
|
+
nk_f32_t total_angular_distance = 0.0f;
|
|
119
|
+
|
|
120
|
+
for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
|
|
121
|
+
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
122
|
+
nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
|
|
123
|
+
: (query_count - row_start);
|
|
124
|
+
svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
|
|
125
|
+
? svptrue_b16()
|
|
126
|
+
: svwhilelt_b16_u64(0u, rows_remaining * 2);
|
|
127
|
+
svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
|
|
128
|
+
: svwhilelt_b32_u64(0u, rows_remaining);
|
|
129
|
+
|
|
130
|
+
// Running max + argmax vectors for angular distance finalization
|
|
131
|
+
svfloat32_t running_maximum_f32x = svdup_f32(NK_F32_MIN);
|
|
132
|
+
svuint32_t running_argmax_u32x = svdup_u32(0);
|
|
133
|
+
|
|
134
|
+
nk_size_t column_tile_index = 0;
|
|
135
|
+
|
|
136
|
+
// Fast path: 4 doc column tiles at a time using ZA0-ZA3
|
|
137
|
+
for (; column_tile_index + 4 <= document_col_tiles; column_tile_index += 4) {
|
|
138
|
+
svzero_za(); // Zero all 4 tiles
|
|
139
|
+
|
|
140
|
+
// Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 FMOPAs
|
|
141
|
+
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
142
|
+
svfloat16_t query_packed_f16x = svld1_f16(
|
|
143
|
+
row_predicate_f16x,
|
|
144
|
+
(float16_t const *)(query_vecs +
|
|
145
|
+
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
146
|
+
svfloat16_t document_packed_0_f16x = svld1_f16(
|
|
147
|
+
predicate_all_f16x,
|
|
148
|
+
(float16_t const *)(document_vecs +
|
|
149
|
+
((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
|
|
150
|
+
svfloat16_t document_packed_1_f16x = svld1_f16(
|
|
151
|
+
predicate_all_f16x,
|
|
152
|
+
(float16_t const *)(document_vecs +
|
|
153
|
+
((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
|
|
154
|
+
svfloat16_t document_packed_2_f16x = svld1_f16(
|
|
155
|
+
predicate_all_f16x,
|
|
156
|
+
(float16_t const *)(document_vecs +
|
|
157
|
+
((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
|
|
158
|
+
svfloat16_t document_packed_3_f16x = svld1_f16(
|
|
159
|
+
predicate_all_f16x,
|
|
160
|
+
(float16_t const *)(document_vecs +
|
|
161
|
+
((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
|
|
162
|
+
svmopa_za32_f16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_0_f16x);
|
|
163
|
+
svmopa_za32_f16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_1_f16x);
|
|
164
|
+
svmopa_za32_f16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_2_f16x);
|
|
165
|
+
svmopa_za32_f16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_3_f16x);
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
// Vertical column extraction + argmax update (manually unrolled over 4 tiles)
|
|
169
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
|
|
170
|
+
// Tile 0
|
|
171
|
+
{
|
|
172
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
|
|
173
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
|
|
174
|
+
column_within_tile);
|
|
175
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
|
|
176
|
+
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
177
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
178
|
+
}
|
|
179
|
+
// Tile 1
|
|
180
|
+
{
|
|
181
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
|
|
182
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
|
|
183
|
+
column_within_tile);
|
|
184
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
|
|
185
|
+
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
186
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
187
|
+
}
|
|
188
|
+
// Tile 2
|
|
189
|
+
{
|
|
190
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
|
|
191
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
|
|
192
|
+
column_within_tile);
|
|
193
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
|
|
194
|
+
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
195
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
196
|
+
}
|
|
197
|
+
// Tile 3
|
|
198
|
+
{
|
|
199
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
|
|
200
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
|
|
201
|
+
column_within_tile);
|
|
202
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
|
|
203
|
+
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
204
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
// Remainder: 1 doc column tile at a time using ZA0 only
|
|
210
|
+
for (; column_tile_index < document_col_tiles; column_tile_index++) {
|
|
211
|
+
nk_size_t const col_start = column_tile_index * tile_dimension;
|
|
212
|
+
nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
|
|
213
|
+
? tile_dimension
|
|
214
|
+
: (document_count - col_start);
|
|
215
|
+
svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
|
|
216
|
+
? svptrue_b16()
|
|
217
|
+
: svwhilelt_b16_u64(0u, cols_remaining * 2);
|
|
218
|
+
|
|
219
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_); // Zero ZA0 only
|
|
220
|
+
|
|
221
|
+
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
222
|
+
svfloat16_t query_packed_f16x = svld1_f16(
|
|
223
|
+
row_predicate_f16x,
|
|
224
|
+
(float16_t const *)(query_vecs +
|
|
225
|
+
(row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
226
|
+
svfloat16_t document_packed_f16x = svld1_f16(
|
|
227
|
+
column_predicate_f16x,
|
|
228
|
+
(float16_t const *)(document_vecs +
|
|
229
|
+
(column_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
230
|
+
svmopa_za32_f16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_f16x,
|
|
231
|
+
document_packed_f16x);
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
// Vertical column extraction from ZA0 + argmax update
|
|
235
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
|
|
236
|
+
nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
|
|
237
|
+
svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
|
|
238
|
+
column_within_tile);
|
|
239
|
+
svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
|
|
240
|
+
running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
|
|
241
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
// Angular distance finalization — SVE-width vector ops
|
|
246
|
+
// Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
|
|
247
|
+
nk_u32_t best_document_indices[64];
|
|
248
|
+
nk_f32_t document_inverse_norms_gathered[64];
|
|
249
|
+
svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
|
|
250
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
|
|
251
|
+
document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
|
|
252
|
+
|
|
253
|
+
// SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
|
|
254
|
+
svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
|
|
255
|
+
svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
|
|
256
|
+
svfloat32_t cosine_f32x = svmul_f32_x(
|
|
257
|
+
row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
|
|
258
|
+
document_inverse_norms_f32x);
|
|
259
|
+
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
260
|
+
row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
261
|
+
total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
*result = total_angular_distance;
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
NK_PUBLIC void nk_maxsim_packed_f16_sme( //
|
|
268
|
+
void const *query_packed, void const *document_packed, //
|
|
269
|
+
nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
|
|
270
|
+
nk_f32_t *result) { //
|
|
271
|
+
|
|
272
|
+
nk_maxsim_packed_f16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
/**
 * MaxSim bf16 kernel: both Q and D pre-packed, vertical column read extraction.
 *
 * For every query row, this finds the document column with the largest bf16 dot product. The products are
 * accumulated by BFMOPA into f32 ZA tiles. The maximum is then turned into an angular distance
 * max(1 - dot * inv_norm_q * inv_norm_d, 0), using the inverse norms written by nk_maxsim_pack_bf16_sme.
 * `*result` receives the sum of those distances over all query rows.
 *
 * 4-tile fast path: processes 4 *fully populated* doc column tiles simultaneously using ZA0-ZA3.
 * - Inner loop per depth_step: 1 Q load + 4 D loads + 4 BFMOPA = 9 ops.
 * - Extraction per 4-tile group: 4 × tile_dimension vertical reads, each followed by one compare and
 *   two selects (running max + argmax).
 *
 * 1-tile remainder: uses ZA0 only, with predicated loads. It also handles the trailing partial column
 * tile. That tile must never go through the unpredicated fast path: its padding columns would take part
 * in the argmax and could yield a document index >= document_count.
 *
 * `depth` is not read; the depth step count comes from the query header.
 */
__arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
    void const *query_packed, void const *document_packed,                            //
    nk_size_t query_count, nk_size_t document_count,                                  //
    nk_size_t depth, nk_f32_t *result) {

    nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
    nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
    nk_size_t const depth_step_count = query_header->depth_tile_count;
    nk_size_t const query_row_tiles = query_header->column_tile_count;
    nk_size_t const document_col_tiles = document_header->column_tile_count;

    nk_size_t const tile_dimension = svcntw();  // 16: ZA32 tile dimension
    nk_size_t const vector_elements = svcnth(); // 32: bf16 elements per SVE vector

    // Column tiles in which every column is a real document. Only these may use the unpredicated
    // 4-tile fast path. A trailing partial tile holds padding columns (presumably zero-filled by the
    // dots packer), whose dot products could otherwise win the argmax.
    nk_size_t const full_document_col_tiles = (document_count / tile_dimension < document_col_tiles)
                                                  ? (document_count / tile_dimension)
                                                  : document_col_tiles;

    nk_bf16_t const *query_vecs = (nk_bf16_t const *)((char const *)query_packed +
                                                      sizeof(nk_maxsim_sme_packed_header_t));
    nk_bf16_t const *document_vecs = (nk_bf16_t const *)((char const *)document_packed +
                                                         sizeof(nk_maxsim_sme_packed_header_t));

    nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
    nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
                                                                document_header->norms_offset);

    svbool_t const predicate_all_f16x = svptrue_b16();
    svbool_t const predicate_all_f32x = svptrue_b32();

    nk_f32_t total_angular_distance = 0.0f;

    for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
        nk_size_t const row_start = row_tile_index * tile_dimension;
        nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
                                                                                     : (query_count - row_start);
        // BFMOPA consumes bf16 pairs per 32-bit row, hence the factor of 2 on the 16-bit predicate.
        svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
                                                ? svptrue_b16()
                                                : svwhilelt_b16_u64(0u, rows_remaining * 2);
        svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
                                                                               : svwhilelt_b32_u64(0u, rows_remaining);

        // Running max + argmax vectors for angular distance finalization
        svfloat32_t running_maximum_f32x = svdup_f32(NK_F32_MIN);
        svuint32_t running_argmax_u32x = svdup_u32(0);

        nk_size_t column_tile_index = 0;

        // Fast path: 4 fully populated doc column tiles at a time using ZA0-ZA3
        for (; column_tile_index + 4 <= full_document_col_tiles; column_tile_index += 4) {
            svzero_za(); // Zero all 4 tiles

            // Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 BFMOPAs
            for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
                svbfloat16_t query_packed_bf16x = svld1_bf16(
                    row_predicate_f16x,
                    (bfloat16_t const *)(query_vecs +
                                         (row_tile_index * depth_step_count + depth_step) * vector_elements));
                svbfloat16_t document_packed_0_bf16x = svld1_bf16(
                    predicate_all_f16x,
                    (bfloat16_t const *)(document_vecs +
                                         ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
                svbfloat16_t document_packed_1_bf16x = svld1_bf16(
                    predicate_all_f16x,
                    (bfloat16_t const *)(document_vecs +
                                         ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
                svbfloat16_t document_packed_2_bf16x = svld1_bf16(
                    predicate_all_f16x,
                    (bfloat16_t const *)(document_vecs +
                                         ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
                svbfloat16_t document_packed_3_bf16x = svld1_bf16(
                    predicate_all_f16x,
                    (bfloat16_t const *)(document_vecs +
                                         ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
                svmopa_za32_bf16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
                                   document_packed_0_bf16x);
                svmopa_za32_bf16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
                                   document_packed_1_bf16x);
                svmopa_za32_bf16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
                                   document_packed_2_bf16x);
                svmopa_za32_bf16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
                                   document_packed_3_bf16x);
            }

            // Vertical column extraction + argmax update (manually unrolled over 4 tiles; the ZA tile
            // index must be a compile-time constant, so it cannot be a loop variable)
            for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
                // Tile 0
                {
                    nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
                    svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
                                                                         column_within_tile);
                    svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
                    running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
                    running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
                }
                // Tile 1
                {
                    nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
                    svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
                                                                         column_within_tile);
                    svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
                    running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
                    running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
                }
                // Tile 2
                {
                    nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
                    svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
                                                                         column_within_tile);
                    svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
                    running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
                    running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
                }
                // Tile 3
                {
                    nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
                    svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
                                                                         column_within_tile);
                    svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
                    running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
                    running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
                }
            }
        }

        // Remainder: 1 doc column tile at a time using ZA0 only. Covers leftover full tiles and the
        // trailing partial tile, whose columns are masked with a column predicate.
        for (; column_tile_index < document_col_tiles; column_tile_index++) {
            nk_size_t const col_start = column_tile_index * tile_dimension;
            nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
                                                 ? tile_dimension
                                                 : (document_count - col_start);
            svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
                                                       ? svptrue_b16()
                                                       : svwhilelt_b16_u64(0u, cols_remaining * 2);

            svzero_mask_za(nk_sme_zero_za32_tile_0_); // Zero ZA0 only

            for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
                svbfloat16_t query_packed_bf16x = svld1_bf16(
                    row_predicate_f16x,
                    (bfloat16_t const *)(query_vecs +
                                         (row_tile_index * depth_step_count + depth_step) * vector_elements));
                svbfloat16_t document_packed_bf16x = svld1_bf16(
                    column_predicate_f16x,
                    (bfloat16_t const *)(document_vecs +
                                         (column_tile_index * depth_step_count + depth_step) * vector_elements));
                svmopa_za32_bf16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_bf16x,
                                   document_packed_bf16x);
            }

            // Vertical column extraction from ZA0 + argmax update (real columns only)
            for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
                nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
                svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
                                                                     column_within_tile);
                svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
                running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
                running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
            }
        }

        // Angular distance finalization — SVE-width vector ops
        // Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
        nk_u32_t best_document_indices[64];     // max tile_dimension across all SVL values
        nk_f32_t document_inverse_norms_gathered[64];
        svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
        for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
            document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];

        // SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
        svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
        svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
        svfloat32_t cosine_f32x = svmul_f32_x(
            row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
            document_inverse_norms_f32x);
        svfloat32_t angular_distance_f32x = svmax_f32_x(
            row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
        total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
    }

    *result = total_angular_distance;
}
|
|
462
|
+
|
|
463
|
+
NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
|
|
464
|
+
void const *query_packed, void const *document_packed, //
|
|
465
|
+
nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
|
|
466
|
+
nk_f32_t *result) { //
|
|
467
|
+
|
|
468
|
+
nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
469
|
+
}
|
|
470
|
+
|
|
471
|
+
/**
 * Bytes required by nk_maxsim_pack_bf16_sme for `n` vectors of `k` bf16 dimensions.
 * The maxsim bf16 packing is built on top of nk_dots_pack_bf16_sme, so the dots size is reused as-is.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t n, nk_size_t k) {
    nk_size_t const packed_bytes = nk_dots_packed_size_bf16_sme(n, k);
    return packed_bytes;
}
|
|
474
|
+
|
|
475
|
+
/**
 * Bytes required by nk_maxsim_pack_f16_sme for `n` vectors of `k` f16 dimensions.
 * The maxsim f16 packing is built on top of nk_dots_pack_f16_sme, so the dots size is reused as-is.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t n, nk_size_t k) {
    nk_size_t const packed_bytes = nk_dots_packed_size_f16_sme(n, k);
    return packed_bytes;
}
|
|
478
|
+
|
|
479
|
+
NK_PUBLIC void nk_maxsim_pack_bf16_sme( //
|
|
480
|
+
nk_bf16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
|
|
481
|
+
|
|
482
|
+
// Delegate tile interleaving and squared norms computation to dots pack.
|
|
483
|
+
// Both headers are 64 bytes with identical layout for the first 6 fields.
|
|
484
|
+
nk_dots_pack_bf16_sme(vectors, n, k, stride, packed);
|
|
485
|
+
|
|
486
|
+
// Set maxsim-specific header fields (overlaps dots reserved area)
|
|
487
|
+
nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
|
|
488
|
+
header->originals_offset = 0; // not used for bf16
|
|
489
|
+
header->original_stride = 0; // not used for bf16
|
|
490
|
+
for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;
|
|
491
|
+
|
|
492
|
+
// Convert squared norms → inverse norms in-place
|
|
493
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
494
|
+
for (nk_size_t i = 0; i < n; i++) {
|
|
495
|
+
nk_f32_t norm_sq = norms[i];
|
|
496
|
+
norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
NK_PUBLIC void nk_maxsim_pack_f16_sme( //
|
|
501
|
+
nk_f16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
|
|
502
|
+
|
|
503
|
+
// Delegate tile interleaving and squared norms computation to dots pack.
|
|
504
|
+
// Both headers are 64 bytes with identical layout for the first 6 fields.
|
|
505
|
+
nk_dots_pack_f16_sme(vectors, n, k, stride, packed);
|
|
506
|
+
|
|
507
|
+
// Set maxsim-specific header fields (overlaps dots reserved area)
|
|
508
|
+
nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
|
|
509
|
+
header->originals_offset = 0; // not used for f16
|
|
510
|
+
header->original_stride = 0; // not used for f16
|
|
511
|
+
for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;
|
|
512
|
+
|
|
513
|
+
// Convert squared norms → inverse norms in-place
|
|
514
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
515
|
+
for (nk_size_t i = 0; i < n; i++) {
|
|
516
|
+
nk_f32_t norm_sq = norms[i];
|
|
517
|
+
norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
|
|
518
|
+
}
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
/**
 * Number of bytes nk_maxsim_pack_f32_sme writes for `n` f32 vectors of `k` dimensions.
 *
 * Packed layout, in order:
 *  - the 64-byte nk_maxsim_sme_packed_header_t;
 *  - i8-quantized screening tiles: ceil(n / tile_dimension) column tiles × ceil(k / 4) depth steps
 *    (4 = i8->i32 SMOPA expansion), each step one streaming vector of svcntsb() bytes;
 *  - n f32 squared norms;
 *  - n copies of the original f32 vectors, each padded to a multiple of 64 bytes.
 *
 * tile_dimension and the tile vector size come from svcntsw()/svcntsb(), so the result depends on the
 * streaming vector length of the machine it is computed on.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t n, nk_size_t k) { //
    nk_size_t const expansion = 4;               // i8->i32 SMOPA
    nk_size_t const tile_dimension = svcntsw();  // 16 for SVL=512
    nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512
    nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
    nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
    nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);

    nk_size_t size = sizeof(nk_maxsim_sme_packed_header_t);          // 64 B header
    size += column_tile_count * depth_step_count * vector_elements; // i8 tiles
    size += n * sizeof(nk_f32_t);                                   // f32 squared norms
    size += n * original_stride;                                    // f32 originals
    return size;
}
|
|
544
|
+
|
|
545
|
+
/**
 * Packs `n` f32 vectors of `k` dimensions (consecutive vectors `stride` bytes apart) into the layout sized
 * by nk_maxsim_packed_size_f32_sme. That layout is: header, i8 screening tiles, f32 squared norms, and
 * 64-byte-strided f32 originals.
 *
 * Per vector:
 *  - Pass 1 computes the absolute maximum and the squared L2 norm (f32 accumulation) in one sweep.
 *  - Pass 2 quantizes symmetrically to i8 with scale absmax / 127 (scale 1 for an all-zero vector).
 *    Values are rounded half away from zero, clamped to [-127, 127], and scattered into SMOPA tile order:
 *    groups of 4 consecutive dimensions per column, `tile_dimension` columns per tile.
 *  - Pass 3 copies the original f32 values and zero-pads them up to `original_stride` bytes.
 * Tile bytes not covered by any vector stay zero.
 */
NK_PUBLIC void nk_maxsim_pack_f32_sme( //
    nk_f32_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //

    nk_size_t const expansion = 4;               // i8->i32 SMOPA
    nk_size_t const tile_dimension = svcntsw();  // 16 for SVL=512
    nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512

    nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
    nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
    nk_size_t const total_vectors = column_tile_count * depth_step_count;
    nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);

    // Set up header
    nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
    header->column_tile_count = (nk_u32_t)column_tile_count;
    header->depth_tile_count = (nk_u32_t)depth_step_count;
    header->columns = (nk_u32_t)n;
    header->depth = (nk_u32_t)k;
    header->svl_bytes = (nk_u32_t)(svcntsw() * sizeof(nk_f32_t));

    nk_size_t const tiles_size = total_vectors * vector_elements;
    nk_size_t const norms_offset = sizeof(nk_maxsim_sme_packed_header_t) + tiles_size;
    nk_size_t const originals_offset = norms_offset + n * sizeof(nk_f32_t);

    header->norms_offset = (nk_u32_t)norms_offset;
    header->originals_offset = (nk_u32_t)originals_offset;
    header->original_stride = (nk_u32_t)original_stride;
    for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;

    nk_i8_t *tiles = (nk_i8_t *)((char *)packed + sizeof(nk_maxsim_sme_packed_header_t));
    nk_f32_t *norms = (nk_f32_t *)((char *)packed + norms_offset);
    char *originals = (char *)packed + originals_offset;

    // Zero-initialize tile data (partial vectors stay zero-padded)
    for (nk_size_t i = 0; i < tiles_size; i++) tiles[i] = 0;

    // For each vector: quantize metadata, quantize+interleave into tiles, copy originals
    for (nk_size_t vector_index = 0; vector_index < n; vector_index++) {
        nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index * stride);

        // Pass 1: Compute absmax and norm_sq simultaneously
        nk_f32_t absmax = 0.0f;
        nk_f32_t norm_sq = 0.0f;
        for (nk_size_t dim = 0; dim < k; dim++) {
            nk_f32_t val = source[dim];
            nk_f32_t abs_val = nk_f32_abs_(val);
            if (abs_val > absmax) absmax = abs_val;
            norm_sq += val * val;
        }
        norms[vector_index] = norm_sq;

        nk_f32_t scale = absmax / 127.0f;
        if (scale == 0.0f) scale = 1.0f; // all-zero (or vanishingly small) vector: avoid division by zero

        // Pass 2: Quantize and scatter into tile-interleaved positions
        nk_size_t const column_tile = vector_index / tile_dimension;
        nk_size_t const column_in_tile = vector_index % tile_dimension;

        for (nk_size_t dim = 0; dim < k; dim++) {
            nk_size_t const depth_step = dim / expansion;
            nk_size_t const sub_element = dim % expansion;
            nk_size_t const vec_index = column_tile * depth_step_count + depth_step;
            nk_size_t const offset = vec_index * vector_elements + expansion * column_in_tile + sub_element;

            nk_f32_t scaled = source[dim] / scale;
            nk_i32_t quantized;
            if (scaled >= 0.0f) quantized = (nk_i32_t)(scaled + 0.5f);
            else quantized = (nk_i32_t)(scaled - 0.5f);
            if (quantized > 127) quantized = 127;
            if (quantized < -127) quantized = -127;

            tiles[offset] = (nk_i8_t)quantized;
        }

        // Pass 3: Copy originals (64B-aligned stride, zero-pad tail)
        char *dest_original = originals + vector_index * original_stride;
        nk_copy_bytes_(dest_original, source, k * sizeof(nk_f32_t));
        for (nk_size_t byte = k * sizeof(nk_f32_t); byte < original_stride; byte++) dest_original[byte] = 0;
    }
}
|
|
626
|
+
|
|
627
|
+
/**
 * Streaming-compatible f32 dot product with f64 accumulation.
 *
 * Walks the inputs svcntd() elements per iteration. Each step loads the 32-bit values directly into
 * 64-bit containers (svld1uw_u64), so lane j of the widened vector holds element i + j. This matters
 * because svcvt_f64_f32 converts only the low (even-numbered) 32-bit half of each 64-bit container.
 * Feeding it a densely packed svld1_f32 vector would read elements i, i+2, i+4, ..., skipping every odd
 * element and double-counting even ones across iterations.
 *
 * The accumulator is updated with merging FMLA (_m). Lanes that are inactive in the tail iteration
 * therefore keep their partial sums instead of becoming unspecified before the full-width reduction.
 *
 * @param a      First input, `count` f32 elements.
 * @param b      Second input, `count` f32 elements.
 * @param count  Number of elements; 0 yields 0.0.
 * @return       Sum of a[i] * b[i], accumulated in f64.
 */
NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
    nk_f32_t const *a, nk_f32_t const *b, nk_size_t count) NK_STREAMING_COMPATIBLE_ { //
    svfloat64_t accumulator_f64x = svdup_f64(0.0);
    for (nk_size_t i = 0; i < count; i += svcntd()) {
        svbool_t const predicate_f64x = svwhilelt_b64_u64(i, count);
        // One f32 per 64-bit container (low half), exactly where svcvt_f64_f32 reads its input.
        svfloat32_t const a_unpacked_f32x = svreinterpret_f32_u64(
            svld1uw_u64(predicate_f64x, (uint32_t const *)(a + i)));
        svfloat32_t const b_unpacked_f32x = svreinterpret_f32_u64(
            svld1uw_u64(predicate_f64x, (uint32_t const *)(b + i)));
        svfloat64_t const a_f64x = svcvt_f64_f32_x(predicate_f64x, a_unpacked_f32x);
        svfloat64_t const b_f64x = svcvt_f64_f32_x(predicate_f64x, b_unpacked_f32x);
        accumulator_f64x = svmla_f64_m(predicate_f64x, accumulator_f64x, a_f64x, b_f64x);
    }
    return svaddv_f64(svptrue_b64(), accumulator_f64x);
}
|
|
642
|
+
|
|
643
|
+
/**
|
|
644
|
+
* MaxSim f32 kernel: i8 SMOPA screening + f32/f64 refinement + angular distance.
|
|
645
|
+
*
|
|
646
|
+
* Screening: i8 SMOPA has expansion=4, processing 4x more depth per instruction than f32 FMOPA.
|
|
647
|
+
* With 4 ZA tiles the fast path processes 64 document columns per iteration.
|
|
648
|
+
*
|
|
649
|
+
* Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
|
|
650
|
+
* Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
|
|
651
|
+
*/
|
|
652
|
+
__arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
|
|
653
|
+
void const *query_packed, void const *document_packed, //
|
|
654
|
+
nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
|
|
655
|
+
nk_f64_t *result) {
|
|
656
|
+
|
|
657
|
+
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
658
|
+
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
659
|
+
|
|
660
|
+
nk_size_t const depth_step_count = query_header->depth_tile_count;
|
|
661
|
+
nk_size_t const query_row_tiles = query_header->column_tile_count;
|
|
662
|
+
nk_size_t const document_col_tiles = document_header->column_tile_count;
|
|
663
|
+
|
|
664
|
+
nk_size_t const tile_dimension = svcntw(); // 16: ZA32 tile dimension
|
|
665
|
+
nk_size_t const vector_elements = svcntb(); // 64: i8 elements per SVE vector
|
|
666
|
+
|
|
667
|
+
// Tile data pointers (i8)
|
|
668
|
+
nk_i8_t const *query_tiles = (nk_i8_t const *)((char const *)query_packed + sizeof(nk_maxsim_sme_packed_header_t));
|
|
669
|
+
nk_i8_t const *document_tiles = (nk_i8_t const *)((char const *)document_packed +
|
|
670
|
+
sizeof(nk_maxsim_sme_packed_header_t));
|
|
671
|
+
|
|
672
|
+
// Norms and originals pointers
|
|
673
|
+
nk_f32_t const *query_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
|
|
674
|
+
nk_f32_t const *document_norms = (nk_f32_t const *)((char const *)document_packed + document_header->norms_offset);
|
|
675
|
+
nk_f32_t const *query_originals = (nk_f32_t const *)((char const *)query_packed + query_header->originals_offset);
|
|
676
|
+
nk_f32_t const *document_originals = (nk_f32_t const *)((char const *)document_packed +
|
|
677
|
+
document_header->originals_offset);
|
|
678
|
+
nk_size_t const query_original_stride_elements = query_header->original_stride / sizeof(nk_f32_t);
|
|
679
|
+
nk_size_t const document_original_stride_elements = document_header->original_stride / sizeof(nk_f32_t);
|
|
680
|
+
|
|
681
|
+
nk_size_t const expansion = 4; // i8->i32 SMOPA
|
|
682
|
+
|
|
683
|
+
svbool_t const predicate_all_i8x = svptrue_b8();
|
|
684
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
685
|
+
|
|
686
|
+
nk_f64_t total_angular_distance_f64 = 0.0;
|
|
687
|
+
|
|
688
|
+
for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
|
|
689
|
+
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
690
|
+
nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
|
|
691
|
+
: (query_count - row_start);
|
|
692
|
+
svbool_t const row_predicate_i8x = (rows_remaining == tile_dimension)
|
|
693
|
+
? svptrue_b8()
|
|
694
|
+
: svwhilelt_b8_u64(0u, rows_remaining * expansion);
|
|
695
|
+
svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
|
|
696
|
+
: svwhilelt_b32_u64(0u, rows_remaining);
|
|
697
|
+
|
|
698
|
+
svint32_t running_max_i32x = svdup_s32(NK_I32_MIN);
|
|
699
|
+
svuint32_t running_argmax_u32x = svdup_u32(0);
|
|
700
|
+
|
|
701
|
+
nk_size_t column_tile_index = 0;
|
|
702
|
+
|
|
703
|
+
// 4-tile fast path: ZA0-ZA3 process 4 document column tiles simultaneously
|
|
704
|
+
for (; column_tile_index + 4 <= document_col_tiles; column_tile_index += 4) {
|
|
705
|
+
svzero_za();
|
|
706
|
+
|
|
707
|
+
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
708
|
+
svint8_t query_packed_i8x = svld1_s8(
|
|
709
|
+
row_predicate_i8x,
|
|
710
|
+
(int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
711
|
+
svint8_t document_packed_0_i8x = svld1_s8(
|
|
712
|
+
predicate_all_i8x,
|
|
713
|
+
(int8_t const *)(document_tiles +
|
|
714
|
+
((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
|
|
715
|
+
svint8_t document_packed_1_i8x = svld1_s8(
|
|
716
|
+
predicate_all_i8x,
|
|
717
|
+
(int8_t const *)(document_tiles +
|
|
718
|
+
((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
|
|
719
|
+
svint8_t document_packed_2_i8x = svld1_s8(
|
|
720
|
+
predicate_all_i8x,
|
|
721
|
+
(int8_t const *)(document_tiles +
|
|
722
|
+
((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
|
|
723
|
+
svint8_t document_packed_3_i8x = svld1_s8(
|
|
724
|
+
predicate_all_i8x,
|
|
725
|
+
(int8_t const *)(document_tiles +
|
|
726
|
+
((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
|
|
727
|
+
svmopa_za32_s8_m(0, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_0_i8x);
|
|
728
|
+
svmopa_za32_s8_m(1, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_1_i8x);
|
|
729
|
+
svmopa_za32_s8_m(2, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_2_i8x);
|
|
730
|
+
svmopa_za32_s8_m(3, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_3_i8x);
|
|
731
|
+
}
|
|
732
|
+
|
|
733
|
+
// Vertical column extraction + argmax update (manually unrolled over 4 tiles)
|
|
734
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
|
|
735
|
+
// Tile 0
|
|
736
|
+
{
|
|
737
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
|
|
738
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
|
|
739
|
+
column_within_tile);
|
|
740
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
|
|
741
|
+
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
742
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
743
|
+
}
|
|
744
|
+
// Tile 1
|
|
745
|
+
{
|
|
746
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
|
|
747
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 1,
|
|
748
|
+
column_within_tile);
|
|
749
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
|
|
750
|
+
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
751
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
752
|
+
}
|
|
753
|
+
// Tile 2
|
|
754
|
+
{
|
|
755
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
|
|
756
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 2,
|
|
757
|
+
column_within_tile);
|
|
758
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
|
|
759
|
+
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
760
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
761
|
+
}
|
|
762
|
+
// Tile 3
|
|
763
|
+
{
|
|
764
|
+
nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
|
|
765
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 3,
|
|
766
|
+
column_within_tile);
|
|
767
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
|
|
768
|
+
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
769
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
770
|
+
}
|
|
771
|
+
}
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
// 1-tile remainder: ZA0 only
|
|
775
|
+
for (; column_tile_index < document_col_tiles; column_tile_index++) {
|
|
776
|
+
nk_size_t const col_start = column_tile_index * tile_dimension;
|
|
777
|
+
nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
|
|
778
|
+
? tile_dimension
|
|
779
|
+
: (document_count - col_start);
|
|
780
|
+
svbool_t const column_predicate_i8x = (cols_remaining == tile_dimension)
|
|
781
|
+
? svptrue_b8()
|
|
782
|
+
: svwhilelt_b8_u64(0u, cols_remaining * expansion);
|
|
783
|
+
|
|
784
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
785
|
+
|
|
786
|
+
for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
|
|
787
|
+
svint8_t query_packed_i8x = svld1_s8(
|
|
788
|
+
row_predicate_i8x,
|
|
789
|
+
(int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
790
|
+
svint8_t document_packed_i8x = svld1_s8(
|
|
791
|
+
column_predicate_i8x,
|
|
792
|
+
(int8_t const *)(document_tiles +
|
|
793
|
+
(column_tile_index * depth_step_count + depth_step) * vector_elements));
|
|
794
|
+
svmopa_za32_s8_m(0, row_predicate_i8x, column_predicate_i8x, query_packed_i8x, document_packed_i8x);
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
|
|
798
|
+
nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
|
|
799
|
+
svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
|
|
800
|
+
column_within_tile);
|
|
801
|
+
svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
|
|
802
|
+
running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
|
|
803
|
+
running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
|
|
804
|
+
}
|
|
805
|
+
}
|
|
806
|
+
|
|
807
|
+
// Refinement: tile-wide interleaved f64 dot products
|
|
808
|
+
nk_u32_t best_document_indices[64]; // max tile_dimension across all SVL values
|
|
809
|
+
svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
|
|
810
|
+
|
|
811
|
+
// Pointer setup: one (query, document) pair per row in the tile
|
|
812
|
+
nk_f32_t const *query_original_ptrs[64];
|
|
813
|
+
nk_f32_t const *document_original_ptrs[64];
|
|
814
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
|
|
815
|
+
nk_size_t query_index = row_start + row_in_tile;
|
|
816
|
+
nk_u32_t best_document_index = best_document_indices[row_in_tile];
|
|
817
|
+
query_original_ptrs[row_in_tile] = query_originals + query_index * query_original_stride_elements;
|
|
818
|
+
document_original_ptrs[row_in_tile] = document_originals +
|
|
819
|
+
best_document_index * document_original_stride_elements;
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
// Interleaved f64 dot products in batches of 4 (hides MLA 4-cycle latency)
|
|
823
|
+
nk_size_t row_batch_start = 0;
|
|
824
|
+
|
|
825
|
+
// Fast path: 4-wide batches
|
|
826
|
+
for (; row_batch_start + 4 <= rows_remaining; row_batch_start += 4) {
|
|
827
|
+
svfloat64_t accumulator_0_f64x = svdup_f64(0.0);
|
|
828
|
+
svfloat64_t accumulator_1_f64x = svdup_f64(0.0);
|
|
829
|
+
svfloat64_t accumulator_2_f64x = svdup_f64(0.0);
|
|
830
|
+
svfloat64_t accumulator_3_f64x = svdup_f64(0.0);
|
|
831
|
+
|
|
832
|
+
for (nk_size_t depth_index = 0; depth_index < depth; depth_index += svcntd()) {
|
|
833
|
+
svbool_t predicate_depth_f64x = svwhilelt_b64_u64(depth_index, depth);
|
|
834
|
+
svbool_t predicate_depth_f32x = svwhilelt_b32_u64(depth_index, depth);
|
|
835
|
+
|
|
836
|
+
svfloat64_t query_values_0_f64x = svcvt_f64_f32_x(
|
|
837
|
+
predicate_depth_f64x,
|
|
838
|
+
svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 0] + depth_index));
|
|
839
|
+
svfloat64_t document_values_0_f64x = svcvt_f64_f32_x(
|
|
840
|
+
predicate_depth_f64x,
|
|
841
|
+
svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 0] + depth_index));
|
|
842
|
+
accumulator_0_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_0_f64x, query_values_0_f64x,
|
|
843
|
+
document_values_0_f64x);
|
|
844
|
+
|
|
845
|
+
svfloat64_t query_values_1_f64x = svcvt_f64_f32_x(
|
|
846
|
+
predicate_depth_f64x,
|
|
847
|
+
svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 1] + depth_index));
|
|
848
|
+
svfloat64_t document_values_1_f64x = svcvt_f64_f32_x(
|
|
849
|
+
predicate_depth_f64x,
|
|
850
|
+
svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 1] + depth_index));
|
|
851
|
+
accumulator_1_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_1_f64x, query_values_1_f64x,
|
|
852
|
+
document_values_1_f64x);
|
|
853
|
+
|
|
854
|
+
svfloat64_t query_values_2_f64x = svcvt_f64_f32_x(
|
|
855
|
+
predicate_depth_f64x,
|
|
856
|
+
svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 2] + depth_index));
|
|
857
|
+
svfloat64_t document_values_2_f64x = svcvt_f64_f32_x(
|
|
858
|
+
predicate_depth_f64x,
|
|
859
|
+
svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 2] + depth_index));
|
|
860
|
+
accumulator_2_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_2_f64x, query_values_2_f64x,
|
|
861
|
+
document_values_2_f64x);
|
|
862
|
+
|
|
863
|
+
svfloat64_t query_values_3_f64x = svcvt_f64_f32_x(
|
|
864
|
+
predicate_depth_f64x,
|
|
865
|
+
svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 3] + depth_index));
|
|
866
|
+
svfloat64_t document_values_3_f64x = svcvt_f64_f32_x(
|
|
867
|
+
predicate_depth_f64x,
|
|
868
|
+
svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 3] + depth_index));
|
|
869
|
+
accumulator_3_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_3_f64x, query_values_3_f64x,
|
|
870
|
+
document_values_3_f64x);
|
|
871
|
+
}
|
|
872
|
+
|
|
873
|
+
// Reduce accumulators and compute angular distance per row
|
|
874
|
+
svfloat64_t *batch_accumulators[] = {&accumulator_0_f64x, &accumulator_1_f64x, &accumulator_2_f64x,
|
|
875
|
+
&accumulator_3_f64x};
|
|
876
|
+
for (nk_size_t batch_index = 0; batch_index < 4; batch_index++) {
|
|
877
|
+
nk_size_t query_index = row_start + row_batch_start + batch_index;
|
|
878
|
+
nk_u32_t best_document_index = best_document_indices[row_batch_start + batch_index];
|
|
879
|
+
nk_f64_t dot_product_f64 = svaddv_f64(svptrue_b64(), *batch_accumulators[batch_index]);
|
|
880
|
+
nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
|
|
881
|
+
(nk_f64_t)document_norms[best_document_index];
|
|
882
|
+
nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
|
|
883
|
+
: 0.0;
|
|
884
|
+
nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
|
|
885
|
+
if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
|
|
886
|
+
total_angular_distance_f64 += angular_distance_f64;
|
|
887
|
+
}
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
// Remainder: 1 row at a time
|
|
891
|
+
for (; row_batch_start < rows_remaining; row_batch_start++) {
|
|
892
|
+
nk_size_t query_index = row_start + row_batch_start;
|
|
893
|
+
nk_u32_t best_document_index = best_document_indices[row_batch_start];
|
|
894
|
+
nk_f64_t dot_product_f64 = nk_maxsim_reduce_dot_f32_ssve_(query_original_ptrs[row_batch_start],
|
|
895
|
+
document_original_ptrs[row_batch_start], depth);
|
|
896
|
+
nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
|
|
897
|
+
(nk_f64_t)document_norms[best_document_index];
|
|
898
|
+
nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
|
|
899
|
+
: 0.0;
|
|
900
|
+
nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
|
|
901
|
+
if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
|
|
902
|
+
total_angular_distance_f64 += angular_distance_f64;
|
|
903
|
+
}
|
|
904
|
+
}
|
|
905
|
+
|
|
906
|
+
*result = total_angular_distance_f64;
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
NK_PUBLIC void nk_maxsim_packed_f32_sme( //
|
|
910
|
+
void const *query_packed, void const *document_packed, //
|
|
911
|
+
nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
|
|
912
|
+
nk_f64_t *result) { //
|
|
913
|
+
|
|
914
|
+
nk_maxsim_packed_f32_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
#if defined(__clang__)
|
|
918
|
+
#pragma clang attribute pop
|
|
919
|
+
#elif defined(__GNUC__)
|
|
920
|
+
#pragma GCC pop_options
|
|
921
|
+
#endif
|
|
922
|
+
|
|
923
|
+
#if defined(__cplusplus)
|
|
924
|
+
} // extern "C"
|
|
925
|
+
#endif
|
|
926
|
+
|
|
927
|
+
#endif // NK_TARGET_SME
|
|
928
|
+
#endif // NK_TARGET_ARM_
|
|
929
|
+
#endif // NK_MAXSIM_SME_H
|