numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Serial Probability Distribution Similarity Measures.
|
|
3
|
+
* @file include/numkong/probability/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/probability.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_PROBABILITY_SERIAL_H
|
|
10
|
+
#define NK_PROBABILITY_SERIAL_H
|
|
11
|
+
|
|
12
|
+
#include "numkong/types.h"
|
|
13
|
+
#include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`, `nk_bf16_to_f32_serial`, `nk_assign_from_to_`
|
|
14
|
+
#include "numkong/spatial/serial.h" // `nk_f32_sqrt_serial`, `nk_f64_sqrt_serial`
|
|
15
|
+
|
|
16
|
+
#if defined(__cplusplus)
|
|
17
|
+
extern "C" {
|
|
18
|
+
#endif
|
|
19
|
+
|
|
20
|
+
/**
 * @brief Stamps out `nk_kld_<input_type>_serial`, the Kullback-Leibler divergence
 *        Σ aᵢ·compute_log((aᵢ+ε)/(bᵢ+ε)).
 *
 * Every element is widened into `nk_<accumulator_type>_t` through `load_and_convert(src, dst)`,
 * the running total lives in that accumulator type, and it is cast to `output_type` only once,
 * when written to `*result`.
 */
#define nk_define_kld_(input_type, accumulator_type, output_type, load_and_convert, epsilon, compute_log)     \
    NK_PUBLIC void nk_kld_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                nk_size_t n, output_type *result) {                          \
        nk_##accumulator_type##_t divergence = 0;                                                            \
        nk_##accumulator_type##_t p_value, q_value;                                                          \
        nk_size_t position = 0;                                                                              \
        while (position < n) {                                                                               \
            load_and_convert(&a[position], &p_value);                                                        \
            load_and_convert(&b[position], &q_value);                                                        \
            divergence += p_value * compute_log((p_value + epsilon) / (q_value + epsilon));                   \
            ++position;                                                                                      \
        }                                                                                                    \
        *result = (output_type)divergence;                                                                   \
    }
|
|
31
|
+
|
|
32
|
+
/**
 * @brief Stamps out `nk_jsd_<input_type>_serial`, the Jensen-Shannon distance
 *        sqrt(½ Σ [aᵢ·log((aᵢ+ε)/(mᵢ+ε)) + bᵢ·log((bᵢ+ε)/(mᵢ+ε))]) with mᵢ = (aᵢ+bᵢ)/2.
 *
 * Elements are widened into `nk_<accumulator_type>_t` via `load_and_convert(src, dst)`. The halved
 * total is formed in `output_type`, and a non-positive total (e.g. from rounding) yields 0 instead
 * of being passed to `compute_sqrt`.
 */
#define nk_define_jsd_(input_type, accumulator_type, output_type, load_and_convert, epsilon, compute_log,      \
                       compute_sqrt)                                                                           \
    NK_PUBLIC void nk_jsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                nk_size_t n, output_type *result) {                          \
        nk_##accumulator_type##_t divergence = 0;                                                            \
        nk_##accumulator_type##_t p_value, q_value;                                                          \
        for (nk_size_t position = 0; position < n; position++) {                                             \
            load_and_convert(&a[position], &p_value);                                                        \
            load_and_convert(&b[position], &q_value);                                                        \
            nk_##accumulator_type##_t midpoint = (p_value + q_value) / 2;                                    \
            divergence += p_value * compute_log((p_value + epsilon) / (midpoint + epsilon));                  \
            divergence += q_value * compute_log((q_value + epsilon) / (midpoint + epsilon));                  \
        }                                                                                                    \
        output_type half_divergence = ((output_type)divergence / 2);                                         \
        if (half_divergence > 0) *result = compute_sqrt(half_divergence);                                    \
        else *result = 0;                                                                                    \
    }
|
|
47
|
+
|
|
48
|
+
/**
|
|
49
|
+
* @brief Computes `log(x)` for any positive float using IEEE 754 bit extraction
|
|
50
|
+
* and a fast-converging series expansion.
|
|
51
|
+
*
|
|
52
|
+
* Exploits the IEEE 754 representation to extract the exponent and mantissa:
|
|
53
|
+
* `log(x) = log(2) * exponent + log(mantissa)`. The mantissa is reduced to the
|
|
54
|
+
* range `[√2/2, √2]` for optimal convergence. Uses the transformation
|
|
55
|
+
* `u = (m-1)/(m+1)` which converges much faster than the classic Mercator series,
|
|
56
|
+
* since `u` is bounded to approximately `[-0.17, 0.17]` after range reduction.
|
|
57
|
+
*
|
|
58
|
+
* Maximum relative error is approximately 0.00001% across all positive floats,
|
|
59
|
+
* roughly 300,000x more accurate than the 3-term Mercator series (which also
|
|
60
|
+
* only converges for inputs in `(0, 2)`).
|
|
61
|
+
*
|
|
62
|
+
* https://en.wikipedia.org/wiki/Logarithm#Power_series
|
|
63
|
+
*/
|
|
64
|
+
NK_INTERNAL nk_f32_t nk_f32_log_serial_(nk_f32_t x) {
|
|
65
|
+
nk_fui32_t conv;
|
|
66
|
+
conv.f = x;
|
|
67
|
+
int exp = ((conv.u >> 23) & 0xFF) - 127;
|
|
68
|
+
conv.u = (conv.u & 0x007FFFFF) | 0x3F800000; // mantissa ∈ [1, 2)
|
|
69
|
+
nk_f32_t m = conv.f;
|
|
70
|
+
// Range reduction: if m > √2, halve it and increment exponent
|
|
71
|
+
if (m > 1.41421356f) m *= 0.5f, exp++;
|
|
72
|
+
// Use (m-1)/(m+1) transformation for faster convergence
|
|
73
|
+
nk_f32_t u = (m - 1.0f) / (m + 1.0f);
|
|
74
|
+
nk_f32_t u2 = u * u;
|
|
75
|
+
// log(m) = 2 × (u + u³/3 + u⁵/5 + u⁷/7)
|
|
76
|
+
nk_f32_t log_m = 2.0f * u * (1.0f + u2 * (0.3333333333f + u2 * (0.2f + u2 * 0.142857143f)));
|
|
77
|
+
return (nk_f32_t)exp * 0.6931471805599453f + log_m;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
/**
|
|
81
|
+
* @brief Computes `log(x)` for any positive double using IEEE 754 bit extraction
|
|
82
|
+
* and a fast-converging series expansion.
|
|
83
|
+
*
|
|
84
|
+
* Exploits the IEEE 754 representation to extract the 11-bit exponent and 52-bit mantissa:
|
|
85
|
+
* `log(x) = log(2) * exponent + log(mantissa)`. The mantissa is reduced to the
|
|
86
|
+
* range `[√2/2, √2]` for optimal convergence. Uses the transformation
|
|
87
|
+
* `u = (m-1)/(m+1)` which converges much faster than the classic Mercator series,
|
|
88
|
+
* since `u` is bounded to approximately `[-0.17, 0.17]` after range reduction.
|
|
89
|
+
*
|
|
90
|
+
* Uses more series terms than the f32 version to achieve near-full f64 precision,
|
|
91
|
+
* with maximum relative error approximately 0.00000000000001% across all positive doubles.
|
|
92
|
+
*
|
|
93
|
+
* https://en.wikipedia.org/wiki/Logarithm#Power_series
|
|
94
|
+
*/
|
|
95
|
+
NK_INTERNAL nk_f64_t nk_f64_log_serial_(nk_f64_t x) {
|
|
96
|
+
nk_fui64_t conv;
|
|
97
|
+
conv.f = x;
|
|
98
|
+
int exp = ((conv.u >> 52) & 0x7FF) - 1023;
|
|
99
|
+
conv.u = (conv.u & 0x000FFFFFFFFFFFFFULL) | 0x3FF0000000000000ULL; // mantissa ∈ [1, 2)
|
|
100
|
+
nk_f64_t m = conv.f;
|
|
101
|
+
// Range reduction: if m > √2, halve it and increment exponent
|
|
102
|
+
if (m > 1.4142135623730950488) m *= 0.5, exp++;
|
|
103
|
+
// Use (m-1)/(m+1) transformation for faster convergence
|
|
104
|
+
nk_f64_t u = (m - 1.0) / (m + 1.0);
|
|
105
|
+
nk_f64_t u2 = u * u;
|
|
106
|
+
// 14-term Horner: P(u²) = 1 + u²/3 + u⁴/5 + ... + u²⁶/27, matching SIMD
|
|
107
|
+
nk_f64_t poly = 1.0 / 27.0;
|
|
108
|
+
poly = u2 * poly + 1.0 / 25.0;
|
|
109
|
+
poly = u2 * poly + 1.0 / 23.0;
|
|
110
|
+
poly = u2 * poly + 1.0 / 21.0;
|
|
111
|
+
poly = u2 * poly + 1.0 / 19.0;
|
|
112
|
+
poly = u2 * poly + 1.0 / 17.0;
|
|
113
|
+
poly = u2 * poly + 1.0 / 15.0;
|
|
114
|
+
poly = u2 * poly + 1.0 / 13.0;
|
|
115
|
+
poly = u2 * poly + 1.0 / 11.0;
|
|
116
|
+
poly = u2 * poly + 1.0 / 9.0;
|
|
117
|
+
poly = u2 * poly + 1.0 / 7.0;
|
|
118
|
+
poly = u2 * poly + 1.0 / 5.0;
|
|
119
|
+
poly = u2 * poly + 1.0 / 3.0;
|
|
120
|
+
poly = u2 * poly + 1.0;
|
|
121
|
+
return (nk_f64_t)exp * 0.6931471805599453 + 2.0 * u * poly;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
nk_define_kld_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
125
|
+
nk_define_jsd_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_, nk_f64_sqrt_serial)
|
|
126
|
+
|
|
127
|
+
nk_define_kld_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
128
|
+
nk_define_jsd_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
129
|
+
nk_f32_sqrt_serial)
|
|
130
|
+
|
|
131
|
+
nk_define_kld_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
|
|
132
|
+
nk_define_jsd_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
|
|
133
|
+
nk_f32_sqrt_serial)
|
|
134
|
+
|
|
135
|
+
/**
 * @brief Kullback-Leibler divergence Σ aᵢ·ln((aᵢ+ε)/(bᵢ+ε)) over two f64 vectors.
 *
 * Uses Neumaier's compensated summation. After each addition, the rounding error is recovered
 * from whichever operand had the smaller magnitude and accumulated separately. The error total is
 * folded back into the running sum only at the end.
 */
NK_PUBLIC void nk_kld_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t running = 0, correction = 0;
    for (nk_size_t k = 0; k < n; k++) {
        nk_f64_t const p = a[k], q = b[k];
        nk_f64_t const contribution =
            p * nk_f64_log_serial_((p + NK_F64_DIVISION_EPSILON) / (q + NK_F64_DIVISION_EPSILON));
        nk_f64_t const updated = running + contribution;
        if (nk_f64_abs_(running) >= nk_f64_abs_(contribution)) correction += (running - updated) + contribution;
        else correction += (contribution - updated) + running;
        running = updated;
    }
    *result = running + correction;
}
|
|
146
|
+
|
|
147
|
+
/**
 * @brief Jensen-Shannon distance between two f64 vectors:
 *        sqrt(½ Σ [aᵢ·ln((aᵢ+ε)/(mᵢ+ε)) + bᵢ·ln((bᵢ+ε)/(mᵢ+ε))]) with mᵢ = (aᵢ+bᵢ)/2.
 *
 * Both per-element terms are accumulated, a-term first, with Neumaier compensated summation.
 * A non-positive halved total (e.g. from rounding) yields 0 instead of a square root.
 */
NK_PUBLIC void nk_jsd_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t running = 0, correction = 0;
    for (nk_size_t k = 0; k < n; k++) {
        nk_f64_t const p = a[k], q = b[k];
        nk_f64_t const mid = (p + q) / 2;
        nk_f64_t const contributions[2] = {
            p * nk_f64_log_serial_((p + NK_F64_DIVISION_EPSILON) / (mid + NK_F64_DIVISION_EPSILON)),
            q * nk_f64_log_serial_((q + NK_F64_DIVISION_EPSILON) / (mid + NK_F64_DIVISION_EPSILON)),
        };
        for (int side = 0; side < 2; side++) {
            nk_f64_t const c = contributions[side];
            nk_f64_t const updated = running + c;
            if (nk_f64_abs_(running) >= nk_f64_abs_(c)) correction += (running - updated) + c;
            else correction += (c - updated) + running;
            running = updated;
        }
    }
    nk_f64_t const half = (running + correction) / 2;
    if (half > 0) *result = nk_f64_sqrt_serial(half);
    else *result = 0;
}
|
|
164
|
+
|
|
165
|
+
#if defined(__cplusplus)
|
|
166
|
+
} // extern "C"
|
|
167
|
+
#endif
|
|
168
|
+
|
|
169
|
+
#endif // NK_PROBABILITY_SERIAL_H
|
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Skylake-accelerated Probability Distribution Similarity Measures.
|
|
3
|
+
* @file include/numkong/probability/skylake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/probability.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_PROBABILITY_SKYLAKE_H
|
|
10
|
+
#define NK_PROBABILITY_SKYLAKE_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_X86_
|
|
13
|
+
#if NK_TARGET_SKYLAKE
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
#include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`, `nk_f64_sqrt_haswell`
|
|
17
|
+
#include "numkong/spatial/skylake.h" // `nk_f32_sqrt_skylake`, `nk_f64_sqrt_skylake`
|
|
18
|
+
|
|
19
|
+
#if defined(__cplusplus)
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
|
|
25
|
+
apply_to = function)
|
|
26
|
+
#elif defined(__GNUC__)
|
|
27
|
+
#pragma GCC push_options
|
|
28
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
29
|
+
#endif
|
|
30
|
+
|
|
31
|
+
/**
 * @brief Vectorized `log2(x)` for 16 positive, finite floats.
 *
 * Decomposes `x = 2^e × m` with `_mm512_getexp_ps`/`_mm512_getmant_ps` (m ∈ [1, 2)). It then folds
 * mantissas above √2 into `[√2/2, √2]` by halving them and incrementing the exponent, the same
 * reduction the serial kernels use. This keeps `s = (m-1)/(m+1)` within about ±0.1716, cutting the
 * series truncation error from ~1.5e-6 (s up to 1/3) to roughly 1e-9. It also avoids catastrophic
 * cancellation for inputs just below 1, e.g. 0.999 is no longer evaluated as `-1 + log2(1.998)`.
 *
 * log2(m) = (2/ln2) × s × (1 + s²/3 + s⁴/5 + s⁶/7 + s⁸/9)
 */
NK_INTERNAL __m512 nk_log2_f32x16_skylake_(__m512 x) {
    // Extract the exponent and mantissa: x = 2^exp × m, m ∈ [1, 2)
    __m512 one_f32x16 = _mm512_set1_ps(1.0f);
    __m512 exponent_f32x16 = _mm512_getexp_ps(x);
    __m512 mantissa_f32x16 = _mm512_getmant_ps(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src);

    // Range reduction to [√2/2, √2]: lanes with m > √2 become m/2 with exponent + 1 (both exact)
    __mmask16 above_sqrt2_mask = _mm512_cmp_ps_mask(mantissa_f32x16, _mm512_set1_ps(1.41421356f), _CMP_GT_OQ);
    mantissa_f32x16 = _mm512_mask_mul_ps(mantissa_f32x16, above_sqrt2_mask, mantissa_f32x16, _mm512_set1_ps(0.5f));
    exponent_f32x16 = _mm512_mask_add_ps(exponent_f32x16, above_sqrt2_mask, exponent_f32x16, one_f32x16);

    // s = (m-1)/(m+1), |s| ≤ ~0.1716 after range reduction
    __m512 s_f32x16 = _mm512_div_ps(_mm512_sub_ps(mantissa_f32x16, one_f32x16),
                                    _mm512_add_ps(mantissa_f32x16, one_f32x16));
    __m512 s2_f32x16 = _mm512_mul_ps(s_f32x16, s_f32x16);
    __m512 series_f32x16 = _mm512_set1_ps(0.111111111f);                                         // 1/9
    series_f32x16 = _mm512_fmadd_ps(series_f32x16, s2_f32x16, _mm512_set1_ps(0.142857143f)); // 1/7
    series_f32x16 = _mm512_fmadd_ps(series_f32x16, s2_f32x16, _mm512_set1_ps(0.2f));         // 1/5
    series_f32x16 = _mm512_fmadd_ps(series_f32x16, s2_f32x16, _mm512_set1_ps(0.333333333f)); // 1/3
    series_f32x16 = _mm512_fmadd_ps(series_f32x16, s2_f32x16, one_f32x16);                   // 1
    // log2(m) = (2/ln2) × s × series
    __m512 log2m_f32x16 = _mm512_mul_ps(_mm512_set1_ps(2.885390081777927f), _mm512_mul_ps(s_f32x16, series_f32x16));
    return _mm512_add_ps(log2m_f32x16, exponent_f32x16);
}
|
|
51
|
+
|
|
52
|
+
/**
 * @brief Kullback-Leibler divergence Σ aᵢ·ln((aᵢ+ε)/(bᵢ+ε)) over two f32 vectors (AVX-512).
 *
 * Terms are accumulated as aᵢ·log2(ratio) across 16 f32 lanes. The two 8-lane halves are widened
 * to f64 for the horizontal sum, which is then scaled by ln(2) to convert back to natural-log
 * units. The last chunk is always read with masked, zero-filling loads, including the n == 0
 * case, which runs the loop body once. A zero-filled lane contributes 0·log2(ε/ε) = 0.
 */
NK_PUBLIC void nk_kld_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f32_t const epsilon = NK_F32_DIVISION_EPSILON;
    __m512 const epsilon_f32x16 = _mm512_set1_ps(epsilon);
    __m512 accumulator_f32x16 = _mm512_setzero_ps();

    do {
        __m512 p_f32x16, q_f32x16;
        if (n >= 16) {
            p_f32x16 = _mm512_loadu_ps(a);
            q_f32x16 = _mm512_loadu_ps(b);
            a += 16;
            b += 16;
            n -= 16;
        }
        else {
            __mmask16 const tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            p_f32x16 = _mm512_maskz_loadu_ps(tail_mask, a);
            q_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b);
            n = 0;
        }
        __m512 const numerator_f32x16 = _mm512_add_ps(p_f32x16, epsilon_f32x16);
        __m512 const denominator_f32x16 = _mm512_add_ps(q_f32x16, epsilon_f32x16);
        __m512 const log2_ratio_f32x16 = nk_log2_f32x16_skylake_(_mm512_div_ps(numerator_f32x16, denominator_f32x16));
        accumulator_f32x16 = _mm512_fmadd_ps(p_f32x16, log2_ratio_f32x16, accumulator_f32x16);
    } while (n);

    // Widen each 8-lane half to f64 so the horizontal sum is carried out in f64.
    nk_f64_t const low_half = _mm512_reduce_add_pd(_mm512_cvtps_pd(_mm512_castps512_ps256(accumulator_f32x16)));
    nk_f64_t const high_half = _mm512_reduce_add_pd(_mm512_cvtps_pd(_mm512_extractf32x8_ps(accumulator_f32x16, 1)));
    *result = (low_half + high_half) * 0.6931471805599453; // log2 → ln
}
|
|
83
|
+
|
|
84
|
+
/**
 * @brief Jensen-Shannon distance between two f32 distributions (AVX-512), returned in f64.
 *
 * Computes `sqrt(½ Σ [aᵢ·ln((aᵢ+ε)/(mᵢ+ε)) + bᵢ·ln((bᵢ+ε)/(mᵢ+ε))])` with `mᵢ = (aᵢ+bᵢ)/2`.
 * Logs are evaluated in base 2 and rescaled by ln(2) once after the horizontal reduction.
 *
 * Each side's term is masked independently. A lane with `aᵢ < ε` drops only its `aᵢ` term, while
 * the lane's `bᵢ` term is still counted, and vice versa. Combining both masks with AND would discard
 * every element where exactly one distribution is zero. For example, fully disjoint distributions
 * would report 0 instead of ≈√ln2, disagreeing with the serial kernel.
 */
NK_PUBLIC void nk_jsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
    __m512 epsilon_f32x16 = _mm512_set1_ps(epsilon);
    __m512 half_f32x16 = _mm512_set1_ps(0.5f);
    __m512 a_f32x16, b_f32x16;

nk_jsd_f32_skylake_cycle:
    if (n < 16) {
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
        a_f32x16 = _mm512_maskz_loadu_ps(mask, a);
        b_f32x16 = _mm512_maskz_loadu_ps(mask, b);
        n = 0;
    }
    else {
        a_f32x16 = _mm512_loadu_ps(a);
        b_f32x16 = _mm512_loadu_ps(b);
        a += 16, b += 16, n -= 16;
    }
    __m512 mean_f32x16 = _mm512_mul_ps(_mm512_add_ps(a_f32x16, b_f32x16), half_f32x16);
    // Per-side masks: a lane keeps its `a` term when aᵢ ≥ ε and its `b` term when bᵢ ≥ ε.
    __mmask16 nonzero_mask_a = _mm512_cmp_ps_mask(a_f32x16, epsilon_f32x16, _CMP_GE_OQ);
    __mmask16 nonzero_mask_b = _mm512_cmp_ps_mask(b_f32x16, epsilon_f32x16, _CMP_GE_OQ);
    __m512 mean_with_epsilon_f32x16 = _mm512_add_ps(mean_f32x16, epsilon_f32x16);
    __m512 ratio_a_f32x16 = _mm512_div_ps(_mm512_add_ps(a_f32x16, epsilon_f32x16), mean_with_epsilon_f32x16);
    __m512 ratio_b_f32x16 = _mm512_div_ps(_mm512_add_ps(b_f32x16, epsilon_f32x16), mean_with_epsilon_f32x16);
    __m512 log_ratio_a_f32x16 = nk_log2_f32x16_skylake_(ratio_a_f32x16);
    __m512 log_ratio_b_f32x16 = nk_log2_f32x16_skylake_(ratio_b_f32x16);
    __m512 contribution_a_f32x16 = _mm512_maskz_mul_ps(nonzero_mask_a, a_f32x16, log_ratio_a_f32x16);
    __m512 contribution_b_f32x16 = _mm512_maskz_mul_ps(nonzero_mask_b, b_f32x16, log_ratio_b_f32x16);
    sum_f32x16 = _mm512_add_ps(sum_f32x16, _mm512_add_ps(contribution_a_f32x16, contribution_b_f32x16));
    if (n) goto nk_jsd_f32_skylake_cycle;

    // Widen to f64 for the horizontal reduction, then convert log2 → ln and halve.
    __m256 lower_f32x8 = _mm512_castps512_ps256(sum_f32x16);
    __m256 upper_f32x8 = _mm512_extractf32x8_ps(sum_f32x16, 1);
    nk_f64_t sum = (_mm512_reduce_add_pd(_mm512_cvtps_pd(lower_f32x8)) +
                    _mm512_reduce_add_pd(_mm512_cvtps_pd(upper_f32x8))) *
                   0.6931471805599453 / 2.0;
    *result = sum > 0 ? nk_f64_sqrt_haswell(sum) : 0;
}
|
|
125
|
+
|
|
126
|
+
/**
 * @brief Vectorized `log2(x)` for 8 positive, finite doubles.
 *
 * Decomposes `x = 2^e × m` with `_mm512_getexp_pd`/`_mm512_getmant_pd` (m ∈ [1, 2)). It then folds
 * mantissas above √2 into `[√2/2, √2]` by halving them and incrementing the exponent, the same
 * reduction `nk_f64_log_serial_` performs. This bounds `s = (m-1)/(m+1)` to about ±0.1716, so the
 * 14-term series truncation error falls far below f64 rounding. Without the fold it is roughly
 * 1.5e-15 at s = 1/3. The fold also avoids catastrophic cancellation for inputs just below 1, which
 * would otherwise be evaluated as `-1 + log2(m ≈ 2)`.
 */
NK_INTERNAL __m512d nk_log2_f64x8_skylake_(__m512d x) {
    // Extract the exponent and mantissa: x = 2^exp × m, m ∈ [1, 2)
    __m512d one_f64x8 = _mm512_set1_pd(1.0);
    __m512d two_f64x8 = _mm512_set1_pd(2.0);
    __m512d exponent_f64x8 = _mm512_getexp_pd(x);
    __m512d mantissa_f64x8 = _mm512_getmant_pd(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src);

    // Range reduction to [√2/2, √2]: lanes with m > √2 become m/2 with exponent + 1 (both exact)
    __mmask8 above_sqrt2_mask = _mm512_cmp_pd_mask(mantissa_f64x8, _mm512_set1_pd(1.4142135623730950488), _CMP_GT_OQ);
    mantissa_f64x8 = _mm512_mask_mul_pd(mantissa_f64x8, above_sqrt2_mask, mantissa_f64x8, _mm512_set1_pd(0.5));
    exponent_f64x8 = _mm512_mask_add_pd(exponent_f64x8, above_sqrt2_mask, exponent_f64x8, one_f64x8);

    // ln(m) = 2 × s × (1 + s²/3 + s⁴/5 + s⁶/7 + ...) with s = (m-1)/(m+1), |s| ≤ ~0.1716
    // log2(m) = ln(m) × log2(e)
    __m512d s_f64x8 = _mm512_div_pd(_mm512_sub_pd(mantissa_f64x8, one_f64x8), _mm512_add_pd(mantissa_f64x8, one_f64x8));
    __m512d s2_f64x8 = _mm512_mul_pd(s_f64x8, s_f64x8);

    // Polynomial P(s²) = 1 + s²/3 + s⁴/5 + ... using Horner's method, 14 terms (k=0..13)
    __m512d poly_f64x8 = _mm512_set1_pd(1.0 / 27.0); // 1/(2*13+1)
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 25.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 23.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 21.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 19.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 17.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 15.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 13.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 11.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 9.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 7.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 5.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0 / 3.0));
    poly_f64x8 = _mm512_fmadd_pd(s2_f64x8, poly_f64x8, _mm512_set1_pd(1.0));

    // ln(m) = 2 × s × P(s²), then log2(m) = ln(m) × log2(e)
    __m512d ln_m_f64x8 = _mm512_mul_pd(_mm512_mul_pd(two_f64x8, s_f64x8), poly_f64x8);
    __m512d log2e_f64x8 = _mm512_set1_pd(1.4426950408889634); // 1/ln(2)
    __m512d log2_m_f64x8 = _mm512_mul_pd(ln_m_f64x8, log2e_f64x8);

    // log2(x) = exponent + log2(m)
    return _mm512_add_pd(exponent_f64x8, log2_m_f64x8);
}
|
|
164
|
+
|
|
165
|
+
/**
 * Kullback-Leibler divergence between two f64 vectors, AVX-512 (Skylake-X) variant.
 *
 * Accumulates a[i] * log2((a[i] + eps) / (b[i] + eps)) eight lanes at a time into a per-lane
 * Kahan-compensated accumulator, then horizontally reduces it and multiplies by ln(2) to turn
 * the base-2 logarithm into a natural one.
 *
 * @param a      First vector (n elements).
 * @param b      Second vector (n elements).
 * @param n      Number of elements; the final partial chunk is handled with masked loads.
 * @param result Output: the divergence.
 */
NK_PUBLIC void nk_kld_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t const eps = NK_F64_DIVISION_EPSILON;
    __m512d const eps_f64x8 = _mm512_set1_pd(eps);
    __m512d acc_f64x8 = _mm512_setzero_pd(); // running Kahan sum
    __m512d err_f64x8 = _mm512_setzero_pd(); // running Kahan compensation term

    do {
        __m512d lhs_f64x8, rhs_f64x8;
        if (n >= 8) {
            lhs_f64x8 = _mm512_loadu_pd(a);
            rhs_f64x8 = _mm512_loadu_pd(b);
            a += 8, b += 8, n -= 8;
        }
        else {
            // Partial (or empty) chunk: lanes past `n` are zero-filled by the masked loads.
            __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
            lhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, a);
            rhs_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b);
            n = 0;
        }

        __m512d const numerator_f64x8 = _mm512_add_pd(lhs_f64x8, eps_f64x8);
        __m512d const denominator_f64x8 = _mm512_add_pd(rhs_f64x8, eps_f64x8);
        __m512d const log2_f64x8 = nk_log2_f64x8_skylake_(_mm512_div_pd(numerator_f64x8, denominator_f64x8));
        __m512d const term_f64x8 = _mm512_mul_pd(lhs_f64x8, log2_f64x8);

        // Kahan step: fold the previous rounding error into this term before adding it.
        __m512d const corrected_f64x8 = _mm512_sub_pd(term_f64x8, err_f64x8);
        __m512d const next_f64x8 = _mm512_add_pd(acc_f64x8, corrected_f64x8);
        err_f64x8 = _mm512_sub_pd(_mm512_sub_pd(next_f64x8, acc_f64x8), corrected_f64x8);
        acc_f64x8 = next_f64x8;
    } while (n);

    nk_f64_t const ln2 = 0.6931471805599453; // ln(x) = log2(x) * ln(2)
    *result = _mm512_reduce_add_pd(acc_f64x8) * ln2;
}
|
|
197
|
+
|
|
198
|
+
/**
 * Jensen-Shannon distance between two f64 vectors, AVX-512 (Skylake-X) variant.
 *
 * With m = (a + b) / 2, accumulates both
 *   a[i] * log2((a[i] + eps) / (m[i] + eps))  and  b[i] * log2((b[i] + eps) / (m[i] + eps))
 * into a single per-lane Kahan-compensated accumulator. The total is then scaled by ln(2) / 2,
 * and the square root is written out (0 when the total is not positive).
 *
 * Each term is masked by its OWN operand: the `a` term is kept where a[i] >= eps, the `b` term
 * where b[i] >= eps. The previous code masked both terms with (a >= eps) & (b >= eps). That
 * discarded the b[i] * log2(2) contribution wherever a[i] == 0, and vice versa, so fully
 * disjoint distributions produced a distance of 0.
 *
 * @param a      First distribution (n elements).
 * @param b      Second distribution (n elements).
 * @param n      Number of elements; the final partial chunk is handled with masked loads.
 * @param result Output: the Jensen-Shannon distance.
 */
NK_PUBLIC void nk_jsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d compensation_f64x8 = _mm512_setzero_pd();
    nk_f64_t epsilon = NK_F64_DIVISION_EPSILON;
    __m512d epsilon_f64x8 = _mm512_set1_pd(epsilon);
    __m512d half_f64x8 = _mm512_set1_pd(0.5);
    __m512d a_f64x8, b_f64x8;

nk_jsd_f64_skylake_cycle:
    if (n < 8) {
        // Lanes past `n` are zero-filled, so they fail both `>= epsilon` tests below and add nothing.
        __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, n);
        a_f64x8 = _mm512_maskz_loadu_pd(mask, a);
        b_f64x8 = _mm512_maskz_loadu_pd(mask, b);
        n = 0;
    }
    else {
        a_f64x8 = _mm512_loadu_pd(a);
        b_f64x8 = _mm512_loadu_pd(b);
        a += 8, b += 8, n -= 8;
    }
    __m512d mean_f64x8 = _mm512_mul_pd(_mm512_add_pd(a_f64x8, b_f64x8), half_f64x8);
    // x * log2(...) is ~0 when x < epsilon. Masking such lanes per operand also keeps a possibly
    // non-finite log of a tiny ratio from turning into NaN via 0 * inf.
    __mmask8 nonzero_mask_a = _mm512_cmp_pd_mask(a_f64x8, epsilon_f64x8, _CMP_GE_OQ);
    __mmask8 nonzero_mask_b = _mm512_cmp_pd_mask(b_f64x8, epsilon_f64x8, _CMP_GE_OQ);
    __m512d mean_with_epsilon_f64x8 = _mm512_add_pd(mean_f64x8, epsilon_f64x8);
    __m512d ratio_a_f64x8 = _mm512_div_pd(_mm512_add_pd(a_f64x8, epsilon_f64x8), mean_with_epsilon_f64x8);
    __m512d ratio_b_f64x8 = _mm512_div_pd(_mm512_add_pd(b_f64x8, epsilon_f64x8), mean_with_epsilon_f64x8);
    __m512d log_ratio_a_f64x8 = nk_log2_f64x8_skylake_(ratio_a_f64x8);
    __m512d log_ratio_b_f64x8 = nk_log2_f64x8_skylake_(ratio_b_f64x8);
    __m512d contribution_a_f64x8 = _mm512_maskz_mul_pd(nonzero_mask_a, a_f64x8, log_ratio_a_f64x8);
    __m512d contribution_b_f64x8 = _mm512_maskz_mul_pd(nonzero_mask_b, b_f64x8, log_ratio_b_f64x8);
    // Kahan compensated summation for contribution a
    __m512d compensated_a_f64x8 = _mm512_sub_pd(contribution_a_f64x8, compensation_f64x8);
    __m512d tentative_a_f64x8 = _mm512_add_pd(sum_f64x8, compensated_a_f64x8);
    compensation_f64x8 = _mm512_sub_pd(_mm512_sub_pd(tentative_a_f64x8, sum_f64x8), compensated_a_f64x8);
    sum_f64x8 = tentative_a_f64x8;
    // Kahan compensated summation for contribution b
    __m512d compensated_b_f64x8 = _mm512_sub_pd(contribution_b_f64x8, compensation_f64x8);
    __m512d tentative_b_f64x8 = _mm512_add_pd(sum_f64x8, compensated_b_f64x8);
    compensation_f64x8 = _mm512_sub_pd(_mm512_sub_pd(tentative_b_f64x8, sum_f64x8), compensated_b_f64x8);
    sum_f64x8 = tentative_b_f64x8;
    if (n) goto nk_jsd_f64_skylake_cycle;

    // ln(2) converts the base-2 logs to natural logs; the extra 1/2 averages the two KL terms.
    nk_f64_t log2_normalizer = 0.6931471805599453;
    nk_f64_t sum = _mm512_reduce_add_pd(sum_f64x8);
    sum *= log2_normalizer / 2;
    *result = sum > 0 ? nk_f64_sqrt_haswell(sum) : 0;
}
|
|
245
|
+
|
|
246
|
+
/**
 * Kullback-Leibler divergence between two f16 vectors, AVX-512 (Skylake-X) variant, f32 result.
 *
 * Sixteen half-precision values per operand are loaded and widened to f32 with VCVTPH2PS.
 * a[i] * log2((a[i] + eps) / (b[i] + eps)) is then accumulated with plain (uncompensated)
 * f32 adds, and the reduced total is multiplied by ln(2) to yield a natural-log divergence.
 *
 * @param a      First vector (n half-precision elements).
 * @param b      Second vector (n half-precision elements).
 * @param n      Number of elements; the final partial chunk is handled with masked loads.
 * @param result Output: the divergence.
 */
NK_PUBLIC void nk_kld_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 const eps_f32x16 = _mm512_set1_ps(NK_F32_DIVISION_EPSILON);
    __m512 acc_f32x16 = _mm512_setzero_ps();

    do {
        __m256i lhs_f16x16, rhs_f16x16;
        if (n >= 16) {
            lhs_f16x16 = _mm256_loadu_si256((__m256i const *)a);
            rhs_f16x16 = _mm256_loadu_si256((__m256i const *)b);
            a += 16, b += 16, n -= 16;
        }
        else {
            // Partial (or empty) chunk: lanes past `n` are zero-filled by the masked loads.
            __mmask16 const tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
            lhs_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, a);
            rhs_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, b);
            n = 0;
        }

        __m512 const lhs_f32x16 = _mm512_cvtph_ps(lhs_f16x16);
        __m512 const rhs_f32x16 = _mm512_cvtph_ps(rhs_f16x16);
        __m512 const numerator_f32x16 = _mm512_add_ps(lhs_f32x16, eps_f32x16);
        __m512 const denominator_f32x16 = _mm512_add_ps(rhs_f32x16, eps_f32x16);
        __m512 const log2_f32x16 = nk_log2_f32x16_skylake_(_mm512_div_ps(numerator_f32x16, denominator_f32x16));
        acc_f32x16 = _mm512_add_ps(acc_f32x16, _mm512_mul_ps(lhs_f32x16, log2_f32x16));
    } while (n);

    nk_f32_t const ln2 = 0.6931471805599453f; // ln(x) = log2(x) * ln(2)
    *result = _mm512_reduce_add_ps(acc_f32x16) * ln2;
}
|
|
273
|
+
|
|
274
|
+
/**
 * Jensen-Shannon distance between two f16 vectors, AVX-512 (Skylake-X) variant, f32 result.
 *
 * Halves are widened to f32. With m = (a + b) / 2, the code accumulates
 * a[i] * log2((a[i] + eps) / (m[i] + eps)) into one f32 accumulator and
 * b[i] * log2((b[i] + eps) / (m[i] + eps)) into another via masked FMAs. The combined total is
 * scaled by ln(2) / 2, and its square root is written out (0 when not positive).
 *
 * Each FMA is masked by its OWN operand's `>= eps` test. The previous code used
 * (a >= eps) & (b >= eps) for both accumulators. That dropped the b[i] * log2(2) term wherever
 * a[i] == 0, and vice versa, so fully disjoint distributions yielded a distance of 0.
 *
 * @param a      First distribution (n half-precision elements).
 * @param b      Second distribution (n half-precision elements).
 * @param n      Number of elements; the final partial chunk is handled with masked loads.
 * @param result Output: the Jensen-Shannon distance.
 */
NK_PUBLIC void nk_jsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 sum_a_f32x16 = _mm512_setzero_ps();
    __m512 sum_b_f32x16 = _mm512_setzero_ps();
    __m512 epsilon_f32x16 = _mm512_set1_ps(NK_F32_DIVISION_EPSILON);
    __m512 half_f32x16 = _mm512_set1_ps(0.5f);
    __m512 a_f32x16, b_f32x16;

nk_jsd_f16_skylake_cycle:
    if (n < 16) {
        // Lanes past `n` are zero-filled, so they fail both `>= epsilon` tests below and add nothing.
        __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        a_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, a));
        b_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, b));
        n = 0;
    }
    else {
        a_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)a));
        b_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)b));
        a += 16, b += 16, n -= 16;
    }
    __m512 mean_f32x16 = _mm512_mul_ps(_mm512_add_ps(a_f32x16, b_f32x16), half_f32x16);
    // x * log2(...) is ~0 when x < epsilon. Masking such lanes per operand also keeps a possibly
    // non-finite log of a tiny ratio from turning into NaN via 0 * inf.
    __mmask16 nonzero_mask_a = _mm512_cmp_ps_mask(a_f32x16, epsilon_f32x16, _CMP_GE_OQ);
    __mmask16 nonzero_mask_b = _mm512_cmp_ps_mask(b_f32x16, epsilon_f32x16, _CMP_GE_OQ);
    __m512 mean_with_epsilon_f32x16 = _mm512_add_ps(mean_f32x16, epsilon_f32x16);
    __m512 ratio_a_f32x16 = _mm512_div_ps(_mm512_add_ps(a_f32x16, epsilon_f32x16), mean_with_epsilon_f32x16);
    __m512 ratio_b_f32x16 = _mm512_div_ps(_mm512_add_ps(b_f32x16, epsilon_f32x16), mean_with_epsilon_f32x16);
    __m512 log_ratio_a_f32x16 = nk_log2_f32x16_skylake_(ratio_a_f32x16);
    __m512 log_ratio_b_f32x16 = nk_log2_f32x16_skylake_(ratio_b_f32x16);
    // mask3 form: lanes with a cleared mask bit keep the old accumulator value unchanged.
    sum_a_f32x16 = _mm512_mask3_fmadd_ps(a_f32x16, log_ratio_a_f32x16, sum_a_f32x16, nonzero_mask_a);
    sum_b_f32x16 = _mm512_mask3_fmadd_ps(b_f32x16, log_ratio_b_f32x16, sum_b_f32x16, nonzero_mask_b);
    if (n) goto nk_jsd_f16_skylake_cycle;

    // ln(2) converts the base-2 logs to natural logs; the extra 1/2 averages the two KL terms.
    nk_f32_t log2_normalizer = 0.6931471805599453f;
    nk_f32_t sum = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_f32x16, sum_b_f32x16));
    sum *= log2_normalizer / 2;
    *result = sum > 0 ? nk_f32_sqrt_haswell(sum) : 0;
}
|
|
311
|
+
|
|
312
|
+
#if defined(__clang__)
|
|
313
|
+
#pragma clang attribute pop
|
|
314
|
+
#elif defined(__GNUC__)
|
|
315
|
+
#pragma GCC pop_options
|
|
316
|
+
#endif
|
|
317
|
+
|
|
318
|
+
#if defined(__cplusplus)
|
|
319
|
+
} // extern "C"
|
|
320
|
+
#endif
|
|
321
|
+
|
|
322
|
+
#endif // NK_TARGET_SKYLAKE
|
|
323
|
+
#endif // NK_TARGET_X86_
|
|
324
|
+
#endif // NK_PROBABILITY_SKYLAKE_H
|