numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -20,69 +20,70 @@ extern "C" {
|
|
|
20
20
|
#endif
|
|
21
21
|
|
|
22
22
|
#if defined(__clang__)
|
|
23
|
-
#pragma clang attribute push(__attribute__((target("sme,
|
|
23
|
+
#pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
|
|
24
24
|
#elif defined(__GNUC__)
|
|
25
25
|
#pragma GCC push_options
|
|
26
26
|
#pragma GCC target("+sme+sme-f64f64")
|
|
27
27
|
#endif
|
|
28
28
|
|
|
29
29
|
/**
 *  @brief  Sum of squares of `count` f32 values, accumulated in f64, in streaming SVE mode.
 *
 *  Each iteration loads one streaming vector of f32 lanes (inactive tail lanes are zeroed by the
 *  predicated load) and widens it to f64 in two interleaved halves:
 *  - `svcvt_f64_f32_x` widens the even-numbered f32 elements (low half of each 64-bit container),
 *  - `svcvtlt_f64_f32_x` widens the odd-numbered f32 elements (high half of each 64-bit container).
 *  The two halves are squared into separate f64 accumulators, which are summed at the end.
 *
 *  @param  data   Pointer to `count` contiguous f32 values.
 *  @param  count  Number of values to reduce; 0 yields 0.0.
 *  @return        The f64 sum of squares of all `count` values.
 */
NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f32_ssve_(nk_f32_t const *data, nk_size_t count) NK_STREAMING_ {
    svfloat64_t accumulator_even_f64x = svdup_f64(0.0);
    svfloat64_t accumulator_odd_f64x = svdup_f64(0.0);
    nk_size_t const vector_length = svcntw();
    // Number of elements at even / odd global positions: ceil(count / 2) and floor(count / 2).
    // `count - count / 2` is used instead of `(count + 1) / 2` so it cannot overflow.
    nk_size_t const even_count = count - count / 2;
    nk_size_t const odd_count = count / 2;
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_b32x = svwhilelt_b32_u64(i, count);
        svfloat32_t values_f32x = svld1_f32(predicate_b32x, data + i);
        // `i` is a multiple of the (always even) f32 vector length, so 64-bit lane `k` covers the
        // element pair starting at global index `i + 2k`, i.e. pair index `i / 2 + k`.
        nk_size_t const pair_index = i / 2;

        // Even element `i + 2k` exists while `pair_index + k < ceil(count / 2)`.
        svbool_t predicate_even_b64x = svwhilelt_b64_u64(pair_index, even_count);
        svfloat64_t values_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, values_f32x);
        accumulator_even_f64x = svmla_f64_m(predicate_even_b64x, accumulator_even_f64x, values_even_f64x,
                                            values_even_f64x);

        // Odd element `i + 2k + 1` exists while `pair_index + k < floor(count / 2)`.
        // Odd lanes are interleaved with even ones, NOT the upper half of the vector: a predicate such as
        // `svwhilelt_b64_u64(i + svcntd(), count)` would drop valid odd tail elements (e.g. `data[1]` when
        // `count == 2` on a 128-bit vector).
        svbool_t predicate_odd_b64x = svwhilelt_b64_u64(pair_index, odd_count);
        svfloat64_t values_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, values_f32x);
        accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, values_odd_f64x, values_odd_f64x);
    }
    return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
}
|
|
48
49
|
|
|
49
|
-
/**
 *  @brief  Sum of squares of `count` f64 values in streaming SVE mode.
 *
 *  Walks the input one streaming vector at a time; the final partial vector is masked with a
 *  `whilelt` predicate so lanes past `count` neither load nor accumulate.
 *
 *  @param  data   Pointer to `count` contiguous f64 values.
 *  @param  count  Number of values to reduce; 0 yields 0.0.
 *  @return        The sum of squares.
 */
NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f64_ssve_(nk_f64_t const *data, nk_size_t count) NK_STREAMING_ {
    svfloat64_t sum_sq_f64x = svdup_f64(0.0);
    nk_size_t const lanes = svcntd();
    nk_size_t offset = 0;
    while (offset < count) {
        svbool_t const active_b64x = svwhilelt_b64_u64(offset, count);
        svfloat64_t const chunk_f64x = svld1_f64(active_b64x, data + offset);
        // Merging multiply-add: inactive lanes keep their running sum untouched.
        sum_sq_f64x = svmla_f64_m(active_b64x, sum_sq_f64x, chunk_f64x, chunk_f64x);
        offset += lanes;
    }
    return svaddv_f64(svptrue_b64(), sum_sq_f64x);
}
|
|
59
60
|
|
|
60
|
-
/**
 *  @brief  Lane-wise angular distance `1 - dot / sqrt(query_norm_sq * target_norm_sq)`.
 *
 *  Lanes whose squared-norm product is not strictly positive cannot be normalized. Those lanes
 *  yield 0 when the dot product is exactly 0 and 1 otherwise. The final result is clamped from
 *  below at 0.
 *
 *  @param  predicate_b64x        Active lanes.
 *  @param  dots_f64x             Dot products.
 *  @param  query_norm_sq_f64x    Squared norm of the query, per lane.
 *  @param  target_norms_sq_f64x  Squared norms of the targets, per lane.
 *  @return Angular distances in active lanes; inactive lanes are unspecified (`_x` predication).
 */
NK_PUBLIC svfloat64_t nk_angulars_from_dot_f64x_ssvef64_(svbool_t predicate_b64x, svfloat64_t dots_f64x,
                                                         svfloat64_t query_norm_sq_f64x,
                                                         svfloat64_t target_norms_sq_f64x) NK_STREAMING_ {
    svfloat64_t const zero_f64x = svdup_n_f64(0.0);
    svfloat64_t const one_f64x = svdup_n_f64(1.0);

    // Only lanes with a strictly positive norm product have a usable denominator.
    svfloat64_t const product_f64x = svmul_f64_x(predicate_b64x, query_norm_sq_f64x, target_norms_sq_f64x);
    svbool_t const usable_b64x = svcmpgt_n_f64(predicate_b64x, product_f64x, 0.0);

    // Fallback for degenerate lanes: an exactly-zero dot product maps to 0, anything else to 1.
    svbool_t const zero_dot_b64x = svcmpeq_n_f64(predicate_b64x, dots_f64x, 0.0);
    svfloat64_t const fallback_f64x = svsel_f64(zero_dot_b64x, zero_f64x, one_f64x);

    // Substitute 1 for unusable denominators so the division never sees zero or negative roots.
    svfloat64_t const root_f64x = svsqrt_f64_x(usable_b64x, product_f64x);
    svfloat64_t const divisor_f64x = svsel_f64(usable_b64x, root_f64x, one_f64x);
    svfloat64_t const normalized_f64x = svdiv_f64_x(predicate_b64x, dots_f64x, divisor_f64x);
    svfloat64_t const distance_f64x = svsub_f64_x(predicate_b64x, one_f64x, normalized_f64x);

    svfloat64_t const merged_f64x = svsel_f64(usable_b64x, distance_f64x, fallback_f64x);
    // Rounding can push the normalized dot slightly above 1; never report a negative distance.
    return svmax_f64_x(predicate_b64x, merged_f64x, zero_f64x);
}
|
|
74
75
|
|
|
75
|
-
/**
 *  @brief  Lane-wise Euclidean distance from a dot product and squared norms:
 *          `sqrt(max(query_norm_sq + target_norm_sq - 2 * dot, 0))`.
 *
 *  @param  predicate_b64x        Active lanes.
 *  @param  dots_f64x             Dot products.
 *  @param  query_norm_sq_f64x    Squared norm of the query, per lane.
 *  @param  target_norms_sq_f64x  Squared norms of the targets, per lane.
 *  @return Distances in active lanes; inactive lanes are unspecified (`_x` predication).
 */
NK_PUBLIC svfloat64_t nk_euclideans_from_dot_f64x_ssvef64_(svbool_t predicate_b64x, svfloat64_t dots_f64x,
                                                           svfloat64_t query_norm_sq_f64x,
                                                           svfloat64_t target_norms_sq_f64x) NK_STREAMING_ {
    svfloat64_t const norms_sum_f64x = svadd_f64_x(predicate_b64x, query_norm_sq_f64x, target_norms_sq_f64x);
    svfloat64_t const twice_dots_f64x = svmul_f64_x(predicate_b64x, svdup_n_f64(2.0), dots_f64x);
    svfloat64_t const raw_sq_f64x = svsub_f64_x(predicate_b64x, norms_sum_f64x, twice_dots_f64x);
    // Cancellation can make the expanded form slightly negative; clamp before the square root.
    svfloat64_t const clamped_sq_f64x = svmax_f64_x(predicate_b64x, raw_sq_f64x, svdup_n_f64(0.0));
    return svsqrt_f64_x(predicate_b64x, clamped_sq_f64x);
}
|
|
84
85
|
|
|
85
|
-
#pragma region
|
|
86
|
+
#pragma region F32 Packed Angular
|
|
86
87
|
|
|
87
88
|
__arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_streaming_( //
|
|
88
89
|
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
@@ -99,11 +100,11 @@ __arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_strea
|
|
|
99
100
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
|
|
100
101
|
|
|
101
102
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
|
|
102
|
-
svbool_t
|
|
103
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
104
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
105
|
-
svst1_f64(
|
|
106
|
-
nk_angulars_from_dot_f64x_ssvef64_(
|
|
103
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
|
|
104
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
|
|
105
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
|
|
106
|
+
svst1_f64(predicate_b64x, c_row + col_index,
|
|
107
|
+
nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
107
108
|
target_norms_sq_f64x));
|
|
108
109
|
}
|
|
109
110
|
}
|
|
@@ -122,7 +123,8 @@ NK_PUBLIC void nk_angulars_packed_f32_smef64( //
|
|
|
122
123
|
c_stride_elements);
|
|
123
124
|
}
|
|
124
125
|
|
|
125
|
-
#pragma
|
|
126
|
+
#pragma endregion F32 Packed Angular
|
|
127
|
+
#pragma region F32 Packed Euclidean
|
|
126
128
|
|
|
127
129
|
__arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_streaming_( //
|
|
128
130
|
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
@@ -139,11 +141,11 @@ __arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_str
|
|
|
139
141
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
|
|
140
142
|
|
|
141
143
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
|
|
142
|
-
svbool_t
|
|
143
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
144
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
145
|
-
svst1_f64(
|
|
146
|
-
nk_euclideans_from_dot_f64x_ssvef64_(
|
|
144
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
|
|
145
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
|
|
146
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
|
|
147
|
+
svst1_f64(predicate_b64x, c_row + col_index,
|
|
148
|
+
nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
147
149
|
target_norms_sq_f64x));
|
|
148
150
|
}
|
|
149
151
|
}
|
|
@@ -162,10 +164,11 @@ NK_PUBLIC void nk_euclideans_packed_f32_smef64( //
|
|
|
162
164
|
c_stride_elements);
|
|
163
165
|
}
|
|
164
166
|
|
|
165
|
-
#pragma
|
|
167
|
+
#pragma endregion F32 Packed Euclidean
|
|
168
|
+
#pragma region F32 Symmetric Angular
|
|
166
169
|
|
|
167
|
-
__arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_streaming_(
|
|
168
|
-
nk_f32_t const *vectors, nk_size_t
|
|
170
|
+
__arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_streaming_( //
|
|
171
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
169
172
|
nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
170
173
|
// Phase 1: cache row norms on diagonal
|
|
171
174
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -175,8 +178,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
|
|
|
175
178
|
}
|
|
176
179
|
// Phase 2: column-chunked post-processing
|
|
177
180
|
nk_f64_t column_norms[256];
|
|
178
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
179
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
181
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
182
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
180
183
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
|
|
181
184
|
nk_f32_t const *col_vector = vectors + col * stride_elements;
|
|
182
185
|
column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f32_ssve_(col_vector, depth);
|
|
@@ -187,11 +190,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
|
|
|
187
190
|
nk_f64_t *result_row = result + row_index * result_stride_elements;
|
|
188
191
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
|
|
189
192
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
|
|
190
|
-
svbool_t
|
|
191
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
192
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
193
|
-
svst1_f64(
|
|
194
|
-
nk_angulars_from_dot_f64x_ssvef64_(
|
|
193
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
|
|
194
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
|
|
195
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
|
|
196
|
+
svst1_f64(predicate_b64x, result_row + col_index,
|
|
197
|
+
nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
195
198
|
target_norms_sq_f64x));
|
|
196
199
|
}
|
|
197
200
|
}
|
|
@@ -201,23 +204,24 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
|
|
|
201
204
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
202
205
|
}
|
|
203
206
|
|
|
204
|
-
NK_PUBLIC void nk_angulars_symmetric_f32_smef64(
|
|
205
|
-
nk_f32_t const *vectors, nk_size_t
|
|
206
|
-
nk_f64_t *result, nk_size_t
|
|
207
|
+
NK_PUBLIC void nk_angulars_symmetric_f32_smef64( //
|
|
208
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
209
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
207
210
|
|
|
208
|
-
nk_size_t const stride_elements =
|
|
209
|
-
nk_size_t const result_stride_elements =
|
|
211
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
212
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
210
213
|
|
|
211
|
-
nk_dots_symmetric_f32_smef64_streaming_(vectors,
|
|
212
|
-
row_start, row_count);
|
|
213
|
-
nk_angulars_symmetric_f32_smef64_finalize_streaming_(vectors,
|
|
214
|
+
nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
215
|
+
result_stride_elements, row_start, row_count);
|
|
216
|
+
nk_angulars_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
214
217
|
result_stride_elements, row_start, row_count);
|
|
215
218
|
}
|
|
216
219
|
|
|
217
|
-
#pragma
|
|
220
|
+
#pragma endregion F32 Symmetric Angular
|
|
221
|
+
#pragma region F32 Symmetric Euclidean
|
|
218
222
|
|
|
219
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_streaming_(
|
|
220
|
-
nk_f32_t const *vectors, nk_size_t
|
|
223
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_streaming_( //
|
|
224
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
221
225
|
nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
222
226
|
// Phase 1: cache row norms on diagonal
|
|
223
227
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -227,8 +231,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
|
|
|
227
231
|
}
|
|
228
232
|
// Phase 2: column-chunked post-processing
|
|
229
233
|
nk_f64_t column_norms[256];
|
|
230
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
231
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
234
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
235
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
232
236
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
|
|
233
237
|
nk_f32_t const *col_vector = vectors + col * stride_elements;
|
|
234
238
|
column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f32_ssve_(col_vector, depth);
|
|
@@ -239,11 +243,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
|
|
|
239
243
|
nk_f64_t *result_row = result + row_index * result_stride_elements;
|
|
240
244
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
|
|
241
245
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
|
|
242
|
-
svbool_t
|
|
243
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
244
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
245
|
-
svst1_f64(
|
|
246
|
-
nk_euclideans_from_dot_f64x_ssvef64_(
|
|
246
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
|
|
247
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
|
|
248
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
|
|
249
|
+
svst1_f64(predicate_b64x, result_row + col_index,
|
|
250
|
+
nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
247
251
|
target_norms_sq_f64x));
|
|
248
252
|
}
|
|
249
253
|
}
|
|
@@ -253,20 +257,21 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
|
|
|
253
257
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
254
258
|
}
|
|
255
259
|
|
|
256
|
-
NK_PUBLIC void nk_euclideans_symmetric_f32_smef64(
|
|
257
|
-
nk_f32_t const *vectors, nk_size_t
|
|
258
|
-
nk_f64_t *result, nk_size_t
|
|
260
|
+
NK_PUBLIC void nk_euclideans_symmetric_f32_smef64( //
|
|
261
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
262
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
259
263
|
|
|
260
|
-
nk_size_t const stride_elements =
|
|
261
|
-
nk_size_t const result_stride_elements =
|
|
264
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
265
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
262
266
|
|
|
263
|
-
nk_dots_symmetric_f32_smef64_streaming_(vectors,
|
|
264
|
-
row_start, row_count);
|
|
265
|
-
nk_euclideans_symmetric_f32_smef64_finalize_streaming_(vectors,
|
|
267
|
+
nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
268
|
+
result_stride_elements, row_start, row_count);
|
|
269
|
+
nk_euclideans_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
266
270
|
result_stride_elements, row_start, row_count);
|
|
267
271
|
}
|
|
268
272
|
|
|
269
|
-
#pragma
|
|
273
|
+
#pragma endregion F32 Symmetric Euclidean
|
|
274
|
+
#pragma region F64 Packed Angular
|
|
270
275
|
|
|
271
276
|
__arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_streaming_( //
|
|
272
277
|
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
@@ -283,11 +288,11 @@ __arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_strea
|
|
|
283
288
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
|
|
284
289
|
|
|
285
290
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
|
|
286
|
-
svbool_t
|
|
287
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
288
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
289
|
-
svst1_f64(
|
|
290
|
-
nk_angulars_from_dot_f64x_ssvef64_(
|
|
291
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
|
|
292
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
|
|
293
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
|
|
294
|
+
svst1_f64(predicate_b64x, c_row + col_index,
|
|
295
|
+
nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
291
296
|
target_norms_sq_f64x));
|
|
292
297
|
}
|
|
293
298
|
}
|
|
@@ -306,7 +311,8 @@ NK_PUBLIC void nk_angulars_packed_f64_smef64( //
|
|
|
306
311
|
c_stride_elements);
|
|
307
312
|
}
|
|
308
313
|
|
|
309
|
-
#pragma
|
|
314
|
+
#pragma endregion F64 Packed Angular
|
|
315
|
+
#pragma region F64 Packed Euclidean
|
|
310
316
|
|
|
311
317
|
__arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_streaming_( //
|
|
312
318
|
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
@@ -323,11 +329,11 @@ __arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_str
|
|
|
323
329
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
|
|
324
330
|
|
|
325
331
|
for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
|
|
326
|
-
svbool_t
|
|
327
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
328
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
329
|
-
svst1_f64(
|
|
330
|
-
nk_euclideans_from_dot_f64x_ssvef64_(
|
|
332
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
|
|
333
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
|
|
334
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
|
|
335
|
+
svst1_f64(predicate_b64x, c_row + col_index,
|
|
336
|
+
nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
331
337
|
target_norms_sq_f64x));
|
|
332
338
|
}
|
|
333
339
|
}
|
|
@@ -346,10 +352,11 @@ NK_PUBLIC void nk_euclideans_packed_f64_smef64( //
|
|
|
346
352
|
c_stride_elements);
|
|
347
353
|
}
|
|
348
354
|
|
|
349
|
-
#pragma
|
|
355
|
+
#pragma endregion F64 Packed Euclidean
|
|
356
|
+
#pragma region F64 Symmetric Angular
|
|
350
357
|
|
|
351
|
-
__arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_streaming_(
|
|
352
|
-
nk_f64_t const *vectors, nk_size_t
|
|
358
|
+
__arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_streaming_( //
|
|
359
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
353
360
|
nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
354
361
|
// Phase 1: cache row norms on diagonal
|
|
355
362
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -359,8 +366,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
|
|
|
359
366
|
}
|
|
360
367
|
// Phase 2: column-chunked post-processing
|
|
361
368
|
nk_f64_t column_norms[256];
|
|
362
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
363
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
369
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
370
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
364
371
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
|
|
365
372
|
nk_f64_t const *col_vector = vectors + col * stride_elements;
|
|
366
373
|
column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f64_ssve_(col_vector, depth);
|
|
@@ -371,11 +378,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
|
|
|
371
378
|
nk_f64_t *result_row = result + row_index * result_stride_elements;
|
|
372
379
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
|
|
373
380
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
|
|
374
|
-
svbool_t
|
|
375
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
376
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
377
|
-
svst1_f64(
|
|
378
|
-
nk_angulars_from_dot_f64x_ssvef64_(
|
|
381
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
|
|
382
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
|
|
383
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
|
|
384
|
+
svst1_f64(predicate_b64x, result_row + col_index,
|
|
385
|
+
nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
379
386
|
target_norms_sq_f64x));
|
|
380
387
|
}
|
|
381
388
|
}
|
|
@@ -385,23 +392,24 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
|
|
|
385
392
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
386
393
|
}
|
|
387
394
|
|
|
388
|
-
NK_PUBLIC void nk_angulars_symmetric_f64_smef64(
|
|
389
|
-
nk_f64_t const *vectors, nk_size_t
|
|
390
|
-
nk_f64_t *result, nk_size_t
|
|
395
|
+
NK_PUBLIC void nk_angulars_symmetric_f64_smef64( //
|
|
396
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
397
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
391
398
|
|
|
392
|
-
nk_size_t const stride_elements =
|
|
393
|
-
nk_size_t const result_stride_elements =
|
|
399
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
400
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
394
401
|
|
|
395
|
-
nk_dots_symmetric_f64_smef64_streaming_(vectors,
|
|
396
|
-
row_start, row_count);
|
|
397
|
-
nk_angulars_symmetric_f64_smef64_finalize_streaming_(vectors,
|
|
402
|
+
nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
403
|
+
result_stride_elements, row_start, row_count);
|
|
404
|
+
nk_angulars_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
398
405
|
result_stride_elements, row_start, row_count);
|
|
399
406
|
}
|
|
400
407
|
|
|
401
|
-
#pragma
|
|
408
|
+
#pragma endregion F64 Symmetric Angular
|
|
409
|
+
#pragma region F64 Symmetric Euclidean
|
|
402
410
|
|
|
403
|
-
__arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_streaming_(
|
|
404
|
-
nk_f64_t const *vectors, nk_size_t
|
|
411
|
+
__arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_streaming_( //
|
|
412
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
|
|
405
413
|
nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
406
414
|
// Phase 1: cache row norms on diagonal
|
|
407
415
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -411,8 +419,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
|
|
|
411
419
|
}
|
|
412
420
|
// Phase 2: column-chunked post-processing
|
|
413
421
|
nk_f64_t column_norms[256];
|
|
414
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
415
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
422
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
423
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
416
424
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
|
|
417
425
|
nk_f64_t const *col_vector = vectors + col * stride_elements;
|
|
418
426
|
column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f64_ssve_(col_vector, depth);
|
|
@@ -423,11 +431,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
|
|
|
423
431
|
nk_f64_t *result_row = result + row_index * result_stride_elements;
|
|
424
432
|
svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
|
|
425
433
|
for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
|
|
426
|
-
svbool_t
|
|
427
|
-
svfloat64_t dots_f64x = svld1_f64(
|
|
428
|
-
svfloat64_t target_norms_sq_f64x = svld1_f64(
|
|
429
|
-
svst1_f64(
|
|
430
|
-
nk_euclideans_from_dot_f64x_ssvef64_(
|
|
434
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
|
|
435
|
+
svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
|
|
436
|
+
svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
|
|
437
|
+
svst1_f64(predicate_b64x, result_row + col_index,
|
|
438
|
+
nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
|
|
431
439
|
target_norms_sq_f64x));
|
|
432
440
|
}
|
|
433
441
|
}
|
|
@@ -437,19 +445,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
|
|
|
437
445
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
438
446
|
}
|
|
439
447
|
|
|
440
|
-
NK_PUBLIC void nk_euclideans_symmetric_f64_smef64(
|
|
441
|
-
nk_f64_t const *vectors, nk_size_t
|
|
442
|
-
nk_f64_t *result, nk_size_t
|
|
448
|
+
NK_PUBLIC void nk_euclideans_symmetric_f64_smef64( //
|
|
449
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
450
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
443
451
|
|
|
444
|
-
nk_size_t const stride_elements =
|
|
445
|
-
nk_size_t const result_stride_elements =
|
|
452
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
453
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
446
454
|
|
|
447
|
-
nk_dots_symmetric_f64_smef64_streaming_(vectors,
|
|
448
|
-
row_start, row_count);
|
|
449
|
-
nk_euclideans_symmetric_f64_smef64_finalize_streaming_(vectors,
|
|
455
|
+
nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
456
|
+
result_stride_elements, row_start, row_count);
|
|
457
|
+
nk_euclideans_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
450
458
|
result_stride_elements, row_start, row_count);
|
|
451
459
|
}
|
|
452
460
|
|
|
461
|
+
#pragma endregion F64 Symmetric Euclidean
|
|
453
462
|
#if defined(__clang__)
|
|
454
463
|
#pragma clang attribute pop
|
|
455
464
|
#elif defined(__GNUC__)
|
|
@@ -144,7 +144,7 @@ nk_define_cross_normalized_symmetric_(euclidean, e5m2, v128relaxed, e5m2, f32, /
|
|
|
144
144
|
nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_,
|
|
145
145
|
nk_store_b128_v128relaxed_, nk_partial_store_b32x4_serial_, 1)
|
|
146
146
|
|
|
147
|
-
nk_define_cross_normalized_packed_(angular, bf16, v128relaxed, bf16,
|
|
147
|
+
nk_define_cross_normalized_packed_(angular, bf16, v128relaxed, bf16, bf16, f32, /*norm_value_type=*/f32, f32,
|
|
148
148
|
nk_b128_vec_t, nk_dots_packed_bf16_v128relaxed,
|
|
149
149
|
nk_angular_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
|
|
150
150
|
nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_, nk_store_b128_v128relaxed_,
|
|
@@ -154,7 +154,7 @@ nk_define_cross_normalized_symmetric_(angular, bf16, v128relaxed, bf16, f32, /*n
|
|
|
154
154
|
nk_angular_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
|
|
155
155
|
nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_,
|
|
156
156
|
nk_store_b128_v128relaxed_, nk_partial_store_b32x4_serial_, 1)
|
|
157
|
-
nk_define_cross_normalized_packed_(euclidean, bf16, v128relaxed, bf16,
|
|
157
|
+
nk_define_cross_normalized_packed_(euclidean, bf16, v128relaxed, bf16, bf16, f32, /*norm_value_type=*/f32, f32,
|
|
158
158
|
nk_b128_vec_t, nk_dots_packed_bf16_v128relaxed,
|
|
159
159
|
nk_euclidean_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
|
|
160
160
|
nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_, nk_store_b128_v128relaxed_,
|