numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1901 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Batched Spatial Distances for ARM SME.
|
|
3
|
+
* @file include/numkong/spatials/sme.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatials.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPATIALS_SME_H
|
|
10
|
+
#define NK_SPATIALS_SME_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
#if NK_TARGET_SME
|
|
14
|
+
|
|
15
|
+
#include "numkong/dots/serial.h"
|
|
16
|
+
#include "numkong/dots/sme.h"
|
|
17
|
+
|
|
18
|
+
#if defined(__cplusplus)
|
|
19
|
+
extern "C" {
|
|
20
|
+
#endif
|
|
21
|
+
|
|
22
|
+
#if defined(__clang__)
|
|
23
|
+
#pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
|
|
24
|
+
#elif defined(__GNUC__)
|
|
25
|
+
#pragma GCC push_options
|
|
26
|
+
#pragma GCC target("+sme")
|
|
27
|
+
#endif
|
|
28
|
+
|
|
29
|
+
/**
 * @brief Sum of squares of `count` half-precision values, accumulated in single precision.
 *
 * `svcvt_f32_f16_x` only converts the f16 value stored in the bottom half of each 32-bit container
 * (the even-numbered f16 lanes). Converting a plainly loaded f16 vector would therefore skip every
 * odd element. Instead, each iteration zero-extends the first `svcntw()` halfwords into 32-bit
 * containers (`svunpklo_u32`) and reinterprets them as f16. As a result, 32-bit lane `k` holds
 * element `i + k`, matching the `svwhilelt_b32_u64(i, count)` predicate.
 * Streaming-compatible: only base SVE instructions are used.
 *
 * @param data  Pointer to `count` f16 values.
 * @param count Number of elements.
 * @return      Sum of squares as f32.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    svfloat32_t accumulator_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcntw();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
        // Only the low `svcntw()` halfwords are consumed below; inactive lanes are loaded as zero.
        svuint16_t raw_u16x = svld1_u16(svwhilelt_b16_u64(i, count), (nk_u16_t const *)data + i);
        svfloat16_t widened_f16x = svreinterpret_f16_u32(svunpklo_u32(raw_u16x));
        svfloat32_t values_f32x = svcvt_f32_f16_x(predicate_f32x, widened_f16x);
        // Merging form: the final reduction reads every lane, so inactive lanes must keep their partial sums.
        accumulator_f32x = svmla_f32_m(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_f32x);
}
|
|
40
|
+
|
|
41
|
+
/**
 * @brief Sum of squares of `count` bfloat16 values, accumulated in single precision.
 *
 * A bf16 value is the upper 16 bits of the f32 with the same value. Zero-extending each 16-bit word
 * to 32 bits and shifting it left by 16 therefore reproduces the exact f32 bit pattern. Each
 * iteration consumes `svcntw()` elements, so 32-bit lane `k` holds element `i + k`.
 *
 * @param data  Pointer to `count` bf16 values.
 * @param count Number of elements.
 * @return      Sum of squares as f32.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    svfloat32_t accumulator_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcntw();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
        // Only the low `svcntw()` halfwords are used; inactive lanes are loaded as zero.
        svuint16_t raw_u16x = svld1_u16(svwhilelt_b16_u64(i, count), (nk_u16_t const *)data + i);
        svfloat32_t values_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_f32x, svunpklo_u32(raw_u16x), 16));
        // `_x` would leave inactive accumulator lanes unspecified, but the reduction below sums all lanes.
        accumulator_f32x = svmla_f32_m(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_f32x);
}
|
|
52
|
+
|
|
53
|
+
/**
 * @brief Sum of squares of `count` e4m3 (FP8) values, accumulated in single precision.
 *
 * Each iteration upcasts up to `svcnth()` bytes to f16 and then widens to f32 in two parts:
 * - `svcvt_f32_f16_x` (FCVT) widens the even-numbered f16 lanes `2k`.
 * - `svcvtlt_f32_f16_x` (FCVTLT) widens the odd-numbered f16 lanes `2k + 1`.
 *
 * The 32-bit predicates must follow that even/odd mapping. Lane `k` is valid iff `2k < batch_size`
 * (even) or `2k + 1 < batch_size` (odd).
 * NOTE(review): assumes `nk_e4m3x_to_f16x_ssve_` places byte `j` into f16 lane `j` — TODO confirm.
 * Requires streaming mode because FCVTLT is an SVE2 instruction.
 *
 * @param data  Pointer to `count` e4m3 bytes.
 * @param count Number of elements.
 * @return      Sum of squares as f32.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
    svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
    svuint16_t subnorm_lut_u16x = svld1_u16(svwhilelt_b16(0u, 8u), nk_e4m3_subnorm_f16_lut_);
    nk_size_t const vector_length = svcnth();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(predicate_f16x, raw_u8x, subnorm_lut_u16x);

        // Even f16 lanes 2k: valid iff k < ceil(batch_size / 2).
        svbool_t predicate_even_f32x = svwhilelt_b32_u64(0u, (batch_size + 1) / 2);
        svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_f32x, values_f16x);
        accumulator_even_f32x = svmla_f32_m(predicate_even_f32x, accumulator_even_f32x, values_even_f32x,
                                            values_even_f32x);

        // Odd f16 lanes 2k + 1: valid iff k < floor(batch_size / 2).
        svbool_t predicate_odd_f32x = svwhilelt_b32_u64(0u, batch_size / 2);
        svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_f32x, values_f16x);
        accumulator_odd_f32x = svmla_f32_m(predicate_odd_f32x, accumulator_odd_f32x, values_odd_f32x,
                                           values_odd_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
}
|
|
76
|
+
|
|
77
|
+
/**
 * @brief Sum of squares of `count` e5m2 (FP8) values, accumulated in single precision.
 *
 * Each iteration upcasts up to `svcnth()` bytes to f16 and then widens to f32 in two parts:
 * - `svcvt_f32_f16_x` (FCVT) widens the even-numbered f16 lanes `2k`.
 * - `svcvtlt_f32_f16_x` (FCVTLT) widens the odd-numbered f16 lanes `2k + 1`.
 *
 * The 32-bit predicates must follow that even/odd mapping. Lane `k` is valid iff `2k < batch_size`
 * (even) or `2k + 1 < batch_size` (odd).
 * NOTE(review): assumes `nk_e5m2x_to_f16x_ssve_` places byte `j` into f16 lane `j` — TODO confirm.
 * Requires streaming mode because FCVTLT is an SVE2 instruction.
 *
 * @param data  Pointer to `count` e5m2 bytes.
 * @param count Number of elements.
 * @return      Sum of squares as f32.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
    svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcnth();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);

        // Even f16 lanes 2k: valid iff k < ceil(batch_size / 2).
        svbool_t predicate_even_f32x = svwhilelt_b32_u64(0u, (batch_size + 1) / 2);
        svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_f32x, values_f16x);
        accumulator_even_f32x = svmla_f32_m(predicate_even_f32x, accumulator_even_f32x, values_even_f32x,
                                            values_even_f32x);

        // Odd f16 lanes 2k + 1: valid iff k < floor(batch_size / 2).
        svbool_t predicate_odd_f32x = svwhilelt_b32_u64(0u, batch_size / 2);
        svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_f32x, values_f16x);
        accumulator_odd_f32x = svmla_f32_m(predicate_odd_f32x, accumulator_odd_f32x, values_odd_f32x,
                                           values_odd_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
}
|
|
99
|
+
|
|
100
|
+
/**
 * @brief Sum of squares of `count` e2m3 (FP6) values, computed in integer arithmetic.
 *
 * `nk_e2m3x_to_i8x_ssve_` maps every value to a signed byte. Squares are formed in 16-bit lanes and
 * widened twice, so lane `k` of the 64-bit accumulator receives element `offset + k`. Each step
 * therefore covers `svcntd()` elements. The integer total is divided by 256.
 * NOTE(review): presumably this undoes a x16 fixed-point scale applied by the conversion helper
 * (16 * 16 = 256) — TODO confirm.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    nk_u8_t const *bytes = (nk_u8_t const *)data;
    nk_size_t const step = svcntd();
    svint64_t sums_i64x = svdup_s64(0);
    for (nk_size_t offset = 0; offset < count; offset += step) {
        svbool_t bytes_active = svwhilelt_b8_u64(offset, count);
        svbool_t halves_active = svwhilelt_b16_u64(offset, count);
        svbool_t words_active = svwhilelt_b64_u64(offset, count);
        svuint8_t encoded_u8x = svld1_u8(bytes_active, bytes + offset);
        svint8_t scaled_i8x = nk_e2m3x_to_i8x_ssve_(bytes_active, encoded_u8x);
        svint16_t widened_i16x = svunpklo_s16(scaled_i8x);
        svint16_t squared_i16x = svmul_s16_z(halves_active, widened_i16x, widened_i16x);
        sums_i64x = svadd_s64_m(words_active, sums_i64x, svunpklo_s64(svunpklo_s32(squared_i16x)));
    }
    return (nk_f32_t)svaddv_s64(svptrue_b64(), sums_i64x) / 256.0f;
}
|
|
114
|
+
|
|
115
|
+
/**
 * @brief Sum of squares of `count` e3m2 (FP6) values, accumulated in single precision.
 *
 * Each iteration upcasts up to `svcnth()` bytes to f16 and then widens to f32 in two parts:
 * - `svcvt_f32_f16_x` (FCVT) widens the even-numbered f16 lanes `2k`.
 * - `svcvtlt_f32_f16_x` (FCVTLT) widens the odd-numbered f16 lanes `2k + 1`.
 *
 * The 32-bit predicates must follow that even/odd mapping. Lane `k` is valid iff `2k < batch_size`
 * (even) or `2k + 1 < batch_size` (odd).
 * NOTE(review): assumes `nk_e3m2x_to_f16x_ssve_` places byte `j` into f16 lane `j` — TODO confirm.
 * Requires streaming mode because FCVTLT is an SVE2 instruction.
 *
 * @param data  Pointer to `count` e3m2 bytes.
 * @param count Number of elements.
 * @return      Sum of squares as f32.
 */
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
    svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcnth();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);

        // Even f16 lanes 2k: valid iff k < ceil(batch_size / 2).
        svbool_t predicate_even_f32x = svwhilelt_b32_u64(0u, (batch_size + 1) / 2);
        svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_f32x, values_f16x);
        accumulator_even_f32x = svmla_f32_m(predicate_even_f32x, accumulator_even_f32x, values_even_f32x,
                                            values_even_f32x);

        // Odd f16 lanes 2k + 1: valid iff k < floor(batch_size / 2).
        svbool_t predicate_odd_f32x = svwhilelt_b32_u64(0u, batch_size / 2);
        svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_f32x, values_f16x);
        accumulator_odd_f32x = svmla_f32_m(predicate_odd_f32x, accumulator_odd_f32x, values_odd_f32x,
                                           values_odd_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
}
|
|
137
|
+
|
|
138
|
+
/**
 * @brief Sum of squares of `count` signed bytes.
 *
 * Every square fits a signed 16-bit lane, since |x| <= 128 gives x*x <= 16384. Squares are widened
 * twice, so lane `k` of the 64-bit accumulator receives element `offset + k`. Each step therefore
 * covers `svcntd()` elements. The 64-bit total is truncated to `nk_u32_t` on return.
 */
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    nk_size_t const step = svcntd();
    svint64_t sums_i64x = svdup_s64(0);
    for (nk_size_t offset = 0; offset < count; offset += step) {
        svbool_t bytes_active = svwhilelt_b8_u64(offset, count);
        svbool_t halves_active = svwhilelt_b16_u64(offset, count);
        svbool_t words_active = svwhilelt_b64_u64(offset, count);
        svint16_t widened_i16x = svunpklo_s16(svld1_s8(bytes_active, data + offset));
        svint16_t squared_i16x = svmul_s16_z(halves_active, widened_i16x, widened_i16x);
        svint64_t squared_i64x = svunpklo_s64(svunpklo_s32(squared_i16x));
        sums_i64x = svadd_s64_m(words_active, sums_i64x, squared_i64x);
    }
    return (nk_u32_t)svaddv_s64(svptrue_b64(), sums_i64x);
}
|
|
151
|
+
|
|
152
|
+
/**
 * @brief Sum of squares of `count` unsigned bytes.
 *
 * Every square fits an unsigned 16-bit lane, since 255 * 255 = 65025. Squares are widened twice, so
 * lane `k` of the 64-bit accumulator receives element `offset + k`. Each step therefore covers
 * `svcntd()` elements. The 64-bit total is truncated to `nk_u32_t` on return.
 */
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    nk_size_t const step = svcntd();
    svuint64_t sums_u64x = svdup_u64(0);
    for (nk_size_t offset = 0; offset < count; offset += step) {
        svbool_t bytes_active = svwhilelt_b8_u64(offset, count);
        svbool_t halves_active = svwhilelt_b16_u64(offset, count);
        svbool_t words_active = svwhilelt_b64_u64(offset, count);
        svuint16_t widened_u16x = svunpklo_u16(svld1_u8(bytes_active, data + offset));
        svuint16_t squared_u16x = svmul_u16_z(halves_active, widened_u16x, widened_u16x);
        svuint64_t squared_u64x = svunpklo_u64(svunpklo_u32(squared_u16x));
        sums_u64x = svadd_u64_m(words_active, sums_u64x, squared_u64x);
    }
    return (nk_u32_t)svaddv_u64(svptrue_b64(), sums_u64x);
}
|
|
165
|
+
|
|
166
|
+
/**
 * @brief Sum of squares of `count` signed 4-bit values packed two per byte.
 *
 * Processes `(count + 1) / 2` bytes. Both nibbles of every loaded byte are squared and summed, which
 * includes the upper nibble of the last byte when `count` is odd. Per-byte sums (at most 2 * 64)
 * are widened so that lane `k` of the 64-bit accumulator receives byte `offset + k`.
 */
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    nk_u8_t const *packed_bytes = (nk_u8_t const *)data;
    nk_size_t const packed_count = (count + 1) / 2;
    nk_size_t const step = svcntd();
    svint64_t sums_i64x = svdup_s64(0);
    for (nk_size_t offset = 0; offset < packed_count; offset += step) {
        svbool_t bytes_active = svwhilelt_b8_u64(offset, packed_count);
        svbool_t halves_active = svwhilelt_b16_u64(offset, packed_count);
        svbool_t words_active = svwhilelt_b64_u64(offset, packed_count);
        svuint8_t packed_u8x = svld1_u8(bytes_active, packed_bytes + offset);
        // Sign-extend each nibble by placing it in the top 4 bits and arithmetic-shifting back down.
        // `packed << 4` equals `(packed & 0x0F) << 4` in 8 bits, and an arithmetic `packed >> 4`
        // already yields the sign-extended upper nibble.
        svint8_t low_i8x = svasr_n_s8_x(bytes_active, svreinterpret_s8_u8(svlsl_n_u8_x(bytes_active, packed_u8x, 4)), 4);
        svint8_t high_i8x = svasr_n_s8_x(bytes_active, svreinterpret_s8_u8(packed_u8x), 4);
        svint16_t low_i16x = svunpklo_s16(low_i8x);
        svint16_t high_i16x = svunpklo_s16(high_i8x);
        svint16_t low_squared_i16x = svmul_s16_z(halves_active, low_i16x, low_i16x);
        svint16_t high_squared_i16x = svmul_s16_z(halves_active, high_i16x, high_i16x);
        svint16_t pair_sum_i16x = svadd_s16_z(halves_active, low_squared_i16x, high_squared_i16x);
        sums_i64x = svadd_s64_m(words_active, sums_i64x, svunpklo_s64(svunpklo_s32(pair_sum_i16x)));
    }
    return (nk_u32_t)svaddv_s64(svptrue_b64(), sums_i64x);
}
|
|
193
|
+
|
|
194
|
+
/**
 * @brief Sum of squares of `count` unsigned 4-bit values packed two per byte.
 *
 * Processes `(count + 1) / 2` bytes. Both nibbles of every loaded byte are squared and summed, which
 * includes the upper nibble of the last byte when `count` is odd. Per-byte sums (at most 2 * 225)
 * are widened so that lane `k` of the 64-bit accumulator receives byte `offset + k`.
 */
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    nk_u8_t const *packed_bytes = (nk_u8_t const *)data;
    nk_size_t const packed_count = (count + 1) / 2;
    nk_size_t const step = svcntd();
    svuint64_t sums_u64x = svdup_u64(0);
    for (nk_size_t offset = 0; offset < packed_count; offset += step) {
        svbool_t bytes_active = svwhilelt_b8_u64(offset, packed_count);
        svbool_t halves_active = svwhilelt_b16_u64(offset, packed_count);
        svbool_t words_active = svwhilelt_b64_u64(offset, packed_count);
        svuint8_t packed_u8x = svld1_u8(bytes_active, packed_bytes + offset);
        svuint16_t low_u16x = svunpklo_u16(svand_n_u8_x(bytes_active, packed_u8x, 0x0F));
        svuint16_t high_u16x = svunpklo_u16(svlsr_n_u8_x(bytes_active, packed_u8x, 4));
        svuint16_t low_squared_u16x = svmul_u16_z(halves_active, low_u16x, low_u16x);
        svuint16_t high_squared_u16x = svmul_u16_z(halves_active, high_u16x, high_u16x);
        svuint16_t pair_sum_u16x = svadd_u16_z(halves_active, low_squared_u16x, high_squared_u16x);
        sums_u64x = svadd_u64_m(words_active, sums_u64x, svunpklo_u64(svunpklo_u32(pair_sum_u16x)));
    }
    return (nk_u32_t)svaddv_u64(svptrue_b64(), sums_u64x);
}
|
|
217
|
+
|
|
218
|
+
/**
 * @brief Converts dot products into angular distances `max(0, 1 - dot / sqrt(query_sq * target_sq))`.
 *
 * The reciprocal square root of the norm product starts from `svrsqrte_f32` and is refined with two
 * Newton-Raphson steps via `svrsqrts_f32`. The result is clamped at zero to absorb rounding.
 *
 * Degenerate lanes: for a zero norm product, `rsqrte(0)` is +inf and the Newton step computes
 * `0 * inf = NaN`. That NaN would survive `svmax` (FMAX propagates NaNs). Such lanes are therefore
 * overridden:
 * - both squared norms zero -> distance 0;
 * - dot exactly zero -> distance 1. For finite non-zero norms this equals the regular formula,
 *   so only the degenerate lanes change.
 *
 * @param predicate_f32x       Active lanes.
 * @param dots_f32x            Dot products between the query and each target.
 * @param query_norm_sq_f32x   Squared norm of the query, broadcast to all lanes.
 * @param target_norms_sq_f32x Squared norms of the targets.
 * @return                     Angular distances for the active lanes.
 */
NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
                                                      svfloat32_t query_norm_sq_f32x,
                                                      svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
    svfloat32_t norms_product_f32x = svmul_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
    svfloat32_t rsqrt_f32x = svrsqrte_f32(norms_product_f32x);
    rsqrt_f32x = svmul_f32_x(predicate_f32x, rsqrt_f32x,
                             svrsqrts_f32(svmul_f32_x(predicate_f32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
    rsqrt_f32x = svmul_f32_x(predicate_f32x, rsqrt_f32x,
                             svrsqrts_f32(svmul_f32_x(predicate_f32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
    svfloat32_t angular_f32x = svsub_f32_x(predicate_f32x, svdup_n_f32(1.0f),
                                           svmul_f32_x(predicate_f32x, dots_f32x, rsqrt_f32x));
    angular_f32x = svmax_f32_x(predicate_f32x, angular_f32x, svdup_n_f32(0.0f));

    // One zero vector (dot is exactly zero): treat as orthogonal instead of propagating NaN.
    svbool_t dot_is_zero = svcmpeq_n_f32(predicate_f32x, dots_f32x, 0.0f);
    angular_f32x = svsel_f32(dot_is_zero, svdup_n_f32(1.0f), angular_f32x);
    // Both vectors zero: treat as identical. Applied last so it takes precedence.
    svbool_t both_norms_zero = svand_b_z(predicate_f32x, svcmpeq_n_f32(predicate_f32x, query_norm_sq_f32x, 0.0f),
                                         svcmpeq_n_f32(predicate_f32x, target_norms_sq_f32x, 0.0f));
    return svsel_f32(both_norms_zero, svdup_n_f32(0.0f), angular_f32x);
}
|
|
231
|
+
|
|
232
|
+
/**
 * @brief Converts dot products into Euclidean distances `sqrt(max(0, query_sq + target_sq - 2 * dot))`.
 *
 * The clamp at zero guards the square root against small negative values caused by cancellation.
 *
 * @param predicate_f32x       Active lanes.
 * @param dots_f32x            Dot products between the query and each target.
 * @param query_norm_sq_f32x   Squared norm of the query, broadcast to all lanes.
 * @param target_norms_sq_f32x Squared norms of the targets.
 * @return                     Euclidean distances for the active lanes.
 */
NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
                                                        svfloat32_t query_norm_sq_f32x,
                                                        svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
    svfloat32_t const zero_f32x = svdup_n_f32(0.0f);
    svfloat32_t twice_dots_f32x = svmul_f32_x(predicate_f32x, svdup_n_f32(2.0f), dots_f32x);
    svfloat32_t norms_sum_f32x = svadd_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
    svfloat32_t difference_f32x = svsub_f32_x(predicate_f32x, norms_sum_f32x, twice_dots_f32x);
    return svsqrt_f32_x(predicate_f32x, svmax_f32_x(predicate_f32x, difference_f32x, zero_f32x));
}
|
|
241
|
+
|
|
242
|
+
#pragma region Half Precision Floats
|
|
243
|
+
|
|
244
|
+
__arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streaming_( //
|
|
245
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
246
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
247
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
248
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
249
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
250
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
251
|
+
nk_f16_t const *a_row = a + row_index * a_stride_elements;
|
|
252
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
253
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
|
|
254
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
255
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
256
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
257
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
258
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
259
|
+
svst1_f32(
|
|
260
|
+
predicate_f32x, result_row + col_index,
|
|
261
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
262
|
+
}
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
NK_PUBLIC void nk_angulars_packed_f16_sme( //
|
|
267
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
268
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
269
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
270
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
271
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
272
|
+
nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
273
|
+
nk_angulars_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
274
|
+
c_stride_elements);
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
__arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_streaming_( //
|
|
278
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
279
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
280
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
281
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
282
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
283
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
284
|
+
nk_f16_t const *a_row = a + row_index * a_stride_elements;
|
|
285
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
286
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
|
|
287
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
288
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
289
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
290
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
291
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
292
|
+
svst1_f32(
|
|
293
|
+
predicate_f32x, result_row + col_index,
|
|
294
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
295
|
+
}
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
NK_PUBLIC void nk_euclideans_packed_f16_sme( //
|
|
300
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
301
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
302
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
303
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
304
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
305
|
+
nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
306
|
+
nk_euclideans_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
307
|
+
c_stride_elements);
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
__arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
|
|
311
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
312
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
313
|
+
// Phase 1: cache row norms on diagonal
|
|
314
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
315
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
316
|
+
result_row[row_index] = nk_dots_reduce_sumsq_f16_ssve_(vectors + row_index * stride_elements, depth);
|
|
317
|
+
}
|
|
318
|
+
// Phase 2: column-first post-processing
|
|
319
|
+
nk_f32_t norms_cache[256];
|
|
320
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
321
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
322
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
323
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
|
|
324
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
325
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
326
|
+
if (col_start >= chunk_end) continue;
|
|
327
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
328
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
329
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
330
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
331
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
332
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
333
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
334
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
335
|
+
target_norms_sq_f32x));
|
|
336
|
+
}
|
|
337
|
+
}
|
|
338
|
+
}
|
|
339
|
+
// Phase 3: zero diagonals
|
|
340
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
341
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
|
|
345
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
346
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
347
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
|
|
348
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
349
|
+
nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
350
|
+
row_start, row_count);
|
|
351
|
+
nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
352
|
+
result_stride_elements, row_start, row_count);
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
|
|
356
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
357
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
358
|
+
// Phase 1: cache row norms on diagonal
|
|
359
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
360
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
361
|
+
result_row[row_index] = nk_dots_reduce_sumsq_f16_ssve_(vectors + row_index * stride_elements, depth);
|
|
362
|
+
}
|
|
363
|
+
// Phase 2: column-first post-processing
|
|
364
|
+
nk_f32_t norms_cache[256];
|
|
365
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
366
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
367
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
368
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
|
|
369
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
370
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
371
|
+
if (col_start >= chunk_end) continue;
|
|
372
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
373
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
374
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
375
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
376
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
377
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
378
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
379
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
380
|
+
target_norms_sq_f32x));
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
}
|
|
384
|
+
// Phase 3: zero diagonals
|
|
385
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
386
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
|
|
390
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
391
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
392
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
|
|
393
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
394
|
+
nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
395
|
+
row_start, row_count);
|
|
396
|
+
nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
397
|
+
result_stride_elements, row_start, row_count);
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
#pragma endregion // Half Precision Floats
|
|
401
|
+
|
|
402
|
+
#pragma region Brain Float 16
|
|
403
|
+
|
|
404
|
+
__arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streaming_( //
|
|
405
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
406
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
407
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
408
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
409
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
410
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
411
|
+
nk_bf16_t const *a_row = a + row_index * a_stride_elements;
|
|
412
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
413
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
|
|
414
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
415
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
416
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
417
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
418
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
419
|
+
svst1_f32(
|
|
420
|
+
predicate_f32x, result_row + col_index,
|
|
421
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
NK_PUBLIC void nk_angulars_packed_bf16_sme( //
|
|
427
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
428
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
429
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
430
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
431
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
432
|
+
nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
433
|
+
nk_angulars_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
434
|
+
c_stride_elements);
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
__arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_streaming_( //
|
|
438
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
439
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
440
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
441
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
442
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
443
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
444
|
+
nk_bf16_t const *a_row = a + row_index * a_stride_elements;
|
|
445
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
446
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
|
|
447
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
448
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
449
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
450
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
451
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
452
|
+
svst1_f32(
|
|
453
|
+
predicate_f32x, result_row + col_index,
|
|
454
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
455
|
+
}
|
|
456
|
+
}
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
|
|
460
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
461
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
462
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
463
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
464
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
465
|
+
nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
466
|
+
nk_euclideans_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
467
|
+
c_stride_elements);
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
__arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
|
|
471
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
472
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
473
|
+
// Phase 1: cache row norms on diagonal
|
|
474
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
475
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
476
|
+
result_row[row_index] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + row_index * stride_elements, depth);
|
|
477
|
+
}
|
|
478
|
+
// Phase 2: column-first post-processing
|
|
479
|
+
nk_f32_t norms_cache[256];
|
|
480
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
481
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
482
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
483
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
|
|
484
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
485
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
486
|
+
if (col_start >= chunk_end) continue;
|
|
487
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
488
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
489
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
490
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
491
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
492
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
493
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
494
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
495
|
+
target_norms_sq_f32x));
|
|
496
|
+
}
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
// Phase 3: zero diagonals
|
|
500
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
501
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
|
|
505
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
506
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
507
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
508
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
509
|
+
nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
510
|
+
row_start, row_count);
|
|
511
|
+
nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
512
|
+
result_stride_elements, row_start, row_count);
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
|
|
516
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
517
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
518
|
+
// Phase 1: cache row norms on diagonal
|
|
519
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
520
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
521
|
+
result_row[row_index] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + row_index * stride_elements, depth);
|
|
522
|
+
}
|
|
523
|
+
// Phase 2: column-first post-processing
|
|
524
|
+
nk_f32_t norms_cache[256];
|
|
525
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
526
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
527
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
528
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
|
|
529
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
530
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
531
|
+
if (col_start >= chunk_end) continue;
|
|
532
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
533
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
534
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
535
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
536
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
537
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
538
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
539
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
540
|
+
target_norms_sq_f32x));
|
|
541
|
+
}
|
|
542
|
+
}
|
|
543
|
+
}
|
|
544
|
+
// Phase 3: zero diagonals
|
|
545
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
546
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
|
|
550
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
551
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
552
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
553
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
554
|
+
nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
555
|
+
row_start, row_count);
|
|
556
|
+
nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
557
|
+
result_stride_elements, row_start, row_count);
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
#pragma endregion // Brain Float 16
|
|
561
|
+
|
|
562
|
+
#pragma region Quarter Precision E4M3
|
|
563
|
+
|
|
564
|
+
__arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streaming_( //
|
|
565
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
566
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
567
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
568
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
569
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
570
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
571
|
+
nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
|
|
572
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
573
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
|
|
574
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
575
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
576
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
577
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
578
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
579
|
+
svst1_f32(
|
|
580
|
+
predicate_f32x, result_row + col_index,
|
|
581
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
582
|
+
}
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
NK_PUBLIC void nk_angulars_packed_e4m3_sme( //
|
|
587
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
588
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
589
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
590
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
591
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
592
|
+
nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
593
|
+
nk_angulars_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
594
|
+
c_stride_elements);
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
__arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_streaming_( //
|
|
598
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
599
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
600
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
601
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
602
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
603
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
604
|
+
nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
|
|
605
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
606
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
|
|
607
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
608
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
609
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
610
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
611
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
612
|
+
svst1_f32(
|
|
613
|
+
predicate_f32x, result_row + col_index,
|
|
614
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
615
|
+
}
|
|
616
|
+
}
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
|
|
620
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
621
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
622
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
623
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
624
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
625
|
+
nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
626
|
+
nk_euclideans_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
627
|
+
c_stride_elements);
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
|
|
631
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
632
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
633
|
+
// Phase 1: cache row norms on diagonal
|
|
634
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
635
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
636
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + row_index * stride_elements, depth);
|
|
637
|
+
}
|
|
638
|
+
// Phase 2: column-first post-processing
|
|
639
|
+
nk_f32_t norms_cache[256];
|
|
640
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
641
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
642
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
643
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
|
|
644
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
645
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
646
|
+
if (col_start >= chunk_end) continue;
|
|
647
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
648
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
649
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
650
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
651
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
652
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
653
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
654
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
655
|
+
target_norms_sq_f32x));
|
|
656
|
+
}
|
|
657
|
+
}
|
|
658
|
+
}
|
|
659
|
+
// Phase 3: zero diagonals
|
|
660
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
661
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
|
|
665
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
666
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
667
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
668
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
669
|
+
nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
670
|
+
row_start, row_count);
|
|
671
|
+
nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
672
|
+
result_stride_elements, row_start, row_count);
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
|
|
676
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
677
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
678
|
+
// Phase 1: cache row norms on diagonal
|
|
679
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
680
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
681
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + row_index * stride_elements, depth);
|
|
682
|
+
}
|
|
683
|
+
// Phase 2: column-first post-processing
|
|
684
|
+
nk_f32_t norms_cache[256];
|
|
685
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
686
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
687
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
688
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
|
|
689
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
690
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
691
|
+
if (col_start >= chunk_end) continue;
|
|
692
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
693
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
694
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
695
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
696
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
697
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
698
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
699
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
700
|
+
target_norms_sq_f32x));
|
|
701
|
+
}
|
|
702
|
+
}
|
|
703
|
+
}
|
|
704
|
+
// Phase 3: zero diagonals
|
|
705
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
706
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
|
|
710
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
711
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
712
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
713
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
714
|
+
nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
715
|
+
row_start, row_count);
|
|
716
|
+
nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
717
|
+
result_stride_elements, row_start, row_count);
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
#pragma endregion // Quarter Precision E4M3
|
|
721
|
+
|
|
722
|
+
#pragma region Quarter Precision E5M2
|
|
723
|
+
|
|
724
|
+
__arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streaming_( //
|
|
725
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
726
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
727
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
728
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
729
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
730
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
731
|
+
nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
|
|
732
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
733
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
|
|
734
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
735
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
736
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
737
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
738
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
739
|
+
svst1_f32(
|
|
740
|
+
predicate_f32x, result_row + col_index,
|
|
741
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
742
|
+
}
|
|
743
|
+
}
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
NK_PUBLIC void nk_angulars_packed_e5m2_sme( //
|
|
747
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
748
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
749
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
750
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
751
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
752
|
+
nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
753
|
+
nk_angulars_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
754
|
+
c_stride_elements);
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
__arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_streaming_( //
|
|
758
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
759
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
760
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
761
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
762
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
763
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
764
|
+
nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
|
|
765
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
766
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
|
|
767
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
768
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
769
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
770
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
771
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
772
|
+
svst1_f32(
|
|
773
|
+
predicate_f32x, result_row + col_index,
|
|
774
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
|
|
779
|
+
NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
|
|
780
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
781
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
782
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
783
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
784
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
785
|
+
nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
786
|
+
nk_euclideans_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
787
|
+
c_stride_elements);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
|
|
791
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
792
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
793
|
+
// Phase 1: cache row norms on diagonal
|
|
794
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
795
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
796
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + row_index * stride_elements, depth);
|
|
797
|
+
}
|
|
798
|
+
// Phase 2: column-first post-processing
|
|
799
|
+
nk_f32_t norms_cache[256];
|
|
800
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
801
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
802
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
803
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
|
|
804
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
805
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
806
|
+
if (col_start >= chunk_end) continue;
|
|
807
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
808
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
809
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
810
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
811
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
812
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
813
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
814
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
815
|
+
target_norms_sq_f32x));
|
|
816
|
+
}
|
|
817
|
+
}
|
|
818
|
+
}
|
|
819
|
+
// Phase 3: zero diagonals
|
|
820
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
821
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
|
|
825
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
826
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
827
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
828
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
829
|
+
nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
830
|
+
row_start, row_count);
|
|
831
|
+
nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
832
|
+
result_stride_elements, row_start, row_count);
|
|
833
|
+
}
|
|
834
|
+
|
|
835
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
|
|
836
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
837
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
838
|
+
// Phase 1: cache row norms on diagonal
|
|
839
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
840
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
841
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + row_index * stride_elements, depth);
|
|
842
|
+
}
|
|
843
|
+
// Phase 2: column-first post-processing
|
|
844
|
+
nk_f32_t norms_cache[256];
|
|
845
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
846
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
847
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
848
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
|
|
849
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
850
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
851
|
+
if (col_start >= chunk_end) continue;
|
|
852
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
853
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
854
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
855
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
856
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
857
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
858
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
859
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
860
|
+
target_norms_sq_f32x));
|
|
861
|
+
}
|
|
862
|
+
}
|
|
863
|
+
}
|
|
864
|
+
// Phase 3: zero diagonals
|
|
865
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
866
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
|
|
870
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
871
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
872
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
873
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
874
|
+
nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
875
|
+
row_start, row_count);
|
|
876
|
+
nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
877
|
+
result_stride_elements, row_start, row_count);
|
|
878
|
+
}
|
|
879
|
+
|
|
880
|
+
#pragma endregion // Quarter Precision E5M2
|
|
881
|
+
|
|
882
|
+
#pragma region Micro Precision E2M3
|
|
883
|
+
|
|
884
|
+
__arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streaming_( //
|
|
885
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
886
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
887
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
888
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
889
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
890
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
891
|
+
nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
|
|
892
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
893
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
|
|
894
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
895
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
896
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
897
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
898
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
899
|
+
svst1_f32(
|
|
900
|
+
predicate_f32x, result_row + col_index,
|
|
901
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
902
|
+
}
|
|
903
|
+
}
|
|
904
|
+
}
|
|
905
|
+
|
|
906
|
+
NK_PUBLIC void nk_angulars_packed_e2m3_sme( //
|
|
907
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
908
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
909
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
910
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
911
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
912
|
+
nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
913
|
+
nk_angulars_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
914
|
+
c_stride_elements);
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
__arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_streaming_( //
|
|
918
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
919
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
920
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
921
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
922
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
923
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
924
|
+
nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
|
|
925
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
926
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
|
|
927
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
928
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
929
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
930
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
931
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
932
|
+
svst1_f32(
|
|
933
|
+
predicate_f32x, result_row + col_index,
|
|
934
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
935
|
+
}
|
|
936
|
+
}
|
|
937
|
+
}
|
|
938
|
+
|
|
939
|
+
NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
|
|
940
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
941
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
942
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
943
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
944
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
945
|
+
nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
946
|
+
nk_euclideans_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
947
|
+
c_stride_elements);
|
|
948
|
+
}
|
|
949
|
+
|
|
950
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
|
|
951
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
952
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
953
|
+
// Phase 1: cache row norms on diagonal
|
|
954
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
955
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
956
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + row_index * stride_elements, depth);
|
|
957
|
+
}
|
|
958
|
+
// Phase 2: column-first post-processing
|
|
959
|
+
nk_f32_t norms_cache[256];
|
|
960
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
961
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
962
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
963
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
|
|
964
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
965
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
966
|
+
if (col_start >= chunk_end) continue;
|
|
967
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
968
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
969
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
970
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
971
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
972
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
973
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
974
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
975
|
+
target_norms_sq_f32x));
|
|
976
|
+
}
|
|
977
|
+
}
|
|
978
|
+
}
|
|
979
|
+
// Phase 3: zero diagonals
|
|
980
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
981
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
982
|
+
}
|
|
983
|
+
|
|
984
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
|
|
985
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
986
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
987
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
988
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
989
|
+
nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
990
|
+
row_start, row_count);
|
|
991
|
+
nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
992
|
+
result_stride_elements, row_start, row_count);
|
|
993
|
+
}
|
|
994
|
+
|
|
995
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
|
|
996
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
997
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
998
|
+
// Phase 1: cache row norms on diagonal
|
|
999
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1000
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1001
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + row_index * stride_elements, depth);
|
|
1002
|
+
}
|
|
1003
|
+
// Phase 2: column-first post-processing
|
|
1004
|
+
nk_f32_t norms_cache[256];
|
|
1005
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1006
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1007
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1008
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
|
|
1009
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1010
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1011
|
+
if (col_start >= chunk_end) continue;
|
|
1012
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1013
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1014
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1015
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1016
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
1017
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
1018
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1019
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1020
|
+
target_norms_sq_f32x));
|
|
1021
|
+
}
|
|
1022
|
+
}
|
|
1023
|
+
}
|
|
1024
|
+
// Phase 3: zero diagonals
|
|
1025
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1026
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1027
|
+
}
|
|
1028
|
+
|
|
1029
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
|
|
1030
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1031
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1032
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
1033
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1034
|
+
nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1035
|
+
row_start, row_count);
|
|
1036
|
+
nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1037
|
+
result_stride_elements, row_start, row_count);
|
|
1038
|
+
}
|
|
1039
|
+
|
|
1040
|
+
#pragma endregion // Micro Precision E2M3
|
|
1041
|
+
|
|
1042
|
+
#pragma region Micro Precision E3M2
|
|
1043
|
+
|
|
1044
|
+
__arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streaming_( //
|
|
1045
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1046
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1047
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1048
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1049
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1050
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1051
|
+
nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
|
|
1052
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1053
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
|
|
1054
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
1055
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1056
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1057
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
1058
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
1059
|
+
svst1_f32(
|
|
1060
|
+
predicate_f32x, result_row + col_index,
|
|
1061
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
1064
|
+
}
|
|
1065
|
+
|
|
1066
|
+
NK_PUBLIC void nk_angulars_packed_e3m2_sme( //
|
|
1067
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1068
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1069
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1070
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1071
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1072
|
+
nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1073
|
+
nk_angulars_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1074
|
+
c_stride_elements);
|
|
1075
|
+
}
|
|
1076
|
+
|
|
1077
|
+
__arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_streaming_( //
|
|
1078
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1079
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1080
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1081
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1082
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1083
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1084
|
+
nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
|
|
1085
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1086
|
+
nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
|
|
1087
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
|
|
1088
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1089
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1090
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
1091
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
|
|
1092
|
+
svst1_f32(
|
|
1093
|
+
predicate_f32x, result_row + col_index,
|
|
1094
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1095
|
+
}
|
|
1096
|
+
}
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
|
|
1100
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1101
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1102
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1103
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1104
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1105
|
+
nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1106
|
+
nk_euclideans_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1107
|
+
c_stride_elements);
|
|
1108
|
+
}
|
|
1109
|
+
|
|
1110
|
+
__arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
|
|
1111
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1112
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1113
|
+
// Phase 1: cache row norms on diagonal
|
|
1114
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1115
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1116
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + row_index * stride_elements, depth);
|
|
1117
|
+
}
|
|
1118
|
+
// Phase 2: column-first post-processing
|
|
1119
|
+
nk_f32_t norms_cache[256];
|
|
1120
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1121
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1122
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1123
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
|
|
1124
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1125
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1126
|
+
if (col_start >= chunk_end) continue;
|
|
1127
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1128
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1129
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1130
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1131
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
1132
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
1133
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1134
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1135
|
+
target_norms_sq_f32x));
|
|
1136
|
+
}
|
|
1137
|
+
}
|
|
1138
|
+
}
|
|
1139
|
+
// Phase 3: zero diagonals
|
|
1140
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1141
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1142
|
+
}
|
|
1143
|
+
|
|
1144
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
|
|
1145
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1146
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1147
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1148
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1149
|
+
nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1150
|
+
row_start, row_count);
|
|
1151
|
+
nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1152
|
+
result_stride_elements, row_start, row_count);
|
|
1153
|
+
}
|
|
1154
|
+
|
|
1155
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
|
|
1156
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1157
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1158
|
+
// Phase 1: cache row norms on diagonal
|
|
1159
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1160
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1161
|
+
result_row[row_index] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + row_index * stride_elements, depth);
|
|
1162
|
+
}
|
|
1163
|
+
// Phase 2: column-first post-processing
|
|
1164
|
+
nk_f32_t norms_cache[256];
|
|
1165
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1166
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1167
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1168
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
|
|
1169
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1170
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1171
|
+
if (col_start >= chunk_end) continue;
|
|
1172
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1173
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
|
|
1174
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1175
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1176
|
+
svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
|
|
1177
|
+
svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
|
|
1178
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1179
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1180
|
+
target_norms_sq_f32x));
|
|
1181
|
+
}
|
|
1182
|
+
}
|
|
1183
|
+
}
|
|
1184
|
+
// Phase 3: zero diagonals
|
|
1185
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1186
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
|
|
1190
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1191
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1192
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1193
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1194
|
+
nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1195
|
+
row_start, row_count);
|
|
1196
|
+
nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1197
|
+
result_stride_elements, row_start, row_count);
|
|
1198
|
+
}
|
|
1199
|
+
|
|
1200
|
+
#pragma endregion // Micro Precision E3M2
|
|
1201
|
+
#pragma region Signed 8-bit Integers
|
|
1202
|
+
|
|
1203
|
+
__arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming_( //
|
|
1204
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1205
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1206
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1207
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1208
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1209
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1210
|
+
nk_i8_t const *a_row = a + row_index * a_stride_elements;
|
|
1211
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1212
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
|
|
1213
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1214
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1215
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1216
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1217
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1218
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1219
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1220
|
+
svst1_f32(
|
|
1221
|
+
predicate_f32x, result_row + col_index,
|
|
1222
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1223
|
+
}
|
|
1224
|
+
}
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
NK_PUBLIC void nk_angulars_packed_i8_sme( //
|
|
1228
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1229
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1230
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1231
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1232
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1233
|
+
nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1234
|
+
c_stride_elements);
|
|
1235
|
+
nk_angulars_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1236
|
+
c_stride_elements);
|
|
1237
|
+
}
|
|
1238
|
+
|
|
1239
|
+
__arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streaming_( //
|
|
1240
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1241
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1242
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1243
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1244
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1245
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1246
|
+
nk_i8_t const *a_row = a + row_index * a_stride_elements;
|
|
1247
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1248
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
|
|
1249
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1250
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1251
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1252
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1253
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1254
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1255
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1256
|
+
svst1_f32(
|
|
1257
|
+
predicate_f32x, result_row + col_index,
|
|
1258
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1259
|
+
}
|
|
1260
|
+
}
|
|
1261
|
+
}
|
|
1262
|
+
|
|
1263
|
+
NK_PUBLIC void nk_euclideans_packed_i8_sme( //
|
|
1264
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1265
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1266
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1267
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1268
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1269
|
+
nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1270
|
+
c_stride_elements);
|
|
1271
|
+
nk_euclideans_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1272
|
+
c_stride_elements);
|
|
1273
|
+
}
|
|
1274
|
+
|
|
1275
|
+
__arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
|
|
1276
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1277
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1278
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1279
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1280
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
|
|
1281
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1282
|
+
}
|
|
1283
|
+
// Phase 2: column-first post-processing
|
|
1284
|
+
nk_u32_t norms_cache[256];
|
|
1285
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1286
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1287
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1288
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
|
|
1289
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1290
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1291
|
+
if (col_start >= chunk_end) continue;
|
|
1292
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1293
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1294
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1295
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1296
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1297
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1298
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
|
|
1299
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1300
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1301
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1302
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1303
|
+
target_norms_sq_f32x));
|
|
1304
|
+
}
|
|
1305
|
+
}
|
|
1306
|
+
}
|
|
1307
|
+
// Phase 3: zero diagonals
|
|
1308
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1309
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1310
|
+
}
|
|
1311
|
+
|
|
1312
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
|
|
1313
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1314
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1315
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
1316
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1317
|
+
nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
|
|
1318
|
+
result_stride_elements, row_start, row_count);
|
|
1319
|
+
nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1320
|
+
result_stride_elements, row_start, row_count);
|
|
1321
|
+
}
|
|
1322
|
+
|
|
1323
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
|
|
1324
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1325
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1326
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1327
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1328
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
|
|
1329
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1330
|
+
}
|
|
1331
|
+
// Phase 2: column-first post-processing
|
|
1332
|
+
nk_u32_t norms_cache[256];
|
|
1333
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1334
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1335
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1336
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
|
|
1337
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1338
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1339
|
+
if (col_start >= chunk_end) continue;
|
|
1340
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1341
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1342
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1343
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1344
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1345
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1346
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
|
|
1347
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1348
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1349
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1350
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1351
|
+
target_norms_sq_f32x));
|
|
1352
|
+
}
|
|
1353
|
+
}
|
|
1354
|
+
}
|
|
1355
|
+
// Phase 3: zero diagonals
|
|
1356
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1357
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1358
|
+
}
|
|
1359
|
+
|
|
1360
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
|
|
1361
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1362
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1363
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
1364
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1365
|
+
nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
|
|
1366
|
+
result_stride_elements, row_start, row_count);
|
|
1367
|
+
nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1368
|
+
result_stride_elements, row_start, row_count);
|
|
1369
|
+
}
|
|
1370
|
+
|
|
1371
|
+
#pragma endregion // Signed 8-bit Integers
|
|
1372
|
+
|
|
1373
|
+
#pragma region Unsigned 8-bit Integers
|
|
1374
|
+
|
|
1375
|
+
__arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming_( //
|
|
1376
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1377
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1378
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1379
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1380
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1381
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1382
|
+
nk_u8_t const *a_row = a + row_index * a_stride_elements;
|
|
1383
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1384
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
|
|
1385
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1386
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1387
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1388
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1389
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1390
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1391
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1392
|
+
svst1_f32(
|
|
1393
|
+
predicate_f32x, result_row + col_index,
|
|
1394
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1395
|
+
}
|
|
1396
|
+
}
|
|
1397
|
+
}
|
|
1398
|
+
|
|
1399
|
+
NK_PUBLIC void nk_angulars_packed_u8_sme( //
|
|
1400
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1401
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1402
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1403
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1404
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1405
|
+
nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1406
|
+
c_stride_elements);
|
|
1407
|
+
nk_angulars_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1408
|
+
c_stride_elements);
|
|
1409
|
+
}
|
|
1410
|
+
|
|
1411
|
+
__arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streaming_( //
|
|
1412
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1413
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1414
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1415
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1416
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1417
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1418
|
+
nk_u8_t const *a_row = a + row_index * a_stride_elements;
|
|
1419
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1420
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
|
|
1421
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1422
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1423
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1424
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1425
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1426
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1427
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1428
|
+
svst1_f32(
|
|
1429
|
+
predicate_f32x, result_row + col_index,
|
|
1430
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1431
|
+
}
|
|
1432
|
+
}
|
|
1433
|
+
}
|
|
1434
|
+
|
|
1435
|
+
NK_PUBLIC void nk_euclideans_packed_u8_sme( //
|
|
1436
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1437
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1438
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1439
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1440
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1441
|
+
nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1442
|
+
c_stride_elements);
|
|
1443
|
+
nk_euclideans_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1444
|
+
c_stride_elements);
|
|
1445
|
+
}
|
|
1446
|
+
|
|
1447
|
+
__arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
|
|
1448
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1449
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1450
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1451
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1452
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
|
|
1453
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1454
|
+
}
|
|
1455
|
+
// Phase 2: column-first post-processing
|
|
1456
|
+
nk_u32_t norms_cache[256];
|
|
1457
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1458
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1459
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1460
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
|
|
1461
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1462
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1463
|
+
if (col_start >= chunk_end) continue;
|
|
1464
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1465
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1466
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1467
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1468
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1469
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1470
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
|
|
1471
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1472
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1473
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1474
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1475
|
+
target_norms_sq_f32x));
|
|
1476
|
+
}
|
|
1477
|
+
}
|
|
1478
|
+
}
|
|
1479
|
+
// Phase 3: zero diagonals
|
|
1480
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1481
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1482
|
+
}
|
|
1483
|
+
|
|
1484
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
|
|
1485
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1486
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1487
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
1488
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1489
|
+
nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
|
|
1490
|
+
result_stride_elements, row_start, row_count);
|
|
1491
|
+
nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1492
|
+
result_stride_elements, row_start, row_count);
|
|
1493
|
+
}
|
|
1494
|
+
|
|
1495
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
|
|
1496
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1497
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1498
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1499
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1500
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
|
|
1501
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1502
|
+
}
|
|
1503
|
+
// Phase 2: column-first post-processing
|
|
1504
|
+
nk_u32_t norms_cache[256];
|
|
1505
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1506
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1507
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1508
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
|
|
1509
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1510
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1511
|
+
if (col_start >= chunk_end) continue;
|
|
1512
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1513
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1514
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1515
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1516
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1517
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1518
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
|
|
1519
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1520
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1521
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1522
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1523
|
+
target_norms_sq_f32x));
|
|
1524
|
+
}
|
|
1525
|
+
}
|
|
1526
|
+
}
|
|
1527
|
+
// Phase 3: zero diagonals
|
|
1528
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1529
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
|
|
1533
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1534
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1535
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
1536
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1537
|
+
nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
|
|
1538
|
+
result_stride_elements, row_start, row_count);
|
|
1539
|
+
nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1540
|
+
result_stride_elements, row_start, row_count);
|
|
1541
|
+
}
|
|
1542
|
+
|
|
1543
|
+
#pragma endregion // Unsigned 8-bit Integers
|
|
1544
|
+
|
|
1545
|
+
#pragma region Nibble Signed Integers
|
|
1546
|
+
|
|
1547
|
+
__arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming_( //
|
|
1548
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1549
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1550
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1551
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1552
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1553
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1554
|
+
nk_i4x2_t const *a_row = a + row_index * a_stride_elements;
|
|
1555
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1556
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
|
|
1557
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1558
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1559
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1560
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1561
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1562
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1563
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1564
|
+
svst1_f32(
|
|
1565
|
+
predicate_f32x, result_row + col_index,
|
|
1566
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1567
|
+
}
|
|
1568
|
+
}
|
|
1569
|
+
}
|
|
1570
|
+
|
|
1571
|
+
NK_PUBLIC void nk_angulars_packed_i4_sme( //
|
|
1572
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1573
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1574
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1575
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1576
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1577
|
+
nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1578
|
+
c_stride_elements);
|
|
1579
|
+
nk_angulars_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1580
|
+
c_stride_elements);
|
|
1581
|
+
}
|
|
1582
|
+
|
|
1583
|
+
__arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streaming_( //
|
|
1584
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1585
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1586
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1587
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1588
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1589
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1590
|
+
nk_i4x2_t const *a_row = a + row_index * a_stride_elements;
|
|
1591
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1592
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
|
|
1593
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1594
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1595
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1596
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1597
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
|
|
1598
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1599
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1600
|
+
svst1_f32(
|
|
1601
|
+
predicate_f32x, result_row + col_index,
|
|
1602
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1603
|
+
}
|
|
1604
|
+
}
|
|
1605
|
+
}
|
|
1606
|
+
|
|
1607
|
+
NK_PUBLIC void nk_euclideans_packed_i4_sme( //
|
|
1608
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1609
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1610
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1611
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1612
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1613
|
+
nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1614
|
+
c_stride_elements);
|
|
1615
|
+
nk_euclideans_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1616
|
+
c_stride_elements);
|
|
1617
|
+
}
|
|
1618
|
+
|
|
1619
|
+
__arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
|
|
1620
|
+
nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1621
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1622
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1623
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1624
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
|
|
1625
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1626
|
+
}
|
|
1627
|
+
// Phase 2: column-first post-processing
|
|
1628
|
+
nk_u32_t norms_cache[256];
|
|
1629
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1630
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1631
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1632
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
|
|
1633
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1634
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1635
|
+
if (col_start >= chunk_end) continue;
|
|
1636
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1637
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1638
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1639
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1640
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1641
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1642
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
|
|
1643
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1644
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1645
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1646
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1647
|
+
target_norms_sq_f32x));
|
|
1648
|
+
}
|
|
1649
|
+
}
|
|
1650
|
+
}
|
|
1651
|
+
// Phase 3: zero diagonals
|
|
1652
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1653
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1654
|
+
}
|
|
1655
|
+
|
|
1656
|
+
NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
|
|
1657
|
+
nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1658
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1659
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
|
|
1660
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1661
|
+
nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
|
|
1662
|
+
result_stride_elements, row_start, row_count);
|
|
1663
|
+
nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1664
|
+
result_stride_elements, row_start, row_count);
|
|
1665
|
+
}
|
|
1666
|
+
|
|
1667
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
|
|
1668
|
+
nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1669
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1670
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1671
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1672
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
|
|
1673
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1674
|
+
}
|
|
1675
|
+
// Phase 2: column-first post-processing
|
|
1676
|
+
nk_u32_t norms_cache[256];
|
|
1677
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1678
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1679
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1680
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
|
|
1681
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1682
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1683
|
+
if (col_start >= chunk_end) continue;
|
|
1684
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1685
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1686
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1687
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1688
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1689
|
+
svfloat32_t dots_f32x = svcvt_f32_s32_x(
|
|
1690
|
+
predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
|
|
1691
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1692
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1693
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1694
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1695
|
+
target_norms_sq_f32x));
|
|
1696
|
+
}
|
|
1697
|
+
}
|
|
1698
|
+
}
|
|
1699
|
+
// Phase 3: zero diagonals
|
|
1700
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1701
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1702
|
+
}
|
|
1703
|
+
|
|
1704
|
+
NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
|
|
1705
|
+
nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1706
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1707
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
|
|
1708
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1709
|
+
nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
|
|
1710
|
+
result_stride_elements, row_start, row_count);
|
|
1711
|
+
nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1712
|
+
result_stride_elements, row_start, row_count);
|
|
1713
|
+
}
|
|
1714
|
+
|
|
1715
|
+
#pragma endregion // Nibble Signed Integers
|
|
1716
|
+
|
|
1717
|
+
#pragma region Nibble Unsigned Integers
|
|
1718
|
+
|
|
1719
|
+
__arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming_( //
|
|
1720
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1721
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1722
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1723
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1724
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1725
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1726
|
+
nk_u4x2_t const *a_row = a + row_index * a_stride_elements;
|
|
1727
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1728
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
|
|
1729
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1730
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1731
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1732
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1733
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1734
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1735
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1736
|
+
svst1_f32(
|
|
1737
|
+
predicate_f32x, result_row + col_index,
|
|
1738
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1739
|
+
}
|
|
1740
|
+
}
|
|
1741
|
+
}
|
|
1742
|
+
|
|
1743
|
+
NK_PUBLIC void nk_angulars_packed_u4_sme( //
|
|
1744
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1745
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1746
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1747
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1748
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1749
|
+
nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1750
|
+
c_stride_elements);
|
|
1751
|
+
nk_angulars_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1752
|
+
c_stride_elements);
|
|
1753
|
+
}
|
|
1754
|
+
|
|
1755
|
+
__arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streaming_( //
|
|
1756
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1757
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1758
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1759
|
+
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1760
|
+
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1761
|
+
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
1762
|
+
nk_u4x2_t const *a_row = a + row_index * a_stride_elements;
|
|
1763
|
+
nk_f32_t *result_row = c + row_index * c_stride_elements;
|
|
1764
|
+
nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
|
|
1765
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
|
|
1766
|
+
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
|
|
1767
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
|
|
1768
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1769
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
|
|
1770
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
|
|
1771
|
+
svld1_u32(predicate_f32x, b_norms + col_index));
|
|
1772
|
+
svst1_f32(
|
|
1773
|
+
predicate_f32x, result_row + col_index,
|
|
1774
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
|
|
1775
|
+
}
|
|
1776
|
+
}
|
|
1777
|
+
}
|
|
1778
|
+
|
|
1779
|
+
NK_PUBLIC void nk_euclideans_packed_u4_sme( //
|
|
1780
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1781
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1782
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1783
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1784
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1785
|
+
nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1786
|
+
c_stride_elements);
|
|
1787
|
+
nk_euclideans_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1788
|
+
c_stride_elements);
|
|
1789
|
+
}
|
|
1790
|
+
|
|
1791
|
+
__arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
|
|
1792
|
+
nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1793
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1794
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1795
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1796
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
|
|
1797
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1798
|
+
}
|
|
1799
|
+
// Phase 2: column-first post-processing
|
|
1800
|
+
nk_u32_t norms_cache[256];
|
|
1801
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1802
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1803
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1804
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
|
|
1805
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1806
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1807
|
+
if (col_start >= chunk_end) continue;
|
|
1808
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1809
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1810
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1811
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1812
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1813
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1814
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
|
|
1815
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1816
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1817
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1818
|
+
nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1819
|
+
target_norms_sq_f32x));
|
|
1820
|
+
}
|
|
1821
|
+
}
|
|
1822
|
+
}
|
|
1823
|
+
// Phase 3: zero diagonals
|
|
1824
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1825
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1826
|
+
}
|
|
1827
|
+
|
|
1828
|
+
NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
|
|
1829
|
+
nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1830
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1831
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
|
|
1832
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1833
|
+
nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
|
|
1834
|
+
result_stride_elements, row_start, row_count);
|
|
1835
|
+
nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1836
|
+
result_stride_elements, row_start, row_count);
|
|
1837
|
+
}
|
|
1838
|
+
|
|
1839
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
|
|
1840
|
+
nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
|
|
1841
|
+
nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1842
|
+
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1843
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1844
|
+
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
|
|
1845
|
+
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
|
|
1846
|
+
}
|
|
1847
|
+
// Phase 2: column-first post-processing
|
|
1848
|
+
nk_u32_t norms_cache[256];
|
|
1849
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1850
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1851
|
+
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1852
|
+
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
|
|
1853
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1854
|
+
nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
|
|
1855
|
+
if (col_start >= chunk_end) continue;
|
|
1856
|
+
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
1857
|
+
nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
|
|
1858
|
+
svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
|
|
1859
|
+
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
|
|
1860
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
|
|
1861
|
+
svfloat32_t dots_f32x = svcvt_f32_u32_x(
|
|
1862
|
+
predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
|
|
1863
|
+
svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
|
|
1864
|
+
predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
|
|
1865
|
+
svst1_f32(predicate_f32x, result_row + col_index,
|
|
1866
|
+
nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
|
|
1867
|
+
target_norms_sq_f32x));
|
|
1868
|
+
}
|
|
1869
|
+
}
|
|
1870
|
+
}
|
|
1871
|
+
// Phase 3: zero diagonals
|
|
1872
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
|
|
1873
|
+
result[row_index * result_stride_elements + row_index] = 0;
|
|
1874
|
+
}
|
|
1875
|
+
|
|
1876
|
+
NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
|
|
1877
|
+
nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1878
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1879
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
|
|
1880
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1881
|
+
nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
|
|
1882
|
+
result_stride_elements, row_start, row_count);
|
|
1883
|
+
nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
|
|
1884
|
+
result_stride_elements, row_start, row_count);
|
|
1885
|
+
}
|
|
1886
|
+
|
|
1887
|
+
#pragma endregion // Nibble Unsigned Integers
|
|
1888
|
+
|
|
1889
|
+
#if defined(__clang__)
|
|
1890
|
+
#pragma clang attribute pop
|
|
1891
|
+
#elif defined(__GNUC__)
|
|
1892
|
+
#pragma GCC pop_options
|
|
1893
|
+
#endif
|
|
1894
|
+
|
|
1895
|
+
#if defined(__cplusplus)
|
|
1896
|
+
} // extern "C"
|
|
1897
|
+
#endif
|
|
1898
|
+
|
|
1899
|
+
#endif // NK_TARGET_SME
|
|
1900
|
+
#endif // NK_TARGET_ARM_
|
|
1901
|
+
#endif // NK_SPATIALS_SME_H
|