numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Dot Products for SME (u1 binary vectors).
|
|
3
|
+
* @file include/numkong/dots/smebi32.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 24, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dots.h
|
|
8
|
+
*
|
|
9
|
+
* Uses ARM SME BMOPA instruction for binary dot products:
|
|
10
|
+
* BMOPA gives: matching = popcount(XNOR(a,b))
|
|
11
|
+
* dot(a,b) = popcount(a AND b) = (pop_a + pop_b - depth + matching) / 2
|
|
12
|
+
*/
|
|
13
|
+
#ifndef NK_DOTS_SMEBI32_H
|
|
14
|
+
#define NK_DOTS_SMEBI32_H
|
|
15
|
+
|
|
16
|
+
#if NK_TARGET_ARM_
|
|
17
|
+
#if NK_TARGET_SMEBI32
|
|
18
|
+
|
|
19
|
+
#include "numkong/types.h"
|
|
20
|
+
#include "numkong/dots/sme.h" // nk_sme_zero_za32_* constants
|
|
21
|
+
#include "numkong/sets/smebi32.h" // nk_dots_packed_size_u1_smebi32, nk_dots_pack_u1_smebi32
|
|
22
|
+
|
|
23
|
+
#if defined(__cplusplus)
|
|
24
|
+
extern "C" {
|
|
25
|
+
#endif
|
|
26
|
+
|
|
27
|
+
#if defined(__clang__)
|
|
28
|
+
#pragma clang attribute push(__attribute__((target("sme2,sve2"))), apply_to = function)
|
|
29
|
+
#elif defined(__GNUC__)
|
|
30
|
+
#pragma GCC push_options
|
|
31
|
+
#pragma GCC target("+sme2")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
/**
|
|
35
|
+
* SME u1 dot-product kernel using ZA transpose for unpacked A.
|
|
36
|
+
* ZA0.S = staging (A rows loaded horizontally, read vertically for BMOPA).
|
|
37
|
+
* ZA1-3.S = BMOPA accumulation (3 B column tiles in fast path).
|
|
38
|
+
*
|
|
39
|
+
* BMOPA gives matching = popcount(XNOR(a,b)).
|
|
40
|
+
* dot(a,b) = popcount(a AND b) = (pop_a + pop_b - depth_bits + matching) / 2
|
|
41
|
+
*/
|
|
42
|
+
__arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_streaming_(
|
|
43
|
+
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
44
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
45
|
+
|
|
46
|
+
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
47
|
+
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
48
|
+
nk_size_t const depth_tile_count = header->depth_tile_count;
|
|
49
|
+
|
|
50
|
+
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
51
|
+
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
52
|
+
nk_size_t const tile_elements = tile_dim * depth_tile_size;
|
|
53
|
+
nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
|
|
54
|
+
|
|
55
|
+
nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
|
|
56
|
+
nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
|
|
57
|
+
: (nk_u32_t const *)0;
|
|
58
|
+
|
|
59
|
+
svbool_t const predicate_all_u32x = svptrue_b32();
|
|
60
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
|
|
61
|
+
nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
|
|
62
|
+
nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
|
|
63
|
+
|
|
64
|
+
for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
|
|
65
|
+
nk_size_t const row_start_a = row_tile_a * tile_dim;
|
|
66
|
+
nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
|
|
67
|
+
: (row_count_a - row_start_a);
|
|
68
|
+
svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_a_remaining);
|
|
69
|
+
|
|
70
|
+
// Compute A row popcounts for this tile
|
|
71
|
+
nk_u32_t a_popcounts[16];
|
|
72
|
+
for (nk_size_t r = 0; r < rows_a_remaining; r++) {
|
|
73
|
+
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
|
|
74
|
+
a_popcounts[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
|
|
78
|
+
nk_size_t row_tile_b = 0;
|
|
79
|
+
for (; row_tile_b + 3 <= row_tile_count_b; row_tile_b += 3) {
|
|
80
|
+
svzero_mask_za(nk_sme_zero_za32_tiles_123_);
|
|
81
|
+
|
|
82
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
83
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
84
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
85
|
+
? depth_tile_size
|
|
86
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
87
|
+
: 0);
|
|
88
|
+
if (u32s_this_tile == 0) break;
|
|
89
|
+
|
|
90
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
91
|
+
|
|
92
|
+
svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
93
|
+
|
|
94
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
95
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
|
|
96
|
+
(row_start_a + row_in_tile) * a_stride_in_bytes) +
|
|
97
|
+
d_start_u32;
|
|
98
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
|
|
102
|
+
nk_u32_t const *b_tile1 = b_tiles + ((row_tile_b + 1) * depth_tile_count + d_tile) * tile_elements;
|
|
103
|
+
nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;
|
|
104
|
+
|
|
105
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
106
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
|
|
107
|
+
|
|
108
|
+
svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
|
|
109
|
+
svld1_u32(predicate_all_u32x, b_tile0 + step * tile_dim));
|
|
110
|
+
svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
|
|
111
|
+
svld1_u32(predicate_all_u32x, b_tile1 + step * tile_dim));
|
|
112
|
+
svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
|
|
113
|
+
svld1_u32(predicate_all_u32x, b_tile2 + step * tile_dim));
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
// Extract: dot = (pop_a + pop_b - depth + matching) / 2
|
|
118
|
+
// matching = ZA[i][j]
|
|
119
|
+
svuint32_t b_pop0_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 0) * tile_dim);
|
|
120
|
+
svuint32_t b_pop1_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 1) * tile_dim);
|
|
121
|
+
svuint32_t b_pop2_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 2) * tile_dim);
|
|
122
|
+
|
|
123
|
+
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
124
|
+
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
125
|
+
svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
|
|
126
|
+
|
|
127
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
|
|
128
|
+
svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop0_u32x);
|
|
129
|
+
svuint32_t numerator0_u32x = svadd_u32_x(
|
|
130
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops0_u32x, depth_u32x), za1_u32x);
|
|
131
|
+
svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 0) * tile_dim,
|
|
132
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator0_u32x, 1));
|
|
133
|
+
|
|
134
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
|
|
135
|
+
svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop1_u32x);
|
|
136
|
+
svuint32_t numerator1_u32x = svadd_u32_x(
|
|
137
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops1_u32x, depth_u32x), za2_u32x);
|
|
138
|
+
svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 1) * tile_dim,
|
|
139
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator1_u32x, 1));
|
|
140
|
+
|
|
141
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
|
|
142
|
+
svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop2_u32x);
|
|
143
|
+
svuint32_t numerator2_u32x = svadd_u32_x(
|
|
144
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops2_u32x, depth_u32x), za3_u32x);
|
|
145
|
+
svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 2) * tile_dim,
|
|
146
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator2_u32x, 1));
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
// Remainder: 1 B column tile at a time using ZA1
|
|
151
|
+
for (; row_tile_b < row_tile_count_b; row_tile_b++) {
|
|
152
|
+
nk_size_t const row_start_b = row_tile_b * tile_dim;
|
|
153
|
+
nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
|
|
154
|
+
: (row_count_b - row_start_b);
|
|
155
|
+
svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, rows_b_remaining);
|
|
156
|
+
|
|
157
|
+
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
158
|
+
|
|
159
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
160
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
161
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
162
|
+
? depth_tile_size
|
|
163
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
164
|
+
: 0);
|
|
165
|
+
if (u32s_this_tile == 0) break;
|
|
166
|
+
|
|
167
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
168
|
+
|
|
169
|
+
svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
170
|
+
|
|
171
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
|
|
172
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
|
|
173
|
+
(row_start_a + row_in_tile) * a_stride_in_bytes) +
|
|
174
|
+
d_start_u32;
|
|
175
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
|
|
179
|
+
|
|
180
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
181
|
+
svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
|
|
182
|
+
svuint32_t b_u32x = svld1_u32(predicate_all_u32x, b_tile + step * tile_dim);
|
|
183
|
+
svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_column_u32x, b_u32x);
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
// Extract: dot = (pop_a + pop_b - depth + matching) / 2
|
|
188
|
+
svuint32_t b_pop_u32x = svld1_u32(predicate_all_u32x, b_norms + row_start_b);
|
|
189
|
+
for (nk_size_t row = 0; row < rows_a_remaining; row++) {
|
|
190
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
|
|
191
|
+
svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
|
|
192
|
+
svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop_u32x);
|
|
193
|
+
svuint32_t numerator_u32x = svadd_u32_x(
|
|
194
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops_u32x, depth_u32x), za1_u32x);
|
|
195
|
+
nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
|
|
196
|
+
svst1_u32(column_predicate_u32x, c_row + row_start_b,
|
|
197
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator_u32x, 1));
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
/**
 *  @brief  Binary (u1) dot products of every A row against every row of a pre-packed B, on SME.
 *
 *  Each output is `popcount(a_i AND b_j)`. This is only the public entry point: the work happens in
 *  `nk_dots_packed_u1_smebi32_streaming_`, which is declared `__arm_locally_streaming __arm_new("za")`.
 *  That kernel therefore switches into streaming mode and owns fresh ZA state for the duration of the call.
 *
 *  @param a                  A bit-matrix rows, `a_stride_in_bytes` apart.
 *  @param b_packed           B as produced by the smebi32 packer (header + tiles + per-row popcounts).
 *  @param c                  Output u32 matrix, `c_stride_in_bytes` between rows.
 *  @param row_count_a        Number of A rows.
 *  @param row_count_b        Number of B rows, i.e. output columns.
 *  @param depth_bits         Vector length in bits.
 *  @param a_stride_in_bytes  Byte distance between consecutive A rows.
 *  @param c_stride_in_bytes  Byte distance between consecutive C rows.
 */
NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a,
                                         nk_size_t row_count_b, nk_size_t depth_bits, nk_size_t a_stride_in_bytes,
                                         nk_size_t c_stride_in_bytes) {
    // Forward everything unchanged; the callee handles streaming-mode entry/exit and ZA allocation.
    nk_dots_packed_u1_smebi32_streaming_(a, b_packed, c,                     //
                                         row_count_a, row_count_b,          //
                                         depth_bits,                        //
                                         a_stride_in_bytes, c_stride_in_bytes);
}
|
|
209
|
+
|
|
210
|
+
/**
|
|
211
|
+
* Symmetric u1 dot-product using ZA0 time-sharing + 3-tile fast path.
|
|
212
|
+
* Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
|
|
213
|
+
*/
|
|
214
|
+
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_(
|
|
215
|
+
nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_u32_t *result,
|
|
216
|
+
nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
217
|
+
|
|
218
|
+
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
219
|
+
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
220
|
+
nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
|
|
221
|
+
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
|
|
222
|
+
nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
|
|
223
|
+
|
|
224
|
+
svbool_t const predicate_all_u32x = svptrue_b32();
|
|
225
|
+
svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
|
|
226
|
+
|
|
227
|
+
NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
|
|
228
|
+
|
|
229
|
+
nk_size_t const row_end = row_start + row_count;
|
|
230
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
|
|
231
|
+
|
|
232
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
|
|
233
|
+
row_tile_start += tile_dim) {
|
|
234
|
+
nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
|
|
235
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
|
|
236
|
+
: (n_vectors - row_tile_start);
|
|
237
|
+
svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_clamped);
|
|
238
|
+
|
|
239
|
+
// Compute A tile popcounts
|
|
240
|
+
NK_ALIGN64 nk_u32_t a_tile_pops[16];
|
|
241
|
+
for (nk_size_t r = 0; r < rows_clamped; r++) {
|
|
242
|
+
nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors + (row_tile_start + r) * stride);
|
|
243
|
+
a_tile_pops[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
|
|
244
|
+
}
|
|
245
|
+
for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_pops[r] = 0;
|
|
246
|
+
|
|
247
|
+
nk_size_t column_tile_index = 0;
|
|
248
|
+
|
|
249
|
+
// Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
|
|
250
|
+
for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
|
|
251
|
+
svzero_mask_za(nk_sme_zero_za32_tiles_123_);
|
|
252
|
+
|
|
253
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
254
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
255
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
256
|
+
? depth_tile_size
|
|
257
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
258
|
+
: 0);
|
|
259
|
+
if (u32s_this_tile == 0) break;
|
|
260
|
+
|
|
261
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
262
|
+
svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
263
|
+
|
|
264
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
265
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
|
|
266
|
+
(row_tile_start + row_in_tile) * stride) +
|
|
267
|
+
d_start_u32;
|
|
268
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
// Save A columns
|
|
272
|
+
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
273
|
+
svst1_u32(predicate_all_u32x, a_buffer[s],
|
|
274
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
|
|
275
|
+
|
|
276
|
+
// B column tile 0
|
|
277
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
278
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
279
|
+
nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
|
|
280
|
+
if (col_abs < n_vectors) {
|
|
281
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
282
|
+
d_start_u32;
|
|
283
|
+
svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
287
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
|
|
288
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
|
|
289
|
+
svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
// B column tile 1
|
|
293
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
294
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
295
|
+
nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
|
|
296
|
+
if (col_abs < n_vectors) {
|
|
297
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
298
|
+
d_start_u32;
|
|
299
|
+
svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
303
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
|
|
304
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
|
|
305
|
+
svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
// B column tile 2
|
|
309
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
310
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
311
|
+
nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
|
|
312
|
+
if (col_abs < n_vectors) {
|
|
313
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
314
|
+
d_start_u32;
|
|
315
|
+
svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
319
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
|
|
320
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
|
|
321
|
+
svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
// Extract: dot = (pop_a + pop_b - depth + matching) / 2
|
|
326
|
+
// Compute B tile popcounts
|
|
327
|
+
NK_ALIGN64 nk_u32_t b_pops[3][16];
|
|
328
|
+
for (nk_size_t t = 0; t < 3; t++) {
|
|
329
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
330
|
+
nk_size_t const col_abs = (column_tile_index + t) * tile_dim + col;
|
|
331
|
+
if (col_abs < n_vectors) {
|
|
332
|
+
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride);
|
|
333
|
+
b_pops[t][col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_in_bytes);
|
|
334
|
+
}
|
|
335
|
+
else { b_pops[t][col] = 0; }
|
|
336
|
+
}
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
340
|
+
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
|
|
341
|
+
svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
|
|
342
|
+
|
|
343
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
|
|
344
|
+
svuint32_t b_popcount_0_u32x = svld1_u32(predicate_all_u32x, b_pops[0]);
|
|
345
|
+
svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_0_u32x);
|
|
346
|
+
svuint32_t numerator0_u32x = svadd_u32_x(
|
|
347
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops0_u32x, depth_u32x), za1_u32x);
|
|
348
|
+
svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 0) * tile_dim,
|
|
349
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator0_u32x, 1));
|
|
350
|
+
|
|
351
|
+
svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
|
|
352
|
+
svuint32_t b_popcount_1_u32x = svld1_u32(predicate_all_u32x, b_pops[1]);
|
|
353
|
+
svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_1_u32x);
|
|
354
|
+
svuint32_t numerator1_u32x = svadd_u32_x(
|
|
355
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops1_u32x, depth_u32x), za2_u32x);
|
|
356
|
+
svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 1) * tile_dim,
|
|
357
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator1_u32x, 1));
|
|
358
|
+
|
|
359
|
+
svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
|
|
360
|
+
svuint32_t b_popcount_2_u32x = svld1_u32(predicate_all_u32x, b_pops[2]);
|
|
361
|
+
svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_2_u32x);
|
|
362
|
+
svuint32_t numerator2_u32x = svadd_u32_x(
|
|
363
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops2_u32x, depth_u32x), za3_u32x);
|
|
364
|
+
svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 2) * tile_dim,
|
|
365
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator2_u32x, 1));
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
// Remainder: 1 column tile at a time using ZA1
|
|
370
|
+
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
371
|
+
nk_size_t const col_tile_start = column_tile_index * tile_dim;
|
|
372
|
+
nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
|
|
373
|
+
: (n_vectors - col_tile_start);
|
|
374
|
+
svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, cols_remaining);
|
|
375
|
+
|
|
376
|
+
svzero_mask_za(nk_sme_zero_za32_tile_1_);
|
|
377
|
+
|
|
378
|
+
for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
|
|
379
|
+
nk_size_t const d_start_u32 = d_tile * depth_tile_size;
|
|
380
|
+
nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
|
|
381
|
+
? depth_tile_size
|
|
382
|
+
: (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
|
|
383
|
+
: 0);
|
|
384
|
+
if (u32s_this_tile == 0) break;
|
|
385
|
+
|
|
386
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
387
|
+
svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
|
|
388
|
+
|
|
389
|
+
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
390
|
+
nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
|
|
391
|
+
(row_tile_start + row_in_tile) * stride) +
|
|
392
|
+
d_start_u32;
|
|
393
|
+
svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
for (nk_size_t s = 0; s < u32s_this_tile; s++)
|
|
397
|
+
svst1_u32(predicate_all_u32x, a_buffer[s],
|
|
398
|
+
svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
|
|
399
|
+
|
|
400
|
+
svzero_mask_za(nk_sme_zero_za32_tile_0_);
|
|
401
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
402
|
+
nk_size_t const col_abs = col_tile_start + col;
|
|
403
|
+
if (col_abs < n_vectors) {
|
|
404
|
+
nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
|
|
405
|
+
d_start_u32;
|
|
406
|
+
svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
|
|
407
|
+
}
|
|
408
|
+
}
|
|
409
|
+
for (nk_size_t step = 0; step < u32s_this_tile; step++) {
|
|
410
|
+
svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
|
|
411
|
+
svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_u32x, 0, step);
|
|
412
|
+
svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_u32x, b_u32x);
|
|
413
|
+
}
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
// Compute B tile popcounts for remainder
|
|
417
|
+
NK_ALIGN64 nk_u32_t b_pops_r[16];
|
|
418
|
+
for (nk_size_t col = 0; col < tile_dim; col++) {
|
|
419
|
+
nk_size_t const col_abs = col_tile_start + col;
|
|
420
|
+
if (col_abs < n_vectors) {
|
|
421
|
+
nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride);
|
|
422
|
+
b_pops_r[col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_in_bytes);
|
|
423
|
+
}
|
|
424
|
+
else { b_pops_r[col] = 0; }
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
428
|
+
svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
|
|
429
|
+
svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
|
|
430
|
+
svuint32_t b_popcount_u32x = svld1_u32(predicate_all_u32x, b_pops_r);
|
|
431
|
+
svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_u32x);
|
|
432
|
+
svuint32_t numerator_u32x = svadd_u32_x(
|
|
433
|
+
predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops_u32x, depth_u32x), za1_u32x);
|
|
434
|
+
nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
|
|
435
|
+
svst1_u32(column_predicate_u32x, result_row + col_tile_start,
|
|
436
|
+
svlsr_n_u32_x(predicate_all_u32x, numerator_u32x, 1));
|
|
437
|
+
}
|
|
438
|
+
}
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
/**
 *  @brief  Binary (u1) dot products between vectors of one bit-packed set, SME BI32 backend.
 *
 *  Thin public entry point: every argument is passed through unchanged to
 *  `nk_dots_symmetric_u1_smebi32_streaming_`, which performs all of the work.
 *  Both operands of each product are read from the same `vectors` buffer.
 *  The kernel reconstructs each result as `(pop_a + pop_b - depth + matching) / 2`
 *  from a BMOPA outer-product accumulation.
 *
 *  NOTE(review): presumably this wrapper exists so that callers using the regular
 *  (non-streaming) ABI can reach the streaming-mode kernel. Confirm against that
 *  kernel's function attributes.
 *
 *  @param vectors        Base pointer of the bit-packed vectors; vector `i` starts `i * stride` bytes in.
 *  @param n_vectors      Number of vectors in the set.
 *  @param depth_bits     Presumably the number of meaningful bits per vector (TODO confirm).
 *  @param stride         Byte distance between consecutive vectors.
 *  @param result         Output matrix of `nk_u32_t` dot products.
 *  @param result_stride  Byte distance between consecutive output rows.
 *  @param row_start      Assumed first output row to compute (TODO confirm in the kernel).
 *  @param row_count      Assumed number of output rows to compute (TODO confirm in the kernel).
 */
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
                                            nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
                                            nk_size_t row_start, nk_size_t row_count) {
    // Pure delegation: no argument is adjusted or validated at this level.
    nk_dots_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, //
                                            result, result_stride,                  //
                                            row_start, row_count);
}
|
|
448
|
+
|
|
449
|
+
#if defined(__clang__)
|
|
450
|
+
#pragma clang attribute pop
|
|
451
|
+
#elif defined(__GNUC__)
|
|
452
|
+
#pragma GCC pop_options
|
|
453
|
+
#endif
|
|
454
|
+
|
|
455
|
+
#if defined(__cplusplus)
|
|
456
|
+
} // extern "C"
|
|
457
|
+
#endif
|
|
458
|
+
|
|
459
|
+
#endif // NK_TARGET_SMEBI32
|
|
460
|
+
#endif // NK_TARGET_ARM_
|
|
461
|
+
#endif // NK_DOTS_SMEBI32_H
|