numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1960 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Batched Spatial Distances for RISC-V Vector (RVV).
|
|
3
|
+
* @file include/numkong/spatials/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatials.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPATIALS_RVV_H
|
|
10
|
+
#define NK_SPATIALS_RVV_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_RISCV_
|
|
13
|
+
#if NK_TARGET_RVV
|
|
14
|
+
|
|
15
|
+
#include "numkong/dots/serial.h"
|
|
16
|
+
#include "numkong/dots/rvv.h"
|
|
17
|
+
#include "numkong/spatial/rvv.h"
|
|
18
|
+
|
|
19
|
+
#if defined(__cplusplus)
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
25
|
+
#elif defined(__GNUC__)
|
|
26
|
+
#pragma GCC push_options
|
|
27
|
+
#pragma GCC target("arch=+v")
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
#pragma region Single Precision Floats
|
|
31
|
+
|
|
32
|
+
/**
 * Rescales raw dot-products in `c` into angular (cosine) distances in-place.
 * The squared norms of the packed columns are read from the tail of `b_packed`,
 * right after the padded `f32` matrix payload; the squared norm of each query
 * row is recomputed on the fly.
 */
NK_INTERNAL void nk_angulars_packed_f32_rvv_finalize_(nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *packed_header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Column norms follow the header and the `column_count x depth_padded_values` f32 payload.
    nk_size_t const payload_bytes = packed_header->column_count * packed_header->depth_padded_values * sizeof(nk_f32_t);
    nk_f64_t const *column_norms =
        (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) + payload_bytes);

    for (nk_size_t row = 0; row < rows; ++row) {
        nk_f32_t const *query = a + row * a_stride_elements;
        nk_f64_t *out = c + row * c_stride_elements;
        nk_f64_t const query_norm_sq = nk_dots_reduce_sumsq_f32_(query, depth);
        nk_f64_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e64m1(remaining);
            vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
            vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
            vfloat64m1_t norms_product = __riscv_vfmul_vf_f64m1(norms_sq, query_norm_sq, vl);
            vfloat64m1_t inv_norm = nk_rsqrt_f64m1_rvv_(norms_product, vl);
            vfloat64m1_t cosine = __riscv_vfmul_vv_f64m1(dots, inv_norm, vl);
            // Angular distance is `1 - cos`, clamped at zero against rounding error.
            vfloat64m1_t distance = __riscv_vfrsub_vf_f64m1(cosine, 1.0, vl);
            distance = __riscv_vfmax_vf_f64m1(distance, 0.0, vl);
            __riscv_vse64_v_f64m1(out, distance, vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
63
|
+
|
|
64
|
+
NK_PUBLIC void nk_angulars_packed_f32_rvv( //
|
|
65
|
+
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
66
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
67
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
68
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
|
|
69
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
70
|
+
nk_dots_packed_f32_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
71
|
+
nk_angulars_packed_f32_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
/**
 * Rescales raw dot-products in `c` into Euclidean distances in-place, via the
 * identity `|a - b|^2 = |a|^2 + |b|^2 - 2 * (a . b)`. Squared column norms are
 * read from the tail of `b_packed`; each query row's squared norm is recomputed.
 */
NK_INTERNAL void nk_euclideans_packed_f32_rvv_finalize_(nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *packed_header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Column norms follow the header and the `column_count x depth_padded_values` f32 payload.
    nk_size_t const payload_bytes = packed_header->column_count * packed_header->depth_padded_values * sizeof(nk_f32_t);
    nk_f64_t const *column_norms =
        (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) + payload_bytes);

    for (nk_size_t row = 0; row < rows; ++row) {
        nk_f32_t const *query = a + row * a_stride_elements;
        nk_f64_t *out = c + row * c_stride_elements;
        nk_f64_t const query_norm_sq = nk_dots_reduce_sumsq_f32_(query, depth);
        nk_f64_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e64m1(remaining);
            vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
            vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
            vfloat64m1_t sum_sq = __riscv_vfadd_vf_f64m1(norms_sq, query_norm_sq, vl);
            vfloat64m1_t twice_dots = __riscv_vfmul_vf_f64m1(dots, 2.0, vl);
            // Clamp at zero before the square root: rounding may drive the result slightly negative.
            vfloat64m1_t dist_sq = __riscv_vfsub_vv_f64m1(sum_sq, twice_dots, vl);
            dist_sq = __riscv_vfmax_vf_f64m1(dist_sq, 0.0, vl);
            __riscv_vse64_v_f64m1(out, __riscv_vfsqrt_v_f64m1(dist_sq, vl), vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
103
|
+
|
|
104
|
+
NK_PUBLIC void nk_euclideans_packed_f32_rvv( //
|
|
105
|
+
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
|
|
106
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
107
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
108
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
|
|
109
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
110
|
+
nk_dots_packed_f32_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
111
|
+
nk_euclideans_packed_f32_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
/**
 * Rescales raw pairwise dot-products in `result` into angular (cosine)
 * distances for rows `[row_start, row_start + row_count)`, touching only the
 * strict upper triangle of each row. The diagonal cell is temporarily used to
 * stash each row's squared norm and is zeroed at the end.
 */
NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f64_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    // Pass 1: park each processed row's squared norm on its own diagonal cell.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
    }
    // Pass 2: walk columns in chunks of 256 so their squared norms can be cached
    // on the stack instead of being recomputed once per row.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal belong to this row's triangle.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // The row's squared norm was stashed on the diagonal by pass 1.
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                vfloat64m1_t norms_product_f64m1 = __riscv_vfmul_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                          vector_length);
                vfloat64m1_t rsqrt_f64m1 = nk_rsqrt_f64m1_rvv_(norms_product_f64m1, vector_length);
                vfloat64m1_t normalized_dots_f64m1 = __riscv_vfmul_vv_f64m1(dots_f64m1, rsqrt_f64m1, vector_length);
                // Angular distance is `1 - cos`, clamped at zero against rounding error.
                vfloat64m1_t angular_f64m1 = __riscv_vfrsub_vf_f64m1(normalized_dots_f64m1, 1.0, vector_length);
                angular_f64m1 = __riscv_vfmax_vf_f64m1(angular_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, angular_f64m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Pass 3: the distance of a vector to itself is zero; overwrite the stashed norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
155
|
+
|
|
156
|
+
NK_PUBLIC void nk_angulars_symmetric_f32_rvv( //
|
|
157
|
+
nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
158
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
159
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
|
|
160
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
|
|
161
|
+
nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
162
|
+
nk_angulars_symmetric_f32_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
163
|
+
row_start, row_count);
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
/**
 * Rescales raw pairwise dot-products in `result` into Euclidean distances for
 * rows `[row_start, row_start + row_count)`, via the identity
 * `|a - b|^2 = |a|^2 + |b|^2 - 2 * (a . b)`, touching only the strict upper
 * triangle of each row. The diagonal cell is temporarily used to stash each
 * row's squared norm and is zeroed at the end.
 */
NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    // Pass 1: park each processed row's squared norm on its own diagonal cell.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
    }
    // Pass 2: walk columns in chunks of 256 so their squared norms can be cached
    // on the stack instead of being recomputed once per row.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal belong to this row's triangle.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // The row's squared norm was stashed on the diagonal by pass 1.
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                vfloat64m1_t sum_sq_f64m1 = __riscv_vfadd_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                   vector_length);
                vfloat64m1_t dist_sq_f64m1 = __riscv_vfsub_vv_f64m1(
                    sum_sq_f64m1, __riscv_vfmul_vf_f64m1(dots_f64m1, 2.0, vector_length), vector_length);
                // Clamp at zero before the square root: rounding may drive the result slightly negative.
                dist_sq_f64m1 = __riscv_vfmax_vf_f64m1(dist_sq_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, __riscv_vfsqrt_v_f64m1(dist_sq_f64m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Pass 3: the distance of a vector to itself is zero; overwrite the stashed norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
206
|
+
|
|
207
|
+
NK_PUBLIC void nk_euclideans_symmetric_f32_rvv( //
|
|
208
|
+
nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
209
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
210
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
|
|
211
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
|
|
212
|
+
nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
213
|
+
nk_euclideans_symmetric_f32_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
214
|
+
result_stride_elements, row_start, row_count);
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
#pragma endregion // Single Precision Floats
|
|
218
|
+
|
|
219
|
+
#pragma region Double Precision Floats
|
|
220
|
+
|
|
221
|
+
/*  Rescales raw dot products into angular distances for the packed f64 path:
 *  c[i][j] <- max(0, 1 - dot / sqrt(|a_i|^2 * |b_j|^2)).
 *  The packed buffer stores per-column squared norms right behind the padded
 *  matrix payload, so the targets never need to be re-read.  */
NK_INTERNAL void nk_angulars_packed_f64_rvv_finalize_(nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f64_t const *column_norms =
        (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f64_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_f64_t const *query = a + i * a_stride_elements;
        nk_f64_t *out = c + i * c_stride_elements;
        nk_f64_t const query_sq = nk_dots_reduce_sumsq_f64_(query, depth);
        nk_f64_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e64m1(remaining);
            vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
            vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
            vfloat64m1_t product = __riscv_vfmul_vf_f64m1(norms_sq, query_sq, vl);
            vfloat64m1_t inv_norm = nk_rsqrt_f64m1_rvv_(product, vl);
            vfloat64m1_t cosine = __riscv_vfmul_vv_f64m1(dots, inv_norm, vl);
            // 1 - cosine, clamped at zero against rounding noise.
            vfloat64m1_t distance = __riscv_vfmax_vf_f64m1(__riscv_vfrsub_vf_f64m1(cosine, 1.0, vl), 0.0, vl);
            __riscv_vse64_v_f64m1(out, distance, vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
252
|
+
|
|
253
|
+
/*  Angular distances between query rows `a` and a pre-packed target matrix:
 *  runs the packed dot-product kernel, then normalizes the results in place.
 *  Strides arrive in bytes; the finalizer wants them in elements.  */
NK_PUBLIC void nk_angulars_packed_f64_rvv(                //
    nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,   //
    nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_dots_packed_f64_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
    nk_angulars_packed_f64_rvv_finalize_(a, b_packed, c, rows, columns, depth,
                                         a_stride_in_bytes / sizeof(nk_f64_t), c_stride_in_bytes / sizeof(nk_f64_t));
}
|
|
262
|
+
|
|
263
|
+
/*  Rescales raw dot products into Euclidean distances for the packed f64 path
 *  using the identity |a - b| = sqrt(|a|^2 + |b|^2 - 2 * dot(a, b)).
 *  Per-column squared norms come from the packed buffer's trailing table.  */
NK_INTERNAL void nk_euclideans_packed_f64_rvv_finalize_(nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f64_t const *column_norms =
        (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f64_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_f64_t const *query = a + i * a_stride_elements;
        nk_f64_t *out = c + i * c_stride_elements;
        nk_f64_t const query_sq = nk_dots_reduce_sumsq_f64_(query, depth);
        nk_f64_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e64m1(remaining);
            vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
            vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
            vfloat64m1_t sum_sq = __riscv_vfadd_vf_f64m1(norms_sq, query_sq, vl);
            vfloat64m1_t dist_sq =
                __riscv_vfsub_vv_f64m1(sum_sq, __riscv_vfmul_vf_f64m1(dots, 2.0, vl), vl);
            // Clamp tiny negatives produced by rounding before the sqrt.
            dist_sq = __riscv_vfmax_vf_f64m1(dist_sq, 0.0, vl);
            __riscv_vse64_v_f64m1(out, __riscv_vfsqrt_v_f64m1(dist_sq, vl), vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
292
|
+
|
|
293
|
+
/*  Euclidean distances between query rows `a` and a pre-packed target matrix:
 *  runs the packed dot-product kernel, then converts the results in place.
 *  Strides arrive in bytes; the finalizer wants them in elements.  */
NK_PUBLIC void nk_euclideans_packed_f64_rvv(              //
    nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,   //
    nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_dots_packed_f64_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
    nk_euclideans_packed_f64_rvv_finalize_(a, b_packed, c, rows, columns, depth,
                                           a_stride_in_bytes / sizeof(nk_f64_t), c_stride_in_bytes / sizeof(nk_f64_t));
}
|
|
302
|
+
|
|
303
|
+
/*  Rewrites the dot-product upper triangle of a symmetric f64 kernel into
 *  angular distances: result[i][j] <- max(0, 1 - dot / sqrt(|v_i|^2 * |v_j|^2)).
 *  Diagonal cells temporarily hold each row's squared norm and are cleared
 *  once the triangle has been rewritten.  */
NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f64_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    nk_size_t const row_end = row_start + row_count;
    // Stash each row's squared norm on the diagonal to serve as the query norm.
    for (nk_size_t i = row_start; i < row_end; ++i)
        result[i * result_stride_elements + i] = nk_dots_reduce_sumsq_f64_(vectors + i * stride_elements, depth);
    // Walk target columns in cache-resident chunks of 256 squared norms.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t const chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
        for (nk_size_t i = row_start; i < row_end; ++i) {
            // Only the strict upper triangle (columns > i) is finalized.
            nk_size_t const first_col = i + 1 > chunk_start ? i + 1 : chunk_start;
            if (first_col >= chunk_end) continue;
            nk_f64_t *row = result + i * result_stride_elements;
            nk_f64_t const query_sq = row[i];
            nk_f64_t *out = row + first_col;
            nk_f64_t const *norms = norms_cache + (first_col - chunk_start);
            for (nk_size_t remaining = chunk_end - first_col; remaining > 0;) {
                size_t vl = __riscv_vsetvl_e64m1(remaining);
                vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
                vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
                vfloat64m1_t product = __riscv_vfmul_vf_f64m1(norms_sq, query_sq, vl);
                vfloat64m1_t inv_norm = nk_rsqrt_f64m1_rvv_(product, vl);
                vfloat64m1_t cosine = __riscv_vfmul_vv_f64m1(dots, inv_norm, vl);
                // 1 - cosine, clamped at zero against rounding noise.
                vfloat64m1_t distance = __riscv_vfmax_vf_f64m1(__riscv_vfrsub_vf_f64m1(cosine, 1.0, vl), 0.0, vl);
                __riscv_vse64_v_f64m1(out, distance, vl);
                out += vl, norms += vl, remaining -= vl;
            }
        }
    }
    // Self-distance is zero by definition; clear the diagonal scratch values.
    for (nk_size_t i = row_start; i < row_end; ++i) result[i * result_stride_elements + i] = 0;
}
|
|
344
|
+
|
|
345
|
+
NK_PUBLIC void nk_angulars_symmetric_f64_rvv( //
|
|
346
|
+
nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
347
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
348
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
|
|
349
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
|
|
350
|
+
nk_dots_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
351
|
+
nk_angulars_symmetric_f64_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
352
|
+
row_start, row_count);
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
/*  Rewrites the dot-product upper triangle of a symmetric f64 kernel into
 *  Euclidean distances via |u - v| = sqrt(|u|^2 + |v|^2 - 2 * dot(u, v)).
 *  Diagonal cells temporarily hold each row's squared norm and are cleared
 *  once the triangle has been rewritten.  */
NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    nk_size_t const row_end = row_start + row_count;
    // Stash each row's squared norm on the diagonal to serve as the query norm.
    for (nk_size_t i = row_start; i < row_end; ++i)
        result[i * result_stride_elements + i] = nk_dots_reduce_sumsq_f64_(vectors + i * stride_elements, depth);
    // Walk target columns in cache-resident chunks of 256 squared norms.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t const chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
        for (nk_size_t i = row_start; i < row_end; ++i) {
            // Only the strict upper triangle (columns > i) is finalized.
            nk_size_t const first_col = i + 1 > chunk_start ? i + 1 : chunk_start;
            if (first_col >= chunk_end) continue;
            nk_f64_t *row = result + i * result_stride_elements;
            nk_f64_t const query_sq = row[i];
            nk_f64_t *out = row + first_col;
            nk_f64_t const *norms = norms_cache + (first_col - chunk_start);
            for (nk_size_t remaining = chunk_end - first_col; remaining > 0;) {
                size_t vl = __riscv_vsetvl_e64m1(remaining);
                vfloat64m1_t dots = __riscv_vle64_v_f64m1(out, vl);
                vfloat64m1_t norms_sq = __riscv_vle64_v_f64m1(norms, vl);
                vfloat64m1_t sum_sq = __riscv_vfadd_vf_f64m1(norms_sq, query_sq, vl);
                vfloat64m1_t dist_sq =
                    __riscv_vfsub_vv_f64m1(sum_sq, __riscv_vfmul_vf_f64m1(dots, 2.0, vl), vl);
                // Clamp tiny negatives produced by rounding before the sqrt.
                dist_sq = __riscv_vfmax_vf_f64m1(dist_sq, 0.0, vl);
                __riscv_vse64_v_f64m1(out, __riscv_vfsqrt_v_f64m1(dist_sq, vl), vl);
                out += vl, norms += vl, remaining -= vl;
            }
        }
    }
    // Self-distance is zero by definition; clear the diagonal scratch values.
    for (nk_size_t i = row_start; i < row_end; ++i) result[i * result_stride_elements + i] = 0;
}
|
|
395
|
+
|
|
396
|
+
NK_PUBLIC void nk_euclideans_symmetric_f64_rvv( //
|
|
397
|
+
nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
398
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
399
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
|
|
400
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
|
|
401
|
+
nk_dots_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
402
|
+
nk_euclideans_symmetric_f64_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
403
|
+
result_stride_elements, row_start, row_count);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
#pragma endregion // Double Precision Floats
|
|
407
|
+
|
|
408
|
+
#pragma region Half Precision Floats
|
|
409
|
+
|
|
410
|
+
/*  Rescales raw dot products into angular distances for the packed f16 path,
 *  with f32 accumulation: c[i][j] <- max(0, 1 - dot / sqrt(|a_i|^2 * |b_j|^2)).
 *  The packed buffer stores per-column squared norms (as f32) right behind
 *  the padded matrix payload.  */
NK_INTERNAL void nk_angulars_packed_f16_rvv_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f32_t const *column_norms =
        (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f32_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_f16_t const *query = a + i * a_stride_elements;
        nk_f32_t *out = c + i * c_stride_elements;
        nk_f32_t const query_sq = nk_dots_reduce_sumsq_f16_(query, depth);
        nk_f32_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e32m1(remaining);
            vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
            vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
            vfloat32m1_t product = __riscv_vfmul_vf_f32m1(norms_sq, query_sq, vl);
            vfloat32m1_t inv_norm = nk_rsqrt_f32m1_rvv_(product, vl);
            vfloat32m1_t cosine = __riscv_vfmul_vv_f32m1(dots, inv_norm, vl);
            // 1 - cosine, clamped at zero against rounding noise.
            vfloat32m1_t distance = __riscv_vfmax_vf_f32m1(__riscv_vfrsub_vf_f32m1(cosine, 1.0f, vl), 0.0f, vl);
            __riscv_vse32_v_f32m1(out, distance, vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
441
|
+
|
|
442
|
+
NK_PUBLIC void nk_angulars_packed_f16_rvv( //
|
|
443
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
444
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
445
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
446
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
447
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
448
|
+
nk_dots_packed_f16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
449
|
+
nk_angulars_packed_f16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
/*  Rescales raw dot products into Euclidean distances for the packed f16 path
 *  (f32 accumulation) via |a - b| = sqrt(|a|^2 + |b|^2 - 2 * dot(a, b)).
 *  Per-column squared norms come from the packed buffer's trailing table.  */
NK_INTERNAL void nk_euclideans_packed_f16_rvv_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f32_t const *column_norms =
        (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f32_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_f16_t const *query = a + i * a_stride_elements;
        nk_f32_t *out = c + i * c_stride_elements;
        nk_f32_t const query_sq = nk_dots_reduce_sumsq_f16_(query, depth);
        nk_f32_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e32m1(remaining);
            vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
            vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
            vfloat32m1_t sum_sq = __riscv_vfadd_vf_f32m1(norms_sq, query_sq, vl);
            vfloat32m1_t dist_sq =
                __riscv_vfsub_vv_f32m1(sum_sq, __riscv_vfmul_vf_f32m1(dots, 2.0f, vl), vl);
            // Clamp tiny negatives produced by rounding before the sqrt.
            dist_sq = __riscv_vfmax_vf_f32m1(dist_sq, 0.0f, vl);
            __riscv_vse32_v_f32m1(out, __riscv_vfsqrt_v_f32m1(dist_sq, vl), vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
481
|
+
|
|
482
|
+
NK_PUBLIC void nk_euclideans_packed_f16_rvv( //
|
|
483
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
484
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
485
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
486
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
487
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
488
|
+
nk_dots_packed_f16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
489
|
+
nk_euclideans_packed_f16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
/*  Rewrites the dot-product upper triangle of a symmetric f16 kernel (f32
 *  results) into angular distances:
 *  result[i][j] <- max(0, 1 - dot / sqrt(|v_i|^2 * |v_j|^2)).
 *  Diagonal cells temporarily hold each row's squared norm and are cleared
 *  once the triangle has been rewritten.  */
NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f32_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    nk_size_t const row_end = row_start + row_count;
    // Stash each row's squared norm on the diagonal to serve as the query norm.
    for (nk_size_t i = row_start; i < row_end; ++i)
        result[i * result_stride_elements + i] = nk_dots_reduce_sumsq_f16_(vectors + i * stride_elements, depth);
    // Walk target columns in cache-resident chunks of 256 squared norms.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t const chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
        for (nk_size_t i = row_start; i < row_end; ++i) {
            // Only the strict upper triangle (columns > i) is finalized.
            nk_size_t const first_col = i + 1 > chunk_start ? i + 1 : chunk_start;
            if (first_col >= chunk_end) continue;
            nk_f32_t *row = result + i * result_stride_elements;
            nk_f32_t const query_sq = row[i];
            nk_f32_t *out = row + first_col;
            nk_f32_t const *norms = norms_cache + (first_col - chunk_start);
            for (nk_size_t remaining = chunk_end - first_col; remaining > 0;) {
                size_t vl = __riscv_vsetvl_e32m1(remaining);
                vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
                vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
                vfloat32m1_t product = __riscv_vfmul_vf_f32m1(norms_sq, query_sq, vl);
                vfloat32m1_t inv_norm = nk_rsqrt_f32m1_rvv_(product, vl);
                vfloat32m1_t cosine = __riscv_vfmul_vv_f32m1(dots, inv_norm, vl);
                // 1 - cosine, clamped at zero against rounding noise.
                vfloat32m1_t distance = __riscv_vfmax_vf_f32m1(__riscv_vfrsub_vf_f32m1(cosine, 1.0f, vl), 0.0f, vl);
                __riscv_vse32_v_f32m1(out, distance, vl);
                out += vl, norms += vl, remaining -= vl;
            }
        }
    }
    // Self-distance is zero by definition; clear the diagonal scratch values.
    for (nk_size_t i = row_start; i < row_end; ++i) result[i * result_stride_elements + i] = 0;
}
|
|
533
|
+
|
|
534
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_rvv( //
|
|
535
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
536
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
537
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
|
|
538
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
539
|
+
nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
540
|
+
nk_angulars_symmetric_f16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
541
|
+
row_start, row_count);
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
/*  Rewrites the dot-product upper triangle of a symmetric f16 kernel (f32
 *  results) into Euclidean distances via
 *  |u - v| = sqrt(|u|^2 + |v|^2 - 2 * dot(u, v)).
 *  Diagonal cells temporarily hold each row's squared norm and are cleared
 *  once the triangle has been rewritten.  */
NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    nk_size_t const row_end = row_start + row_count;
    // Stash each row's squared norm on the diagonal to serve as the query norm.
    for (nk_size_t i = row_start; i < row_end; ++i)
        result[i * result_stride_elements + i] = nk_dots_reduce_sumsq_f16_(vectors + i * stride_elements, depth);
    // Walk target columns in cache-resident chunks of 256 squared norms.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t const chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
        for (nk_size_t i = row_start; i < row_end; ++i) {
            // Only the strict upper triangle (columns > i) is finalized.
            nk_size_t const first_col = i + 1 > chunk_start ? i + 1 : chunk_start;
            if (first_col >= chunk_end) continue;
            nk_f32_t *row = result + i * result_stride_elements;
            nk_f32_t const query_sq = row[i];
            nk_f32_t *out = row + first_col;
            nk_f32_t const *norms = norms_cache + (first_col - chunk_start);
            for (nk_size_t remaining = chunk_end - first_col; remaining > 0;) {
                size_t vl = __riscv_vsetvl_e32m1(remaining);
                vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
                vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
                vfloat32m1_t sum_sq = __riscv_vfadd_vf_f32m1(norms_sq, query_sq, vl);
                vfloat32m1_t dist_sq =
                    __riscv_vfsub_vv_f32m1(sum_sq, __riscv_vfmul_vf_f32m1(dots, 2.0f, vl), vl);
                // Clamp tiny negatives produced by rounding before the sqrt.
                dist_sq = __riscv_vfmax_vf_f32m1(dist_sq, 0.0f, vl);
                __riscv_vse32_v_f32m1(out, __riscv_vfsqrt_v_f32m1(dist_sq, vl), vl);
                out += vl, norms += vl, remaining -= vl;
            }
        }
    }
    // Self-distance is zero by definition; clear the diagonal scratch values.
    for (nk_size_t i = row_start; i < row_end; ++i) result[i * result_stride_elements + i] = 0;
}
|
|
584
|
+
|
|
585
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_rvv( //
|
|
586
|
+
nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
587
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
588
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
|
|
589
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
590
|
+
nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
591
|
+
nk_euclideans_symmetric_f16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
592
|
+
result_stride_elements, row_start, row_count);
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
#pragma endregion // Half Precision Floats
|
|
596
|
+
|
|
597
|
+
#pragma region Brain Float 16
|
|
598
|
+
|
|
599
|
+
/*  Rescales raw dot products into angular distances for the packed bf16 path,
 *  with f32 accumulation: c[i][j] <- max(0, 1 - dot / sqrt(|a_i|^2 * |b_j|^2)).
 *  The packed buffer stores per-column squared norms (as f32) right behind
 *  the padded matrix payload.  */
NK_INTERNAL void nk_angulars_packed_bf16_rvv_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f32_t const *column_norms =
        (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f32_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_bf16_t const *query = a + i * a_stride_elements;
        nk_f32_t *out = c + i * c_stride_elements;
        nk_f32_t const query_sq = nk_dots_reduce_sumsq_bf16_(query, depth);
        nk_f32_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e32m1(remaining);
            vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
            vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
            vfloat32m1_t product = __riscv_vfmul_vf_f32m1(norms_sq, query_sq, vl);
            vfloat32m1_t inv_norm = nk_rsqrt_f32m1_rvv_(product, vl);
            vfloat32m1_t cosine = __riscv_vfmul_vv_f32m1(dots, inv_norm, vl);
            // 1 - cosine, clamped at zero against rounding noise.
            vfloat32m1_t distance = __riscv_vfmax_vf_f32m1(__riscv_vfrsub_vf_f32m1(cosine, 1.0f, vl), 0.0f, vl);
            __riscv_vse32_v_f32m1(out, distance, vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
630
|
+
|
|
631
|
+
NK_PUBLIC void nk_angulars_packed_bf16_rvv( //
|
|
632
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
633
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
634
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
635
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
636
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
637
|
+
nk_dots_packed_bf16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
638
|
+
nk_angulars_packed_bf16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
639
|
+
}
|
|
640
|
+
|
|
641
|
+
/*  Rescales raw dot products into Euclidean distances for the packed bf16
 *  path (f32 accumulation) via |a - b| = sqrt(|a|^2 + |b|^2 - 2 * dot(a, b)).
 *  Per-column squared norms come from the packed buffer's trailing table.  */
NK_INTERNAL void nk_euclideans_packed_bf16_rvv_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live immediately after the padded payload.
    nk_f32_t const *column_norms =
        (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                           header->column_count * header->depth_padded_values * sizeof(nk_f32_t));
    for (nk_size_t i = 0; i < rows; ++i) {
        nk_bf16_t const *query = a + i * a_stride_elements;
        nk_f32_t *out = c + i * c_stride_elements;
        nk_f32_t const query_sq = nk_dots_reduce_sumsq_bf16_(query, depth);
        nk_f32_t const *norms = column_norms;
        for (nk_size_t remaining = columns; remaining > 0;) {
            size_t vl = __riscv_vsetvl_e32m1(remaining);
            vfloat32m1_t dots = __riscv_vle32_v_f32m1(out, vl);
            vfloat32m1_t norms_sq = __riscv_vle32_v_f32m1(norms, vl);
            vfloat32m1_t sum_sq = __riscv_vfadd_vf_f32m1(norms_sq, query_sq, vl);
            vfloat32m1_t dist_sq =
                __riscv_vfsub_vv_f32m1(sum_sq, __riscv_vfmul_vf_f32m1(dots, 2.0f, vl), vl);
            // Clamp tiny negatives produced by rounding before the sqrt.
            dist_sq = __riscv_vfmax_vf_f32m1(dist_sq, 0.0f, vl);
            __riscv_vse32_v_f32m1(out, __riscv_vfsqrt_v_f32m1(dist_sq, vl), vl);
            out += vl, norms += vl, remaining -= vl;
        }
    }
}
|
|
670
|
+
|
|
671
|
+
NK_PUBLIC void nk_euclideans_packed_bf16_rvv( //
|
|
672
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
673
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
674
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
675
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
676
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
677
|
+
nk_dots_packed_bf16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
678
|
+
nk_euclideans_packed_bf16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
// Finalization pass for symmetric angular (cosine) distances over bf16 vectors:
// rewrites, in place, the raw dot products left in `result` by the matching
// `nk_dots_symmetric_bf16_rvv` call into `1 - dot / sqrt(|a|^2 * |b|^2)` for
// every column strictly right of the diagonal, for rows in
// [row_start, row_start + row_count). The diagonal is zeroed at the end.
NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks so the scratch
    // buffer has a bounded stack footprint regardless of `n_vectors`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop: `vsetvl` picks how many f32 lanes each pass handles.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 (shared helper defined elsewhere in this file).
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // Reverse-subtract: 1.0f - normalized_dot.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                // Clamp tiny negatives produced by rounding so distances stay >= 0.
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
722
|
+
|
|
723
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_rvv( //
|
|
724
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
725
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
726
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
727
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
728
|
+
nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
729
|
+
nk_angulars_symmetric_bf16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
730
|
+
row_start, row_count);
|
|
731
|
+
}
|
|
732
|
+
|
|
733
|
+
// Finalization pass for symmetric Euclidean distances over bf16 vectors:
// rewrites, in place, the raw dot products left in `result` by
// `nk_dots_symmetric_bf16_rvv` into `sqrt(|a|^2 + |b|^2 - 2*dot)` for every
// column strictly right of the diagonal, rows in [row_start, row_start + row_count).
// The diagonal is zeroed at the end.
NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks to bound stack usage.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop over the remaining columns of this chunk.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                // |a|^2 + |b|^2 per lane.
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // Squared distance: |a|^2 + |b|^2 - 2*dot.
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives produced by rounding before the sqrt.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
773
|
+
|
|
774
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_rvv( //
|
|
775
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
776
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
777
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
778
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
779
|
+
nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
780
|
+
nk_euclideans_symmetric_bf16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
781
|
+
result_stride_elements, row_start, row_count);
|
|
782
|
+
}
|
|
783
|
+
|
|
784
|
+
#pragma endregion // Brain Float 16
|
|
785
|
+
|
|
786
|
+
#pragma region Micro Precision E2M3
|
|
787
|
+
|
|
788
|
+
// Finalization pass for packed angular (cosine) distances over e2m3 queries:
// rewrites each raw dot product in `c` (produced by `nk_dots_packed_e2m3_rvv`)
// into `1 - dot / sqrt(|a|^2 * |b|^2)`. The targets' squared norms are stored
// in the packed buffer right after the header and the packed column panel.
NK_INTERNAL void nk_angulars_packed_e2m3_rvv_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Norms array lives after header + (columns x padded depth) of packed e2m3 payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_e2m3_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is recomputed per row from the raw e2m3 data.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over this row's dot products.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // Reciprocal square root of |a|^2 * |b|^2 (shared helper defined elsewhere in this file).
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // Reverse-subtract: 1.0f - normalized_dot.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negatives produced by rounding so distances stay >= 0.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
819
|
+
|
|
820
|
+
NK_PUBLIC void nk_angulars_packed_e2m3_rvv( //
|
|
821
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
822
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
823
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
824
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
825
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
826
|
+
nk_dots_packed_e2m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
827
|
+
nk_angulars_packed_e2m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
828
|
+
}
|
|
829
|
+
|
|
830
|
+
// Finalization pass for packed Euclidean distances over e2m3 queries: rewrites
// each raw dot product in `c` (produced by `nk_dots_packed_e2m3_rvv`) into
// `sqrt(|a|^2 + |b|^2 - 2*dot)`. Target squared norms are stored in the packed
// buffer right after the header and the packed column panel.
NK_INTERNAL void nk_euclideans_packed_e2m3_rvv_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Norms array lives after header + (columns x padded depth) of packed e2m3 payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_e2m3_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is recomputed per row from the raw e2m3 data.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over this row's dot products.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 per lane.
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // Squared distance: |a|^2 + |b|^2 - 2*dot.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives produced by rounding before the sqrt.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
859
|
+
|
|
860
|
+
NK_PUBLIC void nk_euclideans_packed_e2m3_rvv( //
|
|
861
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
862
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
863
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
864
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
865
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
866
|
+
nk_dots_packed_e2m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
867
|
+
nk_euclideans_packed_e2m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
// Finalization pass for symmetric angular (cosine) distances over e2m3 vectors:
// rewrites, in place, the raw dot products left in `result` by
// `nk_dots_symmetric_e2m3_rvv` into `1 - dot / sqrt(|a|^2 * |b|^2)` for every
// column strictly right of the diagonal, rows in [row_start, row_start + row_count).
// The diagonal is zeroed at the end.
NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks to bound stack usage.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop over the remaining columns of this chunk.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 (shared helper defined elsewhere in this file).
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // Reverse-subtract: 1.0f - normalized_dot.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                // Clamp tiny negatives produced by rounding so distances stay >= 0.
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
911
|
+
|
|
912
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_rvv( //
|
|
913
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
914
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
915
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
916
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
917
|
+
nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
918
|
+
nk_angulars_symmetric_e2m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
919
|
+
row_start, row_count);
|
|
920
|
+
}
|
|
921
|
+
|
|
922
|
+
// Finalization pass for symmetric Euclidean distances over e2m3 vectors:
// rewrites, in place, the raw dot products left in `result` by
// `nk_dots_symmetric_e2m3_rvv` into `sqrt(|a|^2 + |b|^2 - 2*dot)` for every
// column strictly right of the diagonal, rows in [row_start, row_start + row_count).
// The diagonal is zeroed at the end.
NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks to bound stack usage.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop over the remaining columns of this chunk.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                // |a|^2 + |b|^2 per lane.
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // Squared distance: |a|^2 + |b|^2 - 2*dot.
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives produced by rounding before the sqrt.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
962
|
+
|
|
963
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_rvv( //
|
|
964
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
965
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
966
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
967
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
968
|
+
nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
969
|
+
nk_euclideans_symmetric_e2m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
970
|
+
result_stride_elements, row_start, row_count);
|
|
971
|
+
}
|
|
972
|
+
|
|
973
|
+
#pragma endregion // Micro Precision E2M3
|
|
974
|
+
|
|
975
|
+
#pragma region Micro Precision E3M2
|
|
976
|
+
|
|
977
|
+
// Finalization pass for packed angular (cosine) distances over e3m2 queries:
// rewrites each raw dot product in `c` (produced by `nk_dots_packed_e3m2_rvv`)
// into `1 - dot / sqrt(|a|^2 * |b|^2)`. Target squared norms are stored in the
// packed buffer after the header and the packed column panel.
NK_INTERNAL void nk_angulars_packed_e3m2_rvv_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): the panel size here uses sizeof(nk_i16_t), unlike the e2m3
    // sibling which uses its own element type — presumably the e3m2 packing
    // routine widens values to 16 bits. Confirm against the packing code.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i16_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is recomputed per row from the raw e3m2 data.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over this row's dot products.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // Reciprocal square root of |a|^2 * |b|^2 (shared helper defined elsewhere in this file).
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // Reverse-subtract: 1.0f - normalized_dot.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negatives produced by rounding so distances stay >= 0.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1008
|
+
|
|
1009
|
+
NK_PUBLIC void nk_angulars_packed_e3m2_rvv( //
|
|
1010
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1011
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1012
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1013
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1014
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1015
|
+
nk_dots_packed_e3m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1016
|
+
nk_angulars_packed_e3m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1017
|
+
}
|
|
1018
|
+
|
|
1019
|
+
// Finalization pass for packed Euclidean distances over e3m2 queries: rewrites
// each raw dot product in `c` (produced by `nk_dots_packed_e3m2_rvv`) into
// `sqrt(|a|^2 + |b|^2 - 2*dot)`. Target squared norms are stored in the packed
// buffer after the header and the packed column panel.
NK_INTERNAL void nk_euclideans_packed_e3m2_rvv_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): the panel size here uses sizeof(nk_i16_t), unlike the e2m3
    // sibling which uses its own element type — presumably the e3m2 packing
    // routine widens values to 16 bits. Confirm against the packing code.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i16_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is recomputed per row from the raw e3m2 data.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over this row's dot products.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 per lane.
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // Squared distance: |a|^2 + |b|^2 - 2*dot.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives produced by rounding before the sqrt.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1048
|
+
|
|
1049
|
+
NK_PUBLIC void nk_euclideans_packed_e3m2_rvv( //
|
|
1050
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1051
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1052
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1053
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1054
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1055
|
+
nk_dots_packed_e3m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1056
|
+
nk_euclideans_packed_e3m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
// Finalization pass for symmetric angular (cosine) distances over e3m2 vectors:
// rewrites, in place, the raw dot products left in `result` by
// `nk_dots_symmetric_e3m2_rvv` into `1 - dot / sqrt(|a|^2 * |b|^2)` for every
// column strictly right of the diagonal, rows in [row_start, row_start + row_count).
// The diagonal is zeroed at the end.
NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks to bound stack usage.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop over the remaining columns of this chunk.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 (shared helper defined elsewhere in this file).
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // Reverse-subtract: 1.0f - normalized_dot.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                // Clamp tiny negatives produced by rounding so distances stay >= 0.
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1100
|
+
|
|
1101
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_rvv( //
|
|
1102
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1103
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1104
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1105
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1106
|
+
nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1107
|
+
nk_angulars_symmetric_e3m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1108
|
+
row_start, row_count);
|
|
1109
|
+
}
|
|
1110
|
+
|
|
1111
|
+
// Finalization pass for symmetric Euclidean distances over e3m2 vectors:
// rewrites, in place, the raw dot products left in `result` by
// `nk_dots_symmetric_e3m2_rvv` into `sqrt(|a|^2 + |b|^2 - 2*dot)` for every
// column strictly right of the diagonal, rows in [row_start, row_start + row_count).
// The diagonal is zeroed at the end.
NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each row's squared norm on the diagonal; consumed below, overwritten with 0 at the end.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
    }
    // Column norms are recomputed in fixed 256-wide chunks to bound stack usage.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strictly-upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop over the remaining columns of this chunk.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                // |a|^2 + |b|^2 per lane.
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // Squared distance: |a|^2 + |b|^2 - 2*dot.
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives produced by rounding before the sqrt.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; replaces the squared norm stashed above.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1151
|
+
|
|
1152
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_rvv( //
|
|
1153
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1154
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1155
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1156
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1157
|
+
nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1158
|
+
nk_euclideans_symmetric_e3m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
1159
|
+
result_stride_elements, row_start, row_count);
|
|
1160
|
+
}
|
|
1161
|
+
|
|
1162
|
+
#pragma endregion // Micro Precision E3M2
|
|
1163
|
+
|
|
1164
|
+
#pragma region Quarter Precision E4M3
|
|
1165
|
+
|
|
1166
|
+
// Converts raw dot products in `c` into angular (cosine) distances, in place:
// result = max(0, 1 - dot / sqrt(|a|^2 * |b|^2)).
// `b_packed` layout: header, then column_count * depth_padded_values packed values,
// then one precomputed squared norm (f32) per packed column.
// NOTE(review): the norms offset multiplies depth_padded_values by sizeof(nk_f32_t),
// which assumes the packed e4m3 values are stored widened to f32 — confirm against
// the corresponding packing routine.
NK_INTERNAL void nk_angulars_packed_e4m3_rvv_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Per-column squared norms, located immediately after the packed values region.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row, recomputed here rather than cached.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop: vsetvl picks how many lanes each iteration processes.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 * |b|^2 per lane.
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // vfrsub computes (1.0f - x) per lane.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negatives caused by rounding.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1197
|
+
|
|
1198
|
+
NK_PUBLIC void nk_angulars_packed_e4m3_rvv( //
|
|
1199
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1200
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1201
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1202
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
1203
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1204
|
+
nk_dots_packed_e4m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1205
|
+
nk_angulars_packed_e4m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
// Converts raw dot products in `c` into Euclidean distances, in place:
// result = sqrt(max(0, |a|^2 + |b|^2 - 2*dot)).
// `b_packed` layout: header, then packed values, then one squared norm (f32) per column.
// NOTE(review): the norms offset assumes packed values occupy sizeof(nk_f32_t) each —
// confirm against the packing routine.
NK_INTERNAL void nk_euclideans_packed_e4m3_rvv_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over the row's columns.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 per lane.
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // Squared distance: sum of squares minus twice the dot product.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives caused by rounding before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1237
|
+
|
|
1238
|
+
NK_PUBLIC void nk_euclideans_packed_e4m3_rvv( //
|
|
1239
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1240
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1241
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1242
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
1243
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1244
|
+
nk_dots_packed_e4m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1245
|
+
nk_euclideans_packed_e4m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1246
|
+
}
|
|
1247
|
+
|
|
1248
|
+
// Rewrites the dot products of a symmetric all-pairs matrix as angular distances,
// in place, for rows [row_start, row_start + row_count). Only columns strictly above
// the diagonal (col > row) are finalized. The diagonal temporarily stores each row's
// squared norm and is zeroed at the end. Target norms are cached 256 at a time to
// bound the stack buffer.
NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on the matrix diagonal for later lookup.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    // Process target columns in chunks of 256 so their norms fit in the cache.
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was stashed on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop: angular = max(0, 1 - dot / sqrt(|a|^2 * |b|^2)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // vfrsub computes (1.0f - x); vfmax clamps rounding-induced negatives.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal entries used as norm scratch.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1289
|
+
|
|
1290
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_rvv( //
|
|
1291
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1292
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1293
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
1294
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1295
|
+
nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1296
|
+
nk_angulars_symmetric_e4m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1297
|
+
row_start, row_count);
|
|
1298
|
+
}
|
|
1299
|
+
|
|
1300
|
+
// Rewrites the dot products of a symmetric all-pairs matrix as Euclidean distances,
// in place, for rows [row_start, row_start + row_count). Only columns strictly above
// the diagonal (col > row) are finalized. The diagonal temporarily stores each row's
// squared norm and is zeroed at the end. Target norms are cached 256 at a time.
NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each row's squared norm on the matrix diagonal for later lookup.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    // Process target columns in chunks of 256 so their norms fit in the cache.
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was stashed on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop: dist = sqrt(max(0, |a|^2 + |b|^2 - 2*dot)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives caused by rounding before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal entries used as norm scratch.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1340
|
+
|
|
1341
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_rvv( //
|
|
1342
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1343
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1344
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
1345
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1346
|
+
nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1347
|
+
nk_euclideans_symmetric_e4m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
1348
|
+
result_stride_elements, row_start, row_count);
|
|
1349
|
+
}
|
|
1350
|
+
|
|
1351
|
+
#pragma endregion // Quarter Precision E4M3
|
|
1352
|
+
|
|
1353
|
+
#pragma region Quarter Precision E5M2
|
|
1354
|
+
|
|
1355
|
+
// Converts raw dot products in `c` into angular (cosine) distances, in place:
// result = max(0, 1 - dot / sqrt(|a|^2 * |b|^2)).
// `b_packed` layout: header, then packed values, then one squared norm (f32) per column.
// NOTE(review): the norms offset assumes packed values occupy sizeof(nk_f32_t) each —
// confirm against the packing routine.
NK_INTERNAL void nk_angulars_packed_e5m2_rvv_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over the row's columns.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 * |b|^2 per lane.
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // vfrsub computes (1.0f - x); vfmax clamps rounding-induced negatives.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1386
|
+
|
|
1387
|
+
NK_PUBLIC void nk_angulars_packed_e5m2_rvv( //
|
|
1388
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1389
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1390
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1391
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
1392
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1393
|
+
nk_dots_packed_e5m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1394
|
+
nk_angulars_packed_e5m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1395
|
+
}
|
|
1396
|
+
|
|
1397
|
+
// Converts raw dot products in `c` into Euclidean distances, in place:
// result = sqrt(max(0, |a|^2 + |b|^2 - 2*dot)).
// `b_packed` layout: header, then packed values, then one squared norm (f32) per column.
// NOTE(review): the norms offset assumes packed values occupy sizeof(nk_f32_t) each —
// confirm against the packing routine.
NK_INTERNAL void nk_euclideans_packed_e5m2_rvv_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mined RVV loop over the row's columns.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 per lane.
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // Squared distance: sum of squares minus twice the dot product.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives caused by rounding before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1426
|
+
|
|
1427
|
+
NK_PUBLIC void nk_euclideans_packed_e5m2_rvv( //
|
|
1428
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1429
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1430
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1431
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
1432
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1433
|
+
nk_dots_packed_e5m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1434
|
+
nk_euclideans_packed_e5m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1435
|
+
}
|
|
1436
|
+
|
|
1437
|
+
// Rewrites the dot products of a symmetric all-pairs matrix as angular distances,
// in place, for rows [row_start, row_start + row_count). Only columns strictly above
// the diagonal (col > row) are finalized. The diagonal temporarily stores each row's
// squared norm and is zeroed at the end. Target norms are cached 256 at a time.
NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on the matrix diagonal for later lookup.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    // Process target columns in chunks of 256 so their norms fit in the cache.
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was stashed on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop: angular = max(0, 1 - dot / sqrt(|a|^2 * |b|^2)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // vfrsub computes (1.0f - x); vfmax clamps rounding-induced negatives.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal entries used as norm scratch.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1478
|
+
|
|
1479
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_rvv( //
|
|
1480
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1481
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1482
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
1483
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1484
|
+
nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1485
|
+
nk_angulars_symmetric_e5m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1486
|
+
row_start, row_count);
|
|
1487
|
+
}
|
|
1488
|
+
|
|
1489
|
+
// Rewrites the dot products of a symmetric all-pairs matrix as Euclidean distances,
// in place, for rows [row_start, row_start + row_count). Only columns strictly above
// the diagonal (col > row) are finalized. The diagonal temporarily stores each row's
// squared norm and is zeroed at the end. Target norms are cached 256 at a time.
NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each row's squared norm on the matrix diagonal for later lookup.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    // Process target columns in chunks of 256 so their norms fit in the cache.
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was stashed on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined RVV loop: dist = sqrt(max(0, |a|^2 + |b|^2 - 2*dot)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives caused by rounding before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal entries used as norm scratch.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1529
|
+
|
|
1530
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_rvv( //
|
|
1531
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1532
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1533
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
1534
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1535
|
+
nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
1536
|
+
nk_euclideans_symmetric_e5m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
1537
|
+
result_stride_elements, row_start, row_count);
|
|
1538
|
+
}
|
|
1539
|
+
|
|
1540
|
+
#pragma endregion // Quarter Precision E5M2
|
|
1541
|
+
|
|
1542
|
+
#pragma region Signed 8-bit Integers
|
|
1543
|
+
|
|
1544
|
+
// Rewrites raw i8 dot-products (stored as i32 bit-patterns inside the f32 matrix `c`)
// into angular distances `1 - dot / sqrt(|a|^2 * |b|^2)`, in place, one RVV stripe at a time.
// The packed `b` buffer carries precomputed per-column squared norms immediately after the
// interleaved values; each query row's squared norm is reduced on the fly.
NK_INTERNAL void nk_angulars_packed_i8_rvv_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
                                                     nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                     nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared norms live right after `column_count * depth_padded_values` packed i8 values.
    nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i8_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_i8_(a_row, depth);
        nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_u32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            // Dots were stored as signed i32; reinterpret the f32 cells and convert to float lanes.
            vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
                __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // One rsqrt over the product of squared norms replaces two sqrt + a divide.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negative residue from rounding so the distance stays non-negative.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1578
|
+
|
|
1579
|
+
NK_PUBLIC void nk_angulars_packed_i8_rvv( //
|
|
1580
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1581
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1582
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1583
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1584
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1585
|
+
nk_dots_packed_i8_rvv(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1586
|
+
nk_angulars_packed_i8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1587
|
+
}
|
|
1588
|
+
|
|
1589
|
+
// Rewrites raw i8 dot-products (i32 bit-patterns inside `c`) into Euclidean distances
// `sqrt(|a|^2 + |b|^2 - 2*dot)`, in place, one RVV stripe at a time. Per-column squared
// norms are read from the tail of the packed `b` buffer; query norms are reduced per row.
NK_INTERNAL void nk_euclideans_packed_i8_rvv_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared norms follow the `column_count * depth_padded_values` packed i8 payload.
    nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i8_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_i8_(a_row, depth);
        nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_u32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            // Dots were stored as signed i32; reinterpret and convert to float lanes.
            vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
                __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // |a|^2 + |b|^2 - 2*dot, i.e. the squared Euclidean distance.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp rounding residue below zero before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1621
|
+
|
|
1622
|
+
NK_PUBLIC void nk_euclideans_packed_i8_rvv( //
|
|
1623
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1624
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1625
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1626
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1627
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1628
|
+
nk_dots_packed_i8_rvv(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1629
|
+
nk_euclideans_packed_i8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1630
|
+
}
|
|
1631
|
+
|
|
1632
|
+
// Converts the strictly-upper-triangular raw dots (i32 bit-patterns inside `result`) of a
// symmetric i8 similarity matrix into angular distances, in place. Trick: each diagonal cell
// temporarily holds the row's squared norm as a raw u32 bit-pattern, and is zeroed at the end.
// Column norms are recomputed in 256-wide chunks into a small stack cache.
NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                        nk_size_t stride_elements, nk_f32_t *result,
                                                        nk_size_t result_stride_elements, nk_size_t row_start,
                                                        nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell (as raw u32 bits, not f32).
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_i8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Precompute squared norms for all columns of this chunk once per chunk.
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only the strictly upper triangle is finalized: start past the diagonal.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Read the row norm back from the diagonal stash.
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // Dots were stored as signed i32; reinterpret and convert to float lanes.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
                    __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // One rsqrt of the norm product replaces two sqrt + a divide.
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                // Clamp rounding residue so the distance stays non-negative.
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Replace the diagonal norm stashes with the true self-distance of zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1676
|
+
|
|
1677
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_rvv( //
|
|
1678
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1679
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1680
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
1681
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1682
|
+
nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
|
|
1683
|
+
row_count);
|
|
1684
|
+
nk_angulars_symmetric_i8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1685
|
+
row_start, row_count);
|
|
1686
|
+
}
|
|
1687
|
+
|
|
1688
|
+
// Converts the strictly-upper-triangular raw dots (i32 bit-patterns inside `result`) of a
// symmetric i8 matrix into Euclidean distances `sqrt(|a|^2 + |b|^2 - 2*dot)`, in place.
// Trick: each diagonal cell temporarily holds the row's squared norm as a raw u32
// bit-pattern, and is zeroed at the end. Column norms are cached in 256-wide chunks.
NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                          nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell (as raw u32 bits, not f32).
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_i8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Precompute squared norms for all columns of this chunk once per chunk.
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only the strictly upper triangle is finalized: start past the diagonal.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Read the row norm back from the diagonal stash.
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // Dots were stored as signed i32; reinterpret and convert to float lanes.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
                    __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // |a|^2 + |b|^2 - 2*dot, i.e. the squared Euclidean distance.
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp rounding residue below zero before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Replace the diagonal norm stashes with the true self-distance of zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1731
|
+
|
|
1732
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_rvv( //
|
|
1733
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1734
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1735
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
1736
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1737
|
+
nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
|
|
1738
|
+
row_count);
|
|
1739
|
+
nk_euclideans_symmetric_i8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1740
|
+
row_start, row_count);
|
|
1741
|
+
}
|
|
1742
|
+
|
|
1743
|
+
#pragma endregion // Signed 8-bit Integers
|
|
1744
|
+
|
|
1745
|
+
#pragma region Unsigned 8-bit Integers
|
|
1746
|
+
|
|
1747
|
+
// Rewrites raw u8 dot-products (u32 bit-patterns inside the f32 matrix `c`) into angular
// distances `1 - dot / sqrt(|a|^2 * |b|^2)`, in place, one RVV stripe at a time.
// Per-column squared norms sit at the tail of the packed `b` buffer; query norms are
// reduced per row.
NK_INTERNAL void nk_angulars_packed_u8_rvv_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
                                                     nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                     nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared norms follow the `column_count * depth_padded_values` packed u8 payload.
    nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_u8_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_u8_(a_row, depth);
        nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_u32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            // u8 dots were stored as unsigned u32 (unlike the signed i8 path); convert to float lanes.
            vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // One rsqrt over the product of squared norms replaces two sqrt + a divide.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp rounding residue so the distance stays non-negative.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1781
|
+
|
|
1782
|
+
NK_PUBLIC void nk_angulars_packed_u8_rvv( //
|
|
1783
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1784
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1785
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1786
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1787
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1788
|
+
nk_dots_packed_u8_rvv(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1789
|
+
nk_angulars_packed_u8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1790
|
+
}
|
|
1791
|
+
|
|
1792
|
+
// Rewrites raw u8 dot-products (u32 bit-patterns inside `c`) into Euclidean distances
// `sqrt(|a|^2 + |b|^2 - 2*dot)`, in place, one RVV stripe at a time. Per-column squared
// norms are read from the tail of the packed `b` buffer; query norms are reduced per row.
NK_INTERNAL void nk_euclideans_packed_u8_rvv_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared norms follow the `column_count * depth_padded_values` packed u8 payload.
    nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_u8_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_u8_(a_row, depth);
        nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_u32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            // u8 dots were stored as unsigned u32; reinterpret and convert to float lanes.
            vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // |a|^2 + |b|^2 - 2*dot, i.e. the squared Euclidean distance.
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp rounding residue below zero before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
|
|
1824
|
+
|
|
1825
|
+
NK_PUBLIC void nk_euclideans_packed_u8_rvv( //
|
|
1826
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1827
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1828
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1829
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1830
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1831
|
+
nk_dots_packed_u8_rvv(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1832
|
+
nk_euclideans_packed_u8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
// Converts the strictly-upper-triangular raw dots (u32 bit-patterns inside `result`) of a
// symmetric u8 similarity matrix into angular distances, in place. Trick: each diagonal cell
// temporarily holds the row's squared norm as a raw u32 bit-pattern, and is zeroed at the end.
// Column norms are recomputed in 256-wide chunks into a small stack cache.
NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                        nk_size_t stride_elements, nk_f32_t *result,
                                                        nk_size_t result_stride_elements, nk_size_t row_start,
                                                        nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell (as raw u32 bits, not f32).
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_u8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Precompute squared norms for all columns of this chunk once per chunk.
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only the strictly upper triangle is finalized: start past the diagonal.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Read the row norm back from the diagonal stash.
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // u8 dots were stored as unsigned u32; reinterpret and convert to float lanes.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // One rsqrt of the norm product replaces two sqrt + a divide.
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                // Clamp rounding residue so the distance stays non-negative.
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Replace the diagonal norm stashes with the true self-distance of zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1879
|
+
|
|
1880
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_rvv( //
|
|
1881
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1882
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1883
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
1884
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1885
|
+
nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
|
|
1886
|
+
row_count);
|
|
1887
|
+
nk_angulars_symmetric_u8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1888
|
+
row_start, row_count);
|
|
1889
|
+
}
|
|
1890
|
+
|
|
1891
|
+
// Converts the strictly-upper-triangular raw dots (u32 bit-patterns inside `result`) of a
// symmetric u8 matrix into Euclidean distances `sqrt(|a|^2 + |b|^2 - 2*dot)`, in place.
// Trick: each diagonal cell temporarily holds the row's squared norm as a raw u32
// bit-pattern, and is zeroed at the end. Column norms are cached in 256-wide chunks.
NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                          nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell (as raw u32 bits, not f32).
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_u8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Precompute squared norms for all columns of this chunk once per chunk.
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only the strictly upper triangle is finalized: start past the diagonal.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Read the row norm back from the diagonal stash.
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // u8 dots were stored as unsigned u32; reinterpret and convert to float lanes.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // |a|^2 + |b|^2 - 2*dot, i.e. the squared Euclidean distance.
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp rounding residue below zero before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Replace the diagonal norm stashes with the true self-distance of zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
|
|
1934
|
+
|
|
1935
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_rvv( //
|
|
1936
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1937
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1938
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
1939
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1940
|
+
nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
|
|
1941
|
+
row_count);
|
|
1942
|
+
nk_euclideans_symmetric_u8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
|
|
1943
|
+
row_start, row_count);
|
|
1944
|
+
}
|
|
1945
|
+
|
|
1946
|
+
#pragma endregion // Unsigned 8-bit Integers
|
|
1947
|
+
|
|
1948
|
+
#if defined(__clang__)
|
|
1949
|
+
#pragma clang attribute pop
|
|
1950
|
+
#elif defined(__GNUC__)
|
|
1951
|
+
#pragma GCC pop_options
|
|
1952
|
+
#endif
|
|
1953
|
+
|
|
1954
|
+
#if defined(__cplusplus)
|
|
1955
|
+
} // extern "C"
|
|
1956
|
+
#endif
|
|
1957
|
+
|
|
1958
|
+
#endif // NK_TARGET_RVV
|
|
1959
|
+
#endif // NK_TARGET_RISCV_
|
|
1960
|
+
#endif // NK_SPATIALS_RVV_H
|