numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -133,7 +133,7 @@ NK_INTERNAL vfloat64m4_t nk_f64m4_reciprocal_rvv_(vfloat64m4_t x_f64m4, nk_size_
|
|
|
133
133
|
return est_f64m4;
|
|
134
134
|
}
|
|
135
135
|
|
|
136
|
-
#pragma region
|
|
136
|
+
#pragma region I8 and U8 Integers
|
|
137
137
|
|
|
138
138
|
NK_PUBLIC void nk_sqeuclidean_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
139
139
|
nk_u32_t *result) {
|
|
@@ -187,13 +187,13 @@ NK_PUBLIC void nk_euclidean_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_sc
|
|
|
187
187
|
*result = nk_f32_sqrt_rvv((nk_f32_t)d2);
|
|
188
188
|
}
|
|
189
189
|
|
|
190
|
-
#pragma endregion
|
|
191
|
-
#pragma region
|
|
190
|
+
#pragma endregion I8 and U8 Integers
|
|
191
|
+
#pragma region F32 and F64 Floats
|
|
192
192
|
|
|
193
193
|
NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
194
194
|
nk_f64_t *result) {
|
|
195
|
-
nk_size_t
|
|
196
|
-
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
195
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
196
|
+
vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
197
197
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
198
198
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
199
199
|
vector_length = __riscv_vsetvl_e32m1(count_scalars);
|
|
@@ -206,7 +206,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const
|
|
|
206
206
|
}
|
|
207
207
|
// Single horizontal reduction at the end
|
|
208
208
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
209
|
-
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1,
|
|
209
|
+
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, max_vector_length));
|
|
210
210
|
}
|
|
211
211
|
|
|
212
212
|
NK_PUBLIC void nk_euclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -239,8 +239,8 @@ NK_PUBLIC void nk_euclidean_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b
|
|
|
239
239
|
*result = nk_f64_sqrt_rvv(*result);
|
|
240
240
|
}
|
|
241
241
|
|
|
242
|
-
#pragma endregion
|
|
243
|
-
#pragma region
|
|
242
|
+
#pragma endregion F32 and F64 Floats
|
|
243
|
+
#pragma region I8 and U8 Integers
|
|
244
244
|
|
|
245
245
|
NK_PUBLIC void nk_angular_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
246
246
|
nk_f32_t *result) {
|
|
@@ -320,15 +320,15 @@ NK_PUBLIC void nk_angular_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scal
|
|
|
320
320
|
}
|
|
321
321
|
}
|
|
322
322
|
|
|
323
|
-
#pragma endregion
|
|
324
|
-
#pragma region
|
|
323
|
+
#pragma endregion I8 and U8 Integers
|
|
324
|
+
#pragma region F32 and F64 Floats
|
|
325
325
|
|
|
326
326
|
NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
327
327
|
nk_f64_t *result) {
|
|
328
|
-
nk_size_t
|
|
329
|
-
vfloat64m2_t dot_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
330
|
-
vfloat64m2_t a_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
331
|
-
vfloat64m2_t b_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0,
|
|
328
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
329
|
+
vfloat64m2_t dot_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
330
|
+
vfloat64m2_t a_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
331
|
+
vfloat64m2_t b_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
|
|
332
332
|
|
|
333
333
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
334
334
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -344,11 +344,12 @@ NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_s
|
|
|
344
344
|
|
|
345
345
|
// Single horizontal reduction at the end for all three accumulators
|
|
346
346
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
347
|
-
nk_f64_t dot_f64 = __riscv_vfmv_f_s_f64m1_f64(
|
|
347
|
+
nk_f64_t dot_f64 = __riscv_vfmv_f_s_f64m1_f64(
|
|
348
|
+
__riscv_vfredusum_vs_f64m2_f64m1(dot_f64m2, zero_f64m1, max_vector_length));
|
|
348
349
|
nk_f64_t a_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
|
|
349
|
-
__riscv_vfredusum_vs_f64m2_f64m1(a_norm_sq_f64m2, zero_f64m1,
|
|
350
|
+
__riscv_vfredusum_vs_f64m2_f64m1(a_norm_sq_f64m2, zero_f64m1, max_vector_length));
|
|
350
351
|
nk_f64_t b_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
|
|
351
|
-
__riscv_vfredusum_vs_f64m2_f64m1(b_norm_sq_f64m2, zero_f64m1,
|
|
352
|
+
__riscv_vfredusum_vs_f64m2_f64m1(b_norm_sq_f64m2, zero_f64m1, max_vector_length));
|
|
352
353
|
|
|
353
354
|
// Normalize: 1 − dot / √(‖a‖² × ‖b‖²)
|
|
354
355
|
if (a_norm_sq_f64 == 0.0 && b_norm_sq_f64 == 0.0) { *result = 0.0; }
|
|
@@ -413,13 +414,13 @@ NK_PUBLIC void nk_angular_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_s
|
|
|
413
414
|
}
|
|
414
415
|
}
|
|
415
416
|
|
|
416
|
-
#pragma endregion
|
|
417
|
-
#pragma region
|
|
417
|
+
#pragma endregion F32 and F64 Floats
|
|
418
|
+
#pragma region F16 and BF16 Floats
|
|
418
419
|
|
|
419
420
|
NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
420
421
|
nk_f32_t *result) {
|
|
421
|
-
nk_size_t
|
|
422
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
422
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
423
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
423
424
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
424
425
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
425
426
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -436,7 +437,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const
|
|
|
436
437
|
}
|
|
437
438
|
// Single horizontal reduction at the end
|
|
438
439
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
439
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
440
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
440
441
|
}
|
|
441
442
|
|
|
442
443
|
NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -447,10 +448,10 @@ NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b
|
|
|
447
448
|
|
|
448
449
|
NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
449
450
|
nk_f32_t *result) {
|
|
450
|
-
nk_size_t
|
|
451
|
-
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
452
|
-
vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
453
|
-
vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
451
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
452
|
+
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
453
|
+
vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
454
|
+
vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
454
455
|
|
|
455
456
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
456
457
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -470,11 +471,12 @@ NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_s
|
|
|
470
471
|
|
|
471
472
|
// Single horizontal reduction at the end for all three accumulators
|
|
472
473
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
473
|
-
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
474
|
+
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
475
|
+
__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
|
|
474
476
|
nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
475
|
-
__riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1,
|
|
477
|
+
__riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, max_vector_length));
|
|
476
478
|
nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
477
|
-
__riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1,
|
|
479
|
+
__riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, max_vector_length));
|
|
478
480
|
|
|
479
481
|
if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
|
|
480
482
|
else if (dot_f32 == 0.0f) { *result = 1.0f; }
|
|
@@ -486,8 +488,8 @@ NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_s
|
|
|
486
488
|
|
|
487
489
|
NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
488
490
|
nk_f32_t *result) {
|
|
489
|
-
nk_size_t
|
|
490
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
491
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
492
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
491
493
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
492
494
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
493
495
|
vector_length = __riscv_vsetvl_e16m1(count_scalars);
|
|
@@ -504,7 +506,7 @@ NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t con
|
|
|
504
506
|
}
|
|
505
507
|
// Single horizontal reduction at the end
|
|
506
508
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
507
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
509
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
|
|
508
510
|
}
|
|
509
511
|
|
|
510
512
|
NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -515,10 +517,10 @@ NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const
|
|
|
515
517
|
|
|
516
518
|
NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
517
519
|
nk_f32_t *result) {
|
|
518
|
-
nk_size_t
|
|
519
|
-
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
520
|
-
vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
521
|
-
vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
520
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
521
|
+
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
522
|
+
vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
523
|
+
vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
522
524
|
|
|
523
525
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
524
526
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -538,11 +540,12 @@ NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *
|
|
|
538
540
|
|
|
539
541
|
// Single horizontal reduction at the end for all three accumulators
|
|
540
542
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
541
|
-
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
543
|
+
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
544
|
+
__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
|
|
542
545
|
nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
543
|
-
__riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1,
|
|
546
|
+
__riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, max_vector_length));
|
|
544
547
|
nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
545
|
-
__riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1,
|
|
548
|
+
__riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, max_vector_length));
|
|
546
549
|
|
|
547
550
|
if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
|
|
548
551
|
else if (dot_f32 == 0.0f) { *result = 1.0f; }
|
|
@@ -554,8 +557,8 @@ NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *
|
|
|
554
557
|
|
|
555
558
|
NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
556
559
|
nk_f32_t *result) {
|
|
557
|
-
nk_size_t
|
|
558
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
560
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
561
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
559
562
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
560
563
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
561
564
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -572,7 +575,7 @@ NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t con
|
|
|
572
575
|
}
|
|
573
576
|
// Single horizontal reduction at the end
|
|
574
577
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
575
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
578
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
576
579
|
}
|
|
577
580
|
|
|
578
581
|
NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -583,10 +586,10 @@ NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const
|
|
|
583
586
|
|
|
584
587
|
NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
585
588
|
nk_f32_t *result) {
|
|
586
|
-
nk_size_t
|
|
587
|
-
vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
588
|
-
vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
589
|
-
vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
589
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
590
|
+
vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
591
|
+
vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
592
|
+
vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
590
593
|
|
|
591
594
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
592
595
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -606,11 +609,12 @@ NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
|
|
|
606
609
|
|
|
607
610
|
// Single horizontal reduction at the end for all three accumulators
|
|
608
611
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
609
|
-
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
612
|
+
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
613
|
+
__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, max_vector_length));
|
|
610
614
|
nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
611
|
-
__riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1,
|
|
615
|
+
__riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, max_vector_length));
|
|
612
616
|
nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
613
|
-
__riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1,
|
|
617
|
+
__riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, max_vector_length));
|
|
614
618
|
|
|
615
619
|
if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
|
|
616
620
|
else if (dot_f32 == 0.0f) { *result = 1.0f; }
|
|
@@ -622,8 +626,8 @@ NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
|
|
|
622
626
|
|
|
623
627
|
NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
624
628
|
nk_f32_t *result) {
|
|
625
|
-
nk_size_t
|
|
626
|
-
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
629
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
630
|
+
vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
627
631
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
628
632
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
629
633
|
vector_length = __riscv_vsetvl_e8m1(count_scalars);
|
|
@@ -640,7 +644,7 @@ NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t con
|
|
|
640
644
|
}
|
|
641
645
|
// Single horizontal reduction at the end
|
|
642
646
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
643
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1,
|
|
647
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
|
|
644
648
|
}
|
|
645
649
|
|
|
646
650
|
NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -651,10 +655,10 @@ NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const
|
|
|
651
655
|
|
|
652
656
|
NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
653
657
|
nk_f32_t *result) {
|
|
654
|
-
nk_size_t
|
|
655
|
-
vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
656
|
-
vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
657
|
-
vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f,
|
|
658
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
659
|
+
vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
660
|
+
vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
661
|
+
vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
|
|
658
662
|
|
|
659
663
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
660
664
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -674,11 +678,12 @@ NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
|
|
|
674
678
|
|
|
675
679
|
// Single horizontal reduction at the end for all three accumulators
|
|
676
680
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
677
|
-
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
681
|
+
nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
682
|
+
__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, max_vector_length));
|
|
678
683
|
nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
679
|
-
__riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1,
|
|
684
|
+
__riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, max_vector_length));
|
|
680
685
|
nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
|
|
681
|
-
__riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1,
|
|
686
|
+
__riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, max_vector_length));
|
|
682
687
|
|
|
683
688
|
if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
|
|
684
689
|
else if (dot_f32 == 0.0f) { *result = 1.0f; }
|
|
@@ -688,8 +693,8 @@ NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
|
|
|
688
693
|
}
|
|
689
694
|
}
|
|
690
695
|
|
|
691
|
-
#pragma endregion
|
|
692
|
-
#pragma region
|
|
696
|
+
#pragma endregion F16 and BF16 Floats
|
|
697
|
+
#pragma region I8 and U8 Integers
|
|
693
698
|
|
|
694
699
|
NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
|
|
695
700
|
nk_u32_t *result) {
|
|
@@ -713,31 +718,31 @@ NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const
|
|
|
713
718
|
};
|
|
714
719
|
count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
|
|
715
720
|
nk_size_t n_bytes = count_scalars / 2;
|
|
716
|
-
nk_size_t
|
|
717
|
-
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
721
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
722
|
+
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
718
723
|
for (nk_size_t vector_length; n_bytes > 0;
|
|
719
724
|
n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
720
725
|
vector_length = __riscv_vsetvl_e8m1(n_bytes);
|
|
721
726
|
vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
722
727
|
vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
723
|
-
// Build LUT indices: high nibble pair = (
|
|
724
|
-
vuint8m1_t
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
// Low nibble pair = (
|
|
728
|
-
vuint8m1_t
|
|
728
|
+
// Build LUT indices: high nibble pair = (a_high << 4) | b_hi
|
|
729
|
+
vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
|
|
730
|
+
__riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
|
|
731
|
+
vector_length);
|
|
732
|
+
// Low nibble pair = (a_low << 4) | b_lo
|
|
733
|
+
vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(
|
|
729
734
|
__riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
|
|
730
735
|
__riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
|
|
731
736
|
// Gather squared differences from LUT (0-225, fits u8)
|
|
732
|
-
vuint8m1_t
|
|
733
|
-
vuint8m1_t
|
|
737
|
+
vuint8m1_t sq_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, high_idx_u8m1, vector_length);
|
|
738
|
+
vuint8m1_t sq_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, low_idx_u8m1, vector_length);
|
|
734
739
|
// Combine and per-lane accumulate: u8+u8→u16, then u32+=u16
|
|
735
|
-
vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
740
|
+
vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_high_u8m1, sq_low_u8m1, vector_length);
|
|
736
741
|
sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
|
|
737
742
|
}
|
|
738
743
|
// Single horizontal reduction after loop
|
|
739
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
740
|
-
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1,
|
|
744
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
745
|
+
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
|
|
741
746
|
}
|
|
742
747
|
|
|
743
748
|
NK_PUBLIC void nk_euclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -770,10 +775,10 @@ NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_
|
|
|
770
775
|
static nk_u8_t const nk_i4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 49, 36, 25, 16, 9, 4, 1};
|
|
771
776
|
count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
|
|
772
777
|
nk_size_t n_bytes = count_scalars / 2;
|
|
773
|
-
nk_size_t
|
|
774
|
-
vint32m4_t dot_i32m4 = __riscv_vmv_v_x_i32m4(0,
|
|
775
|
-
vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
776
|
-
vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
778
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
779
|
+
vint32m4_t dot_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
|
|
780
|
+
vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
781
|
+
vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
777
782
|
|
|
778
783
|
for (nk_size_t vector_length; n_bytes > 0;
|
|
779
784
|
n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -782,44 +787,45 @@ NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_
|
|
|
782
787
|
vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
783
788
|
|
|
784
789
|
// Extract nibbles for index building
|
|
785
|
-
vuint8m1_t
|
|
786
|
-
vuint8m1_t
|
|
787
|
-
vuint8m1_t
|
|
788
|
-
vuint8m1_t
|
|
790
|
+
vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
|
|
791
|
+
vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
|
|
792
|
+
vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
|
|
793
|
+
vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
|
|
789
794
|
|
|
790
795
|
// Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a_signed * b_signed (i8)
|
|
791
|
-
vuint8m1_t
|
|
792
|
-
|
|
793
|
-
vuint8m1_t
|
|
794
|
-
|
|
795
|
-
vint8m1_t
|
|
796
|
-
vint8m1_t
|
|
796
|
+
vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
|
|
797
|
+
b_high_u8m1, vector_length);
|
|
798
|
+
vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_low_u8m1, 4, vector_length), b_low_u8m1,
|
|
799
|
+
vector_length);
|
|
800
|
+
vint8m1_t dot_high_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, high_idx_u8m1, vector_length);
|
|
801
|
+
vint8m1_t dot_low_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, low_idx_u8m1, vector_length);
|
|
797
802
|
// Widen i8→i16, add hi+lo, then per-lane accumulate i32+=i16
|
|
798
|
-
vint16m2_t dot_combined_i16m2 = __riscv_vwadd_vv_i16m2(
|
|
803
|
+
vint16m2_t dot_combined_i16m2 = __riscv_vwadd_vv_i16m2(dot_high_i8m1, dot_low_i8m1, vector_length);
|
|
799
804
|
dot_i32m4 = __riscv_vwadd_wv_i32m4_tu(dot_i32m4, dot_i32m4, dot_combined_i16m2, vector_length);
|
|
800
805
|
|
|
801
806
|
// Norms via 16-entry squaring LUT + vluxei8
|
|
802
|
-
vuint8m1_t
|
|
803
|
-
vuint8m1_t
|
|
804
|
-
vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
807
|
+
vuint8m1_t a_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_high_u8m1, vector_length);
|
|
808
|
+
vuint8m1_t a_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_low_u8m1, vector_length);
|
|
809
|
+
vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_high_sq_u8m1, a_low_sq_u8m1, vector_length);
|
|
805
810
|
a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
|
|
806
811
|
vector_length);
|
|
807
812
|
|
|
808
|
-
vuint8m1_t
|
|
809
|
-
vuint8m1_t
|
|
810
|
-
vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
813
|
+
vuint8m1_t b_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_high_u8m1, vector_length);
|
|
814
|
+
vuint8m1_t b_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_low_u8m1, vector_length);
|
|
815
|
+
vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_high_sq_u8m1, b_low_sq_u8m1, vector_length);
|
|
811
816
|
b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
|
|
812
817
|
vector_length);
|
|
813
818
|
}
|
|
814
819
|
|
|
815
820
|
// Single horizontal reductions after loop
|
|
816
|
-
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0,
|
|
817
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
818
|
-
nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(
|
|
821
|
+
vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
|
|
822
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
823
|
+
nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(
|
|
824
|
+
__riscv_vredsum_vs_i32m4_i32m1(dot_i32m4, zero_i32m1, max_vector_length));
|
|
819
825
|
nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
820
|
-
__riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1,
|
|
826
|
+
__riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, max_vector_length));
|
|
821
827
|
nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
822
|
-
__riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1,
|
|
828
|
+
__riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, max_vector_length));
|
|
823
829
|
|
|
824
830
|
if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
|
|
825
831
|
else if (dot_i32 == 0) { *result = 1.0f; }
|
|
@@ -852,31 +858,31 @@ NK_PUBLIC void nk_sqeuclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const
|
|
|
852
858
|
};
|
|
853
859
|
count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
|
|
854
860
|
nk_size_t n_bytes = count_scalars / 2;
|
|
855
|
-
nk_size_t
|
|
856
|
-
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
861
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
862
|
+
vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
857
863
|
for (nk_size_t vector_length; n_bytes > 0;
|
|
858
864
|
n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
859
865
|
vector_length = __riscv_vsetvl_e8m1(n_bytes);
|
|
860
866
|
vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
|
|
861
867
|
vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
862
|
-
// Build LUT indices: high nibble pair = (
|
|
863
|
-
vuint8m1_t
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
// Low nibble pair = (
|
|
867
|
-
vuint8m1_t
|
|
868
|
+
// Build LUT indices: high nibble pair = (a_high & 0xF0) | (b_high >> 4)
|
|
869
|
+
vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
|
|
870
|
+
__riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
|
|
871
|
+
vector_length);
|
|
872
|
+
// Low nibble pair = (a_low << 4) | b_lo
|
|
873
|
+
vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(
|
|
868
874
|
__riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
|
|
869
875
|
__riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
|
|
870
876
|
// Gather squared differences from LUT (0-225, fits u8)
|
|
871
|
-
vuint8m1_t
|
|
872
|
-
vuint8m1_t
|
|
877
|
+
vuint8m1_t sq_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, high_idx_u8m1, vector_length);
|
|
878
|
+
vuint8m1_t sq_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, low_idx_u8m1, vector_length);
|
|
873
879
|
// Combine and per-lane accumulate: u8+u8→u16, then u32+=u16
|
|
874
|
-
vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
880
|
+
vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_high_u8m1, sq_low_u8m1, vector_length);
|
|
875
881
|
sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
|
|
876
882
|
}
|
|
877
883
|
// Single horizontal reduction after loop
|
|
878
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
879
|
-
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1,
|
|
884
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
885
|
+
*result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
|
|
880
886
|
}
|
|
881
887
|
|
|
882
888
|
NK_PUBLIC void nk_euclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -909,10 +915,10 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
|
|
|
909
915
|
static nk_u8_t const nk_u4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225};
|
|
910
916
|
count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
|
|
911
917
|
nk_size_t n_bytes = count_scalars / 2;
|
|
912
|
-
nk_size_t
|
|
913
|
-
vuint32m4_t dot_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
914
|
-
vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
915
|
-
vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0,
|
|
918
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
919
|
+
vuint32m4_t dot_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
920
|
+
vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
921
|
+
vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
|
|
916
922
|
|
|
917
923
|
for (nk_size_t vector_length; n_bytes > 0;
|
|
918
924
|
n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -921,43 +927,44 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
|
|
|
921
927
|
vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
|
|
922
928
|
|
|
923
929
|
// Extract nibbles
|
|
924
|
-
vuint8m1_t
|
|
925
|
-
vuint8m1_t
|
|
926
|
-
vuint8m1_t
|
|
927
|
-
vuint8m1_t
|
|
930
|
+
vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
|
|
931
|
+
vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
|
|
932
|
+
vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
|
|
933
|
+
vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
|
|
928
934
|
|
|
929
935
|
// Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a * b (u8)
|
|
930
|
-
vuint8m1_t
|
|
931
|
-
|
|
932
|
-
vuint8m1_t
|
|
933
|
-
|
|
934
|
-
vuint8m1_t
|
|
935
|
-
vuint8m1_t
|
|
936
|
+
vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
|
|
937
|
+
b_high_u8m1, vector_length);
|
|
938
|
+
vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_low_u8m1, 4, vector_length), b_low_u8m1,
|
|
939
|
+
vector_length);
|
|
940
|
+
vuint8m1_t dot_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, high_idx_u8m1, vector_length);
|
|
941
|
+
vuint8m1_t dot_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, low_idx_u8m1, vector_length);
|
|
936
942
|
// Widen u8→u16, add hi+lo, then per-lane accumulate u32+=u16
|
|
937
|
-
vuint16m2_t dot_combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
943
|
+
vuint16m2_t dot_combined_u16m2 = __riscv_vwaddu_vv_u16m2(dot_high_u8m1, dot_low_u8m1, vector_length);
|
|
938
944
|
dot_u32m4 = __riscv_vwaddu_wv_u32m4_tu(dot_u32m4, dot_u32m4, dot_combined_u16m2, vector_length);
|
|
939
945
|
|
|
940
946
|
// Norms via 16-entry squaring LUT + vluxei8
|
|
941
|
-
vuint8m1_t
|
|
942
|
-
vuint8m1_t
|
|
943
|
-
vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
947
|
+
vuint8m1_t a_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_high_u8m1, vector_length);
|
|
948
|
+
vuint8m1_t a_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_low_u8m1, vector_length);
|
|
949
|
+
vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_high_sq_u8m1, a_low_sq_u8m1, vector_length);
|
|
944
950
|
a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
|
|
945
951
|
vector_length);
|
|
946
952
|
|
|
947
|
-
vuint8m1_t
|
|
948
|
-
vuint8m1_t
|
|
949
|
-
vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(
|
|
953
|
+
vuint8m1_t b_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_high_u8m1, vector_length);
|
|
954
|
+
vuint8m1_t b_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_low_u8m1, vector_length);
|
|
955
|
+
vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_high_sq_u8m1, b_low_sq_u8m1, vector_length);
|
|
950
956
|
b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
|
|
951
957
|
vector_length);
|
|
952
958
|
}
|
|
953
959
|
|
|
954
960
|
// Single horizontal reductions after loop
|
|
955
|
-
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0,
|
|
956
|
-
nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
961
|
+
vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
|
|
962
|
+
nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
963
|
+
__riscv_vredsum_vs_u32m4_u32m1(dot_u32m4, zero_u32m1, max_vector_length));
|
|
957
964
|
nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
958
|
-
__riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1,
|
|
965
|
+
__riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, max_vector_length));
|
|
959
966
|
nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
|
|
960
|
-
__riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1,
|
|
967
|
+
__riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, max_vector_length));
|
|
961
968
|
|
|
962
969
|
if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
|
|
963
970
|
else if (dot_u32 == 0) { *result = 1.0f; }
|
|
@@ -978,7 +985,7 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
|
|
|
978
985
|
#pragma GCC pop_options
|
|
979
986
|
#endif
|
|
980
987
|
|
|
981
|
-
#pragma endregion
|
|
988
|
+
#pragma endregion I8 and U8 Integers
|
|
982
989
|
#endif // NK_TARGET_RVV
|
|
983
990
|
#endif // NK_TARGET_RISCV_
|
|
984
991
|
#endif // NK_SPATIAL_RVV_H
|
|
@@ -37,9 +37,9 @@ extern "C" {
|
|
|
37
37
|
NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
|
|
38
38
|
nk_size_t count_scalars, nk_f32_t *result) {
|
|
39
39
|
// Per-lane accumulators — deferred horizontal reduction
|
|
40
|
-
nk_size_t
|
|
41
|
-
vfloat32m2_t sq_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
42
|
-
vfloat32m2_t ab_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
40
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
41
|
+
vfloat32m2_t sq_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length); // a² + b²
|
|
42
|
+
vfloat32m2_t ab_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length); // a × b
|
|
43
43
|
|
|
44
44
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
45
45
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -57,8 +57,10 @@ NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t
|
|
|
57
57
|
|
|
58
58
|
// Single horizontal reduction after the loop
|
|
59
59
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
60
|
-
nk_f32_t sq_sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
61
|
-
|
|
60
|
+
nk_f32_t sq_sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
61
|
+
__riscv_vfredusum_vs_f32m2_f32m1(sq_sum_f32m2, zero_f32m1, max_vector_length));
|
|
62
|
+
nk_f32_t ab_sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
63
|
+
__riscv_vfredusum_vs_f32m2_f32m1(ab_sum_f32m2, zero_f32m1, max_vector_length));
|
|
62
64
|
*result = sq_sum - 2.0f * ab_sum;
|
|
63
65
|
}
|
|
64
66
|
|
|
@@ -72,10 +74,10 @@ NK_PUBLIC void nk_euclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t c
|
|
|
72
74
|
NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
73
75
|
nk_f32_t *result) {
|
|
74
76
|
// Per-lane accumulators — deferred horizontal reduction
|
|
75
|
-
nk_size_t
|
|
76
|
-
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
77
|
-
vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
78
|
-
vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
77
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
78
|
+
vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
79
|
+
vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
80
|
+
vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
79
81
|
|
|
80
82
|
for (nk_size_t vector_length; count_scalars > 0;
|
|
81
83
|
count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
|
|
@@ -95,9 +97,12 @@ NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t con
|
|
|
95
97
|
|
|
96
98
|
// Single horizontal reduction after the loop
|
|
97
99
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
98
|
-
nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
|
|
99
|
-
|
|
100
|
-
nk_f32_t
|
|
100
|
+
nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
|
|
101
|
+
__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
|
|
102
|
+
nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(
|
|
103
|
+
__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, max_vector_length));
|
|
104
|
+
nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(
|
|
105
|
+
__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, max_vector_length));
|
|
101
106
|
|
|
102
107
|
// Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²)
|
|
103
108
|
if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
|