numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -145,7 +145,7 @@ NK_INTERNAL vfloat64m4_t nk_log2_f64m4_rvv_(vfloat64m4_t x, nk_size_t vector_len
|
|
|
145
145
|
return __riscv_vfadd_vv_f64m4(exp_f, log2_m, vector_length);
|
|
146
146
|
}
|
|
147
147
|
|
|
148
|
-
#pragma region
|
|
148
|
+
#pragma region Kullback Leibler Divergence
|
|
149
149
|
|
|
150
150
|
NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
151
151
|
nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
|
|
@@ -172,8 +172,8 @@ NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
|
172
172
|
}
|
|
173
173
|
|
|
174
174
|
NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
175
|
-
nk_size_t
|
|
176
|
-
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
175
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
176
|
+
vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
177
177
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
178
178
|
vector_length = __riscv_vsetvl_e64m4(n);
|
|
179
179
|
vfloat64m4_t a_f64m4 = __riscv_vle64_v_f64m4(a, vector_length);
|
|
@@ -192,13 +192,13 @@ NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
|
|
|
192
192
|
// Single horizontal reduction after loop
|
|
193
193
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
194
194
|
// Convert from log2 to ln by multiplying by ln(2)
|
|
195
|
-
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1,
|
|
195
|
+
*result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)) *
|
|
196
196
|
0.6931471805599453;
|
|
197
197
|
}
|
|
198
198
|
|
|
199
199
|
NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
200
|
-
nk_size_t
|
|
201
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
200
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
201
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
202
202
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
203
203
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
204
204
|
// Load f16 as raw u16 bits
|
|
@@ -220,12 +220,13 @@ NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
220
220
|
}
|
|
221
221
|
// Single horizontal reduction after loop
|
|
222
222
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
223
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
223
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
|
|
224
|
+
0.693147181f;
|
|
224
225
|
}
|
|
225
226
|
|
|
226
227
|
NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
227
|
-
nk_size_t
|
|
228
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
228
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
229
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
229
230
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
230
231
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
231
232
|
// Load bf16 as raw u16 bits
|
|
@@ -247,12 +248,13 @@ NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t
|
|
|
247
248
|
}
|
|
248
249
|
// Single horizontal reduction after loop
|
|
249
250
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
250
|
-
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1,
|
|
251
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
|
|
252
|
+
0.693147181f;
|
|
251
253
|
}
|
|
252
254
|
|
|
253
|
-
#pragma endregion
|
|
255
|
+
#pragma endregion Kullback Leibler Divergence
|
|
254
256
|
|
|
255
|
-
#pragma region
|
|
257
|
+
#pragma region Jensen Shannon Divergence
|
|
256
258
|
|
|
257
259
|
NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
258
260
|
nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
|
|
@@ -288,9 +290,9 @@ NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
|
288
290
|
}
|
|
289
291
|
|
|
290
292
|
NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
291
|
-
nk_size_t
|
|
292
|
-
vfloat64m4_t sum_a_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
293
|
-
vfloat64m4_t sum_b_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
293
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
294
|
+
vfloat64m4_t sum_a_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
295
|
+
vfloat64m4_t sum_b_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
294
296
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
295
297
|
vector_length = __riscv_vsetvl_e64m4(n);
|
|
296
298
|
vfloat64m4_t va = __riscv_vle64_v_f64m4(a, vector_length);
|
|
@@ -315,14 +317,15 @@ NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
|
|
|
315
317
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
316
318
|
// JSD = sqrt((sum_a + sum_b) * ln(2) / 2)
|
|
317
319
|
nk_f64_t sum = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(
|
|
318
|
-
__riscv_vfadd_vv_f64m4(sum_a_f64m4, sum_b_f64m4,
|
|
320
|
+
__riscv_vfadd_vv_f64m4(sum_a_f64m4, sum_b_f64m4, max_vector_length), zero_f64m1,
|
|
321
|
+
max_vector_length)) *
|
|
319
322
|
0.6931471805599453 / 2;
|
|
320
323
|
*result = sum > 0 ? nk_f64_sqrt_rvv(sum) : 0;
|
|
321
324
|
}
|
|
322
325
|
|
|
323
326
|
NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
324
|
-
nk_size_t
|
|
325
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
327
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
328
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
326
329
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
327
330
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
328
331
|
// Load f16 as raw u16 bits
|
|
@@ -351,14 +354,15 @@ NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
351
354
|
}
|
|
352
355
|
// Single horizontal reduction after loop
|
|
353
356
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
354
|
-
nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
357
|
+
nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
358
|
+
__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
|
|
355
359
|
0.693147181f / 2;
|
|
356
360
|
*result = sum > 0 ? nk_f32_sqrt_rvv(sum) : 0;
|
|
357
361
|
}
|
|
358
362
|
|
|
359
363
|
NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
360
|
-
nk_size_t
|
|
361
|
-
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
364
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
365
|
+
vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
362
366
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
|
|
363
367
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
364
368
|
// Load bf16 as raw u16 bits
|
|
@@ -387,12 +391,13 @@ NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t
|
|
|
387
391
|
}
|
|
388
392
|
// Single horizontal reduction after loop
|
|
389
393
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
390
|
-
nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
394
|
+
nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
|
|
395
|
+
__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
|
|
391
396
|
0.693147181f / 2;
|
|
392
397
|
*result = sum > 0 ? nk_f32_sqrt_rvv(sum) : 0;
|
|
393
398
|
}
|
|
394
399
|
|
|
395
|
-
#pragma endregion
|
|
400
|
+
#pragma endregion Jensen Shannon Divergence
|
|
396
401
|
|
|
397
402
|
#if defined(__cplusplus)
|
|
398
403
|
} // extern "C"
|
|
@@ -17,32 +17,35 @@
|
|
|
17
17
|
extern "C" {
|
|
18
18
|
#endif
|
|
19
19
|
|
|
20
|
-
#define nk_define_kld_(input_type, accumulator_type, output_type, load_and_convert, epsilon,
|
|
20
|
+
#define nk_define_kld_(input_type, unpacked_type, accumulator_type, output_type, load_and_convert, epsilon, \
|
|
21
|
+
compute_log) \
|
|
21
22
|
NK_PUBLIC void nk_kld_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
22
23
|
nk_size_t n, output_type *result) { \
|
|
23
|
-
nk_##accumulator_type##_t
|
|
24
|
+
nk_##accumulator_type##_t sum = 0; \
|
|
25
|
+
nk_##unpacked_type##_t a_value, b_value; \
|
|
24
26
|
for (nk_size_t i = 0; i != n; ++i) { \
|
|
25
|
-
load_and_convert(a + i, &
|
|
26
|
-
load_and_convert(b + i, &
|
|
27
|
-
|
|
27
|
+
load_and_convert(a + i, &a_value); \
|
|
28
|
+
load_and_convert(b + i, &b_value); \
|
|
29
|
+
sum += a_value * compute_log((a_value + epsilon) / (b_value + epsilon)); \
|
|
28
30
|
} \
|
|
29
|
-
*result = (output_type)
|
|
31
|
+
*result = (output_type)sum; \
|
|
30
32
|
}
|
|
31
33
|
|
|
32
|
-
#define nk_define_jsd_(input_type, accumulator_type, output_type, load_and_convert, epsilon,
|
|
33
|
-
compute_sqrt)
|
|
34
|
+
#define nk_define_jsd_(input_type, unpacked_type, accumulator_type, output_type, load_and_convert, epsilon, \
|
|
35
|
+
compute_log, compute_sqrt) \
|
|
34
36
|
NK_PUBLIC void nk_jsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
35
37
|
nk_size_t n, output_type *result) { \
|
|
36
|
-
nk_##accumulator_type##_t
|
|
38
|
+
nk_##accumulator_type##_t sum = 0; \
|
|
39
|
+
nk_##unpacked_type##_t a_value, b_value; \
|
|
37
40
|
for (nk_size_t i = 0; i != n; ++i) { \
|
|
38
|
-
load_and_convert(a + i, &
|
|
39
|
-
load_and_convert(b + i, &
|
|
40
|
-
nk_##
|
|
41
|
-
|
|
42
|
-
|
|
41
|
+
load_and_convert(a + i, &a_value); \
|
|
42
|
+
load_and_convert(b + i, &b_value); \
|
|
43
|
+
nk_##unpacked_type##_t midpoint_value = (a_value + b_value) / 2; \
|
|
44
|
+
sum += a_value * compute_log((a_value + epsilon) / (midpoint_value + epsilon)); \
|
|
45
|
+
sum += b_value * compute_log((b_value + epsilon) / (midpoint_value + epsilon)); \
|
|
43
46
|
} \
|
|
44
|
-
output_type
|
|
45
|
-
*result =
|
|
47
|
+
output_type sum_half = ((output_type)sum / 2); \
|
|
48
|
+
*result = sum_half > 0 ? compute_sqrt(sum_half) : 0; \
|
|
46
49
|
}
|
|
47
50
|
|
|
48
51
|
/**
|
|
@@ -121,45 +124,54 @@ NK_INTERNAL nk_f64_t nk_f64_log_serial_(nk_f64_t x) {
|
|
|
121
124
|
return (nk_f64_t)exp * 0.6931471805599453 + 2.0 * u * poly;
|
|
122
125
|
}
|
|
123
126
|
|
|
124
|
-
nk_define_kld_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
125
|
-
nk_define_jsd_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
127
|
+
nk_define_kld_(f32, f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
128
|
+
nk_define_jsd_(f32, f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
129
|
+
nk_f64_sqrt_serial)
|
|
126
130
|
|
|
127
|
-
nk_define_kld_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
128
|
-
nk_define_jsd_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
131
|
+
nk_define_kld_(f16, f32, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
132
|
+
nk_define_jsd_(f16, f32, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
129
133
|
nk_f32_sqrt_serial)
|
|
130
134
|
|
|
131
|
-
nk_define_kld_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
132
|
-
nk_define_jsd_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
135
|
+
nk_define_kld_(bf16, f32, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
136
|
+
nk_define_jsd_(bf16, f32, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
133
137
|
nk_f32_sqrt_serial)
|
|
134
138
|
|
|
135
139
|
NK_PUBLIC void nk_kld_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
140
|
+
// Use Kahan summation for higher numerical stability in long distributions
|
|
136
141
|
nk_f64_t sum = 0, compensation = 0;
|
|
137
142
|
for (nk_size_t i = 0; i != n; ++i) {
|
|
138
|
-
nk_f64_t
|
|
139
|
-
nk_f64_t term =
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
sum
|
|
143
|
+
nk_f64_t a_value = a[i], b_value = b[i];
|
|
144
|
+
nk_f64_t term = a_value *
|
|
145
|
+
nk_f64_log_serial_((a_value + NK_F64_DIVISION_EPSILON) / (b_value + NK_F64_DIVISION_EPSILON));
|
|
146
|
+
nk_f64_t provisional_sum = sum + term;
|
|
147
|
+
compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term)) ? ((sum - provisional_sum) + term)
|
|
148
|
+
: ((term - provisional_sum) + sum);
|
|
149
|
+
sum = provisional_sum;
|
|
143
150
|
}
|
|
144
151
|
*result = sum + compensation;
|
|
145
152
|
}
|
|
146
153
|
|
|
147
154
|
NK_PUBLIC void nk_jsd_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
155
|
+
// Use Kahan summation for higher numerical stability in long distributions
|
|
148
156
|
nk_f64_t sum = 0, compensation = 0;
|
|
149
157
|
for (nk_size_t i = 0; i != n; ++i) {
|
|
150
|
-
nk_f64_t
|
|
151
|
-
nk_f64_t mi = (
|
|
152
|
-
nk_f64_t term_a =
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
sum
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
158
|
+
nk_f64_t a_value = a[i], b_value = b[i];
|
|
159
|
+
nk_f64_t mi = (a_value + b_value) / 2;
|
|
160
|
+
nk_f64_t term_a = a_value *
|
|
161
|
+
nk_f64_log_serial_((a_value + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
|
|
162
|
+
nk_f64_t provisional_sum = sum + term_a;
|
|
163
|
+
compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_a)) ? ((sum - provisional_sum) + term_a)
|
|
164
|
+
: ((term_a - provisional_sum) + sum);
|
|
165
|
+
sum = provisional_sum;
|
|
166
|
+
nk_f64_t term_b = b_value *
|
|
167
|
+
nk_f64_log_serial_((b_value + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
|
|
168
|
+
provisional_sum = sum + term_b;
|
|
169
|
+
compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_b)) ? ((sum - provisional_sum) + term_b)
|
|
170
|
+
: ((term_b - provisional_sum) + sum);
|
|
171
|
+
sum = provisional_sum;
|
|
160
172
|
}
|
|
161
|
-
nk_f64_t
|
|
162
|
-
*result =
|
|
173
|
+
nk_f64_t sum_half = (sum + compensation) / 2;
|
|
174
|
+
*result = sum_half > 0 ? nk_f64_sqrt_serial(sum_half) : 0;
|
|
163
175
|
}
|
|
164
176
|
|
|
165
177
|
#if defined(__cplusplus)
|
|
@@ -38,14 +38,14 @@
|
|
|
38
38
|
* calls. Division (for p/q ratio) uses either VDIVPS directly or VRCP14PS with Newton-Raphson
|
|
39
39
|
* refinement when ~14-bit precision suffices. Genoa's VGETEXP/VGETMANT are 25% faster than Ice.
|
|
40
40
|
*
|
|
41
|
-
* Intrinsic
|
|
42
|
-
* _mm512_getexp_ps
|
|
43
|
-
* _mm512_getexp_pd
|
|
44
|
-
* _mm512_getmant_ps
|
|
45
|
-
* _mm512_getmant_pd
|
|
46
|
-
* _mm512_rcp14_ps
|
|
47
|
-
* _mm512_div_ps
|
|
48
|
-
* _mm512_fmadd_ps
|
|
41
|
+
* Intrinsic Instruction Icelake Genoa
|
|
42
|
+
* _mm512_getexp_ps VGETEXPPS (ZMM, ZMM) 4cy @ p0 3cy @ p23
|
|
43
|
+
* _mm512_getexp_pd VGETEXPPD (ZMM, ZMM) 4cy @ p0 3cy @ p23
|
|
44
|
+
* _mm512_getmant_ps VGETMANTPS (ZMM, ZMM, I8) 4cy @ p0 3cy @ p23
|
|
45
|
+
* _mm512_getmant_pd VGETMANTPD (ZMM, ZMM, I8) 4cy @ p0 3cy @ p23
|
|
46
|
+
* _mm512_rcp14_ps VRCP14PS (ZMM, ZMM) 7cy @ p0+p0+p05 5cy @ p01
|
|
47
|
+
* _mm512_div_ps VDIVPS (ZMM, ZMM, ZMM) 17cy @ p0+p0+p05 11cy @ p01
|
|
48
|
+
* _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4cy @ p0 4cy @ p01
|
|
49
49
|
*
|
|
50
50
|
* @section arm_instructions Relevant ARM NEON/SVE Instructions
|
|
51
51
|
*
|
|
@@ -53,14 +53,14 @@
|
|
|
53
53
|
* float bits followed by polynomial refinement. FRECPE provides ~8-bit reciprocal approximation
|
|
54
54
|
* for division, refined with FRECPS Newton-Raphson steps to ~22-bit precision.
|
|
55
55
|
*
|
|
56
|
-
* Intrinsic
|
|
57
|
-
* vfmaq_f32
|
|
58
|
-
* vrecpeq_f32
|
|
59
|
-
* vrecpsq_f32
|
|
56
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
57
|
+
* vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
58
|
+
* vrecpeq_f32 FRECPE.S 3cy @ V02 3cy @ V02 3cy @ V02
|
|
59
|
+
* vrecpsq_f32 FRECPS.S 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
60
60
|
*
|
|
61
61
|
* @section references References
|
|
62
62
|
*
|
|
63
|
-
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
63
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
64
64
|
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
65
65
|
*
|
|
66
66
|
*/
|
|
@@ -201,14 +201,11 @@ NK_PUBLIC void nk_jsd_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_siz
|
|
|
201
201
|
NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
202
202
|
/** @copydoc nk_jsd_f32 */
|
|
203
203
|
NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
204
|
-
#endif // NK_TARGET_NEON
|
|
205
|
-
|
|
206
|
-
#if NK_TARGET_NEONHALF
|
|
207
204
|
/** @copydoc nk_kld_f16 */
|
|
208
|
-
NK_PUBLIC void
|
|
205
|
+
NK_PUBLIC void nk_kld_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
209
206
|
/** @copydoc nk_jsd_f16 */
|
|
210
|
-
NK_PUBLIC void
|
|
211
|
-
#endif //
|
|
207
|
+
NK_PUBLIC void nk_jsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
208
|
+
#endif // NK_TARGET_NEON
|
|
212
209
|
|
|
213
210
|
#if NK_TARGET_HASWELL
|
|
214
211
|
/** @copydoc nk_kld_f64 */
|
|
@@ -283,8 +280,8 @@ extern "C" {
|
|
|
283
280
|
#if !NK_DYNAMIC_DISPATCH
|
|
284
281
|
|
|
285
282
|
NK_PUBLIC void nk_kld_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
286
|
-
#if
|
|
287
|
-
|
|
283
|
+
#if NK_TARGET_NEON
|
|
284
|
+
nk_kld_f16_neon(a, b, n, result);
|
|
288
285
|
#elif NK_TARGET_SKYLAKE
|
|
289
286
|
nk_kld_f16_skylake(a, b, n, result);
|
|
290
287
|
#elif NK_TARGET_HASWELL
|
|
@@ -329,8 +326,8 @@ NK_PUBLIC void nk_kld_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_
|
|
|
329
326
|
}
|
|
330
327
|
|
|
331
328
|
NK_PUBLIC void nk_jsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
332
|
-
#if
|
|
333
|
-
|
|
329
|
+
#if NK_TARGET_NEON
|
|
330
|
+
nk_jsd_f16_neon(a, b, n, result);
|
|
334
331
|
#elif NK_TARGET_SKYLAKE
|
|
335
332
|
nk_jsd_f16_skylake(a, b, n, result);
|
|
336
333
|
#elif NK_TARGET_HASWELL
|
package/include/numkong/random.h
CHANGED
|
@@ -29,7 +29,7 @@
|
|
|
29
29
|
*
|
|
30
30
|
* @section references References
|
|
31
31
|
*
|
|
32
|
-
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
32
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
33
33
|
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
34
34
|
*
|
|
35
35
|
*/
|