numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1099 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Set Distances for SME.
|
|
3
|
+
* @file include/numkong/sets/smebi32.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
* @sa include/numkong/sets.h
|
|
7
|
+
*
|
|
8
|
+
* Uses ARM Scalable Matrix Extension (SME) for efficient binary set operations.
|
|
9
|
+
* Leverages streaming mode's wider vectors (512-bit on Apple M4) for fast
|
|
10
|
+
* XOR+POPCNT operations on binary vectors.
|
|
11
|
+
*
|
|
12
|
+
* @section smebi32_math Mathematical Foundation
|
|
13
|
+
*
|
|
14
|
+
* Hamming distance: popcount(a XOR b) = number of differing bits
|
|
15
|
+
*
|
|
16
|
+
* Jaccard distance using intersection:
|
|
17
|
+
* intersection = popcount(a AND b)
|
|
18
|
+
* union = popcount(a) + popcount(b) - intersection
|
|
19
|
+
* jaccard = 1 - intersection / union
|
|
20
|
+
*
|
|
21
|
+
* @section smebi32_tiles SME Dimensions (512-bit SVL)
|
|
22
|
+
*
|
|
23
|
+
* - svcntw(): 16 (number of 32-bit elements per vector)
|
|
24
|
+
* - svcntb(): 64 (number of bytes per SVE vector)
|
|
25
|
+
* - Tile blocking: 16x16 output tiles for cache efficiency
|
|
26
|
+
* - Depth processing: 64 bytes (512 bits) per iteration
|
|
27
|
+
*
|
|
28
|
+
* @section smebi32_perf Performance Characteristics (Apple M4)
|
|
29
|
+
*
|
|
30
|
+
* - SVL: 512 bits (64 bytes)
|
|
31
|
+
* - Streaming mode provides dedicated register file
|
|
32
|
+
* - Streaming mode overhead: ~50-100 cycles for SMSTART/SMSTOP
|
|
33
|
+
*/
|
|
34
|
+
|
|
35
|
+
#ifndef NK_SETS_SMEBI32_H
|
|
36
|
+
#define NK_SETS_SMEBI32_H
|
|
37
|
+
|
|
38
|
+
#if NK_TARGET_ARM_
|
|
39
|
+
#if NK_TARGET_SMEBI32
|
|
40
|
+
|
|
41
|
+
#include "numkong/types.h"
|
|
42
|
+
#include "numkong/set/serial.h"
|
|
43
|
+
#include "numkong/sets/serial.h"
|
|
44
|
+
#include "numkong/dots/sme.h" // `nk_sme_zero_za32_*` constants
|
|
45
|
+
#include "numkong/reduce.h" // `nk_reduce_moments_u1`
|
|
46
|
+
|
|
47
|
+
#if defined(__cplusplus)
|
|
48
|
+
extern "C" {
|
|
49
|
+
#endif
|
|
50
|
+
|
|
51
|
+
/*
|
|
52
|
+
* Binary set operations using SME BMOPA instruction.
|
|
53
|
+
*
|
|
54
|
+
* BMOPA computes: ZA[i,j] += popcount(~(Zn[i] ^ Zm[j])) = popcount(XNOR)
|
|
55
|
+
* This counts matching bits. Hamming = depth_bits - matching.
|
|
56
|
+
*
|
|
57
|
+
* Tile layout (SVL=512, Apple M4):
|
|
58
|
+
* - ZA32 output tile: 16 × 16 u32 elements (1 KB)
|
|
59
|
+
* - Input vectors: 16 u32 elements (SVL/32)
|
|
60
|
+
* - Each BMOPA processes 32 bits (one u32) across 16×16 pairs
|
|
61
|
+
* - BMOPA predicates: b32 (u32 input granularity)
|
|
62
|
+
* - Packed kernel: 4-tile path (ZA0-ZA3) for 4 B-column tiles simultaneously
|
|
63
|
+
* - Unpacked kernel: ZA transpose (ZA0.S=staging, ZA1-3.S=accumulation, 3-tile fast path)
|
|
64
|
+
* - Packed format: column-major u32 within each tile
|
|
65
|
+
*/
|
|
66
|
+
|
|
67
|
+
#if defined(__clang__)
|
|
68
|
+
#pragma clang attribute push(__attribute__((target("sme2,sve2"))), apply_to = function)
|
|
69
|
+
#elif defined(__GNUC__)
|
|
70
|
+
#pragma GCC push_options
|
|
71
|
+
#pragma GCC target("+sme2")
|
|
72
|
+
#endif
|
|
73
|
+
|
|
74
|
+
/* Read SVL in bytes from non-streaming context using RDSVL instruction. */
|
|
75
|
+
NK_INTERNAL nk_size_t nk_smebi32_svl_bytes_(void) {
|
|
76
|
+
nk_size_t svl_bytes;
|
|
77
|
+
__asm__ volatile("rdsvl %0, #1" : "=r"(svl_bytes));
|
|
78
|
+
return svl_bytes;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
/* Get ZA32 tile dimension (number of f32/u32 elements per row). */
|
|
82
|
+
NK_INTERNAL nk_size_t nk_smebi32_tile_dim_(void) { return nk_smebi32_svl_bytes_() / sizeof(nk_u32_t); }
|
|
83
|
+
|
|
84
|
+
/**
 * Header prepended to the buffer produced by `nk_dots_pack_u1_smebi32`.
 * Records the tiling geometry chosen at pack time so the compute kernels can
 * navigate the tile payload and norms without re-deriving the layout.
 * Fixed 64-byte size; field order is part of the packed-buffer format.
 */
typedef struct {
    nk_u32_t row_tile_count;   // ceiling(rows / tile_dim)
    nk_u32_t depth_tile_count; // ceiling(depth_bits / depth_tile_bits)
    nk_u32_t rows;             // actual row count
    nk_u32_t depth_bits;       // actual depth in bits
    nk_u32_t svl_bytes;        // SVL at pack time for validation
    nk_u32_t norms_offset;     // byte offset to norms (0 if none)
    nk_u32_t reserved[10];     // padding to 64 bytes
} nk_sets_smebi32_packed_header_t;
|
|
93
|
+
|
|
94
|
+
/** Count total set bits across a byte vector using streaming SVE.
 *  Accumulates per-byte popcounts into u32 lanes via svdot; single horizontal
 *  reduction at the end. Marked `NK_STREAMING_COMPATIBLE_` so it can run in
 *  both streaming and non-streaming mode.
 *  @param data     Pointer to the packed bit vector (byte-granular).
 *  @param n_bytes  Number of bytes to scan; the tail is handled by predication.
 *  @return Total popcount as u32. */
NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data,
                                                      nk_size_t n_bytes) NK_STREAMING_COMPATIBLE_ {
    svuint32_t acc_u32x = svdup_u32(0);
    svuint8_t const ones_u8x = svdup_u8(1);
    for (nk_size_t offset = 0; offset < n_bytes; offset += svcntb()) {
        // Predicated load: lanes at or beyond `n_bytes` contribute zero.
        svbool_t predicate_u8x = svwhilelt_b8_u64(offset, n_bytes);
        // svcnt yields per-byte popcounts; svdot with an all-ones vector folds
        // each group of 4 byte-counts into one u32 accumulator lane.
        acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_u8x, svld1_u8(predicate_u8x, data + offset)), ones_u8x);
    }
    return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
}
|
|
106
|
+
|
|
107
|
+
#pragma region Hamming Distance
|
|
108
|
+
|
|
109
|
+
/** Bytes required to pack a `row_count` x `depth_bits` binary matrix for the
 *  SME kernels: fixed header, zero-padded column-major tiles, and one u32
 *  population count per row. Tiles are square: tile_dim rows by tile_dim u32
 *  words of depth (tile_dim * 32 bits). */
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_smebi32(nk_size_t row_count, nk_size_t depth_bits) {
    nk_size_t const tile_dim = nk_smebi32_tile_dim_(); // rows per tile == u32 words per depth tile
    nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
    nk_size_t const tiles_down = nk_size_divide_round_up_(row_count, tile_dim);
    nk_size_t const tiles_across = nk_size_divide_round_up_(depth_words, tile_dim);
    nk_size_t const bytes_per_tile = tile_dim * tile_dim * sizeof(nk_u32_t);
    return sizeof(nk_sets_smebi32_packed_header_t) + // geometry header
           tiles_down * tiles_across * bytes_per_tile + // tile payload
           row_count * sizeof(nk_u32_t); // per-row population counts
}
|
|
124
|
+
|
|
125
|
+
/**
 * Pack matrix B for the SME Hamming/Jaccard kernels.
 *
 * Layout written into `b_packed` (sized by `nk_dots_packed_size_u1_smebi32`):
 *   1. `nk_sets_smebi32_packed_header_t` with the tiling geometry.
 *   2. Row-tile-major array of tiles; each tile stores tile_dim x tile_dim u32
 *      words in COLUMN-major order so the compute kernel can load one depth
 *      step (a contiguous column of tile_dim words) with a single SVE load.
 *   3. One u32 population count per source row (at `header->norms_offset`).
 *
 * @param b                  Source rows, bit-packed, byte-granular.
 * @param row_count          Number of rows in B.
 * @param depth_bits         Row length in bits.
 * @param b_stride_in_bytes  Byte stride between consecutive rows of B.
 * @param b_packed           Destination buffer; fully overwritten.
 */
NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count, nk_size_t depth_bits,
                                       nk_size_t b_stride_in_bytes, void *b_packed) {
    nk_size_t const svl_bytes = nk_smebi32_svl_bytes_();
    nk_size_t const tile_dim = nk_smebi32_tile_dim_();        // 16 rows per tile
    nk_size_t const depth_tile_size = nk_smebi32_tile_dim_(); // 16 u32 per depth tile
    nk_size_t const tile_elements = tile_dim * depth_tile_size;
    nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);

    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
    nk_size_t const row_tile_count = nk_size_divide_round_up_(row_count, tile_dim);
    nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
    nk_size_t const total_tiles = row_tile_count * depth_tile_count;
    nk_size_t const data_size = total_tiles * tile_elements * sizeof(nk_u32_t);

    // Record geometry so the kernels can validate SVL and locate the norms.
    nk_sets_smebi32_packed_header_t *header = (nk_sets_smebi32_packed_header_t *)b_packed;
    header->row_tile_count = (nk_u32_t)row_tile_count;
    header->depth_tile_count = (nk_u32_t)depth_tile_count;
    header->rows = (nk_u32_t)row_count;
    header->depth_bits = (nk_u32_t)depth_bits;
    header->svl_bytes = (nk_u32_t)svl_bytes;
    header->norms_offset = (nk_u32_t)(sizeof(nk_sets_smebi32_packed_header_t) + data_size);

    nk_u32_t *tiles_ptr = (nk_u32_t *)((char *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
    nk_u32_t *norms_ptr = (nk_u32_t *)((char *)b_packed + header->norms_offset);

    // Zero-initialize all tiles (partial tiles stay zero-padded for predicated loads)
    for (nk_size_t i = 0; i < total_tiles * tile_elements; i++) tiles_ptr[i] = 0;

    // Pack tiles: column-major u32 within each tile for efficient SVE loads
    for (nk_size_t row_tile = 0; row_tile < row_tile_count; row_tile++) {
        for (nk_size_t depth_tile = 0; depth_tile < depth_tile_count; depth_tile++) {
            nk_size_t const tile_index = row_tile * depth_tile_count + depth_tile;
            nk_u32_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = row_tile * tile_dim;
            nk_size_t const src_u32_start = depth_tile * depth_tile_size;
            // Clamp both dimensions for the edge tiles; the zero-fill above
            // keeps the unused remainder of a partial tile at zero.
            nk_size_t const rows_to_pack = (src_row_start + tile_dim <= row_count) ? tile_dim
                                                                                   : (row_count - src_row_start);
            nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <= depth_u32_total)
                                               ? depth_tile_size
                                               : (depth_u32_total > src_u32_start ? depth_u32_total - src_u32_start
                                                                                  : 0);

            // Column-major packing: tile_output[col * tile_dim + row]
            for (nk_size_t row = 0; row < rows_to_pack; row++) {
                // NOTE(review): rows are read as u32 words; this assumes each
                // source row spans at least depth_u32_total * 4 bytes and is
                // 4-byte readable -- confirm against the caller's allocation
                // when depth_bits is not a multiple of 32.
                nk_u32_t const *src_row = (nk_u32_t const *)((char const *)b +
                                                             (src_row_start + row) * b_stride_in_bytes);
                for (nk_size_t col = 0; col < u32s_to_pack; col++) {
                    nk_size_t const dst_idx = col * tile_dim + row; // Column-major!
                    tile_output[dst_idx] = src_row[src_u32_start + col];
                }
            }
        }
    }

    // Compute per-row population counts
    for (nk_size_t row = 0; row < row_count; row++) {
        nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
        {
            // NOTE(review): counts depth_in_bytes * 8 bits, which includes any
            // padding bits in the final byte when depth_bits % 8 != 0 --
            // presumably callers keep padding bits zeroed; verify.
            nk_u64_t nk_local_sum_, nk_local_sumsq_;
            nk_reduce_moments_u1(src_row, depth_in_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
            norms_ptr[row] = (nk_u32_t)nk_local_sum_;
        }
    }
}
|
|
190
|
+
|
|
191
|
+
/**
 * SME Hamming kernel using ZA transpose for unpacked A.
 * ZA0.S = staging (A rows loaded horizontally, read vertically for BMOPA).
 * ZA1-3.S = BMOPA accumulation (3 B column tiles in fast path).
 *
 * Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
 * BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
 *
 * BMOPA accumulates MATCHING bits (popcount of XNOR), so every result is
 * post-processed as Hamming = depth_bits - matching_bits before the store.
 *
 * @param a                  Unpacked A rows, bit-packed, byte-granular.
 * @param b_packed           B packed by `nk_dots_pack_u1_smebi32` (header + tiles).
 * @param c                  Output matrix of u32 Hamming distances, row_count_a x row_count_b.
 * @param depth_bits         Shared row length in bits.
 * @param a_stride_in_bytes  Byte stride between A rows.
 * @param c_stride_in_bytes  Byte stride between C rows.
 */
__arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_(
    nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
    nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {

    nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
    nk_size_t const row_tile_count_b = header->row_tile_count;
    nk_size_t const depth_tile_count = header->depth_tile_count;

    // Inside streaming mode svcntw() reports the streaming vector length.
    nk_size_t const tile_dim = svcntw();        // 16 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
    nk_size_t const tile_elements = tile_dim * depth_tile_size;
    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);

    nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));

    svbool_t const predicate_all_u32x = svptrue_b32();
    svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
    nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);

    for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
        nk_size_t const row_start_a = row_tile_a * tile_dim;
        nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
                                                                                   : (row_count_a - row_start_a);
        svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_a_remaining);

        // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
        nk_size_t row_tile_b = 0;
        for (; row_tile_b + 3 <= row_tile_count_b; row_tile_b += 3) {
            // Clear the three accumulator tiles for this block of B columns.
            svzero_mask_za(nk_sme_zero_za32_tiles_123_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                // Clear only the staging tile; accumulators keep their sums.
                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally as u32 words
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // B tile pointers for 3 column tiles
                nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile1 = b_tiles + ((row_tile_b + 1) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA for each depth step
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    // Vertical read of ZA0 column `step` = transposed A words.
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);

                    // B tiles are column-major, so one depth step is a contiguous load.
                    svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile0 + step * tile_dim));
                    svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile1 + step * tile_dim));
                    svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile2 + step * tile_dim));
                }
            }

            // Extract from ZA1-3: Hamming = depth_bits - matching_bits
            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);

                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
                svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);

                // Full-width stores are safe here: in the 3-tile fast path all
                // three B tiles are complete (tile_dim columns each).
                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 0) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 1) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 2) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
            }
        }

        // Remainder: 1 B column tile at a time using ZA1
        for (; row_tile_b < row_tile_count_b; row_tile_b++) {
            nk_size_t const row_start_b = row_tile_b * tile_dim;
            nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
                                                                                       : (row_count_b - row_start_b);
            svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, rows_b_remaining);

            svzero_mask_za(nk_sme_zero_za32_tile_1_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
                    svuint32_t b_u32x = svld1_u32(predicate_all_u32x, b_tile + step * tile_dim);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_column_u32x, b_u32x);
                }
            }

            // Extract from ZA1: Hamming = depth_bits - matching_bits
            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
                nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
                // Predicated store: the last B tile may cover fewer than tile_dim columns.
                svst1_u32(column_predicate_u32x, c_row + row_start_b, hamming_u32x);
            }
        }
    }
}
|
|
333
|
+
|
|
334
|
+
/**
 * Public entry point: Hamming distances between each row of `a` (bit-packed,
 * `a_stride_in_bytes` apart) and a pre-packed B operand (`b_packed`), written
 * to `c` as u32 bit counts (`c_stride_in_bytes` apart).
 * Thin forwarding wrapper around the locally-streaming SME kernel, which owns
 * the streaming-mode / ZA state transitions.
 */
NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
                                             nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
                                             nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
                                             c_stride_in_bytes);
}
|
|
340
|
+
|
|
341
|
+
/**
 * Symmetric Hamming using ZA0 time-sharing + 3-tile fast path.
 * ZA0.S = staging (A rows loaded horizontally, read vertically for BMOPA).
 * ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
 * Mirrors the packed kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
 *
 * Writes result[i][j] = depth_bits - popcount(XNOR(vectors[i], vectors[j]))
 * for every row i in [row_start, row_start + row_count) against ALL columns
 * j in [0, n_vectors) — column_tile_index starts at 0, so the full row strip
 * is written (not just a triangle). `stride` / `result_stride` are byte
 * strides between consecutive rows.
 *
 * NOTE(review): a_buffer is fixed at 16x16 u32, which assumes tile_dim ==
 * svcntw() <= 16 (i.e. SVL <= 512 bits) — confirm wider-SVL hardware cannot
 * dispatch to this kernel.
 */
__arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
    nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_u32_t *result,
    nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const tile_dim = svcntw();        // 16 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
    nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);

    svbool_t const predicate_all_u32x = svptrue_b32();
    svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);

    NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save

    nk_size_t const row_end = row_start + row_count;
    nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);

    for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
         row_tile_start += tile_dim) {
        // Clamp the row tile both to the caller's [row_start, row_end) slice and to n_vectors.
        nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
        nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
                                                                                      : (n_vectors - row_tile_start);
        svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_clamped);

        nk_size_t column_tile_index = 0;

        // Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
        for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
            svzero_mask_za(nk_sme_zero_za32_tiles_123_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                // Number of valid u32 words of depth in this tile; 0 means depth is exhausted.
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                // Load A rows into ZA0 horizontally
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
                                                                   (row_tile_start + row_in_tile) * stride) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // Save A columns from ZA0 to stack buffer so ZA0 can be reused for B staging
                for (nk_size_t s = 0; s < u32s_this_tile; s++)
                    svst1_u32(predicate_all_u32x, a_buffer[s],
                              svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));

                // B column tile 0: reload ZA0 with B rows, vertical-read + BMOPA into ZA1
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }

                // B column tile 1: same staging dance, accumulate into ZA2
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }

                // B column tile 2: accumulate into ZA3
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }
            }

            // Extract ZA1-3: hamming = depth_bits - ZA[i][j]
            for (nk_size_t row = 0; row < rows_clamped; row++) {
                nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);

                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
                svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);

                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 0) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 1) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 2) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
            }
        }

        // Remainder: 1 column tile at a time using ZA1
        for (; column_tile_index < column_tile_count; column_tile_index++) {
            nk_size_t const col_tile_start = column_tile_index * tile_dim;
            nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
                                                                                      : (n_vectors - col_tile_start);
            svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, cols_remaining);

            svzero_mask_za(nk_sme_zero_za32_tile_1_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0 horizontally
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
                                                                   (row_tile_start + row_in_tile) * stride) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // Save A columns from ZA0 to stack buffer
                for (nk_size_t s = 0; s < u32s_this_tile; s++)
                    svst1_u32(predicate_all_u32x, a_buffer[s],
                              svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));

                // Load B column tile into ZA0
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = col_tile_start + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_u32x, 0, step);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_u32x, b_u32x);
                }
            }

            // Extract ZA1: hamming = depth_bits - matching, stored under the column predicate
            for (nk_size_t row = 0; row < rows_clamped; row++) {
                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
                nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
                svst1_u32(column_predicate_u32x, c_row + col_tile_start, hamming_u32x);
            }
        }
    }
}
|
|
526
|
+
|
|
527
|
+
/**
 * Public entry point: all-pairs Hamming distances within one bit-packed set.
 * Computes the row strip [row_start, row_start + row_count) of the distance
 * matrix into `result` (u32 counts, `result_stride` bytes between rows).
 * Thin forwarding wrapper around the locally-streaming SME kernel.
 */
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
                                                nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
                                                nk_size_t row_start, nk_size_t row_count) {
    nk_hammings_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
                                                row_start, row_count);
}
|
|
533
|
+
|
|
534
|
+
#pragma endregion // Hamming Distance
|
|
535
|
+
|
|
536
|
+
/*
 * Jaccard distance via BMOPA matching counts + algebraic normalization.
 *
 * BMOPA gives: matching = popcount(XNOR(a,b))
 * Then:
 *   hamming = depth_bits - matching
 *   intersection = (norm_a + norm_b - hamming) / 2 = (norm_a + norm_b - depth_bits + matching) / 2
 *   union = (norm_a + norm_b + hamming) / 2 = sum_norms - intersection
 *   jaccard = 1 - intersection / union (1.0 when union == 0)
 *
 * Inner BMOPA loop is identical to Hamming; only the extraction phase differs.
 * Packed format shares the Hamming tile layout for B operand, plus per-row norms.
 */
|
|
549
|
+
|
|
550
|
+
#pragma region Jaccard Distance
|
|
551
|
+
|
|
552
|
+
/**
 * SME Jaccard kernel using BMOPA for matching-bit counts.
 * Mirrors nk_hammings_packed_u1_smebi32_streaming_ exactly in structure,
 * but derives intersection/union algebraically from the matching counts:
 *   matching = popcount(XNOR(a,b))   (from BMOPA)
 *   hamming = depth_bits - matching
 *   intersection = (norm_a + norm_b - hamming) / 2
 *   union = (norm_a + norm_b + hamming) / 2
 *   jaccard = 1 - intersection / union (1.0 when union == 0)
 *
 * A-row norms (set cardinalities) are computed on the fly; B-row norms come
 * from the packed payload at header->norms_offset.
 *
 * NOTE(review): when header->norms_offset == 0, b_norms is set to NULL but is
 * still dereferenced unconditionally in the extraction phases below —
 * presumably the packer always emits norms for Jaccard payloads; confirm.
 */
__arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_(
    nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
    nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {

    nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
    nk_size_t const row_tile_count_b = header->row_tile_count;
    nk_size_t const depth_tile_count = header->depth_tile_count;

    nk_size_t const tile_dim = svcntw();        // 16 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
    nk_size_t const tile_elements = tile_dim * depth_tile_size;
    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);

    // Packed layout: header, then B tiles, then (optionally) per-row B norms.
    nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
    nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
                                                   : (nk_u32_t const *)0;

    svbool_t const predicate_all_f32x = svptrue_b32();
    svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
    svfloat32_t const half_f32x = svdup_f32(0.5f);
    svfloat32_t const one_f32x = svdup_f32(1.0f);
    svfloat32_t const zero_f32x = svdup_f32(0.0f);
    nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
    nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);

    for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
        nk_size_t const row_start_a = row_tile_a * tile_dim;
        nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
                                                                                   : (row_count_a - row_start_a);
        svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_a_remaining);

        // Compute A tile norms using streaming SVE popcount
        NK_ALIGN64 nk_f32_t a_tile_norms[16];
        for (nk_size_t r = 0; r < rows_a_remaining; r++) {
            nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
            a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
        }

        // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
        nk_size_t row_tile_b = 0;
        for (; row_tile_b + 3 <= row_tile_count_b; row_tile_b += 3) {
            svzero_mask_za(nk_sme_zero_za32_tiles_123_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                // Number of valid u32 words of depth in this tile; 0 means depth is exhausted.
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally as u32 words
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
                }

                // B tile pointers for 3 column tiles
                nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile1 = b_tiles + ((row_tile_b + 1) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA for each depth step
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);

                    svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
                                       svld1_u32(predicate_all_f32x, b_tile0 + step * tile_dim));
                    svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
                                       svld1_u32(predicate_all_f32x, b_tile1 + step * tile_dim));
                    svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
                                       svld1_u32(predicate_all_f32x, b_tile2 + step * tile_dim));
                }
            }

            // Extract from ZA1-3: Jaccard normalization via streaming SVE
            // Hoist B norms outside row loop (same for all A rows in this tile-pair)
            svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(
                predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 0) * tile_dim));
            svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(
                predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 1) * tile_dim));
            svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(
                predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 2) * tile_dim));

            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
                svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);

                // ZA1: intersection = (norm_a + norm_b - depth + matching) / 2; union = sum_norms - intersection
                {
                    svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
                    svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
                    svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
                    svfloat32_t intersection_f32x = svmul_f32_x(
                        predicate_all_f32x,
                        svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
                                    matching_f32x),
                        half_f32x);
                    svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
                    svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
                    svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
                    // Empty-union lanes select the conventional distance 1.0 instead of 0/0.
                    svfloat32_t jaccard_f32x = svsel_f32(
                        nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
                    svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 0) * tile_dim, jaccard_f32x);
                }
                // ZA2
                {
                    svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
                    svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
                    svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
                    svfloat32_t intersection_f32x = svmul_f32_x(
                        predicate_all_f32x,
                        svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
                                    matching_f32x),
                        half_f32x);
                    svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
                    svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
                    svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
                    svfloat32_t jaccard_f32x = svsel_f32(
                        nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
                    svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 1) * tile_dim, jaccard_f32x);
                }
                // ZA3
                {
                    svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
                    svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
                    svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
                    svfloat32_t intersection_f32x = svmul_f32_x(
                        predicate_all_f32x,
                        svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
                                    matching_f32x),
                        half_f32x);
                    svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
                    svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
                    svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
                    svfloat32_t jaccard_f32x = svsel_f32(
                        nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
                    svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 2) * tile_dim, jaccard_f32x);
                }
            }
        }

        // Remainder: 1 B column tile at a time using ZA1
        for (; row_tile_b < row_tile_count_b; row_tile_b++) {
            nk_size_t const row_start_b = row_tile_b * tile_dim;
            nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
                                                                                       : (row_count_b - row_start_b);
            svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, rows_b_remaining);

            svzero_mask_za(nk_sme_zero_za32_tile_1_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
                }

                nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);
                    svuint32_t b_u32x = svld1_u32(predicate_all_f32x, b_tile + step * tile_dim);
                    svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_column_u32x, b_u32x);
                }
            }

            // Extract from ZA1: Jaccard normalization
            svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x,
                                                       svld1_u32(predicate_all_f32x, b_norms + row_start_b));
            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
                svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
                svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
                svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
                svfloat32_t intersection_f32x = svmul_f32_x(
                    predicate_all_f32x,
                    svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
                                matching_f32x),
                    half_f32x);
                svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
                svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
                svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
                svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
                                                     svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
                nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
                svst1_f32(column_predicate_f32x, c_row + row_start_b, jaccard_f32x);
            }
        }
    }
}
|
|
773
|
+
|
|
774
|
+
/**
 * Public entry point: Jaccard distances between each row of `a` (bit-packed)
 * and a pre-packed B operand (`b_packed`, tiles + per-row norms), written to
 * `c` as f32 distances in [0, 1].
 * Thin forwarding wrapper around the locally-streaming SME kernel.
 */
NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c,
                                             nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
                                             nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
                                             c_stride_in_bytes);
}
|
|
780
|
+
|
|
781
|
+
/**
|
|
782
|
+
* Symmetric Jaccard kernel using ZA0 time-sharing + 3-tile fast path.
|
|
783
|
+
* Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
|
|
784
|
+
* Norms computed on-the-fly using streaming SVE popcount.
|
|
785
|
+
*/
|
|
786
|
+
__arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
|
|
787
|
+
nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_f32_t *result,
|
|
788
|
+
nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
789
|
+
|
|
790
|
+
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
791
|
+
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
792
|
+
nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
|
|
793
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
|
|
794
|
+
nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);
|
|
795
|
+
|
|
796
|
+
svbool_t const predicate_all_f32x = svptrue_b32();
|
|
797
|
+
svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
|
|
798
|
+
svfloat32_t const half_f32x = svdup_f32(0.5f);
|
|
799
|
+
svfloat32_t const one_f32x = svdup_f32(1.0f);
|
|
800
|
+
svfloat32_t const zero_f32x = svdup_f32(0.0f);
|
|
801
|
+
|
|
802
|
+
NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
|
|
803
|
+
|
|
804
|
+
nk_size_t const row_end = row_start + row_count;
|
|
805
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
|
|
806
|
+
|
|
807
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
|
|
808
|
+
row_tile_start += tile_dim) {
|
|
809
|
+
nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
|
|
810
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
|
|
811
|
+
: (n_vectors - row_tile_start);
|
|
812
|
+
svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_clamped);
|
|
813
|
+
|
|
814
|
+
// Compute A tile norms
|
|
815
|
+
NK_ALIGN64 nk_f32_t a_tile_norms[16];
|
|
816
|
+
for (nk_size_t r = 0; r < rows_clamped; r++) {
|
|
817
|
+
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors + (row_tile_start + r) * stride);
|
|
818
|
+
a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
|
|
819
|
+
}
|
|
820
|
+
for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_norms[r] = 0.0f;
|
|
821
|
+
|
|
822
|
+
// Upper triangle: start from this row tile's column
|
|
823
|
+
nk_size_t column_tile_index = row_tile_start / tile_dim;
|
|
824
|
+
|
|
825
|
+
// Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
|
|
826
|
+
for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
|
|
827
|
+
svzero_mask_za(nk_sme_zero_za32_tiles_123_);
|
|
828
|
+
|
|
829
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
830
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
831
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
832
|
+
? depth_tile_size
|
|
833
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
834
|
+
: 0);
|
|
835
|
+
if (u32s_this_tile == 0) break;
|
|
836
|
+
|
|
837
|
+
// Load A rows into ZA0 horizontally
|
|
838
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
839
|
+
svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
840
|
+
|
|
841
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
842
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
|
|
843
|
+
(row_tile_start + row_in_tile) * stride) +
|
|
844
|
+
d_start_u32;
|
|
845
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
// Save A columns from ZA0 to stack buffer
|
|
849
|
+
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
850
|
+
svst1_u32(predicate_all_f32x, a_buffer[s],
|
|
851
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
|
|
852
|
+
|
|
853
|
+
// B column tile 0
|
|
854
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
855
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
856
|
+
nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
|
|
857
|
+
if (col_abs < n_vectors) {
|
|
858
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
859
|
+
d_start_u32;
|
|
860
|
+
svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
|
|
861
|
+
}
|
|
862
|
+
}
|
|
863
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
864
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
|
|
865
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
|
|
866
|
+
svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
// B column tile 1
|
|
870
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
871
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
872
|
+
nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
|
|
873
|
+
if (col_abs < n_vectors) {
|
|
874
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
875
|
+
d_start_u32;
|
|
876
|
+
svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
|
|
877
|
+
}
|
|
878
|
+
}
|
|
879
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
880
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
|
|
881
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
|
|
882
|
+
svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
|
|
883
|
+
}
|
|
884
|
+
|
|
885
|
+
// B column tile 2
|
|
886
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
887
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
888
|
+
nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
|
|
889
|
+
if (col_abs < n_vectors) {
|
|
890
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
891
|
+
d_start_u32;
|
|
892
|
+
svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
|
|
893
|
+
}
|
|
894
|
+
}
|
|
895
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
896
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
|
|
897
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
|
|
898
|
+
svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
|
|
899
|
+
}
|
|
900
|
+
}
|
|
901
|
+
|
|
902
|
+
// Compute B tile norms for 3 column tiles
|
|
903
|
+
NK_ALIGN64 nk_u32_t b_tile_norms_0[16];
|
|
904
|
+
NK_ALIGN64 nk_u32_t b_tile_norms_1[16];
|
|
905
|
+
NK_ALIGN64 nk_u32_t b_tile_norms_2[16];
|
|
906
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
907
|
+
nk_size_t const col_abs_0 = (column_tile_index + 0) * tile_dim + col;
|
|
908
|
+
nk_size_t const col_abs_1 = (column_tile_index + 1) * tile_dim + col;
|
|
909
|
+
nk_size_t const col_abs_2 = (column_tile_index + 2) * tile_dim + col;
|
|
910
|
+
b_tile_norms_0[col] = (col_abs_0 < n_vectors)
|
|
911
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
912
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_0 * stride),
|
|
913
|
+
depth_in_bytes)
|
|
914
|
+
: 0;
|
|
915
|
+
b_tile_norms_1[col] = (col_abs_1 < n_vectors)
|
|
916
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
917
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_1 * stride),
|
|
918
|
+
depth_in_bytes)
|
|
919
|
+
: 0;
|
|
920
|
+
b_tile_norms_2[col] = (col_abs_2 < n_vectors)
|
|
921
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
922
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs_2 * stride),
|
|
923
|
+
depth_in_bytes)
|
|
924
|
+
: 0;
|
|
925
|
+
}
|
|
926
|
+
|
|
927
|
+
// Extract ZA1-3: Jaccard normalization
|
|
928
|
+
svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(predicate_all_f32x,
|
|
929
|
+
svld1_u32(predicate_all_f32x, b_tile_norms_0));
|
|
930
|
+
svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(predicate_all_f32x,
|
|
931
|
+
svld1_u32(predicate_all_f32x, b_tile_norms_1));
|
|
932
|
+
svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(predicate_all_f32x,
|
|
933
|
+
svld1_u32(predicate_all_f32x, b_tile_norms_2));
|
|
934
|
+
|
|
935
|
+
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
936
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
|
|
937
|
+
svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
|
|
938
|
+
|
|
939
|
+
// ZA1
|
|
940
|
+
{
|
|
941
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
|
|
942
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
|
|
943
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
|
|
944
|
+
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
945
|
+
predicate_all_f32x,
|
|
946
|
+
svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
|
|
947
|
+
matching_f32x),
|
|
948
|
+
half_f32x);
|
|
949
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
|
|
950
|
+
svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
|
|
951
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
|
|
952
|
+
svfloat32_t jaccard_f32x = svsel_f32(
|
|
953
|
+
nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
|
|
954
|
+
svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 0) * tile_dim, jaccard_f32x);
|
|
955
|
+
}
|
|
956
|
+
// ZA2
|
|
957
|
+
{
|
|
958
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
|
|
959
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
|
|
960
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
|
|
961
|
+
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
962
|
+
predicate_all_f32x,
|
|
963
|
+
svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
|
|
964
|
+
matching_f32x),
|
|
965
|
+
half_f32x);
|
|
966
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
|
|
967
|
+
svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
|
|
968
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
|
|
969
|
+
svfloat32_t jaccard_f32x = svsel_f32(
|
|
970
|
+
nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
|
|
971
|
+
svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 1) * tile_dim, jaccard_f32x);
|
|
972
|
+
}
|
|
973
|
+
// ZA3
|
|
974
|
+
{
|
|
975
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
|
|
976
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
|
|
977
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
|
|
978
|
+
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
979
|
+
predicate_all_f32x,
|
|
980
|
+
svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
|
|
981
|
+
matching_f32x),
|
|
982
|
+
half_f32x);
|
|
983
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
|
|
984
|
+
svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
|
|
985
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
|
|
986
|
+
svfloat32_t jaccard_f32x = svsel_f32(
|
|
987
|
+
nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
|
|
988
|
+
svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 2) * tile_dim, jaccard_f32x);
|
|
989
|
+
}
|
|
990
|
+
}
|
|
991
|
+
}
|
|
992
|
+
|
|
993
|
+
// Remainder: 1 column tile at a time using ZA1
|
|
994
|
+
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
995
|
+
nk_size_t const col_tile_start = column_tile_index * tile_dim;
|
|
996
|
+
nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
|
|
997
|
+
: (n_vectors - col_tile_start);
|
|
998
|
+
svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, cols_remaining);
|
|
999
|
+
|
|
1000
|
+
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
1001
|
+
|
|
1002
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
1003
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
1004
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
1005
|
+
? depth_tile_size
|
|
1006
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
1007
|
+
: 0);
|
|
1008
|
+
if (u32s_this_tile == 0) break;
|
|
1009
|
+
|
|
1010
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1011
|
+
svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
1012
|
+
|
|
1013
|
+
// Load A rows into ZA0 horizontally
|
|
1014
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
1015
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
|
|
1016
|
+
(row_tile_start + row_in_tile) * stride) +
|
|
1017
|
+
d_start_u32;
|
|
1018
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
|
|
1019
|
+
}
|
|
1020
|
+
|
|
1021
|
+
// Save A columns from ZA0 to stack buffer
|
|
1022
|
+
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
1023
|
+
svst1_u32(predicate_all_f32x, a_buffer[s],
|
|
1024
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
|
|
1025
|
+
|
|
1026
|
+
// Load B column tile into ZA0
|
|
1027
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
1028
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
1029
|
+
nk_size_t const col_abs = col_tile_start + col;
|
|
1030
|
+
if (col_abs < n_vectors) {
|
|
1031
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
1032
|
+
d_start_u32;
|
|
1033
|
+
svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
|
|
1034
|
+
}
|
|
1035
|
+
}
|
|
1036
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
1037
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
|
|
1038
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_f32x, 0, step);
|
|
1039
|
+
svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_u32x, b_u32x);
|
|
1040
|
+
}
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
// Compute B tile norms for remainder tile
|
|
1044
|
+
NK_ALIGN64 nk_u32_t b_tile_norms[16];
|
|
1045
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
1046
|
+
nk_size_t const col_abs = col_tile_start + col;
|
|
1047
|
+
b_tile_norms[col] = (col_abs < n_vectors)
|
|
1048
|
+
? nk_sets_reduce_sumsq_u1_streaming_(
|
|
1049
|
+
(nk_u1x8_t const *)((char const *)vectors + col_abs * stride),
|
|
1050
|
+
depth_in_bytes)
|
|
1051
|
+
: 0;
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x, svld1_u32(predicate_all_f32x, b_tile_norms));
|
|
1055
|
+
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
1056
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
|
|
1057
|
+
svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
|
|
1058
|
+
svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
|
|
1059
|
+
svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
|
|
1060
|
+
svfloat32_t intersection_f32x = svmul_f32_x(
|
|
1061
|
+
predicate_all_f32x,
|
|
1062
|
+
svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
|
|
1063
|
+
matching_f32x),
|
|
1064
|
+
half_f32x);
|
|
1065
|
+
svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
|
|
1066
|
+
svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
|
|
1067
|
+
svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
|
|
1068
|
+
svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
|
|
1069
|
+
svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
|
|
1070
|
+
nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
|
|
1071
|
+
svst1_f32(column_predicate_f32x, c_row + col_tile_start, jaccard_f32x);
|
|
1072
|
+
}
|
|
1073
|
+
}
|
|
1074
|
+
}
|
|
1075
|
+
}
|
|
1076
|
+
|
|
1077
|
+
/**
 *  Public entry point for the symmetric binary (u1) Jaccard distance kernel on
 *  SME with the BI32 (binary outer-product) extension.
 *
 *  Computes pairwise Jaccard distances between bit-packed vectors and writes
 *  f32 results into a row-major output matrix.
 *
 *  @param vectors        Bit-packed input vectors, 8 bits per `nk_u1x8_t` byte.
 *  @param n_vectors      Number of vectors (columns of the symmetric result).
 *  @param depth_bits     Length of each vector in bits.
 *  @param stride         Byte stride between consecutive input vectors.
 *  @param result         Output matrix of f32 distances.
 *  @param result_stride  Byte stride between consecutive output rows.
 *  @param row_start      First row of the symmetric matrix to compute.
 *  @param row_count      Number of rows to compute starting at `row_start`.
 *
 *  Delegates to the `_streaming_` variant, which is the function compiled in
 *  SME streaming mode (ZA tiles, `svbmopa_za32_u32_m` accumulation); this
 *  wrapper exists so callers outside streaming mode get a plain entry point.
 */
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
                                                nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                                nk_size_t row_start, nk_size_t row_count) {
    nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
                                                row_start, row_count);
}
|
|
1083
|
+
|
|
1084
|
+
#pragma endregion // Jaccard Distance
|
|
1085
|
+
|
|
1086
|
+
#if defined(__clang__)
|
|
1087
|
+
#pragma clang attribute pop
|
|
1088
|
+
#elif defined(__GNUC__)
|
|
1089
|
+
#pragma GCC pop_options
|
|
1090
|
+
#endif
|
|
1091
|
+
|
|
1092
|
+
#if defined(__cplusplus)
|
|
1093
|
+
} // extern "C"
|
|
1094
|
+
#endif
|
|
1095
|
+
|
|
1096
|
+
#endif // NK_TARGET_SMEBI32
|
|
1097
|
+
#endif // NK_TARGET_ARM_
|
|
1098
|
+
|
|
1099
|
+
#endif // NK_SETS_SMEBI32_H
|