numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NEON-accelerated Probability Distribution Similarity Measures.
|
|
3
|
+
* @file include/numkong/probability/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/probability.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_PROBABILITY_NEON_H
|
|
10
|
+
#define NK_PROBABILITY_NEON_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
|
|
14
|
+
#include "numkong/types.h"
|
|
15
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`, `nk_partial_load_b32x4_serial_`
|
|
16
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
17
|
+
|
|
18
|
+
#if defined(__cplusplus)
|
|
19
|
+
extern "C" {
|
|
20
|
+
#endif
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_NEON
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function)
|
|
25
|
+
#elif defined(__GNUC__)
|
|
26
|
+
#pragma GCC push_options
|
|
27
|
+
#pragma GCC target("arch=armv8.2-a+simd")
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
/**
 *  @brief Approximates `log2(x)` lane-wise for four packed f32 values.
 *
 *  Splits every float into its unbiased exponent and a mantissa re-normalized
 *  into [1, 2), evaluates a degree-5 polynomial on the mantissa with Horner's
 *  rule, and recombines as: log2(x) ≈ exponent + poly(m) * (m - 1).
 */
NK_PUBLIC float32x4_t nk_log2_f32x4_neon_(float32x4_t x) {
    int32x4_t raw_i32x4 = vreinterpretq_s32_f32(x);

    // Unbiased exponent: (bits >> 23) - 127, then widened to float lanes.
    int32x4_t unbiased_i32x4 =
        vsubq_s32(vshrq_n_s32(vandq_s32(raw_i32x4, vdupq_n_s32(0x7F800000)), 23), vdupq_n_s32(127));
    float32x4_t exponent_f32x4 = vcvtq_f32_s32(unbiased_i32x4);

    // Mantissa forced into [1, 2): keep the fraction bits, overwrite the exponent field with bias 127.
    float32x4_t m_f32x4 = vreinterpretq_f32_s32(
        vorrq_s32(vandq_s32(raw_i32x4, vdupq_n_s32(0x007FFFFF)), vdupq_n_s32(0x3F800000)));

    // Horner's rule over the mantissa: each step computes coeff + m * acc.
    float32x4_t acc_f32x4 = vdupq_n_f32(-3.4436006e-2f);
    acc_f32x4 = vmlaq_f32(vdupq_n_f32(3.1821337e-1f), m_f32x4, acc_f32x4);
    acc_f32x4 = vmlaq_f32(vdupq_n_f32(-1.2315303f), m_f32x4, acc_f32x4);
    acc_f32x4 = vmlaq_f32(vdupq_n_f32(2.5988452f), m_f32x4, acc_f32x4);
    acc_f32x4 = vmlaq_f32(vdupq_n_f32(-3.3241990f), m_f32x4, acc_f32x4);
    acc_f32x4 = vmlaq_f32(vdupq_n_f32(3.1157899f), m_f32x4, acc_f32x4);

    // Recombine: exponent + poly(m) * (m - 1).
    return vaddq_f32(vmulq_f32(acc_f32x4, vsubq_f32(m_f32x4, vdupq_n_f32(1.0f))), exponent_f32x4);
}
|
|
56
|
+
|
|
57
|
+
/**
 *  @brief Kullback-Leibler divergence between two f32 distributions with NEON.
 *
 *  Accumulates a[i] * log2((a[i] + eps) / (b[i] + eps)) into two f64 lane
 *  pairs for extra precision, then multiplies by ln(2) to convert the log2
 *  sums into natural-log units. Tails shorter than four lanes are handled
 *  with a zero-padded partial load; padded lanes contribute exactly zero
 *  since `a` there is zero.
 */
NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    float32x4_t const epsilon_f32x4 = vdupq_n_f32(NK_F32_DIVISION_EPSILON);
    float64x2_t lo_acc_f64x2 = vdupq_n_f64(0.0);
    float64x2_t hi_acc_f64x2 = vdupq_n_f64(0.0);

    do {
        float32x4_t a_f32x4, b_f32x4;
        if (n < 4) {
            // Tail path: load the remaining `n` lanes, zero-fill the rest.
            nk_b128_vec_t a_vec, b_vec;
            nk_partial_load_b32x4_serial_(a, &a_vec, n);
            nk_partial_load_b32x4_serial_(b, &b_vec, n);
            a_f32x4 = a_vec.f32x4;
            b_f32x4 = b_vec.f32x4;
            n = 0;
        }
        else {
            a_f32x4 = vld1q_f32(a);
            b_f32x4 = vld1q_f32(b);
            n -= 4, a += 4, b += 4;
        }

        // Epsilon in numerator and denominator guards the division and the log at zero.
        float32x4_t ratio_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(b_f32x4, epsilon_f32x4));
        float32x4_t term_f32x4 = vmulq_f32(a_f32x4, nk_log2_f32x4_neon_(ratio_f32x4));
        lo_acc_f64x2 = vaddq_f64(lo_acc_f64x2, vcvt_f64_f32(vget_low_f32(term_f32x4)));
        hi_acc_f64x2 = vaddq_f64(hi_acc_f64x2, vcvt_f64_f32(vget_high_f32(term_f32x4)));
    } while (n);

    // Convert base: ln(x) = log2(x) * ln(2).
    *result = vaddvq_f64(vaddq_f64(lo_acc_f64x2, hi_acc_f64x2)) * 0.6931471805599453;
}
|
|
90
|
+
|
|
91
|
+
/**
 *  @brief Jensen-Shannon distance between two f32 distributions with NEON.
 *
 *  Accumulates a[i]*log2(a/m) + b[i]*log2(b/m) against the midpoint
 *  m = (a + b) / 2 in two f64 lane pairs, converts to natural-log units,
 *  halves the sum, and returns the square root of the divergence
 *  (clamped at zero to absorb tiny negative rounding residue).
 */
NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    float32x4_t const epsilon_f32x4 = vdupq_n_f32(NK_F32_DIVISION_EPSILON);
    float64x2_t lo_acc_f64x2 = vdupq_n_f64(0.0);
    float64x2_t hi_acc_f64x2 = vdupq_n_f64(0.0);

    do {
        float32x4_t a_f32x4, b_f32x4;
        if (n < 4) {
            // Tail path: load the remaining `n` lanes, zero-fill the rest.
            nk_b128_vec_t a_vec, b_vec;
            nk_partial_load_b32x4_serial_(a, &a_vec, n);
            nk_partial_load_b32x4_serial_(b, &b_vec, n);
            a_f32x4 = a_vec.f32x4;
            b_f32x4 = b_vec.f32x4;
            n = 0;
        }
        else {
            a_f32x4 = vld1q_f32(a);
            b_f32x4 = vld1q_f32(b);
            n -= 4, a += 4, b += 4;
        }

        // Midpoint distribution and epsilon-guarded ratios against it.
        float32x4_t mean_f32x4 = vmulq_n_f32(vaddq_f32(a_f32x4, b_f32x4), 0.5f);
        float32x4_t denom_f32x4 = vaddq_f32(mean_f32x4, epsilon_f32x4);
        float32x4_t ratio_a_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), denom_f32x4);
        float32x4_t ratio_b_f32x4 = vdivq_f32(vaddq_f32(b_f32x4, epsilon_f32x4), denom_f32x4);
        float32x4_t term_f32x4 = vaddq_f32(vmulq_f32(a_f32x4, nk_log2_f32x4_neon_(ratio_a_f32x4)),
                                           vmulq_f32(b_f32x4, nk_log2_f32x4_neon_(ratio_b_f32x4)));
        lo_acc_f64x2 = vaddq_f64(lo_acc_f64x2, vcvt_f64_f32(vget_low_f32(term_f32x4)));
        hi_acc_f64x2 = vaddq_f64(hi_acc_f64x2, vcvt_f64_f32(vget_high_f32(term_f32x4)));
    } while (n);

    // Base conversion (ln 2) plus the 1/2 factor of the JS divergence, then the distance.
    nk_f64_t divergence = vaddvq_f64(vaddq_f64(lo_acc_f64x2, hi_acc_f64x2)) * 0.6931471805599453 / 2.0;
    *result = divergence > 0 ? nk_f64_sqrt_neon(divergence) : 0;
}
|
|
129
|
+
|
|
130
|
+
#if defined(__clang__)
|
|
131
|
+
#pragma clang attribute pop
|
|
132
|
+
#elif defined(__GNUC__)
|
|
133
|
+
#pragma GCC pop_options
|
|
134
|
+
#endif
|
|
135
|
+
#endif // NK_TARGET_NEON
|
|
136
|
+
|
|
137
|
+
#if NK_TARGET_NEONHALF
|
|
138
|
+
#if defined(__clang__)
|
|
139
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
140
|
+
#elif defined(__GNUC__)
|
|
141
|
+
#pragma GCC push_options
|
|
142
|
+
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
143
|
+
#endif
|
|
144
|
+
|
|
145
|
+
/**
 *  @brief Kullback-Leibler divergence between two f16 distributions with NEON FP16.
 *
 *  Upcasts four half-precision lanes at a time to f32, accumulates
 *  a[i] * log2((a[i] + eps) / (b[i] + eps)) in f32, and rescales by ln(2)
 *  at the end. Tails shorter than four lanes use a zero-padded partial load.
 */
NK_PUBLIC void nk_kld_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t acc_f32x4 = vdupq_n_f32(0);
    float32x4_t const epsilon_f32x4 = vdupq_n_f32(NK_F32_DIVISION_EPSILON);

    do {
        float32x4_t a_f32x4, b_f32x4;
        if (n < 4) {
            // Tail path: partial 16-bit loads, then widen to f32.
            nk_b64_vec_t a_vec, b_vec;
            nk_partial_load_b16x4_serial_(a, &a_vec, n);
            nk_partial_load_b16x4_serial_(b, &b_vec, n);
            a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
            b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
            n = 0;
        }
        else {
            a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
            b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
            n -= 4, a += 4, b += 4;
        }

        // Epsilon guards both the division and the logarithm at zero.
        float32x4_t ratio_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(b_f32x4, epsilon_f32x4));
        acc_f32x4 = vaddq_f32(acc_f32x4, vmulq_f32(a_f32x4, nk_log2_f32x4_neon_(ratio_f32x4)));
    } while (n);

    // Convert base: ln(x) = log2(x) * ln(2).
    *result = vaddvq_f32(acc_f32x4) * 0.693147181f;
}
|
|
176
|
+
|
|
177
|
+
/**
 *  @brief Jensen-Shannon distance between two f16 distributions with NEON FP16.
 *
 *  Upcasts four half-precision lanes at a time to f32 and accumulates
 *  a[i]*log2(a/m) + b[i]*log2(b/m) against the midpoint m = (a + b) / 2.
 *  Converts to natural-log units, halves the sum, and returns the square
 *  root of the divergence (clamped at zero against rounding residue).
 */
NK_PUBLIC void nk_jsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t acc_f32x4 = vdupq_n_f32(0);
    float32x4_t const epsilon_f32x4 = vdupq_n_f32(NK_F32_DIVISION_EPSILON);

    do {
        float32x4_t a_f32x4, b_f32x4;
        if (n < 4) {
            // Tail path: partial 16-bit loads, then widen to f32.
            nk_b64_vec_t a_vec, b_vec;
            nk_partial_load_b16x4_serial_(a, &a_vec, n);
            nk_partial_load_b16x4_serial_(b, &b_vec, n);
            a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
            b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
            n = 0;
        }
        else {
            a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
            b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
            n -= 4, a += 4, b += 4;
        }

        // Midpoint distribution and epsilon-guarded ratios against it.
        float32x4_t mean_f32x4 = vmulq_n_f32(vaddq_f32(a_f32x4, b_f32x4), 0.5f);
        float32x4_t ratio_a_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(mean_f32x4, epsilon_f32x4));
        float32x4_t ratio_b_f32x4 = vdivq_f32(vaddq_f32(b_f32x4, epsilon_f32x4), vaddq_f32(mean_f32x4, epsilon_f32x4));
        float32x4_t term_f32x4 = vaddq_f32(vmulq_f32(a_f32x4, nk_log2_f32x4_neon_(ratio_a_f32x4)),
                                           vmulq_f32(b_f32x4, nk_log2_f32x4_neon_(ratio_b_f32x4)));
        acc_f32x4 = vaddq_f32(acc_f32x4, term_f32x4);
    } while (n);

    // Base conversion (ln 2) plus the 1/2 factor of the JS divergence, then the distance.
    nk_f32_t divergence = vaddvq_f32(acc_f32x4) * 0.693147181f / 2;
    *result = divergence > 0 ? nk_f32_sqrt_neon(divergence) : 0;
}
|
|
212
|
+
|
|
213
|
+
#if defined(__clang__)
|
|
214
|
+
#pragma clang attribute pop
|
|
215
|
+
#elif defined(__GNUC__)
|
|
216
|
+
#pragma GCC pop_options
|
|
217
|
+
#endif
|
|
218
|
+
#endif // NK_TARGET_NEONHALF
|
|
219
|
+
|
|
220
|
+
#if defined(__cplusplus)
|
|
221
|
+
} // extern "C"
|
|
222
|
+
#endif
|
|
223
|
+
|
|
224
|
+
#endif // NK_TARGET_ARM_
|
|
225
|
+
#endif // NK_PROBABILITY_NEON_H
|
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Probability Distribution Similarity Measures for RISC-V.
|
|
3
|
+
* @file include/numkong/probability/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/probability.h
|
|
8
|
+
*
|
|
9
|
+
* Implements KLD and JSD using RVV 1.0 vector intrinsics for f32, f64, f16, and bf16.
|
|
10
|
+
* The log2 approximation uses the same polynomial as the Haswell implementation,
|
|
11
|
+
* ported to RVV's vector fused-multiply-add instructions.
|
|
12
|
+
*
|
|
13
|
+
* For f64, uses the s-series 14-term Horner log2 approximation (matching Skylake).
|
|
14
|
+
* For f16/bf16, converts to f32 using the cast helpers from cast/rvv.h,
|
|
15
|
+
* then uses the f32 algorithm.
|
|
16
|
+
*/
|
|
17
|
+
#ifndef NK_PROBABILITY_RVV_H
|
|
18
|
+
#define NK_PROBABILITY_RVV_H
|
|
19
|
+
|
|
20
|
+
#if NK_TARGET_RISCV_
|
|
21
|
+
#if NK_TARGET_RVV
|
|
22
|
+
|
|
23
|
+
#include "numkong/types.h"
|
|
24
|
+
#include "numkong/probability/serial.h" // `nk_kld_f64_serial`, `nk_jsd_f64_serial`
|
|
25
|
+
#include "numkong/cast/rvv.h" // `nk_f16m1_to_f32m2_rvv_`, `nk_bf16m1_to_f32m2_rvv_`
|
|
26
|
+
#include "numkong/spatial/rvv.h" // `nk_f32_sqrt_rvv`
|
|
27
|
+
|
|
28
|
+
#if defined(__clang__)
|
|
29
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
30
|
+
#elif defined(__GNUC__)
|
|
31
|
+
#pragma GCC push_options
|
|
32
|
+
#pragma GCC target("arch=+v")
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
/**
|
|
40
|
+
* @brief Computes `log2(x)` for a vector of f32 values using IEEE 754 bit manipulation
|
|
41
|
+
* and a 5-term Horner polynomial, matching the Haswell log2 approximation.
|
|
42
|
+
*
|
|
43
|
+
* Decomposes each float into exponent and mantissa:
|
|
44
|
+
* - exponent = (bits >> 23) - 127
|
|
45
|
+
* - mantissa = (bits & 0x007FFFFF) | 0x3F800000, yielding m in [1, 2)
|
|
46
|
+
*
|
|
47
|
+
* Then evaluates poly(m) via Horner's method:
|
|
48
|
+
* poly = -3.4436006e-2f
|
|
49
|
+
* poly = poly * m + 3.1821337e-1f
|
|
50
|
+
* poly = poly * m - 1.2315303f
|
|
51
|
+
* poly = poly * m + 2.5988452f
|
|
52
|
+
* poly = poly * m - 3.3241990f
|
|
53
|
+
* poly = poly * m + 3.1157899f
|
|
54
|
+
*
|
|
55
|
+
* Final result: log2(x) = exponent + poly * (m - 1)
|
|
56
|
+
*/
|
|
57
|
+
NK_INTERNAL vfloat32m4_t nk_log2_f32m4_rvv_(vfloat32m4_t x, nk_size_t vector_length) {
    vuint32m4_t raw = __riscv_vreinterpret_v_f32m4_u32m4(x);

    // Unbiased exponent: (bits >> 23) - 127, widened to float lanes.
    vint32m4_t unbiased = __riscv_vsub_vx_i32m4(
        __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vsrl_vx_u32m4(raw, 23, vector_length)), 127, vector_length);
    vfloat32m4_t exponent = __riscv_vfcvt_f_x_v_f32m4(unbiased, vector_length);

    // Mantissa forced into [1, 2): keep fraction bits, overwrite exponent field with bias 127.
    vuint32m4_t frac = __riscv_vor_vx_u32m4(__riscv_vand_vx_u32m4(raw, 0x007FFFFF, vector_length), 0x3F800000,
                                            vector_length);
    vfloat32m4_t mantissa = __riscv_vreinterpret_v_u32m4_f32m4(frac);

    // Horner's rule: `vfmadd_vv(vd, vs1, vs2)` computes vd * vs1 + vs2, i.e. acc * m + coeff.
    vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(-3.4436006e-2f, vector_length);
    acc = __riscv_vfmadd_vv_f32m4(acc, mantissa, __riscv_vfmv_v_f_f32m4(3.1821337e-1f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m4(acc, mantissa, __riscv_vfmv_v_f_f32m4(-1.2315303f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m4(acc, mantissa, __riscv_vfmv_v_f_f32m4(2.5988452f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m4(acc, mantissa, __riscv_vfmv_v_f_f32m4(-3.3241990f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m4(acc, mantissa, __riscv_vfmv_v_f_f32m4(3.1157899f, vector_length), vector_length);

    // Recombine: exponent + poly(m) * (m - 1), fused via vfmacc.
    vfloat32m4_t m_minus_one = __riscv_vfsub_vf_f32m4(mantissa, 1.0f, vector_length);
    return __riscv_vfmacc_vv_f32m4(exponent, acc, m_minus_one, vector_length);
}
|
|
84
|
+
|
|
85
|
+
/**
 *  @brief LMUL=2 variant of `nk_log2_f32m4_rvv_`: exponent/mantissa split
 *  plus a degree-5 Horner polynomial, yielding log2(x) ≈ e + poly(m) * (m - 1).
 */
NK_INTERNAL vfloat32m2_t nk_log2_f32m2_rvv_(vfloat32m2_t x, nk_size_t vector_length) {
    vuint32m2_t raw = __riscv_vreinterpret_v_f32m2_u32m2(x);

    // Unbiased exponent: (bits >> 23) - 127, widened to float lanes.
    vint32m2_t unbiased = __riscv_vsub_vx_i32m2(
        __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vsrl_vx_u32m2(raw, 23, vector_length)), 127, vector_length);
    vfloat32m2_t exponent = __riscv_vfcvt_f_x_v_f32m2(unbiased, vector_length);

    // Mantissa forced into [1, 2): keep fraction bits, overwrite exponent field with bias 127.
    vuint32m2_t frac = __riscv_vor_vx_u32m2(__riscv_vand_vx_u32m2(raw, 0x007FFFFF, vector_length), 0x3F800000,
                                            vector_length);
    vfloat32m2_t mantissa = __riscv_vreinterpret_v_u32m2_f32m2(frac);

    // Horner's rule: each vfmadd step computes acc * m + coeff.
    vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(-3.4436006e-2f, vector_length);
    acc = __riscv_vfmadd_vv_f32m2(acc, mantissa, __riscv_vfmv_v_f_f32m2(3.1821337e-1f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m2(acc, mantissa, __riscv_vfmv_v_f_f32m2(-1.2315303f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m2(acc, mantissa, __riscv_vfmv_v_f_f32m2(2.5988452f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m2(acc, mantissa, __riscv_vfmv_v_f_f32m2(-3.3241990f, vector_length), vector_length);
    acc = __riscv_vfmadd_vv_f32m2(acc, mantissa, __riscv_vfmv_v_f_f32m2(3.1157899f, vector_length), vector_length);

    // Recombine: exponent + poly(m) * (m - 1), fused via vfmacc.
    vfloat32m2_t m_minus_one = __riscv_vfsub_vf_f32m2(mantissa, 1.0f, vector_length);
    return __riscv_vfmacc_vv_f32m2(exponent, acc, m_minus_one, vector_length);
}
|
|
102
|
+
|
|
103
|
+
/**
 * @brief Computes `log2(x)` for a vector of f64 values using the s-series approach.
 *
 * Uses s = (m-1)/(m+1), then evaluates ln(m) = 2 × s × P(s²) with 14-term Horner polynomial.
 * Converts to log2 via multiplication by log2(e). Matches Skylake's f64 log2 algorithm.
 * NOTE(review): no handling is visible for zero/negative/NaN/subnormal lanes — callers
 * appear to pre-bias inputs with an epsilon; confirm.
 */
NK_INTERNAL vfloat64m4_t nk_log2_f64m4_rvv_(vfloat64m4_t x, nk_size_t vector_length) {
    // Step 1-2: Extract exponent and mantissa via bit manipulation
    vuint64m4_t bits = __riscv_vreinterpret_v_f64m4_u64m4(x);
    vuint64m4_t exp_bits = __riscv_vsrl_vx_u64m4(bits, 52, vector_length);
    vint64m4_t exponent = __riscv_vsub_vx_i64m4(__riscv_vreinterpret_v_u64m4_i64m4(exp_bits), 1023, vector_length);
    vfloat64m4_t exp_f = __riscv_vfcvt_f_x_v_f64m4(exponent, vector_length);
    // Force mantissa into [1, 2): keep the 52 fraction bits, OR in the exponent of 1.0.
    vuint64m4_t mant_bits = __riscv_vor_vx_u64m4(__riscv_vand_vx_u64m4(bits, 0x000FFFFFFFFFFFFFULL, vector_length),
                                                 0x3FF0000000000000ULL, vector_length);
    vfloat64m4_t m = __riscv_vreinterpret_v_u64m4_f64m4(mant_bits);

    // Step 3: s = (m - 1) / (m + 1)
    vfloat64m4_t one = __riscv_vfmv_v_f_f64m4(1.0, vector_length);
    vfloat64m4_t s = __riscv_vfdiv_vv_f64m4(__riscv_vfsub_vv_f64m4(m, one, vector_length),
                                            __riscv_vfadd_vv_f64m4(m, one, vector_length), vector_length);
    vfloat64m4_t s2 = __riscv_vfmul_vv_f64m4(s, s, vector_length);

    // Step 4: P(s²) = 1 + s²/3 + s⁴/5 + ... (14 terms, Horner's method)
    // Each vfmadd computes poly = s2 * poly + coeff, stepping down odd reciprocals.
    vfloat64m4_t poly = __riscv_vfmv_v_f_f64m4(1.0 / 27.0, vector_length); // 1/(2*13+1)
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 25.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 23.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 21.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 19.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 17.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 15.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 13.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 11.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 9.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 7.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 5.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, __riscv_vfmv_v_f_f64m4(1.0 / 3.0, vector_length), vector_length);
    poly = __riscv_vfmadd_vv_f64m4(s2, poly, one, vector_length);

    // Step 5-6: ln(m) = 2 × s × P(s²), log2(m) = ln(m) × log2(e), log2(x) = exp + log2(m)
    vfloat64m4_t two_s = __riscv_vfmul_vf_f64m4(s, 2.0, vector_length);
    vfloat64m4_t ln_m = __riscv_vfmul_vv_f64m4(two_s, poly, vector_length);
    vfloat64m4_t log2_m = __riscv_vfmul_vf_f64m4(ln_m, 1.4426950408889634, vector_length);
    return __riscv_vfadd_vv_f64m4(exp_f, log2_m, vector_length);
}
|
|
147
|
+
|
|
148
|
+
#pragma region - Kullback-Leibler Divergence
|
|
149
|
+
|
|
150
|
+
/**
 * @brief Kullback-Leibler divergence between two f32 distributions.
 *
 * Accumulates Σ a[i] × log2((a[i]+ε)/(b[i]+ε)) in f64 lanes for precision, then
 * converts the reduced sum from log2 to natural log by scaling with ln(2).
 *
 * @param a First (reference) distribution, `n` f32 values.
 * @param b Second distribution, `n` f32 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: KLD(a ‖ b) in nats.
 */
NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vector_length_max);
    for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
        vector_length = __riscv_vsetvl_e32m2(n);
        vfloat32m2_t a_f32m2 = __riscv_vle32_v_f32m2(a, vector_length);
        vfloat32m2_t b_f32m2 = __riscv_vle32_v_f32m2(b, vector_length);
        // ratio = (a + ε) / (b + ε)
        vfloat32m2_t a_eps_f32m2 = __riscv_vfadd_vf_f32m2(a_f32m2, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t b_eps_f32m2 = __riscv_vfadd_vf_f32m2(b_f32m2, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t ratio_f32m2 = __riscv_vfmul_vv_f32m2(
            a_eps_f32m2, nk_f32m2_reciprocal_rvv_(b_eps_f32m2, vector_length), vector_length);
        // log2(ratio)
        vfloat32m2_t log_ratio_f32m2 = nk_log2_f32m2_rvv_(ratio_f32m2, vector_length);
        // contribution = a * log2(a / b)
        vfloat32m2_t contribution_f32m2 = __riscv_vfmul_vv_f32m2(a_f32m2, log_ratio_f32m2, vector_length);
        // Widen to f64 (m2 -> m4) before accumulating to limit rounding error.
        vfloat64m4_t contribution_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(contribution_f32m2, vector_length);
        // Tail-undisturbed add: lanes past vector_length keep their running totals.
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, contribution_f64m4, vector_length);
    }
    // Single horizontal reduction, then log2 -> ln by multiplying with ln(2).
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vector_length_max)) *
              0.6931471805599453;
}
|
|
173
|
+
|
|
174
|
+
/**
 * @brief Kullback-Leibler divergence between two f64 distributions.
 *
 * Each lane contributes a × log2((a+ε)/(b+ε)); lane totals are kept per-register
 * across the strip-mined loop, reduced once at the end, and rescaled by ln(2)
 * so the result is reported in nats.
 *
 * @param a First (reference) distribution, `n` f64 values.
 * @param b Second distribution, `n` f64 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: KLD(a ‖ b) in nats.
 */
NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t vl_max = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t acc_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vl_max);
    for (nk_size_t vl; n > 0; n -= vl, a += vl, b += vl) {
        vl = __riscv_vsetvl_e64m4(n);
        vfloat64m4_t a_vec = __riscv_vle64_v_f64m4(a, vl);
        vfloat64m4_t b_vec = __riscv_vle64_v_f64m4(b, vl);
        // Bias both operands by ε so zero-valued lanes cannot divide by zero.
        vfloat64m4_t a_biased = __riscv_vfadd_vf_f64m4(a_vec, NK_F64_DIVISION_EPSILON, vl);
        vfloat64m4_t b_biased = __riscv_vfadd_vf_f64m4(b_vec, NK_F64_DIVISION_EPSILON, vl);
        // Full-precision division — not a reciprocal approximation.
        vfloat64m4_t ratio = __riscv_vfdiv_vv_f64m4(a_biased, b_biased, vl);
        vfloat64m4_t log_ratio = nk_log2_f64m4_rvv_(ratio, vl);
        // Accumulate a * log2(a / b); tail-undisturbed keeps idle lanes intact.
        vfloat64m4_t term = __riscv_vfmul_vv_f64m4(a_vec, log_ratio, vl);
        acc_f64m4 = __riscv_vfadd_vv_f64m4_tu(acc_f64m4, acc_f64m4, term, vl);
    }
    // One horizontal reduction after the loop, then log2 -> ln via ln(2) scaling.
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    nk_f64_t reduced = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(acc_f64m4, zero_f64m1, vl_max));
    *result = reduced * 0.6931471805599453;
}
|
|
198
|
+
|
|
199
|
+
/**
 * @brief Kullback-Leibler divergence between two f16 distributions.
 *
 * Loads half-precision inputs as raw bits, widens to f32, accumulates
 * Σ a[i] × log2((a[i]+ε)/(b[i]+ε)) per lane, then converts the reduced sum
 * from log2 to natural log by scaling with ln(2).
 *
 * @param a First (reference) distribution, `n` f16 values.
 * @param b Second distribution, `n` f16 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: KLD(a ‖ b) in nats, as f32.
 */
NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(n);
        // Load f16 as raw u16 bits
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b, vector_length);
        // Convert f16 to f32 (m1 -> m2)
        vfloat32m2_t a_f32m2 = nk_f16m1_to_f32m2_rvv_(a_u16m1, vector_length);
        vfloat32m2_t b_f32m2 = nk_f16m1_to_f32m2_rvv_(b_u16m1, vector_length);
        // ratio = (a + ε) / (b + ε)
        vfloat32m2_t a_eps_f32m2 = __riscv_vfadd_vf_f32m2(a_f32m2, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t b_eps_f32m2 = __riscv_vfadd_vf_f32m2(b_f32m2, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t ratio_f32m2 = __riscv_vfmul_vv_f32m2(
            a_eps_f32m2, nk_f32m2_reciprocal_rvv_(b_eps_f32m2, vector_length), vector_length);
        vfloat32m2_t log_ratio_f32m2 = nk_log2_f32m2_rvv_(ratio_f32m2, vector_length);
        // contribution = a * log2(a / b)
        vfloat32m2_t contribution_f32m2 = __riscv_vfmul_vv_f32m2(a_f32m2, log_ratio_f32m2, vector_length);
        // Per-lane accumulation; tail-undisturbed keeps lanes past vector_length.
        sum_f32m2 = __riscv_vfadd_vv_f32m2_tu(sum_f32m2, sum_f32m2, contribution_f32m2, vector_length);
    }
    // Single horizontal reduction after loop, then log2 -> ln via ln(2).
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) * 0.693147181f;
}
|
|
225
|
+
|
|
226
|
+
/**
 * @brief Kullback-Leibler divergence between two bf16 distributions.
 *
 * Loads brain-float inputs as raw 16-bit words, widens to f32, accumulates
 * a × log2((a+ε)/(b+ε)) per lane, reduces once, and rescales by ln(2)
 * so the result is reported in nats.
 *
 * @param a First (reference) distribution, `n` bf16 values.
 * @param b Second distribution, `n` bf16 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: KLD(a ‖ b) in nats, as f32.
 */
NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t vl_max = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vl_max);
    for (nk_size_t vl; n > 0; n -= vl, a += vl, b += vl) {
        vl = __riscv_vsetvl_e16m1(n);
        // Raw bf16 bit loads, then widening conversion to f32 (m1 -> m2).
        vuint16m1_t a_bits = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vl);
        vuint16m1_t b_bits = __riscv_vle16_v_u16m1((nk_u16_t const *)b, vl);
        vfloat32m2_t a_vec = nk_bf16m1_to_f32m2_rvv_(a_bits, vl);
        vfloat32m2_t b_vec = nk_bf16m1_to_f32m2_rvv_(b_bits, vl);
        // ratio = (a + ε) / (b + ε), via the reciprocal helper.
        vfloat32m2_t a_biased = __riscv_vfadd_vf_f32m2(a_vec, NK_F32_DIVISION_EPSILON, vl);
        vfloat32m2_t b_biased = __riscv_vfadd_vf_f32m2(b_vec, NK_F32_DIVISION_EPSILON, vl);
        vfloat32m2_t ratio = __riscv_vfmul_vv_f32m2(a_biased, nk_f32m2_reciprocal_rvv_(b_biased, vl), vl);
        vfloat32m2_t log_ratio = nk_log2_f32m2_rvv_(ratio, vl);
        // Accumulate a * log2(a / b); tail-undisturbed keeps idle lanes intact.
        vfloat32m2_t term = __riscv_vfmul_vv_f32m2(a_vec, log_ratio, vl);
        acc_f32m2 = __riscv_vfadd_vv_f32m2_tu(acc_f32m2, acc_f32m2, term, vl);
    }
    // One horizontal reduction at the end, then log2 -> ln via ln(2) scaling.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t reduced = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(acc_f32m2, zero_f32m1, vl_max));
    *result = reduced * 0.693147181f;
}
|
|
252
|
+
|
|
253
|
+
#pragma endregion - Kullback-Leibler Divergence
|
|
254
|
+
|
|
255
|
+
#pragma region - Jensen-Shannon Divergence
|
|
256
|
+
|
|
257
|
+
/**
 * @brief Jensen-Shannon distance between two f32 distributions.
 *
 * Per element, computes a × log2(a/M) + b × log2(b/M) against the mixture
 * M = (a+b)/2 (with ε bias), accumulating in f64 lanes. The reduced sum is
 * converted to nats and halved, and the square root of the divergence is
 * returned (0 for non-positive totals).
 *
 * @param a First distribution, `n` f32 values.
 * @param b Second distribution, `n` f32 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: sqrt(JSD(a, b)) in f64.
 */
NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vector_length_max);
    for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
        vector_length = __riscv_vsetvl_e32m2(n);
        vfloat32m2_t va = __riscv_vle32_v_f32m2(a, vector_length);
        vfloat32m2_t vb = __riscv_vle32_v_f32m2(b, vector_length);
        // M = (a + b) / 2
        vfloat32m2_t mean = __riscv_vfmul_vf_f32m2(__riscv_vfadd_vv_f32m2(va, vb, vector_length), 0.5f, vector_length);
        // ratio_a = (a + eps) / (M + eps); one shared reciprocal of (M + eps) serves both ratios.
        vfloat32m2_t va_eps = __riscv_vfadd_vf_f32m2(va, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t vb_eps = __riscv_vfadd_vf_f32m2(vb, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t mean_eps_f32m2 = __riscv_vfadd_vf_f32m2(mean, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t mean_rcp_f32m2 = nk_f32m2_reciprocal_rvv_(mean_eps_f32m2, vector_length);
        vfloat32m2_t ratio_a = __riscv_vfmul_vv_f32m2(va_eps, mean_rcp_f32m2, vector_length);
        vfloat32m2_t ratio_b = __riscv_vfmul_vv_f32m2(vb_eps, mean_rcp_f32m2, vector_length);
        // log2(ratio_a), log2(ratio_b)
        vfloat32m2_t log_ratio_a = nk_log2_f32m2_rvv_(ratio_a, vector_length);
        vfloat32m2_t log_ratio_b = nk_log2_f32m2_rvv_(ratio_b, vector_length);
        // contribution_a = a * log2(a / M), contribution_b = b * log2(b / M)
        vfloat32m2_t contrib_a = __riscv_vfmul_vv_f32m2(va, log_ratio_a, vector_length);
        vfloat32m2_t contrib_b = __riscv_vfmul_vv_f32m2(vb, log_ratio_b, vector_length);
        vfloat32m2_t contrib_f32m2 = __riscv_vfadd_vv_f32m2(contrib_a, contrib_b, vector_length);
        // Widen to f64 before the tail-undisturbed per-lane accumulation.
        vfloat64m4_t contrib_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(contrib_f32m2, vector_length);
        sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(sum_f64m4, sum_f64m4, contrib_f64m4, vector_length);
    }
    // Reduce once, convert log2 sums to ln via ln(2), halve, then take the square root.
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    nk_f64_t sum = __riscv_vfmv_f_s_f64m1_f64(
                       __riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vector_length_max)) *
                   0.6931471805599453 / 2.0;
    *result = sum > 0 ? nk_f64_sqrt_rvv(sum) : 0;
}
|
|
289
|
+
|
|
290
|
+
/**
 * @brief Jensen-Shannon distance between two f64 distributions.
 *
 * Keeps separate fused-multiply-add accumulators for the a-side and b-side
 * contributions against the mixture M = (a+b)/2, reduces their sum once after
 * the loop, rescales by ln(2)/2, and returns the square root of the divergence
 * (0 for non-positive totals).
 *
 * @param a First distribution, `n` f64 values.
 * @param b Second distribution, `n` f64 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: sqrt(JSD(a, b)).
 */
NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t sum_a_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    vfloat64m4_t sum_b_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
    for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
        vector_length = __riscv_vsetvl_e64m4(n);
        vfloat64m4_t va = __riscv_vle64_v_f64m4(a, vector_length);
        vfloat64m4_t vb = __riscv_vle64_v_f64m4(b, vector_length);
        // M = (a + b) / 2
        vfloat64m4_t mean = __riscv_vfmul_vf_f64m4(__riscv_vfadd_vv_f64m4(va, vb, vector_length), 0.5, vector_length);
        // ratio_a = (a + eps) / (M + eps), ratio_b = (b + eps) / (M + eps)
        vfloat64m4_t va_eps = __riscv_vfadd_vf_f64m4(va, NK_F64_DIVISION_EPSILON, vector_length);
        vfloat64m4_t vb_eps = __riscv_vfadd_vf_f64m4(vb, NK_F64_DIVISION_EPSILON, vector_length);
        vfloat64m4_t mean_eps = __riscv_vfadd_vf_f64m4(mean, NK_F64_DIVISION_EPSILON, vector_length);
        // Full precision division (not reciprocal approximation)
        vfloat64m4_t ratio_a = __riscv_vfdiv_vv_f64m4(va_eps, mean_eps, vector_length);
        vfloat64m4_t ratio_b = __riscv_vfdiv_vv_f64m4(vb_eps, mean_eps, vector_length);
        // log2(ratio_a), log2(ratio_b)
        vfloat64m4_t log_ratio_a = nk_log2_f64m4_rvv_(ratio_a, vector_length);
        vfloat64m4_t log_ratio_b = nk_log2_f64m4_rvv_(ratio_b, vector_length);
        // contribution_a = a * log2(a / M), contribution_b = b * log2(b / M)
        // vfmacc accumulates sum += v * log_ratio; tail-undisturbed keeps idle lanes.
        sum_a_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sum_a_f64m4, va, log_ratio_a, vector_length);
        sum_b_f64m4 = __riscv_vfmacc_vv_f64m4_tu(sum_b_f64m4, vb, log_ratio_b, vector_length);
    }
    // Single horizontal reduction after loop
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    // JSD = sqrt((sum_a + sum_b) * ln(2) / 2)
    nk_f64_t sum = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(
                       __riscv_vfadd_vv_f64m4(sum_a_f64m4, sum_b_f64m4, vlmax), zero_f64m1, vlmax)) *
                   0.6931471805599453 / 2;
    *result = sum > 0 ? nk_f64_sqrt_rvv(sum) : 0;
}
|
|
322
|
+
|
|
323
|
+
/**
 * @brief Jensen-Shannon distance between two f16 distributions.
 *
 * Widens half-precision inputs to f32, accumulates a × log2(a/M) + b × log2(b/M)
 * against the mixture M = (a+b)/2 (with ε bias), converts the reduced sum to
 * nats, halves it, and returns the square root of the divergence (0 for
 * non-positive totals).
 *
 * @param a First distribution, `n` f16 values.
 * @param b Second distribution, `n` f16 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: sqrt(JSD(a, b)) in f32.
 */
NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(n);
        // Load f16 as raw u16 bits
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b, vector_length);
        // Convert f16 to f32 (m1 -> m2)
        vfloat32m2_t va = nk_f16m1_to_f32m2_rvv_(a_u16m1, vector_length);
        vfloat32m2_t vb = nk_f16m1_to_f32m2_rvv_(b_u16m1, vector_length);
        // M = (a + b) / 2
        vfloat32m2_t mean = __riscv_vfmul_vf_f32m2(__riscv_vfadd_vv_f32m2(va, vb, vector_length), 0.5f, vector_length);
        // ratio_a = (a + eps) / (M + eps), ratio_b = (b + eps) / (M + eps)
        // One shared reciprocal of (M + eps) serves both ratios.
        vfloat32m2_t va_eps = __riscv_vfadd_vf_f32m2(va, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t vb_eps = __riscv_vfadd_vf_f32m2(vb, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t mean_eps_f32m2 = __riscv_vfadd_vf_f32m2(mean, NK_F32_DIVISION_EPSILON, vector_length);
        vfloat32m2_t mean_rcp_f32m2 = nk_f32m2_reciprocal_rvv_(mean_eps_f32m2, vector_length);
        vfloat32m2_t ratio_a = __riscv_vfmul_vv_f32m2(va_eps, mean_rcp_f32m2, vector_length);
        vfloat32m2_t ratio_b = __riscv_vfmul_vv_f32m2(vb_eps, mean_rcp_f32m2, vector_length);
        vfloat32m2_t log_ratio_a = nk_log2_f32m2_rvv_(ratio_a, vector_length);
        vfloat32m2_t log_ratio_b = nk_log2_f32m2_rvv_(ratio_b, vector_length);
        // contribution_a = a * log2(a / M), contribution_b = b * log2(b / M)
        vfloat32m2_t contrib_a = __riscv_vfmul_vv_f32m2(va, log_ratio_a, vector_length);
        vfloat32m2_t contrib_b = __riscv_vfmul_vv_f32m2(vb, log_ratio_b, vector_length);
        vfloat32m2_t contrib = __riscv_vfadd_vv_f32m2(contrib_a, contrib_b, vector_length);
        // Per-lane accumulation; tail-undisturbed keeps lanes past vector_length.
        sum_f32m2 = __riscv_vfadd_vv_f32m2_tu(sum_f32m2, sum_f32m2, contrib, vector_length);
    }
    // Single horizontal reduction after loop, then ln(2)/2 scaling and square root.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) *
                   0.693147181f / 2;
    *result = sum > 0 ? nk_f32_sqrt_rvv(sum) : 0;
}
|
|
358
|
+
|
|
359
|
+
/**
 * @brief Jensen-Shannon distance between two bf16 distributions.
 *
 * Widens brain-float inputs to f32, accumulates a × log2(a/M) + b × log2(b/M)
 * against the mixture M = (a+b)/2 (with ε bias), converts the reduced total to
 * nats, halves it, and returns the square root of the divergence (0 for
 * non-positive totals).
 *
 * @param a First distribution, `n` bf16 values.
 * @param b Second distribution, `n` bf16 values.
 * @param n Number of elements in each array.
 * @param result Output scalar: sqrt(JSD(a, b)) in f32.
 */
NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t vl_max = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t acc_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vl_max);
    for (nk_size_t vl; n > 0; n -= vl, a += vl, b += vl) {
        vl = __riscv_vsetvl_e16m1(n);
        // Raw bf16 bit loads, then widening conversion to f32 (m1 -> m2).
        vuint16m1_t a_bits = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vl);
        vuint16m1_t b_bits = __riscv_vle16_v_u16m1((nk_u16_t const *)b, vl);
        vfloat32m2_t a_vec = nk_bf16m1_to_f32m2_rvv_(a_bits, vl);
        vfloat32m2_t b_vec = nk_bf16m1_to_f32m2_rvv_(b_bits, vl);
        // Mixture M = (a + b) / 2.
        vfloat32m2_t mix = __riscv_vfmul_vf_f32m2(__riscv_vfadd_vv_f32m2(a_vec, b_vec, vl), 0.5f, vl);
        // ε-biased ratios against M, sharing a single reciprocal of (M + ε).
        vfloat32m2_t a_biased = __riscv_vfadd_vf_f32m2(a_vec, NK_F32_DIVISION_EPSILON, vl);
        vfloat32m2_t b_biased = __riscv_vfadd_vf_f32m2(b_vec, NK_F32_DIVISION_EPSILON, vl);
        vfloat32m2_t mix_biased = __riscv_vfadd_vf_f32m2(mix, NK_F32_DIVISION_EPSILON, vl);
        vfloat32m2_t mix_rcp = nk_f32m2_reciprocal_rvv_(mix_biased, vl);
        vfloat32m2_t ratio_a = __riscv_vfmul_vv_f32m2(a_biased, mix_rcp, vl);
        vfloat32m2_t ratio_b = __riscv_vfmul_vv_f32m2(b_biased, mix_rcp, vl);
        vfloat32m2_t log_a = nk_log2_f32m2_rvv_(ratio_a, vl);
        vfloat32m2_t log_b = nk_log2_f32m2_rvv_(ratio_b, vl);
        // a·log2(a/M) + b·log2(b/M), accumulated tail-undisturbed per lane.
        vfloat32m2_t term_a = __riscv_vfmul_vv_f32m2(a_vec, log_a, vl);
        vfloat32m2_t term_b = __riscv_vfmul_vv_f32m2(b_vec, log_b, vl);
        vfloat32m2_t term = __riscv_vfadd_vv_f32m2(term_a, term_b, vl);
        acc_f32m2 = __riscv_vfadd_vv_f32m2_tu(acc_f32m2, acc_f32m2, term, vl);
    }
    // One reduction after the loop; convert log2 sums to ln, halve, then take the root.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t total = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(acc_f32m2, zero_f32m1, vl_max)) *
                     0.693147181f / 2;
    *result = total > 0 ? nk_f32_sqrt_rvv(total) : 0;
}
|
|
394
|
+
|
|
395
|
+
#pragma endregion - Jensen-Shannon Divergence
|
|
396
|
+
|
|
397
|
+
#if defined(__cplusplus)
|
|
398
|
+
} // extern "C"
|
|
399
|
+
#endif
|
|
400
|
+
|
|
401
|
+
#if defined(__clang__)
|
|
402
|
+
#pragma clang attribute pop
|
|
403
|
+
#elif defined(__GNUC__)
|
|
404
|
+
#pragma GCC pop_options
|
|
405
|
+
#endif
|
|
406
|
+
|
|
407
|
+
#endif // NK_TARGET_RVV
|
|
408
|
+
#endif // NK_TARGET_RISCV_
|
|
409
|
+
#endif // NK_PROBABILITY_RVV_H
|