numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Similarity Measures for Probability Distributions.
|
|
3
|
+
* @file include/numkong/probability.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date October 20, 2023
|
|
6
|
+
*
|
|
7
|
+
* Contains the following similarity measures:
|
|
8
|
+
*
|
|
9
|
+
* - Kullback-Leibler Divergence (KLD)
|
|
10
|
+
* - Jensen-Shannon Distance (JSD)
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
*
|
|
14
|
+
* - 64-bit floating point numbers → 64-bit
|
|
15
|
+
* - 32-bit floating point numbers → 64-bit
|
|
16
|
+
* - 16-bit floating point numbers → 32-bit
|
|
17
|
+
* - 16-bit brain-floating point numbers → 32-bit
|
|
18
|
+
*
|
|
19
|
+
* Precision policy:
|
|
20
|
+
*
|
|
21
|
+
* - For `f32` inputs, the per-element vertical path stays in `f32` to preserve the fast ratio/log
|
|
22
|
+
* approximations and SIMD throughput.
|
|
23
|
+
* - The horizontal reduction over those per-element contributions widens to `f64`, and public
|
|
24
|
+
* `f32` results are exposed as `f64`.
|
|
25
|
+
* - For `f64` inputs, both the vertical path and the horizontal reduction stay in `f64`, with
|
|
26
|
+
* stable summation in the serial kernels.
|
|
27
|
+
* - For `f16` and `bf16` inputs, the kernels still widen to `f32`.
|
|
28
|
+
*
|
|
29
|
+
* For hardware architectures:
|
|
30
|
+
*
|
|
31
|
+
* - Arm: NEON
|
|
32
|
+
* - x86: Haswell, Skylake
* - RISC-V: RVV
|
|
33
|
+
*
|
|
34
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
35
|
+
*
|
|
36
|
+
* KL/JS divergence requires log2(x) which decomposes into exponent extraction (VGETEXP) plus
|
|
37
|
+
* mantissa polynomial (using VGETMANT + FMA chain). This approach is faster than scalar log()
|
|
38
|
+
* calls. Division (for p/q ratio) uses either VDIVPS directly or VRCP14PS with Newton-Raphson
|
|
39
|
+
* refinement when ~14-bit precision suffices. Genoa's VGETEXP/VGETMANT have 25% lower latency than Ice Lake's (3c vs 4c).
|
|
40
|
+
*
|
|
41
|
+
* Intrinsic Instruction Ice Genoa
|
|
42
|
+
* _mm512_getexp_ps VGETEXPPS (ZMM, ZMM) 4c @ p0 3c @ p23
|
|
43
|
+
* _mm512_getexp_pd VGETEXPPD (ZMM, ZMM) 4c @ p0 3c @ p23
|
|
44
|
+
* _mm512_getmant_ps VGETMANTPS (ZMM, ZMM, I8) 4c @ p0 3c @ p23
|
|
45
|
+
* _mm512_getmant_pd VGETMANTPD (ZMM, ZMM, I8) 4c @ p0 3c @ p23
|
|
46
|
+
* _mm512_rcp14_ps VRCP14PS (ZMM, ZMM) 7c @ p05 5c @ p01
|
|
47
|
+
* _mm512_div_ps VDIVPS (ZMM, ZMM, ZMM) 17c @ p05 11c @ p01
|
|
48
|
+
* _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4c @ p0 4c @ p01
|
|
49
|
+
*
|
|
50
|
+
* @section arm_instructions Relevant ARM NEON/SVE Instructions
|
|
51
|
+
*
|
|
52
|
+
* ARM lacks direct exponent/mantissa extraction, so log2 uses integer reinterpretation of the
|
|
53
|
+
* float bits followed by polynomial refinement. FRECPE provides ~8-bit reciprocal approximation
|
|
54
|
+
* for division, refined with FRECPS Newton-Raphson steps to ~22-bit precision.
|
|
55
|
+
*
|
|
56
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
57
|
+
* vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
|
|
58
|
+
* vrecpeq_f32 FRECPE.S 3c @ V02 3c @ V02 3c @ V02
|
|
59
|
+
* vrecpsq_f32 FRECPS.S 4c @ V0123 4c @ V0123 4c @ V0123
|
|
60
|
+
*
|
|
61
|
+
* @section references References
|
|
62
|
+
*
|
|
63
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
64
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
65
|
+
*
|
|
66
|
+
*/
|
|
67
|
+
#ifndef NK_PROBABILITY_H
|
|
68
|
+
#define NK_PROBABILITY_H
|
|
69
|
+
|
|
70
|
+
#include "numkong/types.h"
|
|
71
|
+
#include "numkong/reduce.h" // For horizontal reduction helpers
|
|
72
|
+
|
|
73
|
+
#if defined(__cplusplus)
|
|
74
|
+
extern "C" {
|
|
75
|
+
#endif
|
|
76
|
+
|
|
77
|
+
/**
|
|
78
|
+
* @brief Kullback-Leibler divergence between two discrete probability distributions.
|
|
79
|
+
*
|
|
80
|
+
* @param[in] a The first discrete probability distribution.
|
|
81
|
+
* @param[in] b The second discrete probability distribution.
|
|
82
|
+
* @param[in] n The number of elements in the distributions.
|
|
83
|
+
* @param[out] result The output divergence value.
|
|
84
|
+
*
|
|
85
|
+
* @note The distributions are assumed to be normalized.
|
|
86
|
+
* @note The output divergence value is non-negative.
|
|
87
|
+
* @note The output divergence value is zero if and only if the two distributions are identical.
|
|
88
|
+
*/
|
|
89
|
+
NK_DYNAMIC void nk_kld_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
90
|
+
/**
|
|
91
|
+
* @brief Kullback-Leibler divergence between two discrete probability distributions.
|
|
92
|
+
*
|
|
93
|
+
* @param[in] a The first discrete probability distribution.
|
|
94
|
+
* @param[in] b The second discrete probability distribution.
|
|
95
|
+
* @param[in] n The number of elements in the distributions.
|
|
96
|
+
* @param[out] result The output divergence value.
|
|
97
|
+
*
|
|
98
|
+
* @note The distributions are assumed to be normalized.
|
|
99
|
+
* @note The output divergence value is non-negative.
|
|
100
|
+
* @note The output divergence value is zero if and only if the two distributions are identical.
|
|
101
|
+
*/
|
|
102
|
+
NK_DYNAMIC void nk_kld_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
103
|
+
/**
|
|
104
|
+
* @brief Kullback-Leibler divergence between two discrete probability distributions.
|
|
105
|
+
*
|
|
106
|
+
* @param[in] a The first discrete probability distribution.
|
|
107
|
+
* @param[in] b The second discrete probability distribution.
|
|
108
|
+
* @param[in] n The number of elements in the distributions.
|
|
109
|
+
* @param[out] result The output divergence value.
|
|
110
|
+
*
|
|
111
|
+
* @note The distributions are assumed to be normalized.
|
|
112
|
+
* @note The output divergence value is non-negative.
|
|
113
|
+
* @note The output divergence value is zero if and only if the two distributions are identical.
|
|
114
|
+
*/
|
|
115
|
+
NK_DYNAMIC void nk_kld_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
116
|
+
/**
|
|
117
|
+
* @brief Kullback-Leibler divergence between two discrete probability distributions.
|
|
118
|
+
*
|
|
119
|
+
* @param[in] a The first discrete probability distribution.
|
|
120
|
+
* @param[in] b The second discrete probability distribution.
|
|
121
|
+
* @param[in] n The number of elements in the distributions.
|
|
122
|
+
* @param[out] result The output divergence value.
|
|
123
|
+
*
|
|
124
|
+
* @note The distributions are assumed to be normalized.
|
|
125
|
+
* @note The output divergence value is non-negative.
|
|
126
|
+
* @note The output divergence value is zero if and only if the two distributions are identical.
|
|
127
|
+
*/
|
|
128
|
+
NK_DYNAMIC void nk_kld_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
129
|
+
/**
|
|
130
|
+
* @brief Jensen-Shannon distance between two discrete probability distributions.
|
|
131
|
+
*
|
|
132
|
+
* @param[in] a The first discrete probability distribution.
|
|
133
|
+
* @param[in] b The second discrete probability distribution.
|
|
134
|
+
* @param[in] n The number of elements in the distributions.
|
|
135
|
+
* @param[out] result The output distance value.
|
|
136
|
+
*
|
|
137
|
+
* @note The distributions are assumed to be normalized.
|
|
138
|
+
* @note The output distance value is non-negative.
|
|
139
|
+
* @note The output distance value is zero if and only if the two distributions are identical.
|
|
140
|
+
*/
|
|
141
|
+
NK_DYNAMIC void nk_jsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
142
|
+
/**
|
|
143
|
+
* @brief Jensen-Shannon distance between two discrete probability distributions.
|
|
144
|
+
*
|
|
145
|
+
* @param[in] a The first discrete probability distribution.
|
|
146
|
+
* @param[in] b The second discrete probability distribution.
|
|
147
|
+
* @param[in] n The number of elements in the distributions.
|
|
148
|
+
* @param[out] result The output distance value.
|
|
149
|
+
*
|
|
150
|
+
* @note The distributions are assumed to be normalized.
|
|
151
|
+
* @note The output distance value is non-negative.
|
|
152
|
+
* @note The output distance value is zero if and only if the two distributions are identical.
|
|
153
|
+
*/
|
|
154
|
+
NK_DYNAMIC void nk_jsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
155
|
+
/**
|
|
156
|
+
* @brief Jensen-Shannon distance between two discrete probability distributions.
|
|
157
|
+
*
|
|
158
|
+
* @param[in] a The first discrete probability distribution.
|
|
159
|
+
* @param[in] b The second discrete probability distribution.
|
|
160
|
+
* @param[in] n The number of elements in the distributions.
|
|
161
|
+
* @param[out] result The output distance value.
|
|
162
|
+
*
|
|
163
|
+
* @note The distributions are assumed to be normalized.
|
|
164
|
+
* @note The output distance value is non-negative.
|
|
165
|
+
* @note The output distance value is zero if and only if the two distributions are identical.
|
|
166
|
+
*/
|
|
167
|
+
NK_DYNAMIC void nk_jsd_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
168
|
+
/**
|
|
169
|
+
* @brief Jensen-Shannon distance between two discrete probability distributions.
|
|
170
|
+
*
|
|
171
|
+
* @param[in] a The first discrete probability distribution.
|
|
172
|
+
* @param[in] b The second discrete probability distribution.
|
|
173
|
+
* @param[in] n The number of elements in the distributions.
|
|
174
|
+
* @param[out] result The output distance value.
|
|
175
|
+
*
|
|
176
|
+
* @note The distributions are assumed to be normalized.
|
|
177
|
+
* @note The output distance value is non-negative.
|
|
178
|
+
* @note The output distance value is zero if and only if the two distributions are identical.
|
|
179
|
+
*/
|
|
180
|
+
NK_DYNAMIC void nk_jsd_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
181
|
+
|
|
182
|
+
/** @copydoc nk_kld_f64 */
|
|
183
|
+
NK_PUBLIC void nk_kld_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
184
|
+
/** @copydoc nk_jsd_f64 */
|
|
185
|
+
NK_PUBLIC void nk_jsd_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
186
|
+
/** @copydoc nk_kld_f32 */
|
|
187
|
+
NK_PUBLIC void nk_kld_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
188
|
+
/** @copydoc nk_jsd_f32 */
|
|
189
|
+
NK_PUBLIC void nk_jsd_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
190
|
+
/** @copydoc nk_kld_f16 */
|
|
191
|
+
NK_PUBLIC void nk_kld_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
192
|
+
/** @copydoc nk_jsd_f16 */
|
|
193
|
+
NK_PUBLIC void nk_jsd_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
194
|
+
/** @copydoc nk_kld_bf16 */
|
|
195
|
+
NK_PUBLIC void nk_kld_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
196
|
+
/** @copydoc nk_jsd_bf16 */
|
|
197
|
+
NK_PUBLIC void nk_jsd_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
198
|
+
|
|
199
|
+
#if NK_TARGET_NEON
|
|
200
|
+
/** @copydoc nk_kld_f32 */
|
|
201
|
+
NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
202
|
+
/** @copydoc nk_jsd_f32 */
|
|
203
|
+
NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
204
|
+
#endif // NK_TARGET_NEON
|
|
205
|
+
|
|
206
|
+
#if NK_TARGET_NEONHALF
|
|
207
|
+
/** @copydoc nk_kld_f16 */
|
|
208
|
+
NK_PUBLIC void nk_kld_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
209
|
+
/** @copydoc nk_jsd_f16 */
|
|
210
|
+
NK_PUBLIC void nk_jsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
211
|
+
#endif // NK_TARGET_NEONHALF
|
|
212
|
+
|
|
213
|
+
#if NK_TARGET_HASWELL
|
|
214
|
+
/** @copydoc nk_kld_f64 */
|
|
215
|
+
NK_PUBLIC void nk_kld_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
216
|
+
/** @copydoc nk_jsd_f64 */
|
|
217
|
+
NK_PUBLIC void nk_jsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
218
|
+
/** @copydoc nk_kld_f16 */
|
|
219
|
+
NK_PUBLIC void nk_kld_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
220
|
+
/** @copydoc nk_jsd_f16 */
|
|
221
|
+
NK_PUBLIC void nk_jsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
222
|
+
#endif // NK_TARGET_HASWELL
|
|
223
|
+
|
|
224
|
+
#if NK_TARGET_SKYLAKE
|
|
225
|
+
/** @copydoc nk_kld_f64 */
|
|
226
|
+
NK_PUBLIC void nk_kld_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
227
|
+
/** @copydoc nk_jsd_f64 */
|
|
228
|
+
NK_PUBLIC void nk_jsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
229
|
+
/** @copydoc nk_kld_f32 */
|
|
230
|
+
NK_PUBLIC void nk_kld_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
231
|
+
/** @copydoc nk_jsd_f32 */
|
|
232
|
+
NK_PUBLIC void nk_jsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
233
|
+
/** @copydoc nk_kld_f16 */
|
|
234
|
+
NK_PUBLIC void nk_kld_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
235
|
+
/** @copydoc nk_jsd_f16 */
|
|
236
|
+
NK_PUBLIC void nk_jsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
237
|
+
#endif // NK_TARGET_SKYLAKE
|
|
238
|
+
|
|
239
|
+
#if NK_TARGET_RVV
|
|
240
|
+
/** @copydoc nk_kld_f32 */
|
|
241
|
+
NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
242
|
+
/** @copydoc nk_jsd_f32 */
|
|
243
|
+
NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
244
|
+
/** @copydoc nk_kld_f64 */
|
|
245
|
+
NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
246
|
+
/** @copydoc nk_jsd_f64 */
|
|
247
|
+
NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
248
|
+
/** @copydoc nk_kld_f16 */
|
|
249
|
+
NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
250
|
+
/** @copydoc nk_jsd_f16 */
|
|
251
|
+
NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
252
|
+
/** @copydoc nk_kld_bf16 */
|
|
253
|
+
NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
254
|
+
/** @copydoc nk_jsd_bf16 */
|
|
255
|
+
NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
256
|
+
#endif // NK_TARGET_RVV
|
|
257
|
+
|
|
258
|
+
/** @brief Returns the output dtype for probability measures (KLD, JSD). */
|
|
259
|
+
NK_INTERNAL nk_dtype_t nk_probability_output_dtype(nk_dtype_t dtype) {
|
|
260
|
+
switch (dtype) {
|
|
261
|
+
case nk_f64_k: return nk_f64_k;
|
|
262
|
+
case nk_f32_k: return nk_f64_k;
|
|
263
|
+
case nk_f16_k: return nk_f32_k;
|
|
264
|
+
case nk_bf16_k: return nk_f32_k;
|
|
265
|
+
default: return nk_dtype_unknown_k;
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
#if defined(__cplusplus)
|
|
270
|
+
} // extern "C"
|
|
271
|
+
#endif
|
|
272
|
+
|
|
273
|
+
#include "numkong/probability/serial.h"
|
|
274
|
+
#include "numkong/probability/neon.h"
|
|
275
|
+
#include "numkong/probability/haswell.h"
|
|
276
|
+
#include "numkong/probability/skylake.h"
|
|
277
|
+
#include "numkong/probability/rvv.h"
|
|
278
|
+
|
|
279
|
+
#if defined(__cplusplus)
|
|
280
|
+
extern "C" {
|
|
281
|
+
#endif
|
|
282
|
+
|
|
283
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
284
|
+
|
|
285
|
+
NK_PUBLIC void nk_kld_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Bind the most specialized kernel enabled at compile time; the serial one is the fallback.
    void (*const kernel)(nk_f16_t const *, nk_f16_t const *, nk_size_t, nk_f32_t *) =
#if NK_TARGET_NEONHALF
        nk_kld_f16_neonhalf;
#elif NK_TARGET_SKYLAKE
        nk_kld_f16_skylake;
#elif NK_TARGET_HASWELL
        nk_kld_f16_haswell;
#elif NK_TARGET_RVV
        nk_kld_f16_rvv;
#else
        nk_kld_f16_serial;
#endif
    kernel(a, b, n, result);
}
|
|
298
|
+
|
|
299
|
+
NK_PUBLIC void nk_kld_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Only RVV provides a `bf16` kernel; every other build uses the serial one.
    void (*const kernel)(nk_bf16_t const *, nk_bf16_t const *, nk_size_t, nk_f32_t *) =
#if NK_TARGET_RVV
        nk_kld_bf16_rvv;
#else
        nk_kld_bf16_serial;
#endif
    kernel(a, b, n, result);
}
|
|
306
|
+
|
|
307
|
+
NK_PUBLIC void nk_kld_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Bind the most specialized kernel enabled at compile time; the serial one is the fallback.
    // There is no Haswell `f32` kernel, so Haswell-only builds take the serial path.
    void (*const kernel)(nk_f32_t const *, nk_f32_t const *, nk_size_t, nk_f64_t *) =
#if NK_TARGET_NEON
        nk_kld_f32_neon;
#elif NK_TARGET_SKYLAKE
        nk_kld_f32_skylake;
#elif NK_TARGET_RVV
        nk_kld_f32_rvv;
#else
        nk_kld_f32_serial;
#endif
    kernel(a, b, n, result);
}
|
|
318
|
+
|
|
319
|
+
NK_PUBLIC void nk_kld_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Bind the most specialized kernel enabled at compile time; the serial one is the fallback.
    void (*const kernel)(nk_f64_t const *, nk_f64_t const *, nk_size_t, nk_f64_t *) =
#if NK_TARGET_SKYLAKE
        nk_kld_f64_skylake;
#elif NK_TARGET_HASWELL
        nk_kld_f64_haswell;
#elif NK_TARGET_RVV
        nk_kld_f64_rvv;
#else
        nk_kld_f64_serial;
#endif
    kernel(a, b, n, result);
}
|
|
330
|
+
|
|
331
|
+
/**
 * @brief Jensen-Shannon distance of two f16 vectors, resolved to one kernel at compile time.
 *
 * Backend priority: NEON-half, Skylake, Haswell, RVV, then the serial fallback.
 *
 * @param[in] a,b Input vectors of `n` elements each.
 * @param[in] n Number of elements.
 * @param[out] result Receives the distance as `nk_f32_t`.
 */
NK_PUBLIC void nk_jsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    void (*kernel)(nk_f16_t const *, nk_f16_t const *, nk_size_t, nk_f32_t *) =
#if NK_TARGET_NEONHALF
        nk_jsd_f16_neonhalf;
#elif NK_TARGET_SKYLAKE
        nk_jsd_f16_skylake;
#elif NK_TARGET_HASWELL
        nk_jsd_f16_haswell;
#elif NK_TARGET_RVV
        nk_jsd_f16_rvv;
#else
        nk_jsd_f16_serial;
#endif
    kernel(a, b, n, result);
}
|
|
344
|
+
|
|
345
|
+
/**
 * @brief Jensen-Shannon distance of two bf16 vectors, resolved to one kernel at compile time.
 *
 * Uses the RVV kernel when `NK_TARGET_RVV` is enabled, otherwise the serial fallback.
 *
 * @param[in] a,b Input vectors of `n` elements each.
 * @param[in] n Number of elements.
 * @param[out] result Receives the distance as `nk_f32_t`.
 */
NK_PUBLIC void nk_jsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    void (*kernel)(nk_bf16_t const *, nk_bf16_t const *, nk_size_t, nk_f32_t *) =
#if NK_TARGET_RVV
        nk_jsd_bf16_rvv;
#else
        nk_jsd_bf16_serial;
#endif
    kernel(a, b, n, result);
}
|
|
352
|
+
|
|
353
|
+
/**
 * @brief Jensen-Shannon distance of two f32 vectors, resolved to one kernel at compile time.
 *
 * Backend priority: NEON, Skylake, RVV, then the serial fallback.
 *
 * @param[in] a,b Input vectors of `n` elements each.
 * @param[in] n Number of elements.
 * @param[out] result Receives the distance as `nk_f64_t` (wider than the inputs).
 */
NK_PUBLIC void nk_jsd_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    void (*kernel)(nk_f32_t const *, nk_f32_t const *, nk_size_t, nk_f64_t *) =
#if NK_TARGET_NEON
        nk_jsd_f32_neon;
#elif NK_TARGET_SKYLAKE
        nk_jsd_f32_skylake;
#elif NK_TARGET_RVV
        nk_jsd_f32_rvv;
#else
        nk_jsd_f32_serial;
#endif
    kernel(a, b, n, result);
}
|
|
364
|
+
|
|
365
|
+
/**
 * @brief Jensen-Shannon distance of two f64 vectors, resolved to one kernel at compile time.
 *
 * Backend priority: Skylake, Haswell, RVV, then the serial fallback.
 *
 * @param[in] a,b Input vectors of `n` elements each.
 * @param[in] n Number of elements.
 * @param[out] result Receives the distance as `nk_f64_t`.
 */
NK_PUBLIC void nk_jsd_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    void (*kernel)(nk_f64_t const *, nk_f64_t const *, nk_size_t, nk_f64_t *) =
#if NK_TARGET_SKYLAKE
        nk_jsd_f64_skylake;
#elif NK_TARGET_HASWELL
        nk_jsd_f64_haswell;
#elif NK_TARGET_RVV
        nk_jsd_f64_rvv;
#else
        nk_jsd_f64_serial;
#endif
    kernel(a, b, n, result);
}
|
|
376
|
+
|
|
377
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
378
|
+
|
|
379
|
+
#if defined(__cplusplus)
|
|
380
|
+
} // extern "C"
|
|
381
|
+
#endif
|
|
382
|
+
|
|
383
|
+
#endif
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief C++ wrappers for SIMD-accelerated Similarity Measures for Probability Distributions.
|
|
3
|
+
* @file include/numkong/probability.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 5, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_PROBABILITY_HPP
|
|
8
|
+
#define NK_PROBABILITY_HPP
|
|
9
|
+
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <type_traits>
|
|
12
|
+
|
|
13
|
+
#include "numkong/probability.h"
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.hpp"
|
|
16
|
+
|
|
17
|
+
namespace ashvardanian::numkong {
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* @brief Kullback-Leibler divergence: Σ pᵢ × log(pᵢ / qᵢ)
|
|
21
|
+
* @param[in] p,q First and second probability distributions
|
|
22
|
+
* @param[in] d Number of dimensions in input vectors
|
|
23
|
+
* @param[out] r Pointer to output divergence value
|
|
24
|
+
*
|
|
25
|
+
* @tparam in_type_ Input distribution type (probability vectors)
|
|
26
|
+
* @tparam result_type_ Result type, defaults to `in_type_::probability_result_t`
|
|
27
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
28
|
+
*/
|
|
29
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
30
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
31
|
+
void kld(in_type_ const *p, in_type_ const *q, std::size_t d, result_type_ *r) noexcept {
|
|
32
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
33
|
+
std::is_same_v<result_type_, typename in_type_::probability_result_t>;
|
|
34
|
+
|
|
35
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_kld_f64(&p->raw_, &q->raw_, d, &r->raw_);
|
|
36
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_kld_f32(&p->raw_, &q->raw_, d, &r->raw_);
|
|
37
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_kld_f16(&p->raw_, &q->raw_, d, &r->raw_);
|
|
38
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_kld_bf16(&p->raw_, &q->raw_, d, &r->raw_);
|
|
39
|
+
// Scalar fallback
|
|
40
|
+
else {
|
|
41
|
+
result_type_ sum {};
|
|
42
|
+
for (std::size_t i = 0; i < d; i++) {
|
|
43
|
+
result_type_ pi(p[i]), qi(q[i]);
|
|
44
|
+
if (pi > result_type_(0)) sum = sum + pi * (pi / qi).log();
|
|
45
|
+
}
|
|
46
|
+
*r = sum;
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/**
|
|
51
|
+
* @brief Jensen-Shannon distance: √(½ × (KL(p‖m) + KL(q‖m))), where m = (p + q) / 2
|
|
52
|
+
* @param[in] p,q First and second probability distributions
|
|
53
|
+
* @param[in] d Number of dimensions in input vectors
|
|
54
|
+
* @param[out] r Pointer to output distance value
|
|
55
|
+
*
|
|
56
|
+
* @tparam in_type_ Input distribution type (probability vectors)
|
|
57
|
+
* @tparam result_type_ Result type, defaults to `in_type_::probability_result_t`
|
|
58
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
59
|
+
*/
|
|
60
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
61
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
62
|
+
void jsd(in_type_ const *p, in_type_ const *q, std::size_t d, result_type_ *r) noexcept {
|
|
63
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
64
|
+
std::is_same_v<result_type_, typename in_type_::probability_result_t>;
|
|
65
|
+
|
|
66
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_jsd_f64(&p->raw_, &q->raw_, d, &r->raw_);
|
|
67
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_jsd_f32(&p->raw_, &q->raw_, d, &r->raw_);
|
|
68
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_jsd_f16(&p->raw_, &q->raw_, d, &r->raw_);
|
|
69
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) nk_jsd_bf16(&p->raw_, &q->raw_, d, &r->raw_);
|
|
70
|
+
// Scalar fallback
|
|
71
|
+
else {
|
|
72
|
+
result_type_ sum {};
|
|
73
|
+
result_type_ half(0.5);
|
|
74
|
+
for (std::size_t i = 0; i < d; i++) {
|
|
75
|
+
result_type_ pi(p[i]), qi(q[i]);
|
|
76
|
+
result_type_ mi = half * (pi + qi);
|
|
77
|
+
if (pi > result_type_(0)) sum = sum + pi * (pi / mi).log();
|
|
78
|
+
if (qi > result_type_(0)) sum = sum + qi * (qi / mi).log();
|
|
79
|
+
}
|
|
80
|
+
// JSD distance = sqrt(divergence / 2), clamped to non-negative
|
|
81
|
+
result_type_ divergence = half * sum;
|
|
82
|
+
*r = divergence > result_type_(0) ? divergence.sqrt() : result_type_(0);
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
} // namespace ashvardanian::numkong
|
|
87
|
+
|
|
88
|
+
#include "numkong/tensor.hpp"
|
|
89
|
+
|
|
90
|
+
namespace ashvardanian::numkong {
|
|
91
|
+
|
|
92
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
93
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
|
|
94
|
+
void kld(tensor_view<in_type_, max_rank_a_> p, tensor_view<in_type_, max_rank_b_> q, std::size_t d,
|
|
95
|
+
result_type_ *r) noexcept {
|
|
96
|
+
kld<in_type_, result_type_, allow_simd_>(p.data(), q.data(), d, r);
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
100
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
101
|
+
void kld(vector_view<in_type_> p, vector_view<in_type_> q, std::size_t d, result_type_ *r) noexcept {
|
|
102
|
+
kld<in_type_, result_type_, allow_simd_>(p.data(), q.data(), d, r);
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
106
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_>
|
|
107
|
+
void jsd(tensor_view<in_type_, max_rank_a_> p, tensor_view<in_type_, max_rank_b_> q, std::size_t d,
|
|
108
|
+
result_type_ *r) noexcept {
|
|
109
|
+
jsd<in_type_, result_type_, allow_simd_>(p.data(), q.data(), d, r);
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::probability_result_t,
|
|
113
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
114
|
+
void jsd(vector_view<in_type_> p, vector_view<in_type_> q, std::size_t d, result_type_ *r) noexcept {
|
|
115
|
+
jsd<in_type_, result_type_, allow_simd_>(p.data(), q.data(), d, r);
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
} // namespace ashvardanian::numkong
|
|
119
|
+
|
|
120
|
+
#endif // NK_PROBABILITY_HPP
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Pseudo-Random Number Generators.
|
|
3
|
+
* @file include/numkong/random.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 11, 2026
|
|
6
|
+
*
|
|
7
|
+
* Implements the following statistical distributions:
|
|
8
|
+
*
|
|
9
|
+
* - Uniform Distribution
|
|
10
|
+
* - Gaussian (Normal) Distribution
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
*
|
|
14
|
+
* - 64-bit floating point numbers
|
|
15
|
+
* - 32-bit floating point numbers
|
|
16
|
+
* - 16-bit floating point numbers
|
|
17
|
+
* - 16-bit brain-floating point numbers
|
|
18
|
+
* - 8-bit floating point numbers
|
|
19
|
+
* - 8-bit integers
|
|
20
|
+
*
|
|
21
|
+
* For hardware architectures:
|
|
22
|
+
*
|
|
23
|
+
* - Arm: NEON, SSVE
|
|
24
|
+
* - x86: Haswell, Ice Lake, Skylake, Genoa
|
|
25
|
+
*
|
|
26
|
+
* @section usage Usage and Benefits
|
|
27
|
+
*
|
|
28
|
+
*
|
|
29
|
+
*
|
|
30
|
+
* @section references References
|
|
31
|
+
*
|
|
32
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
33
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
34
|
+
*
|
|
35
|
+
*/
|
|
36
|
+
#ifndef NK_RANDOM_H
|
|
37
|
+
#define NK_RANDOM_H
|
|
38
|
+
|
|
39
|
+
#include "numkong/types.h"
|
|
40
|
+
#include "numkong/cast.h"
|
|
41
|
+
|
|
42
|
+
#if defined(__cplusplus)
|
|
43
|
+
extern "C" {
|
|
44
|
+
#endif // defined(__cplusplus)
|
|
45
|
+
|
|
46
|
+
#if defined(__cplusplus)
|
|
47
|
+
} // extern "C"
|
|
48
|
+
#endif // defined(__cplusplus)
|
|
49
|
+
|
|
50
|
+
#endif // NK_RANDOM_H
|