numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -20,226 +20,211 @@ extern "C" {
|
|
|
20
20
|
#endif
|
|
21
21
|
|
|
22
22
|
#if defined(__clang__)
|
|
23
|
-
#pragma clang attribute push(__attribute__((target("sme
|
|
23
|
+
#pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
|
|
24
24
|
#elif defined(__GNUC__)
|
|
25
25
|
#pragma GCC push_options
|
|
26
26
|
#pragma GCC target("+sme")
|
|
27
27
|
#endif
|
|
28
28
|
|
|
29
|
-
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count)
|
|
30
|
-
svfloat32_t
|
|
31
|
-
|
|
29
|
+
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
30
|
+
svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
|
|
31
|
+
svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
|
|
32
|
+
nk_size_t const vector_length = svcnth();
|
|
33
|
+
nk_size_t const half_vector_length = svcntw();
|
|
32
34
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
33
|
-
svbool_t
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
35
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, count);
|
|
36
|
+
svfloat16_t values_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(data + i));
|
|
37
|
+
|
|
38
|
+
svbool_t predicate_even_b32x = svwhilelt_b32_u64(i, count);
|
|
39
|
+
svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
|
|
40
|
+
accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
|
|
41
|
+
values_even_f32x);
|
|
42
|
+
|
|
43
|
+
svbool_t predicate_odd_b32x = svwhilelt_b32_u64(i + half_vector_length, count);
|
|
44
|
+
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
45
|
+
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
37
46
|
}
|
|
38
|
-
return svaddv_f32(svptrue_b32(),
|
|
47
|
+
return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
|
|
39
48
|
}
|
|
40
49
|
|
|
41
|
-
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count)
|
|
50
|
+
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
42
51
|
svfloat32_t accumulator_f32x = svdup_f32(0.0f);
|
|
43
|
-
nk_size_t const vector_length =
|
|
52
|
+
nk_size_t const vector_length = svcnth();
|
|
44
53
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
45
|
-
svbool_t
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
accumulator_f32x = svmla_f32_x(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
|
|
54
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, count);
|
|
55
|
+
svbfloat16_t values_bf16x = svld1_bf16(predicate_b16x, (nk_bf16_for_arm_simd_t const *)(data + i));
|
|
56
|
+
accumulator_f32x = svbfdot_f32(accumulator_f32x, values_bf16x, values_bf16x);
|
|
49
57
|
}
|
|
50
58
|
return svaddv_f32(svptrue_b32(), accumulator_f32x);
|
|
51
59
|
}
|
|
52
60
|
|
|
53
61
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
54
|
-
svfloat32_t
|
|
55
|
-
svfloat32_t
|
|
56
|
-
svuint16_t subnorm_lut_u16x = svld1_u16(svwhilelt_b16(0u, 8u), nk_e4m3_subnorm_f16_lut_);
|
|
62
|
+
svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
|
|
63
|
+
svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
|
|
57
64
|
nk_size_t const vector_length = svcnth();
|
|
58
65
|
nk_size_t const half_vector_length = svcntw();
|
|
59
66
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
60
67
|
nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
|
|
61
|
-
svbool_t
|
|
62
|
-
svbool_t
|
|
63
|
-
svuint8_t raw_u8x = svld1_u8(
|
|
64
|
-
svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(
|
|
68
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
|
|
69
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
|
|
70
|
+
svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
|
|
71
|
+
svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(predicate_b16x, raw_u8x);
|
|
65
72
|
|
|
66
|
-
svbool_t
|
|
67
|
-
svfloat32_t
|
|
68
|
-
|
|
73
|
+
svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
|
|
74
|
+
svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
|
|
75
|
+
accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
|
|
76
|
+
values_even_f32x);
|
|
69
77
|
|
|
70
|
-
svbool_t
|
|
71
|
-
svfloat32_t
|
|
72
|
-
|
|
78
|
+
svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
|
|
79
|
+
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
80
|
+
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
73
81
|
}
|
|
74
|
-
return svaddv_f32(svptrue_b32(),
|
|
82
|
+
return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
|
|
75
83
|
}
|
|
76
84
|
|
|
77
85
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
78
|
-
svfloat32_t
|
|
79
|
-
svfloat32_t
|
|
86
|
+
svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
|
|
87
|
+
svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
|
|
80
88
|
nk_size_t const vector_length = svcnth();
|
|
81
89
|
nk_size_t const half_vector_length = svcntw();
|
|
82
90
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
83
91
|
nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
|
|
84
|
-
svbool_t
|
|
85
|
-
svbool_t
|
|
86
|
-
svuint8_t raw_u8x = svld1_u8(
|
|
87
|
-
svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(
|
|
92
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
|
|
93
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
|
|
94
|
+
svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
|
|
95
|
+
svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(predicate_b16x, raw_u8x);
|
|
88
96
|
|
|
89
|
-
svbool_t
|
|
90
|
-
svfloat32_t
|
|
91
|
-
|
|
97
|
+
svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
|
|
98
|
+
svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
|
|
99
|
+
accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
|
|
100
|
+
values_even_f32x);
|
|
92
101
|
|
|
93
|
-
svbool_t
|
|
94
|
-
svfloat32_t
|
|
95
|
-
|
|
102
|
+
svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
|
|
103
|
+
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
104
|
+
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
96
105
|
}
|
|
97
|
-
return svaddv_f32(svptrue_b32(),
|
|
106
|
+
return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
|
|
98
107
|
}
|
|
99
108
|
|
|
100
|
-
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count)
|
|
101
|
-
|
|
102
|
-
nk_size_t const vector_length =
|
|
109
|
+
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
110
|
+
svint32_t accumulator_i32x = svdup_s32(0);
|
|
111
|
+
nk_size_t const vector_length = svcntb();
|
|
103
112
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
104
|
-
svbool_t
|
|
105
|
-
svuint8_t raw_u8x = svld1_u8(
|
|
106
|
-
svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(
|
|
107
|
-
|
|
108
|
-
svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
|
|
109
|
-
svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
|
|
110
|
-
accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
|
|
113
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
|
|
114
|
+
svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
|
|
115
|
+
svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(predicate_b8x, raw_u8x);
|
|
116
|
+
accumulator_i32x = svdot_s32(accumulator_i32x, values_i8x, values_i8x);
|
|
111
117
|
}
|
|
112
|
-
return (nk_f32_t)
|
|
118
|
+
return (nk_f32_t)svaddv_s32(svptrue_b32(), accumulator_i32x) / 256.0f;
|
|
113
119
|
}
|
|
114
120
|
|
|
115
121
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
116
|
-
svfloat32_t
|
|
117
|
-
svfloat32_t
|
|
122
|
+
svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
|
|
123
|
+
svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
|
|
118
124
|
nk_size_t const vector_length = svcnth();
|
|
119
125
|
nk_size_t const half_vector_length = svcntw();
|
|
120
126
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
121
127
|
nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
|
|
122
|
-
svbool_t
|
|
123
|
-
svbool_t
|
|
124
|
-
svuint8_t raw_u8x = svld1_u8(
|
|
125
|
-
svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(
|
|
128
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
|
|
129
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
|
|
130
|
+
svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
|
|
131
|
+
svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(predicate_b16x, raw_u8x);
|
|
126
132
|
|
|
127
|
-
svbool_t
|
|
128
|
-
svfloat32_t
|
|
129
|
-
|
|
133
|
+
svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
|
|
134
|
+
svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
|
|
135
|
+
accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
|
|
136
|
+
values_even_f32x);
|
|
130
137
|
|
|
131
|
-
svbool_t
|
|
132
|
-
svfloat32_t
|
|
133
|
-
|
|
138
|
+
svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
|
|
139
|
+
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
140
|
+
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
134
141
|
}
|
|
135
|
-
return svaddv_f32(svptrue_b32(),
|
|
142
|
+
return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
|
|
136
143
|
}
|
|
137
144
|
|
|
138
|
-
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count)
|
|
139
|
-
|
|
140
|
-
nk_size_t const vector_length =
|
|
145
|
+
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
146
|
+
svint32_t accumulator_i32x = svdup_s32(0);
|
|
147
|
+
nk_size_t const vector_length = svcntb();
|
|
141
148
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
142
|
-
svbool_t
|
|
143
|
-
svint8_t loaded_i8x = svld1_s8(
|
|
144
|
-
|
|
145
|
-
svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
|
|
146
|
-
svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
|
|
147
|
-
accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
|
|
149
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
|
|
150
|
+
svint8_t loaded_i8x = svld1_s8(predicate_b8x, data + i);
|
|
151
|
+
accumulator_i32x = svdot_s32(accumulator_i32x, loaded_i8x, loaded_i8x);
|
|
148
152
|
}
|
|
149
|
-
return (nk_u32_t)
|
|
153
|
+
return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
|
|
150
154
|
}
|
|
151
155
|
|
|
152
|
-
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count)
|
|
153
|
-
|
|
154
|
-
nk_size_t const vector_length =
|
|
156
|
+
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
157
|
+
svuint32_t accumulator_u32x = svdup_u32(0);
|
|
158
|
+
nk_size_t const vector_length = svcntb();
|
|
155
159
|
for (nk_size_t i = 0; i < count; i += vector_length) {
|
|
156
|
-
svbool_t
|
|
157
|
-
svuint8_t
|
|
158
|
-
|
|
159
|
-
svuint16_t squares_u16x = svmul_u16_z(svwhilelt_b16_u64(i, count), values_u16x, values_u16x);
|
|
160
|
-
svuint64_t squares_u64x = svunpklo_u64(svunpklo_u32(squares_u16x));
|
|
161
|
-
accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, squares_u64x);
|
|
160
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
|
|
161
|
+
svuint8_t loaded_u8x = svld1_u8(predicate_b8x, data + i);
|
|
162
|
+
accumulator_u32x = svdot_u32(accumulator_u32x, loaded_u8x, loaded_u8x);
|
|
162
163
|
}
|
|
163
|
-
return (nk_u32_t)
|
|
164
|
+
return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
|
|
164
165
|
}
|
|
165
166
|
|
|
166
|
-
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count)
|
|
167
|
-
|
|
167
|
+
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
168
|
+
svint32_t accumulator_i32x = svdup_s32(0);
|
|
168
169
|
nk_u8_t const *bytes = (nk_u8_t const *)data;
|
|
169
170
|
nk_size_t const byte_count = (count + 1) / 2;
|
|
170
|
-
nk_size_t const vector_length =
|
|
171
|
+
nk_size_t const vector_length = svcntb();
|
|
171
172
|
for (nk_size_t i = 0; i < byte_count; i += vector_length) {
|
|
172
|
-
svbool_t
|
|
173
|
-
svuint8_t packed_u8x = svld1_u8(
|
|
174
|
-
svuint8_t low_u8x = svand_n_u8_x(
|
|
175
|
-
svuint8_t high_u8x = svlsr_n_u8_x(
|
|
173
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, byte_count);
|
|
174
|
+
svuint8_t packed_u8x = svld1_u8(predicate_b8x, bytes + i);
|
|
175
|
+
svuint8_t low_u8x = svand_n_u8_x(predicate_b8x, packed_u8x, 0x0F);
|
|
176
|
+
svuint8_t high_u8x = svlsr_n_u8_x(predicate_b8x, packed_u8x, 4);
|
|
176
177
|
// Sign-extend 4-bit to 8-bit: shift left 4, arithmetic shift right 4
|
|
177
|
-
svint8_t low_i8x = svasr_n_s8_x(
|
|
178
|
-
svint8_t high_i8x = svasr_n_s8_x(
|
|
178
|
+
svint8_t low_i8x = svasr_n_s8_x(predicate_b8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_b8x, low_u8x, 4)), 4);
|
|
179
|
+
svint8_t high_i8x = svasr_n_s8_x(predicate_b8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_b8x, high_u8x, 4)),
|
|
179
180
|
4);
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
svint64_t sum_i64x = svunpklo_s64(svunpklo_s32(sum_i16x));
|
|
189
|
-
accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, sum_i64x);
|
|
190
|
-
}
|
|
191
|
-
return (nk_u32_t)svaddv_s64(svptrue_b64(), accumulator_i64x);
|
|
192
|
-
}
|
|
193
|
-
|
|
194
|
-
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
|
|
195
|
-
svuint64_t accumulator_u64x = svdup_u64(0);
|
|
181
|
+
accumulator_i32x = svdot_s32(accumulator_i32x, low_i8x, low_i8x);
|
|
182
|
+
accumulator_i32x = svdot_s32(accumulator_i32x, high_i8x, high_i8x);
|
|
183
|
+
}
|
|
184
|
+
return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
188
|
+
svuint32_t accumulator_u32x = svdup_u32(0);
|
|
196
189
|
nk_u8_t const *bytes = (nk_u8_t const *)data;
|
|
197
190
|
nk_size_t const byte_count = (count + 1) / 2;
|
|
198
|
-
nk_size_t const vector_length =
|
|
191
|
+
nk_size_t const vector_length = svcntb();
|
|
199
192
|
for (nk_size_t i = 0; i < byte_count; i += vector_length) {
|
|
200
|
-
svbool_t
|
|
201
|
-
svuint8_t packed_u8x = svld1_u8(
|
|
202
|
-
svuint8_t low_u8x = svand_n_u8_x(
|
|
203
|
-
svuint8_t high_u8x = svlsr_n_u8_x(
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
svbool_t predicate_u64x = svwhilelt_b64_u64(i, byte_count);
|
|
212
|
-
svuint64_t sum_u64x = svunpklo_u64(svunpklo_u32(sum_u16x));
|
|
213
|
-
accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, sum_u64x);
|
|
214
|
-
}
|
|
215
|
-
return (nk_u32_t)svaddv_u64(svptrue_b64(), accumulator_u64x);
|
|
216
|
-
}
|
|
217
|
-
|
|
218
|
-
NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
|
|
193
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, byte_count);
|
|
194
|
+
svuint8_t packed_u8x = svld1_u8(predicate_b8x, bytes + i);
|
|
195
|
+
svuint8_t low_u8x = svand_n_u8_x(predicate_b8x, packed_u8x, 0x0F);
|
|
196
|
+
svuint8_t high_u8x = svlsr_n_u8_x(predicate_b8x, packed_u8x, 4);
|
|
197
|
+
accumulator_u32x = svdot_u32(accumulator_u32x, low_u8x, low_u8x);
|
|
198
|
+
accumulator_u32x = svdot_u32(accumulator_u32x, high_u8x, high_u8x);
|
|
199
|
+
}
|
|
200
|
+
return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
|
|
219
204
|
svfloat32_t query_norm_sq_f32x,
|
|
220
|
-
svfloat32_t target_norms_sq_f32x)
|
|
221
|
-
svfloat32_t norms_product_f32x = svmul_f32_x(
|
|
205
|
+
svfloat32_t target_norms_sq_f32x) NK_STREAMING_ {
|
|
206
|
+
svfloat32_t norms_product_f32x = svmul_f32_x(predicate_b32x, query_norm_sq_f32x, target_norms_sq_f32x);
|
|
222
207
|
svfloat32_t rsqrt_f32x = svrsqrte_f32(norms_product_f32x);
|
|
223
|
-
rsqrt_f32x = svmul_f32_x(
|
|
224
|
-
svrsqrts_f32(svmul_f32_x(
|
|
225
|
-
rsqrt_f32x = svmul_f32_x(
|
|
226
|
-
svrsqrts_f32(svmul_f32_x(
|
|
227
|
-
svfloat32_t angular_f32x = svsub_f32_x(
|
|
228
|
-
svmul_f32_x(
|
|
229
|
-
return svmax_f32_x(
|
|
208
|
+
rsqrt_f32x = svmul_f32_x(predicate_b32x, rsqrt_f32x,
|
|
209
|
+
svrsqrts_f32(svmul_f32_x(predicate_b32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
|
|
210
|
+
rsqrt_f32x = svmul_f32_x(predicate_b32x, rsqrt_f32x,
|
|
211
|
+
svrsqrts_f32(svmul_f32_x(predicate_b32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
|
|
212
|
+
svfloat32_t angular_f32x = svsub_f32_x(predicate_b32x, svdup_n_f32(1.0f),
|
|
213
|
+
svmul_f32_x(predicate_b32x, dots_f32x, rsqrt_f32x));
|
|
214
|
+
return svmax_f32_x(predicate_b32x, angular_f32x, svdup_n_f32(0.0f));
|
|
230
215
|
}
|
|
231
216
|
|
|
232
|
-
NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t
|
|
217
|
+
NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
|
|
233
218
|
svfloat32_t query_norm_sq_f32x,
|
|
234
|
-
svfloat32_t target_norms_sq_f32x)
|
|
235
|
-
svfloat32_t sum_sq_f32x = svadd_f32_x(
|
|
236
|
-
svfloat32_t dist_sq_f32x = svsub_f32_x(
|
|
237
|
-
svmul_f32_x(
|
|
238
|
-
dist_sq_f32x = svmax_f32_x(
|
|
239
|
-
return svsqrt_f32_x(
|
|
219
|
+
svfloat32_t target_norms_sq_f32x) NK_STREAMING_ {
|
|
220
|
+
svfloat32_t sum_sq_f32x = svadd_f32_x(predicate_b32x, query_norm_sq_f32x, target_norms_sq_f32x);
|
|
221
|
+
svfloat32_t dist_sq_f32x = svsub_f32_x(predicate_b32x, sum_sq_f32x,
|
|
222
|
+
svmul_f32_x(predicate_b32x, svdup_n_f32(2.0f), dots_f32x));
|
|
223
|
+
dist_sq_f32x = svmax_f32_x(predicate_b32x, dist_sq_f32x, svdup_n_f32(0.0f));
|
|
224
|
+
return svsqrt_f32_x(predicate_b32x, dist_sq_f32x);
|
|
240
225
|
}
|
|
241
226
|
|
|
242
|
-
#pragma region
|
|
227
|
+
#pragma region F16 Floats
|
|
243
228
|
|
|
244
229
|
__arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streaming_( //
|
|
245
230
|
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -253,12 +238,12 @@ __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streamin
|
|
|
253
238
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
|
|
254
239
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
255
240
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
256
|
-
svbool_t
|
|
257
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
258
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
241
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
242
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
243
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
259
244
|
svst1_f32(
|
|
260
|
-
|
|
261
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
245
|
+
predicate_b32x, result_row + col_index,
|
|
246
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
262
247
|
}
|
|
263
248
|
}
|
|
264
249
|
}
|
|
@@ -286,12 +271,12 @@ __arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_stream
|
|
|
286
271
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
|
|
287
272
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
288
273
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
289
|
-
svbool_t
|
|
290
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
291
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
274
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
275
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
276
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
292
277
|
svst1_f32(
|
|
293
|
-
|
|
294
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
278
|
+
predicate_b32x, result_row + col_index,
|
|
279
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
295
280
|
}
|
|
296
281
|
}
|
|
297
282
|
}
|
|
@@ -307,8 +292,8 @@ NK_PUBLIC void nk_euclideans_packed_f16_sme( //
|
|
|
307
292
|
c_stride_elements);
|
|
308
293
|
}
|
|
309
294
|
|
|
310
|
-
__arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_(
|
|
311
|
-
nk_f16_t const *vectors, nk_size_t
|
|
295
|
+
__arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
|
|
296
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
312
297
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
313
298
|
// Phase 1: cache row norms on diagonal
|
|
314
299
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -317,8 +302,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
|
|
|
317
302
|
}
|
|
318
303
|
// Phase 2: column-first post-processing
|
|
319
304
|
nk_f32_t norms_cache[256];
|
|
320
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
321
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
305
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
306
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
322
307
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
323
308
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
|
|
324
309
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -327,11 +312,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
|
|
|
327
312
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
328
313
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
329
314
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
330
|
-
svbool_t
|
|
331
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
332
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
333
|
-
svst1_f32(
|
|
334
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
315
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
316
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
317
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
318
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
319
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
335
320
|
target_norms_sq_f32x));
|
|
336
321
|
}
|
|
337
322
|
}
|
|
@@ -341,19 +326,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
|
|
|
341
326
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
342
327
|
}
|
|
343
328
|
|
|
344
|
-
NK_PUBLIC void nk_angulars_symmetric_f16_sme(
|
|
345
|
-
nk_f16_t const *vectors, nk_size_t
|
|
346
|
-
nk_f32_t *result, nk_size_t
|
|
347
|
-
nk_size_t const stride_elements =
|
|
348
|
-
nk_size_t const result_stride_elements =
|
|
349
|
-
nk_dots_symmetric_f16_sme_streaming_(vectors,
|
|
329
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
|
|
330
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
331
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
332
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
333
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
334
|
+
nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
350
335
|
row_start, row_count);
|
|
351
|
-
nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors,
|
|
336
|
+
nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
352
337
|
result_stride_elements, row_start, row_count);
|
|
353
338
|
}
|
|
354
339
|
|
|
355
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_(
|
|
356
|
-
nk_f16_t const *vectors, nk_size_t
|
|
340
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
|
|
341
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
357
342
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
358
343
|
// Phase 1: cache row norms on diagonal
|
|
359
344
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -362,8 +347,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
|
|
|
362
347
|
}
|
|
363
348
|
// Phase 2: column-first post-processing
|
|
364
349
|
nk_f32_t norms_cache[256];
|
|
365
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
366
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
350
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
351
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
367
352
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
368
353
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
|
|
369
354
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -372,11 +357,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
|
|
|
372
357
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
373
358
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
374
359
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
375
|
-
svbool_t
|
|
376
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
377
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
378
|
-
svst1_f32(
|
|
379
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
360
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
361
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
362
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
363
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
364
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
380
365
|
target_norms_sq_f32x));
|
|
381
366
|
}
|
|
382
367
|
}
|
|
@@ -386,20 +371,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
|
|
|
386
371
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
387
372
|
}
|
|
388
373
|
|
|
389
|
-
NK_PUBLIC void nk_euclideans_symmetric_f16_sme(
|
|
390
|
-
nk_f16_t const *vectors, nk_size_t
|
|
391
|
-
nk_f32_t *result, nk_size_t
|
|
392
|
-
nk_size_t const stride_elements =
|
|
393
|
-
nk_size_t const result_stride_elements =
|
|
394
|
-
nk_dots_symmetric_f16_sme_streaming_(vectors,
|
|
374
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
|
|
375
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
376
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
377
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
378
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
379
|
+
nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
395
380
|
row_start, row_count);
|
|
396
|
-
nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors,
|
|
381
|
+
nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
397
382
|
result_stride_elements, row_start, row_count);
|
|
398
383
|
}
|
|
399
384
|
|
|
400
|
-
#pragma endregion
|
|
385
|
+
#pragma endregion F16 Floats
|
|
401
386
|
|
|
402
|
-
#pragma region
|
|
387
|
+
#pragma region BF16 Floats
|
|
403
388
|
|
|
404
389
|
__arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streaming_( //
|
|
405
390
|
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -413,12 +398,12 @@ __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streami
|
|
|
413
398
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
|
|
414
399
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
415
400
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
416
|
-
svbool_t
|
|
417
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
418
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
401
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
402
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
403
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
419
404
|
svst1_f32(
|
|
420
|
-
|
|
421
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
405
|
+
predicate_b32x, result_row + col_index,
|
|
406
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
422
407
|
}
|
|
423
408
|
}
|
|
424
409
|
}
|
|
@@ -446,12 +431,12 @@ __arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_strea
|
|
|
446
431
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
|
|
447
432
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
448
433
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
449
|
-
svbool_t
|
|
450
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
451
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
434
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
435
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
436
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
452
437
|
svst1_f32(
|
|
453
|
-
|
|
454
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
438
|
+
predicate_b32x, result_row + col_index,
|
|
439
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
455
440
|
}
|
|
456
441
|
}
|
|
457
442
|
}
|
|
@@ -467,8 +452,8 @@ NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
|
|
|
467
452
|
c_stride_elements);
|
|
468
453
|
}
|
|
469
454
|
|
|
470
|
-
__arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_(
|
|
471
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
455
|
+
__arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
|
|
456
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
472
457
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
473
458
|
// Phase 1: cache row norms on diagonal
|
|
474
459
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -477,8 +462,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
|
|
|
477
462
|
}
|
|
478
463
|
// Phase 2: column-first post-processing
|
|
479
464
|
nk_f32_t norms_cache[256];
|
|
480
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
481
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
465
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
466
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
482
467
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
483
468
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
|
|
484
469
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -487,11 +472,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
|
|
|
487
472
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
488
473
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
489
474
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
490
|
-
svbool_t
|
|
491
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
492
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
493
|
-
svst1_f32(
|
|
494
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
475
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
476
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
477
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
478
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
479
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
495
480
|
target_norms_sq_f32x));
|
|
496
481
|
}
|
|
497
482
|
}
|
|
@@ -501,19 +486,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
|
|
|
501
486
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
502
487
|
}
|
|
503
488
|
|
|
504
|
-
NK_PUBLIC void nk_angulars_symmetric_bf16_sme(
|
|
505
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
506
|
-
nk_f32_t *result, nk_size_t
|
|
507
|
-
nk_size_t const stride_elements =
|
|
508
|
-
nk_size_t const result_stride_elements =
|
|
509
|
-
nk_dots_symmetric_bf16_sme_streaming_(vectors,
|
|
510
|
-
row_start, row_count);
|
|
511
|
-
nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors,
|
|
489
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
|
|
490
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
491
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
492
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
493
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
494
|
+
nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
495
|
+
result_stride_elements, row_start, row_count);
|
|
496
|
+
nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
512
497
|
result_stride_elements, row_start, row_count);
|
|
513
498
|
}
|
|
514
499
|
|
|
515
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_(
|
|
516
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
500
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
|
|
501
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
517
502
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
518
503
|
// Phase 1: cache row norms on diagonal
|
|
519
504
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -522,8 +507,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
|
|
|
522
507
|
}
|
|
523
508
|
// Phase 2: column-first post-processing
|
|
524
509
|
nk_f32_t norms_cache[256];
|
|
525
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
526
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
510
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
511
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
527
512
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
528
513
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
|
|
529
514
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -532,11 +517,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
|
|
|
532
517
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
533
518
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
534
519
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
535
|
-
svbool_t
|
|
536
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
537
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
538
|
-
svst1_f32(
|
|
539
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
520
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
521
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
522
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
523
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
524
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
540
525
|
target_norms_sq_f32x));
|
|
541
526
|
}
|
|
542
527
|
}
|
|
@@ -546,20 +531,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
|
|
|
546
531
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
547
532
|
}
|
|
548
533
|
|
|
549
|
-
NK_PUBLIC void nk_euclideans_symmetric_bf16_sme(
|
|
550
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
551
|
-
nk_f32_t *result, nk_size_t
|
|
552
|
-
nk_size_t const stride_elements =
|
|
553
|
-
nk_size_t const result_stride_elements =
|
|
554
|
-
nk_dots_symmetric_bf16_sme_streaming_(vectors,
|
|
555
|
-
row_start, row_count);
|
|
556
|
-
nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors,
|
|
534
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
|
|
535
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
536
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
537
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
538
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
539
|
+
nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
540
|
+
result_stride_elements, row_start, row_count);
|
|
541
|
+
nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
557
542
|
result_stride_elements, row_start, row_count);
|
|
558
543
|
}
|
|
559
544
|
|
|
560
|
-
#pragma endregion
|
|
545
|
+
#pragma endregion BF16 Floats
|
|
561
546
|
|
|
562
|
-
#pragma region
|
|
547
|
+
#pragma region E4M3 Floats
|
|
563
548
|
|
|
564
549
|
__arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streaming_( //
|
|
565
550
|
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -573,12 +558,12 @@ __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streami
|
|
|
573
558
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
|
|
574
559
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
575
560
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
576
|
-
svbool_t
|
|
577
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
578
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
561
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
562
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
563
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
579
564
|
svst1_f32(
|
|
580
|
-
|
|
581
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
565
|
+
predicate_b32x, result_row + col_index,
|
|
566
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
582
567
|
}
|
|
583
568
|
}
|
|
584
569
|
}
|
|
@@ -606,12 +591,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_strea
|
|
|
606
591
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
|
|
607
592
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
608
593
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
609
|
-
svbool_t
|
|
610
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
611
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
594
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
595
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
596
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
612
597
|
svst1_f32(
|
|
613
|
-
|
|
614
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
598
|
+
predicate_b32x, result_row + col_index,
|
|
599
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
615
600
|
}
|
|
616
601
|
}
|
|
617
602
|
}
|
|
@@ -627,8 +612,8 @@ NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
|
|
|
627
612
|
c_stride_elements);
|
|
628
613
|
}
|
|
629
614
|
|
|
630
|
-
__arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_(
|
|
631
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
615
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
|
|
616
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
632
617
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
633
618
|
// Phase 1: cache row norms on diagonal
|
|
634
619
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -637,8 +622,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
|
|
|
637
622
|
}
|
|
638
623
|
// Phase 2: column-first post-processing
|
|
639
624
|
nk_f32_t norms_cache[256];
|
|
640
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
641
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
625
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
626
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
642
627
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
643
628
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
|
|
644
629
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -647,11 +632,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
|
|
|
647
632
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
648
633
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
649
634
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
650
|
-
svbool_t
|
|
651
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
652
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
653
|
-
svst1_f32(
|
|
654
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
635
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
636
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
637
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
638
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
639
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
655
640
|
target_norms_sq_f32x));
|
|
656
641
|
}
|
|
657
642
|
}
|
|
@@ -661,19 +646,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
|
|
|
661
646
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
662
647
|
}
|
|
663
648
|
|
|
664
|
-
NK_PUBLIC void nk_angulars_symmetric_e4m3_sme(
|
|
665
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
666
|
-
nk_f32_t *result, nk_size_t
|
|
667
|
-
nk_size_t const stride_elements =
|
|
668
|
-
nk_size_t const result_stride_elements =
|
|
669
|
-
nk_dots_symmetric_e4m3_sme_streaming_(vectors,
|
|
670
|
-
row_start, row_count);
|
|
671
|
-
nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors,
|
|
649
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
|
|
650
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
651
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
652
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
653
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
654
|
+
nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
655
|
+
result_stride_elements, row_start, row_count);
|
|
656
|
+
nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
672
657
|
result_stride_elements, row_start, row_count);
|
|
673
658
|
}
|
|
674
659
|
|
|
675
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(
|
|
676
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
660
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
|
|
661
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
677
662
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
678
663
|
// Phase 1: cache row norms on diagonal
|
|
679
664
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -682,8 +667,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
|
|
|
682
667
|
}
|
|
683
668
|
// Phase 2: column-first post-processing
|
|
684
669
|
nk_f32_t norms_cache[256];
|
|
685
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
686
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
670
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
671
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
687
672
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
688
673
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
|
|
689
674
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -692,11 +677,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
|
|
|
692
677
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
693
678
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
694
679
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
695
|
-
svbool_t
|
|
696
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
697
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
698
|
-
svst1_f32(
|
|
699
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
680
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
681
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
682
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
683
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
684
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
700
685
|
target_norms_sq_f32x));
|
|
701
686
|
}
|
|
702
687
|
}
|
|
@@ -706,20 +691,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
|
|
|
706
691
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
707
692
|
}
|
|
708
693
|
|
|
709
|
-
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme(
|
|
710
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
711
|
-
nk_f32_t *result, nk_size_t
|
|
712
|
-
nk_size_t const stride_elements =
|
|
713
|
-
nk_size_t const result_stride_elements =
|
|
714
|
-
nk_dots_symmetric_e4m3_sme_streaming_(vectors,
|
|
715
|
-
row_start, row_count);
|
|
716
|
-
nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors,
|
|
694
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
|
|
695
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
696
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
697
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
698
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
699
|
+
nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
700
|
+
result_stride_elements, row_start, row_count);
|
|
701
|
+
nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
717
702
|
result_stride_elements, row_start, row_count);
|
|
718
703
|
}
|
|
719
704
|
|
|
720
|
-
#pragma endregion
|
|
705
|
+
#pragma endregion E4M3 Floats
|
|
721
706
|
|
|
722
|
-
#pragma region
|
|
707
|
+
#pragma region E5M2 Floats
|
|
723
708
|
|
|
724
709
|
__arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streaming_( //
|
|
725
710
|
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -733,12 +718,12 @@ __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streami
|
|
|
733
718
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
|
|
734
719
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
735
720
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
736
|
-
svbool_t
|
|
737
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
738
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
721
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
722
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
723
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
739
724
|
svst1_f32(
|
|
740
|
-
|
|
741
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
725
|
+
predicate_b32x, result_row + col_index,
|
|
726
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
742
727
|
}
|
|
743
728
|
}
|
|
744
729
|
}
|
|
@@ -766,12 +751,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_strea
|
|
|
766
751
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
|
|
767
752
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
768
753
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
769
|
-
svbool_t
|
|
770
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
771
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
754
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
755
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
756
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
772
757
|
svst1_f32(
|
|
773
|
-
|
|
774
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
758
|
+
predicate_b32x, result_row + col_index,
|
|
759
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
775
760
|
}
|
|
776
761
|
}
|
|
777
762
|
}
|
|
@@ -787,8 +772,8 @@ NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
|
|
|
787
772
|
c_stride_elements);
|
|
788
773
|
}
|
|
789
774
|
|
|
790
|
-
__arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_(
|
|
791
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
775
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
|
|
776
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
792
777
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
793
778
|
// Phase 1: cache row norms on diagonal
|
|
794
779
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -797,8 +782,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
|
|
|
797
782
|
}
|
|
798
783
|
// Phase 2: column-first post-processing
|
|
799
784
|
nk_f32_t norms_cache[256];
|
|
800
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
801
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
785
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
786
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
802
787
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
803
788
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
|
|
804
789
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -807,11 +792,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
|
|
|
807
792
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
808
793
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
809
794
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
810
|
-
svbool_t
|
|
811
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
812
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
813
|
-
svst1_f32(
|
|
814
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
795
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
796
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
797
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
798
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
799
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
815
800
|
target_norms_sq_f32x));
|
|
816
801
|
}
|
|
817
802
|
}
|
|
@@ -821,19 +806,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
|
|
|
821
806
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
822
807
|
}
|
|
823
808
|
|
|
824
|
-
NK_PUBLIC void nk_angulars_symmetric_e5m2_sme(
|
|
825
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
826
|
-
nk_f32_t *result, nk_size_t
|
|
827
|
-
nk_size_t const stride_elements =
|
|
828
|
-
nk_size_t const result_stride_elements =
|
|
829
|
-
nk_dots_symmetric_e5m2_sme_streaming_(vectors,
|
|
830
|
-
row_start, row_count);
|
|
831
|
-
nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors,
|
|
809
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
|
|
810
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
811
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
812
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
813
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
814
|
+
nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
815
|
+
result_stride_elements, row_start, row_count);
|
|
816
|
+
nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
832
817
|
result_stride_elements, row_start, row_count);
|
|
833
818
|
}
|
|
834
819
|
|
|
835
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(
|
|
836
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
820
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
|
|
821
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
837
822
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
838
823
|
// Phase 1: cache row norms on diagonal
|
|
839
824
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -842,8 +827,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
|
|
|
842
827
|
}
|
|
843
828
|
// Phase 2: column-first post-processing
|
|
844
829
|
nk_f32_t norms_cache[256];
|
|
845
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
846
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
830
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
831
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
847
832
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
848
833
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
|
|
849
834
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -852,11 +837,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
|
|
|
852
837
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
853
838
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
854
839
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
855
|
-
svbool_t
|
|
856
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
857
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
858
|
-
svst1_f32(
|
|
859
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
840
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
841
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
842
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
843
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
844
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
860
845
|
target_norms_sq_f32x));
|
|
861
846
|
}
|
|
862
847
|
}
|
|
@@ -866,20 +851,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
|
|
|
866
851
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
867
852
|
}
|
|
868
853
|
|
|
869
|
-
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme(
|
|
870
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
871
|
-
nk_f32_t *result, nk_size_t
|
|
872
|
-
nk_size_t const stride_elements =
|
|
873
|
-
nk_size_t const result_stride_elements =
|
|
874
|
-
nk_dots_symmetric_e5m2_sme_streaming_(vectors,
|
|
875
|
-
row_start, row_count);
|
|
876
|
-
nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors,
|
|
854
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
|
|
855
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
856
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
857
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
858
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
859
|
+
nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
860
|
+
result_stride_elements, row_start, row_count);
|
|
861
|
+
nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
877
862
|
result_stride_elements, row_start, row_count);
|
|
878
863
|
}
|
|
879
864
|
|
|
880
|
-
#pragma endregion
|
|
865
|
+
#pragma endregion E5M2 Floats
|
|
881
866
|
|
|
882
|
-
#pragma region
|
|
867
|
+
#pragma region E2M3 Floats
|
|
883
868
|
|
|
884
869
|
__arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streaming_( //
|
|
885
870
|
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -893,12 +878,12 @@ __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streami
|
|
|
893
878
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
|
|
894
879
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
895
880
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
896
|
-
svbool_t
|
|
897
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
898
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
881
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
882
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
883
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
899
884
|
svst1_f32(
|
|
900
|
-
|
|
901
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
885
|
+
predicate_b32x, result_row + col_index,
|
|
886
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
902
887
|
}
|
|
903
888
|
}
|
|
904
889
|
}
|
|
@@ -926,12 +911,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_strea
|
|
|
926
911
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
|
|
927
912
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
928
913
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
929
|
-
svbool_t
|
|
930
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
931
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
914
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
915
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
916
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
932
917
|
svst1_f32(
|
|
933
|
-
|
|
934
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
918
|
+
predicate_b32x, result_row + col_index,
|
|
919
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
935
920
|
}
|
|
936
921
|
}
|
|
937
922
|
}
|
|
@@ -947,8 +932,8 @@ NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
|
|
|
947
932
|
c_stride_elements);
|
|
948
933
|
}
|
|
949
934
|
|
|
950
|
-
__arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_(
|
|
951
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
935
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
|
|
936
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
952
937
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
953
938
|
// Phase 1: cache row norms on diagonal
|
|
954
939
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -957,8 +942,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
|
|
|
957
942
|
}
|
|
958
943
|
// Phase 2: column-first post-processing
|
|
959
944
|
nk_f32_t norms_cache[256];
|
|
960
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
961
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
945
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
946
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
962
947
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
963
948
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
|
|
964
949
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -967,11 +952,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
|
|
|
967
952
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
968
953
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
969
954
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
970
|
-
svbool_t
|
|
971
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
972
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
973
|
-
svst1_f32(
|
|
974
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
955
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
956
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
957
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
958
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
959
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
975
960
|
target_norms_sq_f32x));
|
|
976
961
|
}
|
|
977
962
|
}
|
|
@@ -981,19 +966,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
|
|
|
981
966
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
982
967
|
}
|
|
983
968
|
|
|
984
|
-
NK_PUBLIC void nk_angulars_symmetric_e2m3_sme(
|
|
985
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
986
|
-
nk_f32_t *result, nk_size_t
|
|
987
|
-
nk_size_t const stride_elements =
|
|
988
|
-
nk_size_t const result_stride_elements =
|
|
989
|
-
nk_dots_symmetric_e2m3_sme_streaming_(vectors,
|
|
990
|
-
row_start, row_count);
|
|
991
|
-
nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors,
|
|
969
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
|
|
970
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
971
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
972
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
973
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
974
|
+
nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
975
|
+
result_stride_elements, row_start, row_count);
|
|
976
|
+
nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
992
977
|
result_stride_elements, row_start, row_count);
|
|
993
978
|
}
|
|
994
979
|
|
|
995
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(
|
|
996
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
980
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
|
|
981
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
997
982
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
998
983
|
// Phase 1: cache row norms on diagonal
|
|
999
984
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1002,8 +987,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
|
|
|
1002
987
|
}
|
|
1003
988
|
// Phase 2: column-first post-processing
|
|
1004
989
|
nk_f32_t norms_cache[256];
|
|
1005
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1006
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
990
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
991
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1007
992
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1008
993
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
|
|
1009
994
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1012,11 +997,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
|
|
|
1012
997
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1013
998
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1014
999
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1015
|
-
svbool_t
|
|
1016
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
1017
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
1018
|
-
svst1_f32(
|
|
1019
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1000
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1001
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
1002
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
1003
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1004
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1020
1005
|
target_norms_sq_f32x));
|
|
1021
1006
|
}
|
|
1022
1007
|
}
|
|
@@ -1026,20 +1011,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
|
|
|
1026
1011
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1027
1012
|
}
|
|
1028
1013
|
|
|
1029
|
-
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme(
|
|
1030
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
1031
|
-
nk_f32_t *result, nk_size_t
|
|
1032
|
-
nk_size_t const stride_elements =
|
|
1033
|
-
nk_size_t const result_stride_elements =
|
|
1034
|
-
nk_dots_symmetric_e2m3_sme_streaming_(vectors,
|
|
1035
|
-
row_start, row_count);
|
|
1036
|
-
nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors,
|
|
1014
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
|
|
1015
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1016
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1017
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
1018
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1019
|
+
nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1020
|
+
result_stride_elements, row_start, row_count);
|
|
1021
|
+
nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1037
1022
|
result_stride_elements, row_start, row_count);
|
|
1038
1023
|
}
|
|
1039
1024
|
|
|
1040
|
-
#pragma endregion
|
|
1025
|
+
#pragma endregion E2M3 Floats
|
|
1041
1026
|
|
|
1042
|
-
#pragma region
|
|
1027
|
+
#pragma region E3M2 Floats
|
|
1043
1028
|
|
|
1044
1029
|
__arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streaming_( //
|
|
1045
1030
|
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -1053,12 +1038,12 @@ __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streami
|
|
|
1053
1038
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
|
|
1054
1039
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
1055
1040
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1056
|
-
svbool_t
|
|
1057
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
1058
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
1041
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1042
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
1043
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
1059
1044
|
svst1_f32(
|
|
1060
|
-
|
|
1061
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1045
|
+
predicate_b32x, result_row + col_index,
|
|
1046
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1062
1047
|
}
|
|
1063
1048
|
}
|
|
1064
1049
|
}
|
|
@@ -1086,12 +1071,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_strea
|
|
|
1086
1071
|
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
|
|
1087
1072
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
1088
1073
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1089
|
-
svbool_t
|
|
1090
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
1091
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
1074
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1075
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
1076
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
|
|
1092
1077
|
svst1_f32(
|
|
1093
|
-
|
|
1094
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1078
|
+
predicate_b32x, result_row + col_index,
|
|
1079
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1095
1080
|
}
|
|
1096
1081
|
}
|
|
1097
1082
|
}
|
|
@@ -1107,8 +1092,8 @@ NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
|
|
|
1107
1092
|
c_stride_elements);
|
|
1108
1093
|
}
|
|
1109
1094
|
|
|
1110
|
-
__arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_(
|
|
1111
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1095
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
|
|
1096
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1112
1097
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1113
1098
|
// Phase 1: cache row norms on diagonal
|
|
1114
1099
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1117,8 +1102,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
|
|
|
1117
1102
|
}
|
|
1118
1103
|
// Phase 2: column-first post-processing
|
|
1119
1104
|
nk_f32_t norms_cache[256];
|
|
1120
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1121
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1105
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1106
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1122
1107
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1123
1108
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
|
|
1124
1109
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1127,11 +1112,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
|
|
|
1127
1112
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1128
1113
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1129
1114
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1130
|
-
svbool_t
|
|
1131
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
1132
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
1133
|
-
svst1_f32(
|
|
1134
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1115
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1116
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
1117
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
1118
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1119
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1135
1120
|
target_norms_sq_f32x));
|
|
1136
1121
|
}
|
|
1137
1122
|
}
|
|
@@ -1141,19 +1126,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
|
|
|
1141
1126
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1142
1127
|
}
|
|
1143
1128
|
|
|
1144
|
-
NK_PUBLIC void nk_angulars_symmetric_e3m2_sme(
|
|
1145
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1146
|
-
nk_f32_t *result, nk_size_t
|
|
1147
|
-
nk_size_t const stride_elements =
|
|
1148
|
-
nk_size_t const result_stride_elements =
|
|
1149
|
-
nk_dots_symmetric_e3m2_sme_streaming_(vectors,
|
|
1150
|
-
row_start, row_count);
|
|
1151
|
-
nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors,
|
|
1129
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
|
|
1130
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1131
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1132
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1133
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1134
|
+
nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1135
|
+
result_stride_elements, row_start, row_count);
|
|
1136
|
+
nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1152
1137
|
result_stride_elements, row_start, row_count);
|
|
1153
1138
|
}
|
|
1154
1139
|
|
|
1155
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(
|
|
1156
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1140
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
|
|
1141
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1157
1142
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1158
1143
|
// Phase 1: cache row norms on diagonal
|
|
1159
1144
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1162,8 +1147,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
|
|
|
1162
1147
|
}
|
|
1163
1148
|
// Phase 2: column-first post-processing
|
|
1164
1149
|
nk_f32_t norms_cache[256];
|
|
1165
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1166
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1150
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1151
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1167
1152
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1168
1153
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
|
|
1169
1154
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1172,11 +1157,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
|
|
|
1172
1157
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1173
1158
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1174
1159
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1175
|
-
svbool_t
|
|
1176
|
-
svfloat32_t dots_f32x = svld1_f32(
|
|
1177
|
-
svfloat32_t target_norms_sq_f32x = svld1_f32(
|
|
1178
|
-
svst1_f32(
|
|
1179
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1160
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1161
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
|
|
1162
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
|
|
1163
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1164
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1180
1165
|
target_norms_sq_f32x));
|
|
1181
1166
|
}
|
|
1182
1167
|
}
|
|
@@ -1186,19 +1171,19 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
|
|
|
1186
1171
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1187
1172
|
}
|
|
1188
1173
|
|
|
1189
|
-
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme(
|
|
1190
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1191
|
-
nk_f32_t *result, nk_size_t
|
|
1192
|
-
nk_size_t const stride_elements =
|
|
1193
|
-
nk_size_t const result_stride_elements =
|
|
1194
|
-
nk_dots_symmetric_e3m2_sme_streaming_(vectors,
|
|
1195
|
-
row_start, row_count);
|
|
1196
|
-
nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors,
|
|
1174
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
|
|
1175
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1176
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1177
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1178
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1179
|
+
nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1180
|
+
result_stride_elements, row_start, row_count);
|
|
1181
|
+
nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1197
1182
|
result_stride_elements, row_start, row_count);
|
|
1198
1183
|
}
|
|
1199
1184
|
|
|
1200
|
-
#pragma endregion
|
|
1201
|
-
#pragma region
|
|
1185
|
+
#pragma endregion E3M2 Floats
|
|
1186
|
+
#pragma region I8 Integers
|
|
1202
1187
|
|
|
1203
1188
|
__arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming_( //
|
|
1204
1189
|
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -1212,14 +1197,14 @@ __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming
|
|
|
1212
1197
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
|
|
1213
1198
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1214
1199
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1215
|
-
svbool_t
|
|
1200
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1216
1201
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1217
|
-
|
|
1218
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1219
|
-
svld1_u32(
|
|
1202
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1203
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1204
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1220
1205
|
svst1_f32(
|
|
1221
|
-
|
|
1222
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1206
|
+
predicate_b32x, result_row + col_index,
|
|
1207
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1223
1208
|
}
|
|
1224
1209
|
}
|
|
1225
1210
|
}
|
|
@@ -1248,14 +1233,14 @@ __arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streami
|
|
|
1248
1233
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
|
|
1249
1234
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1250
1235
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1251
|
-
svbool_t
|
|
1236
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1252
1237
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1253
|
-
|
|
1254
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1255
|
-
svld1_u32(
|
|
1238
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1239
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1240
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1256
1241
|
svst1_f32(
|
|
1257
|
-
|
|
1258
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1242
|
+
predicate_b32x, result_row + col_index,
|
|
1243
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1259
1244
|
}
|
|
1260
1245
|
}
|
|
1261
1246
|
}
|
|
@@ -1272,8 +1257,8 @@ NK_PUBLIC void nk_euclideans_packed_i8_sme( //
|
|
|
1272
1257
|
c_stride_elements);
|
|
1273
1258
|
}
|
|
1274
1259
|
|
|
1275
|
-
__arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_(
|
|
1276
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1260
|
+
__arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
|
|
1261
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1277
1262
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1278
1263
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1279
1264
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1282,8 +1267,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
|
|
|
1282
1267
|
}
|
|
1283
1268
|
// Phase 2: column-first post-processing
|
|
1284
1269
|
nk_u32_t norms_cache[256];
|
|
1285
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1286
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1270
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1271
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1287
1272
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1288
1273
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
|
|
1289
1274
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1293,13 +1278,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
|
|
|
1293
1278
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1294
1279
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1295
1280
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1296
|
-
svbool_t
|
|
1281
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1297
1282
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1298
|
-
|
|
1283
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
|
|
1299
1284
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1300
|
-
|
|
1301
|
-
svst1_f32(
|
|
1302
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1285
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1286
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1287
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1303
1288
|
target_norms_sq_f32x));
|
|
1304
1289
|
}
|
|
1305
1290
|
}
|
|
@@ -1309,19 +1294,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
|
|
|
1309
1294
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1310
1295
|
}
|
|
1311
1296
|
|
|
1312
|
-
NK_PUBLIC void nk_angulars_symmetric_i8_sme(
|
|
1313
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1314
|
-
nk_f32_t *result, nk_size_t
|
|
1315
|
-
nk_size_t const stride_elements =
|
|
1316
|
-
nk_size_t const result_stride_elements =
|
|
1317
|
-
nk_dots_symmetric_i8_sme_streaming_(vectors,
|
|
1297
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
|
|
1298
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1299
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1300
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1301
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1302
|
+
nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1318
1303
|
result_stride_elements, row_start, row_count);
|
|
1319
|
-
nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors,
|
|
1304
|
+
nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1320
1305
|
result_stride_elements, row_start, row_count);
|
|
1321
1306
|
}
|
|
1322
1307
|
|
|
1323
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_(
|
|
1324
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1308
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
|
|
1309
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1325
1310
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1326
1311
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1327
1312
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1330,8 +1315,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
|
|
|
1330
1315
|
}
|
|
1331
1316
|
// Phase 2: column-first post-processing
|
|
1332
1317
|
nk_u32_t norms_cache[256];
|
|
1333
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1334
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1318
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1319
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1335
1320
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1336
1321
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
|
|
1337
1322
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1341,13 +1326,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
|
|
|
1341
1326
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1342
1327
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1343
1328
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1344
|
-
svbool_t
|
|
1329
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1345
1330
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1346
|
-
|
|
1331
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
|
|
1347
1332
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1348
|
-
|
|
1349
|
-
svst1_f32(
|
|
1350
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1333
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1334
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1335
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1351
1336
|
target_norms_sq_f32x));
|
|
1352
1337
|
}
|
|
1353
1338
|
}
|
|
@@ -1357,20 +1342,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
|
|
|
1357
1342
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1358
1343
|
}
|
|
1359
1344
|
|
|
1360
|
-
NK_PUBLIC void nk_euclideans_symmetric_i8_sme(
|
|
1361
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1362
|
-
nk_f32_t *result, nk_size_t
|
|
1363
|
-
nk_size_t const stride_elements =
|
|
1364
|
-
nk_size_t const result_stride_elements =
|
|
1365
|
-
nk_dots_symmetric_i8_sme_streaming_(vectors,
|
|
1345
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
|
|
1346
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1347
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1348
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1349
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1350
|
+
nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1366
1351
|
result_stride_elements, row_start, row_count);
|
|
1367
|
-
nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors,
|
|
1352
|
+
nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1368
1353
|
result_stride_elements, row_start, row_count);
|
|
1369
1354
|
}
|
|
1370
1355
|
|
|
1371
|
-
#pragma endregion
|
|
1356
|
+
#pragma endregion I8 Integers
|
|
1372
1357
|
|
|
1373
|
-
#pragma region
|
|
1358
|
+
#pragma region U8 Integers
|
|
1374
1359
|
|
|
1375
1360
|
__arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming_( //
|
|
1376
1361
|
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -1384,14 +1369,14 @@ __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming
|
|
|
1384
1369
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
|
|
1385
1370
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1386
1371
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1387
|
-
svbool_t
|
|
1372
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1388
1373
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1389
|
-
|
|
1390
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1391
|
-
svld1_u32(
|
|
1374
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1375
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1376
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1392
1377
|
svst1_f32(
|
|
1393
|
-
|
|
1394
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1378
|
+
predicate_b32x, result_row + col_index,
|
|
1379
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1395
1380
|
}
|
|
1396
1381
|
}
|
|
1397
1382
|
}
|
|
@@ -1420,14 +1405,14 @@ __arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streami
|
|
|
1420
1405
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
|
|
1421
1406
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1422
1407
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1423
|
-
svbool_t
|
|
1408
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1424
1409
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1425
|
-
|
|
1426
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1427
|
-
svld1_u32(
|
|
1410
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1411
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1412
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1428
1413
|
svst1_f32(
|
|
1429
|
-
|
|
1430
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1414
|
+
predicate_b32x, result_row + col_index,
|
|
1415
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1431
1416
|
}
|
|
1432
1417
|
}
|
|
1433
1418
|
}
|
|
@@ -1444,8 +1429,8 @@ NK_PUBLIC void nk_euclideans_packed_u8_sme( //
|
|
|
1444
1429
|
c_stride_elements);
|
|
1445
1430
|
}
|
|
1446
1431
|
|
|
1447
|
-
__arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_(
|
|
1448
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1432
|
+
__arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
|
|
1433
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1449
1434
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1450
1435
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1451
1436
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1454,8 +1439,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
|
|
|
1454
1439
|
}
|
|
1455
1440
|
// Phase 2: column-first post-processing
|
|
1456
1441
|
nk_u32_t norms_cache[256];
|
|
1457
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1458
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1442
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1443
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1459
1444
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1460
1445
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
|
|
1461
1446
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1465,13 +1450,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
|
|
|
1465
1450
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1466
1451
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1467
1452
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1468
|
-
svbool_t
|
|
1453
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1469
1454
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1470
|
-
|
|
1455
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
|
|
1471
1456
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1472
|
-
|
|
1473
|
-
svst1_f32(
|
|
1474
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1457
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1458
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1459
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1475
1460
|
target_norms_sq_f32x));
|
|
1476
1461
|
}
|
|
1477
1462
|
}
|
|
@@ -1481,19 +1466,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
|
|
|
1481
1466
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1482
1467
|
}
|
|
1483
1468
|
|
|
1484
|
-
NK_PUBLIC void nk_angulars_symmetric_u8_sme(
|
|
1485
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1486
|
-
nk_f32_t *result, nk_size_t
|
|
1487
|
-
nk_size_t const stride_elements =
|
|
1488
|
-
nk_size_t const result_stride_elements =
|
|
1489
|
-
nk_dots_symmetric_u8_sme_streaming_(vectors,
|
|
1469
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
|
|
1470
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1471
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1472
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1473
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1474
|
+
nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1490
1475
|
result_stride_elements, row_start, row_count);
|
|
1491
|
-
nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors,
|
|
1476
|
+
nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1492
1477
|
result_stride_elements, row_start, row_count);
|
|
1493
1478
|
}
|
|
1494
1479
|
|
|
1495
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_(
|
|
1496
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1480
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
|
|
1481
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1497
1482
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1498
1483
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1499
1484
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1502,8 +1487,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
|
|
|
1502
1487
|
}
|
|
1503
1488
|
// Phase 2: column-first post-processing
|
|
1504
1489
|
nk_u32_t norms_cache[256];
|
|
1505
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1506
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1490
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1491
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1507
1492
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1508
1493
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
|
|
1509
1494
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1513,13 +1498,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
|
|
|
1513
1498
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1514
1499
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1515
1500
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1516
|
-
svbool_t
|
|
1501
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1517
1502
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1518
|
-
|
|
1503
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
|
|
1519
1504
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1520
|
-
|
|
1521
|
-
svst1_f32(
|
|
1522
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1505
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1506
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1507
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1523
1508
|
target_norms_sq_f32x));
|
|
1524
1509
|
}
|
|
1525
1510
|
}
|
|
@@ -1529,20 +1514,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
|
|
|
1529
1514
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1530
1515
|
}
|
|
1531
1516
|
|
|
1532
|
-
NK_PUBLIC void nk_euclideans_symmetric_u8_sme(
|
|
1533
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1534
|
-
nk_f32_t *result, nk_size_t
|
|
1535
|
-
nk_size_t const stride_elements =
|
|
1536
|
-
nk_size_t const result_stride_elements =
|
|
1537
|
-
nk_dots_symmetric_u8_sme_streaming_(vectors,
|
|
1517
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
|
|
1518
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1519
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1520
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1521
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1522
|
+
nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1538
1523
|
result_stride_elements, row_start, row_count);
|
|
1539
|
-
nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors,
|
|
1524
|
+
nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1540
1525
|
result_stride_elements, row_start, row_count);
|
|
1541
1526
|
}
|
|
1542
1527
|
|
|
1543
|
-
#pragma endregion
|
|
1528
|
+
#pragma endregion U8 Integers
|
|
1544
1529
|
|
|
1545
|
-
#pragma region
|
|
1530
|
+
#pragma region I4 Integers
|
|
1546
1531
|
|
|
1547
1532
|
__arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming_( //
|
|
1548
1533
|
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -1556,14 +1541,14 @@ __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming
|
|
|
1556
1541
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
|
|
1557
1542
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1558
1543
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1559
|
-
svbool_t
|
|
1544
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1560
1545
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1561
|
-
|
|
1562
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1563
|
-
svld1_u32(
|
|
1546
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1547
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1548
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1564
1549
|
svst1_f32(
|
|
1565
|
-
|
|
1566
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1550
|
+
predicate_b32x, result_row + col_index,
|
|
1551
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1567
1552
|
}
|
|
1568
1553
|
}
|
|
1569
1554
|
}
|
|
@@ -1592,14 +1577,14 @@ __arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streami
|
|
|
1592
1577
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
|
|
1593
1578
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1594
1579
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1595
|
-
svbool_t
|
|
1580
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1596
1581
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1597
|
-
|
|
1598
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1599
|
-
svld1_u32(
|
|
1582
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1583
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1584
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1600
1585
|
svst1_f32(
|
|
1601
|
-
|
|
1602
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1586
|
+
predicate_b32x, result_row + col_index,
|
|
1587
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1603
1588
|
}
|
|
1604
1589
|
}
|
|
1605
1590
|
}
|
|
@@ -1616,8 +1601,8 @@ NK_PUBLIC void nk_euclideans_packed_i4_sme( //
|
|
|
1616
1601
|
c_stride_elements);
|
|
1617
1602
|
}
|
|
1618
1603
|
|
|
1619
|
-
__arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_(
|
|
1620
|
-
nk_i4x2_t const *vectors, nk_size_t
|
|
1604
|
+
__arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
|
|
1605
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1621
1606
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1622
1607
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1623
1608
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1626,8 +1611,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
|
|
|
1626
1611
|
}
|
|
1627
1612
|
// Phase 2: column-first post-processing
|
|
1628
1613
|
nk_u32_t norms_cache[256];
|
|
1629
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1630
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1614
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1615
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1631
1616
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1632
1617
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
|
|
1633
1618
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1637,13 +1622,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
|
|
|
1637
1622
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1638
1623
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1639
1624
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1640
|
-
svbool_t
|
|
1625
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1641
1626
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1642
|
-
|
|
1627
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
|
|
1643
1628
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1644
|
-
|
|
1645
|
-
svst1_f32(
|
|
1646
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1629
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1630
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1631
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1647
1632
|
target_norms_sq_f32x));
|
|
1648
1633
|
}
|
|
1649
1634
|
}
|
|
@@ -1653,19 +1638,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
|
|
|
1653
1638
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1654
1639
|
}
|
|
1655
1640
|
|
|
1656
|
-
NK_PUBLIC void nk_angulars_symmetric_i4_sme(
|
|
1657
|
-
nk_i4x2_t const *vectors, nk_size_t
|
|
1658
|
-
nk_f32_t *result, nk_size_t
|
|
1659
|
-
nk_size_t const stride_elements =
|
|
1660
|
-
nk_size_t const result_stride_elements =
|
|
1661
|
-
nk_dots_symmetric_i4_sme_streaming_(vectors,
|
|
1641
|
+
NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
|
|
1642
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1643
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1644
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1645
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1646
|
+
nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1662
1647
|
result_stride_elements, row_start, row_count);
|
|
1663
|
-
nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors,
|
|
1648
|
+
nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1664
1649
|
result_stride_elements, row_start, row_count);
|
|
1665
1650
|
}
|
|
1666
1651
|
|
|
1667
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_(
|
|
1668
|
-
nk_i4x2_t const *vectors, nk_size_t
|
|
1652
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
|
|
1653
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1669
1654
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1670
1655
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1671
1656
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1674,8 +1659,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
|
|
|
1674
1659
|
}
|
|
1675
1660
|
// Phase 2: column-first post-processing
|
|
1676
1661
|
nk_u32_t norms_cache[256];
|
|
1677
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1678
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1662
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1663
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1679
1664
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1680
1665
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
|
|
1681
1666
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1685,13 +1670,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
|
|
|
1685
1670
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1686
1671
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1687
1672
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1688
|
-
svbool_t
|
|
1673
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1689
1674
|
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1690
|
-
|
|
1675
|
+
predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
|
|
1691
1676
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1692
|
-
|
|
1693
|
-
svst1_f32(
|
|
1694
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1677
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1678
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1679
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1695
1680
|
target_norms_sq_f32x));
|
|
1696
1681
|
}
|
|
1697
1682
|
}
|
|
@@ -1701,20 +1686,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
|
|
|
1701
1686
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1702
1687
|
}
|
|
1703
1688
|
|
|
1704
|
-
NK_PUBLIC void nk_euclideans_symmetric_i4_sme(
|
|
1705
|
-
nk_i4x2_t const *vectors, nk_size_t
|
|
1706
|
-
nk_f32_t *result, nk_size_t
|
|
1707
|
-
nk_size_t const stride_elements =
|
|
1708
|
-
nk_size_t const result_stride_elements =
|
|
1709
|
-
nk_dots_symmetric_i4_sme_streaming_(vectors,
|
|
1689
|
+
NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
|
|
1690
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1691
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1692
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1693
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1694
|
+
nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1710
1695
|
result_stride_elements, row_start, row_count);
|
|
1711
|
-
nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors,
|
|
1696
|
+
nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1712
1697
|
result_stride_elements, row_start, row_count);
|
|
1713
1698
|
}
|
|
1714
1699
|
|
|
1715
|
-
#pragma endregion
|
|
1700
|
+
#pragma endregion Signed Integers
|
|
1716
1701
|
|
|
1717
|
-
#pragma region
|
|
1702
|
+
#pragma region U4 Integers
|
|
1718
1703
|
|
|
1719
1704
|
__arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming_( //
|
|
1720
1705
|
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
@@ -1728,14 +1713,14 @@ __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming
|
|
|
1728
1713
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
|
|
1729
1714
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1730
1715
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1731
|
-
svbool_t
|
|
1716
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1732
1717
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1733
|
-
|
|
1734
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1735
|
-
svld1_u32(
|
|
1718
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1719
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1720
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1736
1721
|
svst1_f32(
|
|
1737
|
-
|
|
1738
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1722
|
+
predicate_b32x, result_row + col_index,
|
|
1723
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1739
1724
|
}
|
|
1740
1725
|
}
|
|
1741
1726
|
}
|
|
@@ -1764,14 +1749,14 @@ __arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streami
|
|
|
1764
1749
|
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
|
|
1765
1750
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1766
1751
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1767
|
-
svbool_t
|
|
1752
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
|
|
1768
1753
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1769
|
-
|
|
1770
|
-
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1771
|
-
svld1_u32(
|
|
1754
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1755
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
|
|
1756
|
+
svld1_u32(predicate_b32x, b_norms + col_index));
|
|
1772
1757
|
svst1_f32(
|
|
1773
|
-
|
|
1774
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1758
|
+
predicate_b32x, result_row + col_index,
|
|
1759
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1775
1760
|
}
|
|
1776
1761
|
}
|
|
1777
1762
|
}
|
|
@@ -1788,8 +1773,8 @@ NK_PUBLIC void nk_euclideans_packed_u4_sme( //
|
|
|
1788
1773
|
c_stride_elements);
|
|
1789
1774
|
}
|
|
1790
1775
|
|
|
1791
|
-
__arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_(
|
|
1792
|
-
nk_u4x2_t const *vectors, nk_size_t
|
|
1776
|
+
__arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
|
|
1777
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1793
1778
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1794
1779
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1795
1780
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1798,8 +1783,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
|
|
|
1798
1783
|
}
|
|
1799
1784
|
// Phase 2: column-first post-processing
|
|
1800
1785
|
nk_u32_t norms_cache[256];
|
|
1801
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1802
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1786
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1787
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1803
1788
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1804
1789
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
|
|
1805
1790
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1809,13 +1794,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
|
|
|
1809
1794
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1810
1795
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1811
1796
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1812
|
-
svbool_t
|
|
1797
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1813
1798
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1814
|
-
|
|
1799
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
|
|
1815
1800
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1816
|
-
|
|
1817
|
-
svst1_f32(
|
|
1818
|
-
nk_angulars_from_dot_f32x_ssve_(
|
|
1801
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1802
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1803
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1819
1804
|
target_norms_sq_f32x));
|
|
1820
1805
|
}
|
|
1821
1806
|
}
|
|
@@ -1825,19 +1810,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
|
|
|
1825
1810
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1826
1811
|
}
|
|
1827
1812
|
|
|
1828
|
-
NK_PUBLIC void nk_angulars_symmetric_u4_sme(
|
|
1829
|
-
nk_u4x2_t const *vectors, nk_size_t
|
|
1830
|
-
nk_f32_t *result, nk_size_t
|
|
1831
|
-
nk_size_t const stride_elements =
|
|
1832
|
-
nk_size_t const result_stride_elements =
|
|
1833
|
-
nk_dots_symmetric_u4_sme_streaming_(vectors,
|
|
1813
|
+
NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
|
|
1814
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1815
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1816
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1817
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1818
|
+
nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1834
1819
|
result_stride_elements, row_start, row_count);
|
|
1835
|
-
nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors,
|
|
1820
|
+
nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1836
1821
|
result_stride_elements, row_start, row_count);
|
|
1837
1822
|
}
|
|
1838
1823
|
|
|
1839
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_(
|
|
1840
|
-
nk_u4x2_t const *vectors, nk_size_t
|
|
1824
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
|
|
1825
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
1841
1826
|
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1842
1827
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1843
1828
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1846,8 +1831,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
|
|
|
1846
1831
|
}
|
|
1847
1832
|
// Phase 2: column-first post-processing
|
|
1848
1833
|
nk_u32_t norms_cache[256];
|
|
1849
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1850
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1834
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1835
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1851
1836
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1852
1837
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
|
|
1853
1838
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1857,13 +1842,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
|
|
|
1857
1842
|
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1858
1843
|
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1859
1844
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1860
|
-
svbool_t
|
|
1845
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1861
1846
|
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1862
|
-
|
|
1847
|
+
predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
|
|
1863
1848
|
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1864
|
-
|
|
1865
|
-
svst1_f32(
|
|
1866
|
-
nk_euclideans_from_dot_f32x_ssve_(
|
|
1849
|
+
predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
|
|
1850
|
+
svst1_f32(predicate_b32x, result_row + col_index,
|
|
1851
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
|
|
1867
1852
|
target_norms_sq_f32x));
|
|
1868
1853
|
}
|
|
1869
1854
|
}
|
|
@@ -1873,18 +1858,18 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
|
|
|
1873
1858
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1874
1859
|
}
|
|
1875
1860
|
|
|
1876
|
-
NK_PUBLIC void nk_euclideans_symmetric_u4_sme(
|
|
1877
|
-
nk_u4x2_t const *vectors, nk_size_t
|
|
1878
|
-
nk_f32_t *result, nk_size_t
|
|
1879
|
-
nk_size_t const stride_elements =
|
|
1880
|
-
nk_size_t const result_stride_elements =
|
|
1881
|
-
nk_dots_symmetric_u4_sme_streaming_(vectors,
|
|
1861
|
+
NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
|
|
1862
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1863
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1864
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1865
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1866
|
+
nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1882
1867
|
result_stride_elements, row_start, row_count);
|
|
1883
|
-
nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors,
|
|
1868
|
+
nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1884
1869
|
result_stride_elements, row_start, row_count);
|
|
1885
1870
|
}
|
|
1886
1871
|
|
|
1887
|
-
#pragma endregion
|
|
1872
|
+
#pragma endregion Unsigned Integers
|
|
1888
1873
|
|
|
1889
1874
|
#if defined(__clang__)
|
|
1890
1875
|
#pragma clang attribute pop
|