numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Similarity Measures and Distance Functions.
|
|
3
|
+
* @file include/numkong/numkong.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 14, 2023
|
|
6
|
+
*
|
|
7
|
+
* Umbrella header that includes all domain-specific kernel headers
|
|
8
|
+
* and the runtime capability detection infrastructure.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
#ifndef NK_NUMKONG_H
|
|
12
|
+
#define NK_NUMKONG_H
|
|
13
|
+
|
|
14
|
+
#include "numkong/capabilities.h" // Runtime detection, like `nk_capabilities_x86_`
|
|
15
|
+
#include "numkong/scalar.h" // Scalar math: sqrt, rsqrt, fma, saturating, order, like `nk_f32_sqrt`
|
|
16
|
+
#include "numkong/cast.h" // Type conversions, like `nk_cast`
|
|
17
|
+
#include "numkong/set.h" // Hamming, Jaccard, like `nk_hamming_u1`
|
|
18
|
+
#include "numkong/curved.h" // Mahalanobis, Bilinear Forms, like `nk_bilinear_f64`
|
|
19
|
+
#include "numkong/dot.h" // Inner (dot) product and its conjugate, like `nk_dot_f32`
|
|
20
|
+
#include "numkong/dots.h" // GEMM-style MxN batched dot-products, like `nk_dots_packed_size_bf16`
|
|
21
|
+
#include "numkong/each.h" // Weighted Sum, Fused-Multiply-Add, like `nk_each_scale_f64`
|
|
22
|
+
#include "numkong/geospatial.h" // Haversine and Vincenty, like `nk_haversine_f64`
|
|
23
|
+
#include "numkong/mesh.h" // RMSD, Kabsch, Umeyama, like `nk_rmsd_f64`
|
|
24
|
+
#include "numkong/probability.h" // Kullback-Leibler, Jensen-Shannon, like `nk_kld_f16`
|
|
25
|
+
#include "numkong/reduce.h" // Horizontal MinMax & Moments reductions, like `nk_reduce_moments_f64`
|
|
26
|
+
#include "numkong/sets.h" // Hamming & Jaccard for binary sets, like `nk_hammings_packed_u1`
|
|
27
|
+
#include "numkong/sparse.h" // Set Intersections and Sparse Dot Products, like `nk_sparse_intersect_u16`
|
|
28
|
+
#include "numkong/spatial.h" // Euclidean, Angular, like `nk_euclidean_f64`
|
|
29
|
+
#include "numkong/spatials.h" // Batched Angular & Euclidean distances, like `nk_angulars_packed_f32`
|
|
30
|
+
#include "numkong/maxsim.h" // MaxSim: Multi-Vector Maximum Similarity, like `nk_maxsim_packed_f32`
|
|
31
|
+
#include "numkong/trigonometry.h" // Sin, Cos, Atan, like `nk_each_sin_f64`
|
|
32
|
+
|
|
33
|
+
#if defined(__cplusplus)
|
|
34
|
+
extern "C" {
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
/**
|
|
38
|
+
* @brief Returns the output dtype for a given metric kind and input dtype.
|
|
39
|
+
*/
|
|
40
|
+
NK_PUBLIC nk_dtype_t nk_kernel_output_dtype(nk_kernel_kind_t kind, nk_dtype_t input) {
|
|
41
|
+
switch (kind) {
|
|
42
|
+
case nk_kernel_dot_k:
|
|
43
|
+
case nk_kernel_vdot_k:
|
|
44
|
+
case nk_kernel_dots_packed_k:
|
|
45
|
+
case nk_kernel_dots_symmetric_k: return nk_dot_output_dtype(input);
|
|
46
|
+
case nk_kernel_angular_k:
|
|
47
|
+
case nk_kernel_angulars_packed_k:
|
|
48
|
+
case nk_kernel_angulars_symmetric_k: return nk_angular_output_dtype(input);
|
|
49
|
+
case nk_kernel_euclidean_k:
|
|
50
|
+
case nk_kernel_euclideans_packed_k:
|
|
51
|
+
case nk_kernel_euclideans_symmetric_k: return nk_euclidean_output_dtype(input);
|
|
52
|
+
case nk_kernel_sqeuclidean_k: return nk_sqeuclidean_output_dtype(input);
|
|
53
|
+
case nk_kernel_bilinear_k: return nk_bilinear_output_dtype(input);
|
|
54
|
+
case nk_kernel_mahalanobis_k: return nk_mahalanobis_output_dtype(input);
|
|
55
|
+
case nk_kernel_hamming_k:
|
|
56
|
+
case nk_kernel_hammings_packed_k:
|
|
57
|
+
case nk_kernel_hammings_symmetric_k: return nk_hamming_output_dtype(input);
|
|
58
|
+
case nk_kernel_jaccard_k:
|
|
59
|
+
case nk_kernel_jaccards_packed_k:
|
|
60
|
+
case nk_kernel_jaccards_symmetric_k: return nk_jaccard_output_dtype(input);
|
|
61
|
+
case nk_kernel_haversine_k: return nk_haversine_output_dtype(input);
|
|
62
|
+
case nk_kernel_vincenty_k: return nk_vincenty_output_dtype(input);
|
|
63
|
+
case nk_kernel_kld_k:
|
|
64
|
+
case nk_kernel_jsd_k: return nk_probability_output_dtype(input);
|
|
65
|
+
case nk_kernel_rmsd_k: return nk_rmsd_output_dtype(input);
|
|
66
|
+
case nk_kernel_kabsch_k: return nk_kabsch_output_dtype(input);
|
|
67
|
+
case nk_kernel_umeyama_k: return nk_umeyama_output_dtype(input);
|
|
68
|
+
case nk_kernel_sparse_dot_k: return nk_sparse_dot_output_dtype(input);
|
|
69
|
+
case nk_kernel_maxsim_packed_k: return nk_maxsim_output_dtype(input);
|
|
70
|
+
default: return nk_dtype_unknown_k;
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
#if defined(__cplusplus)
|
|
75
|
+
} // extern "C"
|
|
76
|
+
#endif
|
|
77
|
+
|
|
78
|
+
#endif // NK_NUMKONG_H
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NumKong SDK for C++23 and newer.
|
|
3
|
+
* @file include/numkong/numkong.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 7, 2026
|
|
6
|
+
*
|
|
7
|
+
* C doesn't have a strong type system or composable infrastructure for complex kernels
|
|
8
|
+
* and data structures like the C++ templates and Rust traits. Unlike C++, C also lacks
|
|
9
|
+
* function overloading, namespaces and templates, thus requiring verbose signatures and
|
|
10
|
+
* naming conventions, like:
|
|
11
|
+
*
|
|
12
|
+
* @code{c}
|
|
13
|
+
* void nk_dot_f64(nk_f64_t const*, nk_f64_t const*, nk_size_t, nk_f64_t *);
|
|
14
|
+
* void nk_dot_f32(nk_f32_t const*, nk_f32_t const*, nk_size_t, nk_f64_t *);
|
|
15
|
+
* void nk_dot_f16(nk_f16_t const*, nk_f16_t const*, nk_size_t, nk_f32_t *);
|
|
16
|
+
* void nk_dot_bf16(nk_bf16_t const*, nk_bf16_t const*, nk_size_t, nk_f32_t *);
|
|
17
|
+
* void nk_dot_e4m3(nk_e4m3_t const*, nk_e4m3_t const*, nk_size_t, nk_f32_t *);
|
|
18
|
+
* void nk_dot_e5m2(nk_e5m2_t const*, nk_e5m2_t const*, nk_size_t, nk_f32_t *);
|
|
19
|
+
* @endcode
|
|
20
|
+
*
|
|
21
|
+
* As opposed to C++:
|
|
22
|
+
*
|
|
23
|
+
* @code{cpp}
|
|
24
|
+
* namespace ashvardanian::numkong {
|
|
25
|
+
* template <typename input_type_, typename result_type_>
|
|
26
|
+
* void dot(input_type_ const*, input_type_ const*, size_t, result_type_ *);
|
|
27
|
+
* }
* @endcode
|
|
28
|
+
*
|
|
29
|
+
* In HPC implementations, where pretty much every kernel and every datatype uses different
|
|
30
|
+
* assembly instructions on different CPU generations/models, those higher-level abstractions
|
|
31
|
+
* aren't always productive for the primary implementation, but they can still be handy as
|
|
32
|
+
* a higher-level API for NumKong. They are also used for algorithm verification in no-SIMD
|
|
33
|
+
* mode, upcasting to much larger number types like `f118_t`.
|
|
34
|
+
*/
|
|
35
|
+
|
|
36
|
+
#ifndef NK_NUMKONG_HPP
|
|
37
|
+
#define NK_NUMKONG_HPP
|
|
38
|
+
|
|
39
|
+
#include "numkong/random.hpp"
|
|
40
|
+
#include "numkong/dot.hpp"
|
|
41
|
+
#include "numkong/spatial.hpp"
|
|
42
|
+
#include "numkong/spatials.hpp"
|
|
43
|
+
#include "numkong/probability.hpp"
|
|
44
|
+
#include "numkong/each.hpp"
|
|
45
|
+
#include "numkong/reduce.hpp"
|
|
46
|
+
#include "numkong/curved.hpp"
|
|
47
|
+
#include "numkong/geospatial.hpp"
|
|
48
|
+
#include "numkong/sparse.hpp"
|
|
49
|
+
#include "numkong/set.hpp"
|
|
50
|
+
#include "numkong/mesh.hpp"
|
|
51
|
+
#include "numkong/trigonometry.hpp"
|
|
52
|
+
#include "numkong/dots.hpp"
|
|
53
|
+
#include "numkong/matrix.hpp"
|
|
54
|
+
#include "numkong/maxsim.hpp"
|
|
55
|
+
#include "numkong/tensor.hpp"
|
|
56
|
+
|
|
57
|
+
#endif // NK_NUMKONG_HPP
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# Divergence Measures for Probability Distributions in NumKong
|
|
2
|
+
|
|
3
|
+
NumKong implements divergence functions between discrete probability distributions: Kullback-Leibler divergence measures the information lost when one distribution approximates another, while Jensen-Shannon distance provides a symmetric and bounded alternative.
|
|
4
|
+
These are used in variational inference, topic modeling, and distribution comparison tasks.
|
|
5
|
+
|
|
6
|
+
Kullback-Leibler divergence from $P$ to $Q$:
|
|
7
|
+
|
|
8
|
+
```math
|
|
9
|
+
\text{KLD}(P \| Q) = \sum_{i=0}^{n-1} P(i) \log_2 \frac{P(i)}{Q(i)}
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
Jensen-Shannon distance is the square root of the symmetrized KLD through a mixture:
|
|
13
|
+
|
|
14
|
+
$$\text{JSD}(P, Q) = \frac{1}{2} \text{KLD}(P \| M) + \frac{1}{2} \text{KLD}(Q \| M)$$
|
|
15
|
+
|
|
16
|
+
where $M = \frac{P + Q}{2}$, yielding the distance:
|
|
17
|
+
|
|
18
|
+
$$d_{JS}(P, Q) = \sqrt{\text{JSD}(P, Q)}$$
|
|
19
|
+
|
|
20
|
+
Unlike the raw divergence, $d_{JS}$ is a true metric satisfying the triangle inequality.
|
|
21
|
+
|
|
22
|
+
Reformulating as Python pseudocode:
|
|
23
|
+
|
|
24
|
+
```python
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
def kld(p: np.ndarray, q: np.ndarray) -> float:
|
|
28
|
+
mask = p > 0
|
|
29
|
+
return np.sum(p[mask] * np.log2(p[mask] / q[mask]))
|
|
30
|
+
|
|
31
|
+
def jsd(p: np.ndarray, q: np.ndarray) -> float:
|
|
32
|
+
m = (p + q) / 2
|
|
33
|
+
return np.sqrt((kld(p, m) + kld(q, m)) / 2)
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Use Cases
|
|
37
|
+
|
|
38
|
+
__Kullback-Leibler divergence__ is the workhorse of variational inference (ELBO objective), knowledge distillation between neural networks, information gain in decision trees, and measuring fit between a model and observed data.
|
|
39
|
+
|
|
40
|
+
__Jensen-Shannon distance__ sees primary use in microbiome community comparison (enterotyping), where its metric property enables clustering with standard algorithms. It also appears in distribution drift detection, topic model evaluation, and as the theoretical foundation of the original GAN objective — though in practice GAN training uses proxy losses rather than computing JSD directly.
|
|
41
|
+
|
|
42
|
+
## Input & Output Types
|
|
43
|
+
|
|
44
|
+
| Input Type | Output Type | Description |
|
|
45
|
+
| ---------- | ----------- | ---------------------------------------------- |
|
|
46
|
+
| `f64` | `f64` | 64-bit IEEE 754 double precision |
|
|
47
|
+
| `f32` | `f32` | 32-bit IEEE 754 single precision |
|
|
48
|
+
| `f16` | `f32` | 16-bit IEEE 754 half precision, widened output |
|
|
49
|
+
| `bf16` | `f32` | 16-bit brain float, widened output |
|
|
50
|
+
|
|
51
|
+
## Optimizations
|
|
52
|
+
|
|
53
|
+
### SIMD Log2 Approximation
|
|
54
|
+
|
|
55
|
+
`nk_kld_f32_skylake`, `nk_jsd_f32_skylake` use `VGETEXP` and `VGETMANT` to decompose floating-point values into exponent and mantissa components, then apply a polynomial approximation to the mantissa to compute $\log_2$.
|
|
56
|
+
The pipeline on Skylake is:
|
|
57
|
+
|
|
58
|
+
```
|
|
59
|
+
exponent = VGETEXPPS(x)
|
|
60
|
+
mantissa = VGETMANTPS(x, normalize_to_[1,2)) - 1
|
|
61
|
+
log2(x) ≈ exponent + polynomial(mantissa)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
`VGETEXP` extracts the unbiased exponent as a float, while `VGETMANT` normalizes the mantissa to $[1, 2)$.
|
|
65
|
+
A degree-4 minimax polynomial over the normalized mantissa completes the approximation.
|
|
66
|
+
These instructions handle subnormals correctly without extra integer bit manipulation.
|
|
67
|
+
|
|
68
|
+
`nk_kld_f32_neon`, `nk_jsd_f32_neon`, `nk_kld_f16_haswell`, `nk_jsd_f16_haswell` use integer bit extraction instead:
|
|
69
|
+
|
|
70
|
+
```
|
|
71
|
+
exponent = (reinterpret_as_int(x) >> 23) - 127
|
|
72
|
+
mantissa = reinterpret_as_float((reinterpret_as_int(x) & 0x7FFFFF) | 0x3F800000) - 1
|
|
73
|
+
log2(x) ≈ exponent + c₁·m + c₂·m² + c₃·m³ + c₄·m⁴ + c₅·m⁵
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
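This approach reinterprets the float as an integer, shifts out the mantissa bits to obtain the exponent, then masks and recombines to produce a normalized mantissa in $[1, 2)$.
Because it always assumes an implicit leading one and a biased exponent field, it does not renormalize subnormal inputs the way `VGETEXP`/`VGETMANT` do.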
|
|
77
|
+
It works on any ISA with integer-float reinterpretation and avoids the need for specialized exponent/mantissa instructions.
|
|
78
|
+
|
|
79
|
+
### Kahan Compensated Summation for Float64
|
|
80
|
+
|
|
81
|
+
`nk_kld_f64_haswell`, `nk_jsd_f64_haswell` use Kahan compensated summation to maintain a running correction term alongside the accumulator.
|
|
82
|
+
The Kahan update for each divergence term is:
|
|
83
|
+
|
|
84
|
+
```
|
|
85
|
+
compensated_term = divergence_term - correction
|
|
86
|
+
tentative_sum = accumulator + compensated_term
|
|
87
|
+
correction = (tentative_sum - accumulator) - compensated_term
|
|
88
|
+
accumulator = tentative_sum
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
After each $P(i) \log_2(P(i) / Q(i))$ term is computed, `correction` captures the low-order bits lost in the addition, and the next iteration subtracts this correction from the new term before adding it to the accumulator.
|
|
92
|
+
This keeps the accumulated error bounded by $O(1)$ ULP regardless of vector length, rather than the $O(n)$ ULP growth of naive summation.
|
|
93
|
+
|
|
94
|
+
## Performance
|
|
95
|
+
|
|
96
|
+
The following performance tables are produced by manually re-running the included `nk_test` and `nk_bench` internal tools, which measure both accuracy and throughput at different input shapes.
|
|
97
|
+
The input size is controlled by the `NK_DENSE_DIMENSIONS` environment variable and set to 256, 1024, and 4096 elements.
|
|
98
|
+
The throughput is measured in GB/s as the number of input bytes per second.
|
|
99
|
+
The published tables below summarize mean ULP (units in last place) across all test pairs — the average number of representable floating-point values between the computed result and the exact answer. The current `nk_test` family also reports max/mean absolute and relative divergence error for detailed inspection.
|
|
100
|
+
Each kernel runs for at least 20 seconds per configuration.
|
|
101
|
+
Benchmark threads are pinned to specific cores; on machines with heterogeneous core types (e.g., Apple P/E cores), only the fastest cores are used.
|
|
102
|
+
Workloads that significantly degrade CPU frequencies (Intel AMX, Apple SME) run in separate passes to avoid affecting throughput measurements of other kernels.
|
|
103
|
+
|
|
104
|
+
### Intel Sapphire Rapids
|
|
105
|
+
|
|
106
|
+
#### Native
|
|
107
|
+
|
|
108
|
+
| Kernel | 256 | 1024 | 4096 |
|
|
109
|
+
| :------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
110
|
+
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
111
|
+
| `nk_kld_f64_serial` | 0.693 gb/s, 5.65K ulp | 0.699 gb/s, 24.5K ulp | 0.753 gb/s, 98.9K ulp |
|
|
112
|
+
| `nk_jsd_f64_serial` | 0.324 gb/s, 0.5 ulp | 0.349 gb/s, 0.3 ulp | 0.391 gb/s, 0.6 ulp |
|
|
113
|
+
| `nk_kld_f64_haswell` | 5.34 gb/s, 5.64K ulp | 5.59 gb/s, 24.6K ulp | 5.76 gb/s, 99.1K ulp |
|
|
114
|
+
| `nk_jsd_f64_haswell` | 3.03 gb/s, 1.7 ulp | 3.05 gb/s, 1.4 ulp | 3.25 gb/s, 1.2 ulp |
|
|
115
|
+
| `nk_kld_f64_skylake` | 7.01 gb/s, 5.64K ulp | 6.85 gb/s, 24.4K ulp | 6.86 gb/s, 98.9K ulp |
|
|
116
|
+
| `nk_jsd_f64_skylake` | 3.66 gb/s, 1.6 ulp | 3.85 gb/s, 1.4 ulp | 4.30 gb/s, 1.2 ulp |
|
|
117
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
118
|
+
| `nk_kld_f32_serial` | 0.528 gb/s, 1.04K ulp | 0.516 gb/s, 4.54K ulp | 0.527 gb/s, 18.2K ulp |
|
|
119
|
+
| `nk_jsd_f32_serial` | 0.273 gb/s, 0.4 ulp | 0.272 gb/s, 0.4 ulp | 0.268 gb/s, 4.5 ulp |
|
|
120
|
+
| `nk_kld_f32_skylake` | 11.8 gb/s, 1.04K ulp | 10.4 gb/s, 4.55K ulp | 8.73 gb/s, 18.3K ulp |
|
|
121
|
+
| `nk_jsd_f32_skylake` | 6.25 gb/s, 6.6 ulp | 5.96 gb/s, 7.0 ulp | 6.05 gb/s, 11.1 ulp |
|
|
122
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
123
|
+
| `nk_kld_bf16_serial` | 0.138 gb/s, 1.04K ulp | 0.142 gb/s, 4.53K ulp | 0.136 gb/s, 18.3K ulp |
|
|
124
|
+
| `nk_jsd_bf16_serial` | 0.0857 gb/s, 1.5 ulp | 0.0842 gb/s, 3.4 ulp | 0.0841 gb/s, 10.7 ulp |
|
|
125
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
126
|
+
| `nk_kld_f16_serial` | 0.166 gb/s, 1.05K ulp | 0.163 gb/s, 4.53K ulp | 0.163 gb/s, 18.2K ulp |
|
|
127
|
+
| `nk_jsd_f16_serial` | 0.151 gb/s, 1.5 ulp | 0.148 gb/s, 2.3 ulp | 0.152 gb/s, 9.4 ulp |
|
|
128
|
+
| `nk_kld_f16_haswell` | 6.99 gb/s, 1.05K ulp | 6.09 gb/s, 4.54K ulp | 6.97 gb/s, 18.2K ulp |
|
|
129
|
+
| `nk_jsd_f16_haswell` | 2.81 gb/s, 6.4 ulp | 2.79 gb/s, 6.8 ulp | 2.72 gb/s, 11.5 ulp |
|
|
130
|
+
| `nk_kld_f16_skylake` | 6.16 gb/s, 1.05K ulp | 5.65 gb/s, 4.54K ulp | 5.78 gb/s, 18.3K ulp |
|
|
131
|
+
| `nk_jsd_f16_skylake` | 3.51 gb/s, 6.5 ulp | 3.22 gb/s, 6.9 ulp | 3.35 gb/s, 11.4 ulp |
|
|
132
|
+
|
|
133
|
+
#### WASM
|
|
134
|
+
|
|
135
|
+
Measured with Wasmtime v42 (Cranelift backend).
|
|
136
|
+
|
|
137
|
+
| Kernel | 256 | 1024 | 4096 |
|
|
138
|
+
| :------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
139
|
+
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
140
|
+
| `nk_kld_f64_serial` | 0.239 gb/s, 5.64K ulp | 0.223 gb/s, 24.6K ulp | 0.13 gb/s, 99.6K ulp |
|
|
141
|
+
| `nk_jsd_f64_serial` | 0.315 gb/s, 0.5 ulp | 0.402 gb/s, 0.3 ulp | 0.29 gb/s, 0.5 ulp |
|
|
142
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
143
|
+
| `nk_kld_f32_serial` | 0.302 gb/s, 1.04K ulp | 0.342 gb/s, 4.52K ulp | 0.277 gb/s, 18.3K ulp |
|
|
144
|
+
| `nk_jsd_f32_serial` | 0.152 gb/s, 0.4 ulp | 0.164 gb/s, 0.4 ulp | 0.160 gb/s, 4.7 ulp |
|
|
145
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
146
|
+
| `nk_kld_bf16_serial` | 0.139 gb/s, 1.05K ulp | 0.143 gb/s, 4.53K ulp | 0.150 gb/s, 18.3K ulp |
|
|
147
|
+
| `nk_jsd_bf16_serial` | 0.0867 gb/s, 1.5 ulp | 0.0775 gb/s, 3.1 ulp | 0.0679 gb/s, 9.8 ulp |
|
|
148
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
149
|
+
| `nk_kld_f16_serial` | 0.118 gb/s, 1.04K ulp | 0.127 gb/s, 4.53K ulp | 0.111 gb/s, 18.3K ulp |
|
|
150
|
+
| `nk_jsd_f16_serial` | 0.0748 gb/s, 1.4 ulp | 0.0681 gb/s, 2.6 ulp | 0.0857 gb/s, 9.7 ulp |
|
|
151
|
+
|
|
152
|
+
### Apple M4
|
|
153
|
+
|
|
154
|
+
#### Native
|
|
155
|
+
|
|
156
|
+
| Kernel | 256 | 1024 | 4096 |
|
|
157
|
+
| :-------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
158
|
+
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
159
|
+
| `nk_kld_f64_serial` | 2.21 gb/s, 5.6K ulp | 2.22 gb/s, 25K ulp | 2.18 gb/s, 99K ulp |
|
|
160
|
+
| `nk_jsd_f64_serial` | 1.40 gb/s, 0.4 ulp | 1.45 gb/s, 0.4 ulp | 1.45 gb/s, 0.5 ulp |
|
|
161
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
162
|
+
| `nk_kld_f32_serial` | 6.29 gb/s, 1.0K ulp | 6.35 gb/s, 4.5K ulp | 6.22 gb/s, 18K ulp |
|
|
163
|
+
| `nk_jsd_f32_serial` | 1.21 gb/s, 0.4 ulp | 1.20 gb/s, 0.4 ulp | 1.20 gb/s, 4.6 ulp |
|
|
164
|
+
| `nk_kld_f32_neon` | 14.5 gb/s, 1.0K ulp | 14.4 gb/s, 4.5K ulp | 12.8 gb/s, 18K ulp |
|
|
165
|
+
| `nk_jsd_f32_neon` | 6.81 gb/s, 15 ulp | 7.04 gb/s, 14 ulp | 6.78 gb/s, 9.9 ulp |
|
|
166
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
167
|
+
| `nk_kld_bf16_serial` | 3.16 gb/s, 1.0K ulp | 2.96 gb/s, 4.5K ulp | 3.16 gb/s, 18K ulp |
|
|
168
|
+
| `nk_jsd_bf16_serial` | 0.611 gb/s, 1.4 ulp | 0.595 gb/s, 2.9 ulp | 0.613 gb/s, 9.7 ulp |
|
|
169
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
170
|
+
| `nk_kld_f16_serial` | 3.15 gb/s, 1.0K ulp | 3.14 gb/s, 4.5K ulp | 2.81 gb/s, 18K ulp |
|
|
171
|
+
| `nk_jsd_f16_serial` | 0.610 gb/s, 1.4 ulp | 0.611 gb/s, 2.7 ulp | 0.602 gb/s, 8.7 ulp |
|
|
172
|
+
| `nk_kld_f16_neonhalf` | 6.78 gb/s, 1.0K ulp | 6.72 gb/s, 4.5K ulp | 6.09 gb/s, 18K ulp |
|
|
173
|
+
| `nk_jsd_f16_neonhalf` | 3.42 gb/s, 15 ulp | 3.40 gb/s, 14 ulp | 3.14 gb/s, 9.9 ulp |
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Haswell-accelerated Probability Distribution Similarity Measures.
|
|
3
|
+
* @file include/numkong/probability/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/probability.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_PROBABILITY_HASWELL_H
|
|
10
|
+
#define NK_PROBABILITY_HASWELL_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_X86_
|
|
13
|
+
#if NK_TARGET_HASWELL
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_f32x8_haswell_`, `nk_reduce_add_f64x4_haswell_`
|
|
17
|
+
#include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`, `nk_f64_sqrt_haswell`
|
|
18
|
+
#include "numkong/cast/haswell.h" // `nk_partial_load_f16x8_to_f32x8_haswell_`, `nk_partial_load_b64x4_haswell_`
|
|
19
|
+
|
|
20
|
+
#if defined(__cplusplus)
|
|
21
|
+
extern "C" {
|
|
22
|
+
#endif
|
|
23
|
+
|
|
24
|
+
#if defined(__clang__)
|
|
25
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
26
|
+
#elif defined(__GNUC__)
|
|
27
|
+
#pragma GCC push_options
|
|
28
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
29
|
+
#endif
|
|
30
|
+
|
|
31
|
+
/**
 * @brief Vectorized base-2 logarithm of 8 single-precision lanes.
 *
 * Each lane is split into an unbiased exponent `e` and a mantissa `m` ∈ [1, 2). The mantissa is
 * then range-reduced into [√2/2, √2]: lanes with m > √2 use m/2 and e + 1, and both operations are exact.
 * This keeps s = (m-1)/(m+1) within about ±0.1716. The truncated series
 *     log2(m) = (2/ln2) × s × (1 + s²/3 + s⁴/5 + s⁶/7 + s⁸/9)
 * is then accurate to roughly 1e-9, well below single-precision resolution.
 *
 * Without the reduction, s could reach 1/3, giving a truncation error of about 1.5e-6. Worse, inputs
 * just below 1.0 produced log2(m) ≈ 1 and e = -1, and the final addition cancelled catastrophically.
 *
 * @param x Positive, finite, normal inputs. Zero, negative, subnormal, infinite and NaN lanes are
 *          not special-cased and produce meaningless results.
 * @return log2(x) per lane.
 */
NK_INTERNAL __m256 nk_log2_f32x8_haswell_(__m256 x) {
    __m256i const bits_i32x8 = _mm256_castps_si256(x);
    __m256 const one_f32x8 = _mm256_set1_ps(1.0f);

    // Unbiased exponent: ((bits & 0x7F800000) >> 23) - 127
    __m256i exponent_i32x8 = _mm256_srli_epi32(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x7F800000)), 23);
    exponent_i32x8 = _mm256_sub_epi32(exponent_i32x8, _mm256_set1_epi32(127));
    __m256 exponent_f32x8 = _mm256_cvtepi32_ps(exponent_i32x8);

    // Mantissa ∈ [1, 2): keep the 23 fraction bits and force the biased exponent to 127
    __m256 mantissa_f32x8 = _mm256_castsi256_ps(
        _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)), _mm256_set1_epi32(0x3F800000)));

    // Range reduction into [√2/2, √2]: m > √2 → (m/2, e + 1). Halving and adding 1.0 are exact.
    __m256 const above_sqrt2_mask = _mm256_cmp_ps(mantissa_f32x8, _mm256_set1_ps(1.41421356f), _CMP_GT_OQ);
    mantissa_f32x8 = _mm256_blendv_ps(mantissa_f32x8, _mm256_mul_ps(mantissa_f32x8, _mm256_set1_ps(0.5f)),
                                      above_sqrt2_mask);
    exponent_f32x8 = _mm256_add_ps(exponent_f32x8, _mm256_and_ps(above_sqrt2_mask, one_f32x8));

    // s = (m-1)/(m+1). Since m ∈ [0.5, 2], m - 1 is exact (Sterbenz lemma).
    __m256 const s_f32x8 = _mm256_div_ps(_mm256_sub_ps(mantissa_f32x8, one_f32x8),
                                         _mm256_add_ps(mantissa_f32x8, one_f32x8));
    __m256 const s2_f32x8 = _mm256_mul_ps(s_f32x8, s_f32x8);

    // Horner evaluation of 1 + s²/3 + s⁴/5 + s⁶/7 + s⁸/9
    __m256 series_f32x8 = _mm256_set1_ps(0.111111111f);                                      // 1/9
    series_f32x8 = _mm256_fmadd_ps(series_f32x8, s2_f32x8, _mm256_set1_ps(0.142857143f)); // 1/7
    series_f32x8 = _mm256_fmadd_ps(series_f32x8, s2_f32x8, _mm256_set1_ps(0.2f));         // 1/5
    series_f32x8 = _mm256_fmadd_ps(series_f32x8, s2_f32x8, _mm256_set1_ps(0.333333333f)); // 1/3
    series_f32x8 = _mm256_fmadd_ps(series_f32x8, s2_f32x8, one_f32x8);                    // 1

    // 2.885390081777927 = 2 / ln(2), turning 2·atanh(s) = ln(m) into log2(m)
    __m256 const log2m_f32x8 = _mm256_mul_ps(_mm256_set1_ps(2.885390081777927f), _mm256_mul_ps(s_f32x8, series_f32x8));
    return _mm256_add_ps(log2m_f32x8, exponent_f32x8);
}
|
|
55
|
+
|
|
56
|
+
/**
 * @brief Vectorized base-2 logarithm of 4 double-precision lanes.
 *
 * Decomposes x = 2^e × m and range-reduces m into [√2/2, √2], adjusting e. It then evaluates
 * ln(m) = 2·atanh(s), with s = (m-1)/(m+1), through a 14-term odd series, and scales the result by log2(e).
 *
 * With |s| ≤ 0.1716 the series truncation error is far below double-precision resolution. Inputs just
 * below 1.0 also avoid the cancellation between log2(m) ≈ 1 and e = -1 that an unreduced [1, 2)
 * mantissa would cause.
 *
 * @param x Positive, finite, normal inputs. Zero, negative, subnormal, infinite and NaN lanes are
 *          not special-cased and produce meaningless results.
 * @return log2(x) per lane.
 */
NK_INTERNAL __m256d nk_log2_f64x4_haswell_(__m256d x) {
    __m256i const bits_i64x4 = _mm256_castpd_si256(x);
    __m256d const one_f64x4 = _mm256_set1_pd(1.0);

    // Biased 11-bit exponent field. The mask drops the sign bit that a bare shift by 52 would keep.
    __m256i const biased_exponent_i64x4 = _mm256_and_si256(_mm256_srli_epi64(bits_i64x4, 52),
                                                           _mm256_set1_epi64x(0x7FF));

    // AVX2 has no packed i64 → f64 conversion. OR-ing a value below 2^52 into the mantissa of 2^52
    // (bit pattern 0x4330000000000000) gives exactly 2^52 + field. Subtracting 2^52 + 1023 then
    // gives the unbiased exponent exactly.
    // This stays in registers and, unlike `_mm256_extract_epi64`, which compiler headers only provide
    // for 64-bit targets, also builds for 32-bit x86.
    __m256d exponent_f64x4 = _mm256_sub_pd(
        _mm256_castsi256_pd(_mm256_or_si256(biased_exponent_i64x4, _mm256_set1_epi64x(0x4330000000000000LL))),
        _mm256_set1_pd(4503599627371519.0)); // 2^52 + 1023

    // Mantissa ∈ [1, 2): keep the 52 fraction bits and force the biased exponent to 1023
    __m256d mantissa_f64x4 = _mm256_castsi256_pd(
        _mm256_or_si256(_mm256_and_si256(bits_i64x4, _mm256_set1_epi64x(0x000FFFFFFFFFFFFFLL)),
                        _mm256_set1_epi64x(0x3FF0000000000000LL)));

    // Range reduction into [√2/2, √2]: m > √2 → (m/2, e + 1). Halving and adding 1.0 are exact.
    __m256d const above_sqrt2_mask = _mm256_cmp_pd(mantissa_f64x4, _mm256_set1_pd(1.4142135623730951), _CMP_GT_OQ);
    mantissa_f64x4 = _mm256_blendv_pd(mantissa_f64x4, _mm256_mul_pd(mantissa_f64x4, _mm256_set1_pd(0.5)),
                                      above_sqrt2_mask);
    exponent_f64x4 = _mm256_add_pd(exponent_f64x4, _mm256_and_pd(above_sqrt2_mask, one_f64x4));

    // s = (m-1)/(m+1). Since m ∈ [0.5, 2], m - 1 is exact (Sterbenz lemma).
    __m256d const s_f64x4 = _mm256_div_pd(_mm256_sub_pd(mantissa_f64x4, one_f64x4),
                                          _mm256_add_pd(mantissa_f64x4, one_f64x4));
    __m256d const s2_f64x4 = _mm256_mul_pd(s_f64x4, s_f64x4);

    // 14-term Horner: P(s²) = 1 + s²/3 + s⁴/5 + ... + s²⁶/27, so ln(m) = 2·s·P(s²)
    __m256d poly_f64x4 = _mm256_set1_pd(1.0 / 27.0);
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 25.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 23.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 21.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 19.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 17.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 15.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 13.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 11.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 9.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 7.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 5.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0 / 3.0));
    poly_f64x4 = _mm256_fmadd_pd(s2_f64x4, poly_f64x4, _mm256_set1_pd(1.0));

    __m256d const ln_m_f64x4 = _mm256_mul_pd(_mm256_mul_pd(_mm256_set1_pd(2.0), s_f64x4), poly_f64x4);
    __m256d const log2_m_f64x4 = _mm256_mul_pd(ln_m_f64x4, _mm256_set1_pd(1.4426950408889634)); // × log2(e)

    return _mm256_add_pd(exponent_f64x4, log2_m_f64x4);
}
|
|
100
|
+
|
|
101
|
+
/**
 * @brief Kullback-Leibler divergence Σ aᵢ · ln((aᵢ + ε) / (bᵢ + ε)) over half-precision inputs.
 *
 * Eight halves per step are widened to f32 with F16C. The last short chunk (or an empty input) is
 * handled by the partial loader, whose padding lanes must read as zero for the sum to be unaffected
 * (assumed here — see `nk_partial_load_f16x8_to_f32x8_haswell_`).
 * Logarithms are taken in base 2 and rescaled once by ln(2) after the horizontal reduction.
 *
 * @param a First distribution, `n` elements.
 * @param b Second distribution, `n` elements.
 * @param n Number of elements.
 * @param result Receives the divergence, in natural-log units.
 */
NK_PUBLIC void nk_kld_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t const epsilon = NK_F32_DIVISION_EPSILON;
    __m256 const eps_f32x8 = _mm256_set1_ps(epsilon);
    __m256 acc_f32x8 = _mm256_setzero_ps();

    do {
        __m256 p_f32x8, q_f32x8;
        if (n >= 8) {
            p_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a));
            q_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            // Tail (or empty input): load what is left
            nk_b256_vec_t p_tail, q_tail;
            nk_partial_load_f16x8_to_f32x8_haswell_(a, &p_tail, n);
            nk_partial_load_f16x8_to_f32x8_haswell_(b, &q_tail, n);
            p_f32x8 = p_tail.ymm_ps;
            q_f32x8 = q_tail.ymm_ps;
            n = 0;
        }
        __m256 const quotient_f32x8 = _mm256_div_ps(_mm256_add_ps(p_f32x8, eps_f32x8),
                                                    _mm256_add_ps(q_f32x8, eps_f32x8));
        __m256 const term_f32x8 = _mm256_mul_ps(p_f32x8, nk_log2_f32x8_haswell_(quotient_f32x8));
        acc_f32x8 = _mm256_add_ps(acc_f32x8, term_f32x8);
    } while (n);

    nk_f32_t const ln2 = 0.6931471805599453f; // converts log2 to ln
    nk_f32_t total = nk_reduce_add_f32x8_haswell_(acc_f32x8);
    total *= ln2;
    *result = total;
}
|
|
132
|
+
|
|
133
|
+
/**
 * @brief Jensen-Shannon distance between two half-precision inputs.
 *
 * With M = (a + b) / 2, it accumulates Σ aᵢ·log2((aᵢ+ε)/(Mᵢ+ε)) + bᵢ·log2((bᵢ+ε)/(Mᵢ+ε)) in f32,
 * eight lanes per step. The sum is then scaled by ln(2)/2, giving the divergence in natural-log units.
 * The output is the square root of that value, or 0 when it is not positive.
 *
 * @param a First distribution, `n` elements.
 * @param b Second distribution, `n` elements.
 * @param n Number of elements.
 * @param result Receives the distance.
 */
NK_PUBLIC void nk_jsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t const epsilon = NK_F32_DIVISION_EPSILON;
    __m256 const eps_f32x8 = _mm256_set1_ps(epsilon);
    __m256 const half_f32x8 = _mm256_set1_ps(0.5f);
    __m256 acc_f32x8 = _mm256_setzero_ps();

    do {
        __m256 p_f32x8, q_f32x8;
        if (n >= 8) {
            p_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)a));
            q_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)b));
            a += 8, b += 8, n -= 8;
        }
        else {
            // Tail (or empty input): load what is left
            nk_b256_vec_t p_tail, q_tail;
            nk_partial_load_f16x8_to_f32x8_haswell_(a, &p_tail, n);
            nk_partial_load_f16x8_to_f32x8_haswell_(b, &q_tail, n);
            p_f32x8 = p_tail.ymm_ps;
            q_f32x8 = q_tail.ymm_ps;
            n = 0;
        }
        __m256 const midpoint_f32x8 = _mm256_mul_ps(_mm256_add_ps(p_f32x8, q_f32x8), half_f32x8);
        __m256 const midpoint_eps_f32x8 = _mm256_add_ps(midpoint_f32x8, eps_f32x8);
        __m256 const log_p_f32x8 =
            nk_log2_f32x8_haswell_(_mm256_div_ps(_mm256_add_ps(p_f32x8, eps_f32x8), midpoint_eps_f32x8));
        __m256 const log_q_f32x8 =
            nk_log2_f32x8_haswell_(_mm256_div_ps(_mm256_add_ps(q_f32x8, eps_f32x8), midpoint_eps_f32x8));
        __m256 const term_p_f32x8 = _mm256_mul_ps(p_f32x8, log_p_f32x8);
        __m256 const term_q_f32x8 = _mm256_mul_ps(q_f32x8, log_q_f32x8);
        // Accumulate the P term first, then the Q term
        acc_f32x8 = _mm256_add_ps(acc_f32x8, term_p_f32x8);
        acc_f32x8 = _mm256_add_ps(acc_f32x8, term_q_f32x8);
    } while (n);

    nk_f32_t const ln2 = 0.6931471805599453f; // converts log2 to ln
    nk_f32_t divergence = nk_reduce_add_f32x8_haswell_(acc_f32x8);
    divergence *= ln2 / 2;
    if (divergence > 0) *result = nk_f32_sqrt_haswell(divergence);
    else *result = 0;
}
|
|
171
|
+
|
|
172
|
+
/**
 * @brief Kullback-Leibler divergence Σ aᵢ · ln((aᵢ + ε) / (bᵢ + ε)) over double-precision inputs.
 *
 * Processes four lanes per step, with a partial load for the final short chunk or an empty input.
 * Each lane keeps its own Kahan running sum and compensation term. Base-2 logarithms are rescaled by
 * ln(2) after the horizontal reduction of the running sums.
 *
 * @param a First distribution, `n` elements.
 * @param b Second distribution, `n` elements.
 * @param n Number of elements.
 * @param result Receives the divergence, in natural-log units.
 */
NK_PUBLIC void nk_kld_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t const epsilon = NK_F64_DIVISION_EPSILON;
    __m256d const eps_f64x4 = _mm256_set1_pd(epsilon);
    __m256d running_f64x4 = _mm256_setzero_pd();
    __m256d lost_f64x4 = _mm256_setzero_pd(); // Kahan compensation

    do {
        __m256d p_f64x4, q_f64x4;
        if (n >= 4) {
            p_f64x4 = _mm256_loadu_pd(a);
            q_f64x4 = _mm256_loadu_pd(b);
            a += 4, b += 4, n -= 4;
        }
        else {
            // Tail (or empty input): load what is left
            nk_b256_vec_t p_tail, q_tail;
            nk_partial_load_b64x4_haswell_(a, &p_tail, n);
            nk_partial_load_b64x4_haswell_(b, &q_tail, n);
            p_f64x4 = p_tail.ymm_pd;
            q_f64x4 = q_tail.ymm_pd;
            n = 0;
        }
        __m256d const quotient_f64x4 = _mm256_div_pd(_mm256_add_pd(p_f64x4, eps_f64x4),
                                                     _mm256_add_pd(q_f64x4, eps_f64x4));
        __m256d const term_f64x4 = _mm256_mul_pd(p_f64x4, nk_log2_f64x4_haswell_(quotient_f64x4));

        // Kahan step: y = term - c; t = sum + y; c = (t - sum) - y; sum = t
        __m256d const adjusted_f64x4 = _mm256_sub_pd(term_f64x4, lost_f64x4);
        __m256d const next_f64x4 = _mm256_add_pd(running_f64x4, adjusted_f64x4);
        lost_f64x4 = _mm256_sub_pd(_mm256_sub_pd(next_f64x4, running_f64x4), adjusted_f64x4);
        running_f64x4 = next_f64x4;
    } while (n);

    nk_f64_t const ln2 = 0.6931471805599453; // converts log2 to ln
    *result = nk_reduce_add_f64x4_haswell_(running_f64x4) * ln2;
}
|
|
206
|
+
|
|
207
|
+
/**
 * @brief Jensen-Shannon distance between two double-precision inputs.
 *
 * With M = (a + b) / 2, it accumulates aᵢ·log2((aᵢ+ε)/(Mᵢ+ε)) and then bᵢ·log2((bᵢ+ε)/(Mᵢ+ε)).
 * Both terms go through one per-lane Kahan accumulator, four lanes per step. The reduced sum is scaled
 * by ln(2)/2, and the output is its square root, or 0 when it is not positive.
 *
 * @param a First distribution, `n` elements.
 * @param b Second distribution, `n` elements.
 * @param n Number of elements.
 * @param result Receives the distance.
 */
NK_PUBLIC void nk_jsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t const epsilon = NK_F64_DIVISION_EPSILON;
    __m256d const eps_f64x4 = _mm256_set1_pd(epsilon);
    __m256d const half_f64x4 = _mm256_set1_pd(0.5);
    __m256d running_f64x4 = _mm256_setzero_pd();
    __m256d lost_f64x4 = _mm256_setzero_pd(); // Kahan compensation

    do {
        __m256d p_f64x4, q_f64x4;
        if (n >= 4) {
            p_f64x4 = _mm256_loadu_pd(a);
            q_f64x4 = _mm256_loadu_pd(b);
            a += 4, b += 4, n -= 4;
        }
        else {
            // Tail (or empty input): load what is left
            nk_b256_vec_t p_tail, q_tail;
            nk_partial_load_b64x4_haswell_(a, &p_tail, n);
            nk_partial_load_b64x4_haswell_(b, &q_tail, n);
            p_f64x4 = p_tail.ymm_pd;
            q_f64x4 = q_tail.ymm_pd;
            n = 0;
        }
        __m256d const midpoint_f64x4 = _mm256_mul_pd(_mm256_add_pd(p_f64x4, q_f64x4), half_f64x4);
        __m256d const midpoint_eps_f64x4 = _mm256_add_pd(midpoint_f64x4, eps_f64x4);
        __m256d const log_p_f64x4 =
            nk_log2_f64x4_haswell_(_mm256_div_pd(_mm256_add_pd(p_f64x4, eps_f64x4), midpoint_eps_f64x4));
        __m256d const log_q_f64x4 =
            nk_log2_f64x4_haswell_(_mm256_div_pd(_mm256_add_pd(q_f64x4, eps_f64x4), midpoint_eps_f64x4));
        __m256d const term_p_f64x4 = _mm256_mul_pd(p_f64x4, log_p_f64x4);
        __m256d const term_q_f64x4 = _mm256_mul_pd(q_f64x4, log_q_f64x4);

        // Kahan step for the P term
        __m256d const adjusted_p_f64x4 = _mm256_sub_pd(term_p_f64x4, lost_f64x4);
        __m256d const next_p_f64x4 = _mm256_add_pd(running_f64x4, adjusted_p_f64x4);
        lost_f64x4 = _mm256_sub_pd(_mm256_sub_pd(next_p_f64x4, running_f64x4), adjusted_p_f64x4);
        running_f64x4 = next_p_f64x4;

        // Kahan step for the Q term
        __m256d const adjusted_q_f64x4 = _mm256_sub_pd(term_q_f64x4, lost_f64x4);
        __m256d const next_q_f64x4 = _mm256_add_pd(running_f64x4, adjusted_q_f64x4);
        lost_f64x4 = _mm256_sub_pd(_mm256_sub_pd(next_q_f64x4, running_f64x4), adjusted_q_f64x4);
        running_f64x4 = next_q_f64x4;
    } while (n);

    nk_f64_t const ln2 = 0.6931471805599453; // converts log2 to ln
    nk_f64_t divergence = nk_reduce_add_f64x4_haswell_(running_f64x4);
    divergence *= ln2 / 2;
    if (divergence > 0) *result = nk_f64_sqrt_haswell(divergence);
    else *result = 0;
}
|
|
254
|
+
|
|
255
|
+
#if defined(__clang__)
|
|
256
|
+
#pragma clang attribute pop
|
|
257
|
+
#elif defined(__GNUC__)
|
|
258
|
+
#pragma GCC pop_options
|
|
259
|
+
#endif
|
|
260
|
+
|
|
261
|
+
#if defined(__cplusplus)
|
|
262
|
+
} // extern "C"
|
|
263
|
+
#endif
|
|
264
|
+
|
|
265
|
+
#endif // NK_TARGET_HASWELL
|
|
266
|
+
#endif // NK_TARGET_X86_
|
|
267
|
+
#endif // NK_PROBABILITY_HASWELL_H
|