numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -27,7 +27,7 @@ extern "C" {
|
|
|
27
27
|
#pragma GCC target("arch=+v")
|
|
28
28
|
#endif
|
|
29
29
|
|
|
30
|
-
#pragma region
|
|
30
|
+
#pragma region F32 Floats
|
|
31
31
|
|
|
32
32
|
NK_INTERNAL void nk_angulars_packed_f32_rvv_finalize_(nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
|
|
33
33
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -111,8 +111,8 @@ NK_PUBLIC void nk_euclideans_packed_f32_rvv( //
|
|
|
111
111
|
nk_euclideans_packed_f32_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
112
112
|
}
|
|
113
113
|
|
|
114
|
-
NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t
|
|
115
|
-
nk_size_t stride_elements, nk_f64_t *result,
|
|
114
|
+
NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t vectors_count,
|
|
115
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
116
116
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
117
117
|
nk_size_t row_count) {
|
|
118
118
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -120,8 +120,8 @@ NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors
|
|
|
120
120
|
result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
|
|
121
121
|
}
|
|
122
122
|
nk_f64_t norms_cache[256];
|
|
123
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
124
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
123
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
124
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
125
125
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
126
126
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
|
|
127
127
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -153,17 +153,18 @@ NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors
|
|
|
153
153
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
154
154
|
}
|
|
155
155
|
|
|
156
|
-
NK_PUBLIC void nk_angulars_symmetric_f32_rvv(
|
|
157
|
-
nk_f32_t const *vectors, nk_size_t
|
|
158
|
-
nk_f64_t *result, nk_size_t
|
|
159
|
-
nk_size_t const stride_elements =
|
|
160
|
-
nk_size_t const result_stride_elements =
|
|
161
|
-
nk_dots_symmetric_f32_rvv(vectors,
|
|
162
|
-
|
|
163
|
-
|
|
156
|
+
NK_PUBLIC void nk_angulars_symmetric_f32_rvv( //
|
|
157
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
158
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
159
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
160
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
161
|
+
nk_dots_symmetric_f32_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
162
|
+
row_count);
|
|
163
|
+
nk_angulars_symmetric_f32_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
164
|
+
result_stride_elements, row_start, row_count);
|
|
164
165
|
}
|
|
165
166
|
|
|
166
|
-
NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t
|
|
167
|
+
NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t vectors_count,
|
|
167
168
|
nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
168
169
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
169
170
|
nk_size_t row_count) {
|
|
@@ -172,8 +173,8 @@ NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vecto
|
|
|
172
173
|
result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
|
|
173
174
|
}
|
|
174
175
|
nk_f64_t norms_cache[256];
|
|
175
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
176
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
176
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
177
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
177
178
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
178
179
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
|
|
179
180
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -204,19 +205,20 @@ NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vecto
|
|
|
204
205
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
205
206
|
}
|
|
206
207
|
|
|
207
|
-
NK_PUBLIC void nk_euclideans_symmetric_f32_rvv(
|
|
208
|
-
nk_f32_t const *vectors, nk_size_t
|
|
209
|
-
nk_f64_t *result, nk_size_t
|
|
210
|
-
nk_size_t const stride_elements =
|
|
211
|
-
nk_size_t const result_stride_elements =
|
|
212
|
-
nk_dots_symmetric_f32_rvv(vectors,
|
|
213
|
-
|
|
208
|
+
NK_PUBLIC void nk_euclideans_symmetric_f32_rvv( //
|
|
209
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
210
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
211
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
212
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
213
|
+
nk_dots_symmetric_f32_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
214
|
+
row_count);
|
|
215
|
+
nk_euclideans_symmetric_f32_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
214
216
|
result_stride_elements, row_start, row_count);
|
|
215
217
|
}
|
|
216
218
|
|
|
217
|
-
#pragma endregion
|
|
219
|
+
#pragma endregion F32 Floats
|
|
218
220
|
|
|
219
|
-
#pragma region
|
|
221
|
+
#pragma region F64 Floats
|
|
220
222
|
|
|
221
223
|
NK_INTERNAL void nk_angulars_packed_f64_rvv_finalize_(nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
|
|
222
224
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -300,8 +302,8 @@ NK_PUBLIC void nk_euclideans_packed_f64_rvv( //
|
|
|
300
302
|
nk_euclideans_packed_f64_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
301
303
|
}
|
|
302
304
|
|
|
303
|
-
NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t
|
|
304
|
-
nk_size_t stride_elements, nk_f64_t *result,
|
|
305
|
+
NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t vectors_count,
|
|
306
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
305
307
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
306
308
|
nk_size_t row_count) {
|
|
307
309
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -309,8 +311,8 @@ NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors
|
|
|
309
311
|
result_row[row_index] = nk_dots_reduce_sumsq_f64_(vectors + row_index * stride_elements, depth);
|
|
310
312
|
}
|
|
311
313
|
nk_f64_t norms_cache[256];
|
|
312
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
313
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
314
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
315
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
314
316
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
315
317
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
|
|
316
318
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -342,17 +344,18 @@ NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors
|
|
|
342
344
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
343
345
|
}
|
|
344
346
|
|
|
345
|
-
NK_PUBLIC void nk_angulars_symmetric_f64_rvv(
|
|
346
|
-
nk_f64_t const *vectors, nk_size_t
|
|
347
|
-
nk_f64_t *result, nk_size_t
|
|
348
|
-
nk_size_t const stride_elements =
|
|
349
|
-
nk_size_t const result_stride_elements =
|
|
350
|
-
nk_dots_symmetric_f64_rvv(vectors,
|
|
351
|
-
|
|
352
|
-
|
|
347
|
+
NK_PUBLIC void nk_angulars_symmetric_f64_rvv( //
|
|
348
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
349
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
350
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
351
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
352
|
+
nk_dots_symmetric_f64_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
353
|
+
row_count);
|
|
354
|
+
nk_angulars_symmetric_f64_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
355
|
+
result_stride_elements, row_start, row_count);
|
|
353
356
|
}
|
|
354
357
|
|
|
355
|
-
NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t
|
|
358
|
+
NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t vectors_count,
|
|
356
359
|
nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
357
360
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
358
361
|
nk_size_t row_count) {
|
|
@@ -361,8 +364,8 @@ NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vecto
|
|
|
361
364
|
result_row[row_index] = nk_dots_reduce_sumsq_f64_(vectors + row_index * stride_elements, depth);
|
|
362
365
|
}
|
|
363
366
|
nk_f64_t norms_cache[256];
|
|
364
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
365
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
367
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
368
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
366
369
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
367
370
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
|
|
368
371
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -393,19 +396,20 @@ NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vecto
|
|
|
393
396
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
394
397
|
}
|
|
395
398
|
|
|
396
|
-
NK_PUBLIC void nk_euclideans_symmetric_f64_rvv(
|
|
397
|
-
nk_f64_t const *vectors, nk_size_t
|
|
398
|
-
nk_f64_t *result, nk_size_t
|
|
399
|
-
nk_size_t const stride_elements =
|
|
400
|
-
nk_size_t const result_stride_elements =
|
|
401
|
-
nk_dots_symmetric_f64_rvv(vectors,
|
|
402
|
-
|
|
399
|
+
NK_PUBLIC void nk_euclideans_symmetric_f64_rvv( //
|
|
400
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
401
|
+
nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
402
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
403
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
404
|
+
nk_dots_symmetric_f64_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
405
|
+
row_count);
|
|
406
|
+
nk_euclideans_symmetric_f64_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
403
407
|
result_stride_elements, row_start, row_count);
|
|
404
408
|
}
|
|
405
409
|
|
|
406
|
-
#pragma endregion
|
|
410
|
+
#pragma endregion F64 Floats
|
|
407
411
|
|
|
408
|
-
#pragma region
|
|
412
|
+
#pragma region F16 Floats
|
|
409
413
|
|
|
410
414
|
NK_INTERNAL void nk_angulars_packed_f16_rvv_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
411
415
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -489,8 +493,8 @@ NK_PUBLIC void nk_euclideans_packed_f16_rvv( //
|
|
|
489
493
|
nk_euclideans_packed_f16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
490
494
|
}
|
|
491
495
|
|
|
492
|
-
NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t
|
|
493
|
-
nk_size_t stride_elements, nk_f32_t *result,
|
|
496
|
+
NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
|
|
497
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
494
498
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
495
499
|
nk_size_t row_count) {
|
|
496
500
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -498,8 +502,8 @@ NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors
|
|
|
498
502
|
result_row[row_index] = nk_dots_reduce_sumsq_f16_(vectors + row_index * stride_elements, depth);
|
|
499
503
|
}
|
|
500
504
|
nk_f32_t norms_cache[256];
|
|
501
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
502
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
505
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
506
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
503
507
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
504
508
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
|
|
505
509
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -531,17 +535,18 @@ NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors
|
|
|
531
535
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
532
536
|
}
|
|
533
537
|
|
|
534
|
-
NK_PUBLIC void nk_angulars_symmetric_f16_rvv(
|
|
535
|
-
nk_f16_t const *vectors, nk_size_t
|
|
536
|
-
nk_f32_t *result, nk_size_t
|
|
537
|
-
nk_size_t const stride_elements =
|
|
538
|
-
nk_size_t const result_stride_elements =
|
|
539
|
-
nk_dots_symmetric_f16_rvv(vectors,
|
|
540
|
-
|
|
541
|
-
|
|
538
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_rvv( //
|
|
539
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
540
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
541
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
542
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
543
|
+
nk_dots_symmetric_f16_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
544
|
+
row_count);
|
|
545
|
+
nk_angulars_symmetric_f16_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
546
|
+
result_stride_elements, row_start, row_count);
|
|
542
547
|
}
|
|
543
548
|
|
|
544
|
-
NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t
|
|
549
|
+
NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
|
|
545
550
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
546
551
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
547
552
|
nk_size_t row_count) {
|
|
@@ -550,8 +555,8 @@ NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vecto
|
|
|
550
555
|
result_row[row_index] = nk_dots_reduce_sumsq_f16_(vectors + row_index * stride_elements, depth);
|
|
551
556
|
}
|
|
552
557
|
nk_f32_t norms_cache[256];
|
|
553
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
554
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
558
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
559
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
555
560
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
556
561
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
|
|
557
562
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -582,19 +587,20 @@ NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vecto
|
|
|
582
587
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
583
588
|
}
|
|
584
589
|
|
|
585
|
-
NK_PUBLIC void nk_euclideans_symmetric_f16_rvv(
|
|
586
|
-
nk_f16_t const *vectors, nk_size_t
|
|
587
|
-
nk_f32_t *result, nk_size_t
|
|
588
|
-
nk_size_t const stride_elements =
|
|
589
|
-
nk_size_t const result_stride_elements =
|
|
590
|
-
nk_dots_symmetric_f16_rvv(vectors,
|
|
591
|
-
|
|
590
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_rvv( //
|
|
591
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
592
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
593
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
594
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
595
|
+
nk_dots_symmetric_f16_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes, row_start,
|
|
596
|
+
row_count);
|
|
597
|
+
nk_euclideans_symmetric_f16_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
592
598
|
result_stride_elements, row_start, row_count);
|
|
593
599
|
}
|
|
594
600
|
|
|
595
|
-
#pragma endregion
|
|
601
|
+
#pragma endregion F16 Floats
|
|
596
602
|
|
|
597
|
-
#pragma region
|
|
603
|
+
#pragma region BF16 Floats
|
|
598
604
|
|
|
599
605
|
NK_INTERNAL void nk_angulars_packed_bf16_rvv_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
600
606
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -678,7 +684,7 @@ NK_PUBLIC void nk_euclideans_packed_bf16_rvv( //
|
|
|
678
684
|
nk_euclideans_packed_bf16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
679
685
|
}
|
|
680
686
|
|
|
681
|
-
NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t
|
|
687
|
+
NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t vectors_count,
|
|
682
688
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
683
689
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
684
690
|
nk_size_t row_count) {
|
|
@@ -687,8 +693,8 @@ NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vecto
|
|
|
687
693
|
result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
|
|
688
694
|
}
|
|
689
695
|
nk_f32_t norms_cache[256];
|
|
690
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
691
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
696
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
697
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
692
698
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
693
699
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
694
700
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -720,17 +726,18 @@ NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vecto
|
|
|
720
726
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
721
727
|
}
|
|
722
728
|
|
|
723
|
-
NK_PUBLIC void nk_angulars_symmetric_bf16_rvv(
|
|
724
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
725
|
-
nk_f32_t *result, nk_size_t
|
|
726
|
-
nk_size_t const stride_elements =
|
|
727
|
-
nk_size_t const result_stride_elements =
|
|
728
|
-
nk_dots_symmetric_bf16_rvv(vectors,
|
|
729
|
-
|
|
730
|
-
|
|
729
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_rvv( //
|
|
730
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
731
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
732
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
733
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
734
|
+
nk_dots_symmetric_bf16_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
735
|
+
row_start, row_count);
|
|
736
|
+
nk_angulars_symmetric_bf16_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
737
|
+
result_stride_elements, row_start, row_count);
|
|
731
738
|
}
|
|
732
739
|
|
|
733
|
-
NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t
|
|
740
|
+
NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t vectors_count,
|
|
734
741
|
nk_size_t depth, nk_size_t stride_elements,
|
|
735
742
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
736
743
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -739,8 +746,8 @@ NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vec
|
|
|
739
746
|
result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
|
|
740
747
|
}
|
|
741
748
|
nk_f32_t norms_cache[256];
|
|
742
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
743
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
749
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
750
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
744
751
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
745
752
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
746
753
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -771,19 +778,20 @@ NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vec
|
|
|
771
778
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
772
779
|
}
|
|
773
780
|
|
|
774
|
-
NK_PUBLIC void nk_euclideans_symmetric_bf16_rvv(
|
|
775
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
776
|
-
nk_f32_t *result, nk_size_t
|
|
777
|
-
nk_size_t const stride_elements =
|
|
778
|
-
nk_size_t const result_stride_elements =
|
|
779
|
-
nk_dots_symmetric_bf16_rvv(vectors,
|
|
780
|
-
|
|
781
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_rvv( //
|
|
782
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
783
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
784
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
785
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
786
|
+
nk_dots_symmetric_bf16_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
787
|
+
row_start, row_count);
|
|
788
|
+
nk_euclideans_symmetric_bf16_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
781
789
|
result_stride_elements, row_start, row_count);
|
|
782
790
|
}
|
|
783
791
|
|
|
784
|
-
#pragma endregion
|
|
792
|
+
#pragma endregion BF16 Floats
|
|
785
793
|
|
|
786
|
-
#pragma region
|
|
794
|
+
#pragma region E2M3 Floats
|
|
787
795
|
|
|
788
796
|
NK_INTERNAL void nk_angulars_packed_e2m3_rvv_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
789
797
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -867,7 +875,7 @@ NK_PUBLIC void nk_euclideans_packed_e2m3_rvv( //
|
|
|
867
875
|
nk_euclideans_packed_e2m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
868
876
|
}
|
|
869
877
|
|
|
870
|
-
NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t
|
|
878
|
+
NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t vectors_count,
|
|
871
879
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
872
880
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
873
881
|
nk_size_t row_count) {
|
|
@@ -876,8 +884,8 @@ NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vecto
|
|
|
876
884
|
result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
|
|
877
885
|
}
|
|
878
886
|
nk_f32_t norms_cache[256];
|
|
879
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
880
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
887
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
888
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
881
889
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
882
890
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
883
891
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -909,17 +917,18 @@ NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vecto
|
|
|
909
917
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
910
918
|
}
|
|
911
919
|
|
|
912
|
-
NK_PUBLIC void nk_angulars_symmetric_e2m3_rvv(
|
|
913
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
914
|
-
nk_f32_t *result, nk_size_t
|
|
915
|
-
nk_size_t const stride_elements =
|
|
916
|
-
nk_size_t const result_stride_elements =
|
|
917
|
-
nk_dots_symmetric_e2m3_rvv(vectors,
|
|
918
|
-
|
|
919
|
-
|
|
920
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_rvv( //
|
|
921
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
922
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
923
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
924
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
925
|
+
nk_dots_symmetric_e2m3_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
926
|
+
row_start, row_count);
|
|
927
|
+
nk_angulars_symmetric_e2m3_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
928
|
+
result_stride_elements, row_start, row_count);
|
|
920
929
|
}
|
|
921
930
|
|
|
922
|
-
NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t
|
|
931
|
+
NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t vectors_count,
|
|
923
932
|
nk_size_t depth, nk_size_t stride_elements,
|
|
924
933
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
925
934
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -928,8 +937,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vec
|
|
|
928
937
|
result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
|
|
929
938
|
}
|
|
930
939
|
nk_f32_t norms_cache[256];
|
|
931
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
932
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
940
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
941
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
933
942
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
934
943
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
935
944
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -960,19 +969,20 @@ NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vec
|
|
|
960
969
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
961
970
|
}
|
|
962
971
|
|
|
963
|
-
NK_PUBLIC void nk_euclideans_symmetric_e2m3_rvv(
|
|
964
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
965
|
-
nk_f32_t *result, nk_size_t
|
|
966
|
-
nk_size_t const stride_elements =
|
|
967
|
-
nk_size_t const result_stride_elements =
|
|
968
|
-
nk_dots_symmetric_e2m3_rvv(vectors,
|
|
969
|
-
|
|
972
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_rvv( //
|
|
973
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
974
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
975
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
976
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
977
|
+
nk_dots_symmetric_e2m3_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
978
|
+
row_start, row_count);
|
|
979
|
+
nk_euclideans_symmetric_e2m3_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
970
980
|
result_stride_elements, row_start, row_count);
|
|
971
981
|
}
|
|
972
982
|
|
|
973
|
-
#pragma endregion
|
|
983
|
+
#pragma endregion E2M3 Floats
|
|
974
984
|
|
|
975
|
-
#pragma region
|
|
985
|
+
#pragma region E3M2 Floats
|
|
976
986
|
|
|
977
987
|
NK_INTERNAL void nk_angulars_packed_e3m2_rvv_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
978
988
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -1056,7 +1066,7 @@ NK_PUBLIC void nk_euclideans_packed_e3m2_rvv( //
|
|
|
1056
1066
|
nk_euclideans_packed_e3m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1057
1067
|
}
|
|
1058
1068
|
|
|
1059
|
-
NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t
|
|
1069
|
+
NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t vectors_count,
|
|
1060
1070
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1061
1071
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1062
1072
|
nk_size_t row_count) {
|
|
@@ -1065,8 +1075,8 @@ NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vecto
|
|
|
1065
1075
|
result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
|
|
1066
1076
|
}
|
|
1067
1077
|
nk_f32_t norms_cache[256];
|
|
1068
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1069
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1078
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1079
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1070
1080
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1071
1081
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1072
1082
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1098,17 +1108,18 @@ NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vecto
|
|
|
1098
1108
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1099
1109
|
}
|
|
1100
1110
|
|
|
1101
|
-
NK_PUBLIC void nk_angulars_symmetric_e3m2_rvv(
|
|
1102
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1103
|
-
nk_f32_t *result, nk_size_t
|
|
1104
|
-
nk_size_t const stride_elements =
|
|
1105
|
-
nk_size_t const result_stride_elements =
|
|
1106
|
-
nk_dots_symmetric_e3m2_rvv(vectors,
|
|
1107
|
-
|
|
1108
|
-
|
|
1111
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_rvv( //
|
|
1112
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1113
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1114
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1115
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1116
|
+
nk_dots_symmetric_e3m2_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1117
|
+
row_start, row_count);
|
|
1118
|
+
nk_angulars_symmetric_e3m2_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1119
|
+
result_stride_elements, row_start, row_count);
|
|
1109
1120
|
}
|
|
1110
1121
|
|
|
1111
|
-
NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t
|
|
1122
|
+
NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t vectors_count,
|
|
1112
1123
|
nk_size_t depth, nk_size_t stride_elements,
|
|
1113
1124
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1114
1125
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -1117,8 +1128,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vec
|
|
|
1117
1128
|
result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
|
|
1118
1129
|
}
|
|
1119
1130
|
nk_f32_t norms_cache[256];
|
|
1120
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1121
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1131
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1132
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1122
1133
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1123
1134
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1124
1135
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1149,19 +1160,20 @@ NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vec
|
|
|
1149
1160
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1150
1161
|
}
|
|
1151
1162
|
|
|
1152
|
-
NK_PUBLIC void nk_euclideans_symmetric_e3m2_rvv(
|
|
1153
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1154
|
-
nk_f32_t *result, nk_size_t
|
|
1155
|
-
nk_size_t const stride_elements =
|
|
1156
|
-
nk_size_t const result_stride_elements =
|
|
1157
|
-
nk_dots_symmetric_e3m2_rvv(vectors,
|
|
1158
|
-
|
|
1163
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_rvv( //
|
|
1164
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1165
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1166
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1167
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1168
|
+
nk_dots_symmetric_e3m2_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1169
|
+
row_start, row_count);
|
|
1170
|
+
nk_euclideans_symmetric_e3m2_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1159
1171
|
result_stride_elements, row_start, row_count);
|
|
1160
1172
|
}
|
|
1161
1173
|
|
|
1162
|
-
#pragma endregion
|
|
1174
|
+
#pragma endregion E3M2 Floats
|
|
1163
1175
|
|
|
1164
|
-
#pragma region
|
|
1176
|
+
#pragma region E4M3 Floats
|
|
1165
1177
|
|
|
1166
1178
|
NK_INTERNAL void nk_angulars_packed_e4m3_rvv_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1167
1179
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -1245,7 +1257,7 @@ NK_PUBLIC void nk_euclideans_packed_e4m3_rvv( //
|
|
|
1245
1257
|
nk_euclideans_packed_e4m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1246
1258
|
}
|
|
1247
1259
|
|
|
1248
|
-
NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t
|
|
1260
|
+
NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t vectors_count,
|
|
1249
1261
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1250
1262
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1251
1263
|
nk_size_t row_count) {
|
|
@@ -1254,8 +1266,8 @@ NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vecto
|
|
|
1254
1266
|
result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
|
|
1255
1267
|
}
|
|
1256
1268
|
nk_f32_t norms_cache[256];
|
|
1257
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1258
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1269
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1270
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1259
1271
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1260
1272
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
1261
1273
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1287,17 +1299,18 @@ NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vecto
|
|
|
1287
1299
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1288
1300
|
}
|
|
1289
1301
|
|
|
1290
|
-
NK_PUBLIC void nk_angulars_symmetric_e4m3_rvv(
|
|
1291
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
1292
|
-
nk_f32_t *result, nk_size_t
|
|
1293
|
-
nk_size_t const stride_elements =
|
|
1294
|
-
nk_size_t const result_stride_elements =
|
|
1295
|
-
nk_dots_symmetric_e4m3_rvv(vectors,
|
|
1296
|
-
|
|
1297
|
-
|
|
1302
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_rvv( //
|
|
1303
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1304
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1305
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
1306
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1307
|
+
nk_dots_symmetric_e4m3_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1308
|
+
row_start, row_count);
|
|
1309
|
+
nk_angulars_symmetric_e4m3_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1310
|
+
result_stride_elements, row_start, row_count);
|
|
1298
1311
|
}
|
|
1299
1312
|
|
|
1300
|
-
NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t
|
|
1313
|
+
NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t vectors_count,
|
|
1301
1314
|
nk_size_t depth, nk_size_t stride_elements,
|
|
1302
1315
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1303
1316
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -1306,8 +1319,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vec
|
|
|
1306
1319
|
result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
|
|
1307
1320
|
}
|
|
1308
1321
|
nk_f32_t norms_cache[256];
|
|
1309
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1310
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1322
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1323
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1311
1324
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1312
1325
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
1313
1326
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1338,19 +1351,20 @@ NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vec
|
|
|
1338
1351
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1339
1352
|
}
|
|
1340
1353
|
|
|
1341
|
-
NK_PUBLIC void nk_euclideans_symmetric_e4m3_rvv(
|
|
1342
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
1343
|
-
nk_f32_t *result, nk_size_t
|
|
1344
|
-
nk_size_t const stride_elements =
|
|
1345
|
-
nk_size_t const result_stride_elements =
|
|
1346
|
-
nk_dots_symmetric_e4m3_rvv(vectors,
|
|
1347
|
-
|
|
1354
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_rvv( //
|
|
1355
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1356
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1357
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
1358
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1359
|
+
nk_dots_symmetric_e4m3_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1360
|
+
row_start, row_count);
|
|
1361
|
+
nk_euclideans_symmetric_e4m3_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1348
1362
|
result_stride_elements, row_start, row_count);
|
|
1349
1363
|
}
|
|
1350
1364
|
|
|
1351
|
-
#pragma endregion
|
|
1365
|
+
#pragma endregion E4M3 Floats
|
|
1352
1366
|
|
|
1353
|
-
#pragma region
|
|
1367
|
+
#pragma region E5M2 Floats
|
|
1354
1368
|
|
|
1355
1369
|
NK_INTERNAL void nk_angulars_packed_e5m2_rvv_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1356
1370
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -1434,7 +1448,7 @@ NK_PUBLIC void nk_euclideans_packed_e5m2_rvv( //
|
|
|
1434
1448
|
nk_euclideans_packed_e5m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1435
1449
|
}
|
|
1436
1450
|
|
|
1437
|
-
NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t
|
|
1451
|
+
NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
|
|
1438
1452
|
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1439
1453
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1440
1454
|
nk_size_t row_count) {
|
|
@@ -1443,8 +1457,8 @@ NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vecto
|
|
|
1443
1457
|
result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
|
|
1444
1458
|
}
|
|
1445
1459
|
nk_f32_t norms_cache[256];
|
|
1446
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1447
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1460
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1461
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1448
1462
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1449
1463
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
1450
1464
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1476,17 +1490,18 @@ NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vecto
|
|
|
1476
1490
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1477
1491
|
}
|
|
1478
1492
|
|
|
1479
|
-
NK_PUBLIC void nk_angulars_symmetric_e5m2_rvv(
|
|
1480
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
1481
|
-
nk_f32_t *result, nk_size_t
|
|
1482
|
-
nk_size_t const stride_elements =
|
|
1483
|
-
nk_size_t const result_stride_elements =
|
|
1484
|
-
nk_dots_symmetric_e5m2_rvv(vectors,
|
|
1485
|
-
|
|
1486
|
-
|
|
1493
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_rvv( //
|
|
1494
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1495
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1496
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
1497
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1498
|
+
nk_dots_symmetric_e5m2_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1499
|
+
row_start, row_count);
|
|
1500
|
+
nk_angulars_symmetric_e5m2_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1501
|
+
result_stride_elements, row_start, row_count);
|
|
1487
1502
|
}
|
|
1488
1503
|
|
|
1489
|
-
NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t
|
|
1504
|
+
NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
|
|
1490
1505
|
nk_size_t depth, nk_size_t stride_elements,
|
|
1491
1506
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1492
1507
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -1495,8 +1510,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vec
|
|
|
1495
1510
|
result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
|
|
1496
1511
|
}
|
|
1497
1512
|
nk_f32_t norms_cache[256];
|
|
1498
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1499
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1513
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1514
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1500
1515
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1501
1516
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
1502
1517
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1527,19 +1542,20 @@ NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vec
|
|
|
1527
1542
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1528
1543
|
}
|
|
1529
1544
|
|
|
1530
|
-
NK_PUBLIC void nk_euclideans_symmetric_e5m2_rvv(
|
|
1531
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
1532
|
-
nk_f32_t *result, nk_size_t
|
|
1533
|
-
nk_size_t const stride_elements =
|
|
1534
|
-
nk_size_t const result_stride_elements =
|
|
1535
|
-
nk_dots_symmetric_e5m2_rvv(vectors,
|
|
1536
|
-
|
|
1545
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_rvv( //
|
|
1546
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1547
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1548
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
1549
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1550
|
+
nk_dots_symmetric_e5m2_rvv(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
1551
|
+
row_start, row_count);
|
|
1552
|
+
nk_euclideans_symmetric_e5m2_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1537
1553
|
result_stride_elements, row_start, row_count);
|
|
1538
1554
|
}
|
|
1539
1555
|
|
|
1540
|
-
#pragma endregion
|
|
1556
|
+
#pragma endregion E5M2 Floats
|
|
1541
1557
|
|
|
1542
|
-
#pragma region
|
|
1558
|
+
#pragma region I8 Integers
|
|
1543
1559
|
|
|
1544
1560
|
NK_INTERNAL void nk_angulars_packed_i8_rvv_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1545
1561
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -1629,8 +1645,8 @@ NK_PUBLIC void nk_euclideans_packed_i8_rvv( //
|
|
|
1629
1645
|
nk_euclideans_packed_i8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1630
1646
|
}
|
|
1631
1647
|
|
|
1632
|
-
NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t
|
|
1633
|
-
nk_size_t stride_elements, nk_f32_t *result,
|
|
1648
|
+
NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t vectors_count,
|
|
1649
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1634
1650
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1635
1651
|
nk_size_t row_count) {
|
|
1636
1652
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1638,8 +1654,8 @@ NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors,
|
|
|
1638
1654
|
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
|
|
1639
1655
|
}
|
|
1640
1656
|
nk_u32_t norms_cache[256];
|
|
1641
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1642
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1657
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1658
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1643
1659
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1644
1660
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
1645
1661
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1674,19 +1690,19 @@ NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors,
|
|
|
1674
1690
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1675
1691
|
}
|
|
1676
1692
|
|
|
1677
|
-
NK_PUBLIC void nk_angulars_symmetric_i8_rvv(
|
|
1678
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1679
|
-
nk_f32_t *result, nk_size_t
|
|
1680
|
-
nk_size_t const stride_elements =
|
|
1681
|
-
nk_size_t const result_stride_elements =
|
|
1682
|
-
nk_dots_symmetric_i8_rvv(vectors,
|
|
1683
|
-
row_count);
|
|
1684
|
-
nk_angulars_symmetric_i8_rvv_finalize_(vectors,
|
|
1685
|
-
row_start, row_count);
|
|
1693
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_rvv( //
|
|
1694
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1695
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1696
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1697
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1698
|
+
nk_dots_symmetric_i8_rvv(vectors, vectors_count, depth, stride_in_bytes, (nk_i32_t *)result, result_stride_in_bytes,
|
|
1699
|
+
row_start, row_count);
|
|
1700
|
+
nk_angulars_symmetric_i8_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1701
|
+
result_stride_elements, row_start, row_count);
|
|
1686
1702
|
}
|
|
1687
1703
|
|
|
1688
|
-
NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t
|
|
1689
|
-
nk_size_t stride_elements, nk_f32_t *result,
|
|
1704
|
+
NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t vectors_count,
|
|
1705
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1690
1706
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1691
1707
|
nk_size_t row_count) {
|
|
1692
1708
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1694,8 +1710,8 @@ NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors
|
|
|
1694
1710
|
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
|
|
1695
1711
|
}
|
|
1696
1712
|
nk_u32_t norms_cache[256];
|
|
1697
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1698
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1713
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1714
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1699
1715
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1700
1716
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
1701
1717
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1729,20 +1745,20 @@ NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors
|
|
|
1729
1745
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1730
1746
|
}
|
|
1731
1747
|
|
|
1732
|
-
NK_PUBLIC void nk_euclideans_symmetric_i8_rvv(
|
|
1733
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1734
|
-
nk_f32_t *result, nk_size_t
|
|
1735
|
-
nk_size_t const stride_elements =
|
|
1736
|
-
nk_size_t const result_stride_elements =
|
|
1737
|
-
nk_dots_symmetric_i8_rvv(vectors,
|
|
1738
|
-
row_count);
|
|
1739
|
-
nk_euclideans_symmetric_i8_rvv_finalize_(vectors,
|
|
1740
|
-
row_start, row_count);
|
|
1748
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_rvv( //
|
|
1749
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1750
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1751
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1752
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1753
|
+
nk_dots_symmetric_i8_rvv(vectors, vectors_count, depth, stride_in_bytes, (nk_i32_t *)result, result_stride_in_bytes,
|
|
1754
|
+
row_start, row_count);
|
|
1755
|
+
nk_euclideans_symmetric_i8_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1756
|
+
result_stride_elements, row_start, row_count);
|
|
1741
1757
|
}
|
|
1742
1758
|
|
|
1743
|
-
#pragma endregion
|
|
1759
|
+
#pragma endregion I8 Integers
|
|
1744
1760
|
|
|
1745
|
-
#pragma region
|
|
1761
|
+
#pragma region U8 Integers
|
|
1746
1762
|
|
|
1747
1763
|
NK_INTERNAL void nk_angulars_packed_u8_rvv_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1748
1764
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -1832,8 +1848,8 @@ NK_PUBLIC void nk_euclideans_packed_u8_rvv( //
|
|
|
1832
1848
|
nk_euclideans_packed_u8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1833
1849
|
}
|
|
1834
1850
|
|
|
1835
|
-
NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t
|
|
1836
|
-
nk_size_t stride_elements, nk_f32_t *result,
|
|
1851
|
+
NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t vectors_count,
|
|
1852
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1837
1853
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1838
1854
|
nk_size_t row_count) {
|
|
1839
1855
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1841,8 +1857,8 @@ NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors,
|
|
|
1841
1857
|
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
|
|
1842
1858
|
}
|
|
1843
1859
|
nk_u32_t norms_cache[256];
|
|
1844
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1845
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1860
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1861
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1846
1862
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1847
1863
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
1848
1864
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1877,19 +1893,19 @@ NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors,
|
|
|
1877
1893
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1878
1894
|
}
|
|
1879
1895
|
|
|
1880
|
-
NK_PUBLIC void nk_angulars_symmetric_u8_rvv(
|
|
1881
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1882
|
-
nk_f32_t *result, nk_size_t
|
|
1883
|
-
nk_size_t const stride_elements =
|
|
1884
|
-
nk_size_t const result_stride_elements =
|
|
1885
|
-
nk_dots_symmetric_u8_rvv(vectors,
|
|
1886
|
-
row_count);
|
|
1887
|
-
nk_angulars_symmetric_u8_rvv_finalize_(vectors,
|
|
1888
|
-
row_start, row_count);
|
|
1896
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_rvv( //
|
|
1897
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1898
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1899
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1900
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1901
|
+
nk_dots_symmetric_u8_rvv(vectors, vectors_count, depth, stride_in_bytes, (nk_u32_t *)result, result_stride_in_bytes,
|
|
1902
|
+
row_start, row_count);
|
|
1903
|
+
nk_angulars_symmetric_u8_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1904
|
+
result_stride_elements, row_start, row_count);
|
|
1889
1905
|
}
|
|
1890
1906
|
|
|
1891
|
-
NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t
|
|
1892
|
-
nk_size_t stride_elements, nk_f32_t *result,
|
|
1907
|
+
NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t vectors_count,
|
|
1908
|
+
nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1893
1909
|
nk_size_t result_stride_elements, nk_size_t row_start,
|
|
1894
1910
|
nk_size_t row_count) {
|
|
1895
1911
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1897,8 +1913,8 @@ NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors
|
|
|
1897
1913
|
((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
|
|
1898
1914
|
}
|
|
1899
1915
|
nk_u32_t norms_cache[256];
|
|
1900
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1901
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1916
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1917
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1902
1918
|
for (nk_size_t col = chunk_start; col < chunk_end; ++col)
|
|
1903
1919
|
norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
1904
1920
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
@@ -1932,18 +1948,18 @@ NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors
|
|
|
1932
1948
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1933
1949
|
}
|
|
1934
1950
|
|
|
1935
|
-
NK_PUBLIC void nk_euclideans_symmetric_u8_rvv(
|
|
1936
|
-
nk_u8_t const *vectors, nk_size_t
|
|
1937
|
-
nk_f32_t *result, nk_size_t
|
|
1938
|
-
nk_size_t const stride_elements =
|
|
1939
|
-
nk_size_t const result_stride_elements =
|
|
1940
|
-
nk_dots_symmetric_u8_rvv(vectors,
|
|
1941
|
-
row_count);
|
|
1942
|
-
nk_euclideans_symmetric_u8_rvv_finalize_(vectors,
|
|
1943
|
-
row_start, row_count);
|
|
1951
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_rvv( //
|
|
1952
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1953
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1954
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1955
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1956
|
+
nk_dots_symmetric_u8_rvv(vectors, vectors_count, depth, stride_in_bytes, (nk_u32_t *)result, result_stride_in_bytes,
|
|
1957
|
+
row_start, row_count);
|
|
1958
|
+
nk_euclideans_symmetric_u8_rvv_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1959
|
+
result_stride_elements, row_start, row_count);
|
|
1944
1960
|
}
|
|
1945
1961
|
|
|
1946
|
-
#pragma endregion
|
|
1962
|
+
#pragma endregion U8 Integers
|
|
1947
1963
|
|
|
1948
1964
|
#if defined(__clang__)
|
|
1949
1965
|
#pragma clang attribute pop
|