numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Similarity Measures for Curved Spaces.
|
|
3
|
+
* @file include/numkong/curved.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date August 27, 2024
|
|
6
|
+
*
|
|
7
|
+
* Contains following similarity measures:
|
|
8
|
+
*
|
|
9
|
+
* - Mahalanobis distance: √((a-b)ᵀ × C × (a-b))
|
|
10
|
+
* - Bilinear form: aᵀ × C × b
|
|
11
|
+
* - Bilinear form over complex numbers
|
|
12
|
+
*
|
|
13
|
+
* For dtypes:
|
|
14
|
+
*
|
|
15
|
+
* - 64-bit floating point numbers → 64-bit floats
|
|
16
|
+
* - 32-bit floating point numbers → 64-bit floats
|
|
17
|
+
* - 16-bit floating point numbers → 32-bit floats
|
|
18
|
+
* - 16-bit brain-floating point numbers → 32-bit floats
|
|
19
|
+
*
|
|
20
|
+
* For hardware architectures:
|
|
21
|
+
*
|
|
22
|
+
* - Arm: NEON, NEON+F16, NEON+BF16, SME+F64
|
|
23
|
+
* - x86: Haswell, Skylake, Genoa
|
|
24
|
+
* - RISC-V: RVV
|
|
25
|
+
*
|
|
26
|
+
* @section numerical_stability Numerical Stability
|
|
27
|
+
*
|
|
28
|
+
* To minimize catastrophic cancellation in large-magnitude sums:
|
|
29
|
+
* - f32 kernels widen public outputs to f64/f64c and accumulate in f64 precision where possible
|
|
30
|
+
* - f64 kernels use Dot2 algorithm (Ogita-Rump-Oishi 2005) in SIMD paths
|
|
31
|
+
* - Serial kernels use Neumaier compensated summation for all types
|
|
32
|
+
*
|
|
33
|
+
* @section usage Usage and Benefits
|
|
34
|
+
*
|
|
35
|
+
* These kernels target BLAS level 2 patterns where vectors are combined with a metric
|
|
36
|
+
* tensor or covariance matrix. Using raw bilinear and Mahalanobis forms avoids constructing
|
|
37
|
+
* intermediates and keeps memory traffic low, which is often faster than a full GEMM path
|
|
38
|
+
* for small and medium sizes. Complex bilinear forms return a complex scalar as two reals,
|
|
39
|
+
* serving complex-valued signals without extra packing or unpacking.
|
|
40
|
+
*
|
|
41
|
+
* @section references References
|
|
42
|
+
*
|
|
43
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
44
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
45
|
+
* - Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
|
|
46
|
+
* - Ogita, T., Rump, S.M., Oishi, S. (2005). "Accurate Sum and Dot Product"
|
|
47
|
+
*
|
|
48
|
+
*/
|
|
49
|
+
#ifndef NK_CURVED_H
|
|
50
|
+
#define NK_CURVED_H
|
|
51
|
+
|
|
52
|
+
#include "numkong/types.h"
|
|
53
|
+
|
|
54
|
+
#if defined(__cplusplus)
|
|
55
|
+
extern "C" {
|
|
56
|
+
#endif
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* @brief Bilinear form between vectors a and b under metric tensor C.
|
|
60
|
+
*
|
|
61
|
+
* Computes aᵀ × C × b = Σᵢ Σⱼ aᵢ × cᵢⱼ × bⱼ
|
|
62
|
+
*
|
|
63
|
+
* @param[in] a The first vector.
|
|
64
|
+
* @param[in] b The second vector.
|
|
65
|
+
* @param[in] c The metric tensor or covariance matrix, stored row-major as an n×n matrix.
|
|
66
|
+
* @param[in] n The number of dimensions in the vectors.
|
|
67
|
+
* @param[out] result The output bilinear form value.
|
|
68
|
+
*
|
|
69
|
+
* @note The output value can be negative.
|
|
70
|
+
*/
|
|
71
|
+
NK_DYNAMIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
|
|
72
|
+
/** @copydoc nk_bilinear_f64 */
|
|
73
|
+
NK_DYNAMIC void nk_bilinear_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
|
|
74
|
+
/** @copydoc nk_bilinear_f64 */
|
|
75
|
+
NK_DYNAMIC void nk_bilinear_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
|
|
76
|
+
/** @copydoc nk_bilinear_f64 */
|
|
77
|
+
NK_DYNAMIC void nk_bilinear_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
78
|
+
nk_f32_t *result);
|
|
79
|
+
|
|
80
|
+
/**
|
|
81
|
+
* @brief Mahalanobis distance between vectors a and b under metric tensor C.
|
|
82
|
+
*
|
|
83
|
+
* Computes √((a-b)ᵀ × C × (a-b)) = √(Σᵢ Σⱼ (aᵢ-bᵢ) × cᵢⱼ × (aⱼ-bⱼ))
|
|
84
|
+
*
|
|
85
|
+
* @param[in] a The first vector.
|
|
86
|
+
* @param[in] b The second vector.
|
|
87
|
+
* @param[in] c The Positive Semi-Definite (PSD) matrix, typically the inverse covariance matrix, stored row-major as an n×n matrix.
|
|
88
|
+
* @param[in] n The number of dimensions in the vectors.
|
|
89
|
+
* @param[out] result The output distance value.
|
|
90
|
+
*
|
|
91
|
+
* @note The output value is non-negative when C is PSD.
|
|
92
|
+
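* @note The output value is zero when the two vectors are identical; the converse holds only when C is positive definite, since any a-b in the null space of a merely PSD C also gives zero.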
|
|
93
|
+
* @note The matrix C must be positive semi-definite. If C is not PSD, the quadratic form
|
|
94
|
+
* (a-b)ᵀ C (a-b) may be negative, and the square root will produce NaN.
|
|
95
|
+
*/
|
|
96
|
+
NK_DYNAMIC void nk_mahalanobis_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
97
|
+
nk_f64_t *result);
|
|
98
|
+
/** @copydoc nk_mahalanobis_f64 */
|
|
99
|
+
NK_DYNAMIC void nk_mahalanobis_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
100
|
+
nk_f64_t *result);
|
|
101
|
+
/** @copydoc nk_mahalanobis_f64 */
|
|
102
|
+
NK_DYNAMIC void nk_mahalanobis_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
103
|
+
nk_f32_t *result);
|
|
104
|
+
/** @copydoc nk_mahalanobis_f64 */
|
|
105
|
+
NK_DYNAMIC void nk_mahalanobis_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
106
|
+
nk_f32_t *result);
|
|
107
|
+
|
|
108
|
+
/**
|
|
109
|
+
* @brief Complex bilinear form between vectors a and b under metric tensor C.
|
|
110
|
+
*
|
|
111
|
+
* @param[in] a The first complex vector.
|
|
112
|
+
* @param[in] b The second complex vector.
|
|
113
|
+
* @param[in] c The complex metric tensor, stored row-major as an n×n matrix.
|
|
114
|
+
* @param[in] n The number of dimensions in the vectors.
|
|
115
|
+
* @param[out] results The output complex value with real and imaginary parts.
|
|
116
|
+
*/
|
|
117
|
+
NK_DYNAMIC void nk_bilinear_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
118
|
+
nk_f64c_t *results);
|
|
119
|
+
/** @copydoc nk_bilinear_f64c */
|
|
120
|
+
NK_DYNAMIC void nk_bilinear_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
121
|
+
nk_f64c_t *results);
|
|
122
|
+
/** @copydoc nk_bilinear_f64c */
|
|
123
|
+
NK_DYNAMIC void nk_bilinear_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
|
|
124
|
+
nk_f32c_t *results);
|
|
125
|
+
/** @copydoc nk_bilinear_f64c */
|
|
126
|
+
NK_DYNAMIC void nk_bilinear_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
|
|
127
|
+
nk_f32c_t *results);
|
|
128
|
+
|
|
129
|
+
/** @copydoc nk_bilinear_f64 */
|
|
130
|
+
NK_PUBLIC void nk_bilinear_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
131
|
+
nk_f64_t *result);
|
|
132
|
+
/** @copydoc nk_bilinear_f64c */
|
|
133
|
+
NK_PUBLIC void nk_bilinear_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
134
|
+
nk_f64c_t *results);
|
|
135
|
+
/** @copydoc nk_mahalanobis_f64 */
|
|
136
|
+
NK_PUBLIC void nk_mahalanobis_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
137
|
+
nk_f64_t *result);
|
|
138
|
+
/** @copydoc nk_bilinear_f32 */
|
|
139
|
+
NK_PUBLIC void nk_bilinear_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
140
|
+
nk_f64_t *result);
|
|
141
|
+
/** @copydoc nk_bilinear_f32c */
|
|
142
|
+
NK_PUBLIC void nk_bilinear_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
143
|
+
nk_f64c_t *results);
|
|
144
|
+
/** @copydoc nk_mahalanobis_f32 */
|
|
145
|
+
NK_PUBLIC void nk_mahalanobis_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
146
|
+
nk_f64_t *result);
|
|
147
|
+
/** @copydoc nk_bilinear_f16 */
|
|
148
|
+
NK_PUBLIC void nk_bilinear_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
149
|
+
nk_f32_t *result);
|
|
150
|
+
/** @copydoc nk_bilinear_f16c */
|
|
151
|
+
NK_PUBLIC void nk_bilinear_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
|
|
152
|
+
nk_f32c_t *results);
|
|
153
|
+
/** @copydoc nk_mahalanobis_f16 */
|
|
154
|
+
NK_PUBLIC void nk_mahalanobis_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
155
|
+
nk_f32_t *result);
|
|
156
|
+
/** @copydoc nk_bilinear_bf16 */
|
|
157
|
+
NK_PUBLIC void nk_bilinear_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
158
|
+
nk_f32_t *result);
|
|
159
|
+
/** @copydoc nk_bilinear_bf16c */
|
|
160
|
+
NK_PUBLIC void nk_bilinear_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
|
|
161
|
+
nk_f32c_t *results);
|
|
162
|
+
/** @copydoc nk_mahalanobis_bf16 */
|
|
163
|
+
NK_PUBLIC void nk_mahalanobis_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
164
|
+
nk_f32_t *result);
|
|
165
|
+
|
|
166
|
+
#if NK_TARGET_NEON
// Arm NEON backend: this header declares f32 (real and complex) kernels for it.
/** @copydoc nk_bilinear_f32 */
NK_PUBLIC void nk_bilinear_f32_neon(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_bilinear_f32c */
NK_PUBLIC void nk_bilinear_f32c_neon(
    nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n, nk_f64c_t *results);
/** @copydoc nk_mahalanobis_f32 */
NK_PUBLIC void nk_mahalanobis_f32_neon(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
#endif // NK_TARGET_NEON
|
|
177
|
+
|
|
178
|
+
#if NK_TARGET_NEONHALF
// Arm NEON with half-precision arithmetic: f16 (real and complex) kernels.
/** @copydoc nk_bilinear_f16 */
NK_PUBLIC void nk_bilinear_f16_neonhalf(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_bilinear_f16c */
NK_PUBLIC void nk_bilinear_f16c_neonhalf(
    nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n, nk_f32c_t *results);
/** @copydoc nk_mahalanobis_f16 */
NK_PUBLIC void nk_mahalanobis_f16_neonhalf(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_NEONHALF
|
|
189
|
+
|
|
190
|
+
#if NK_TARGET_NEONBFDOT
// Arm NEON with BF16 dot-product support: bf16 (real and complex) kernels.
/** @copydoc nk_bilinear_bf16 */
NK_PUBLIC void nk_bilinear_bf16_neonbfdot(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_bilinear_bf16c */
NK_PUBLIC void nk_bilinear_bf16c_neonbfdot(
    nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n, nk_f32c_t *results);
/** @copydoc nk_mahalanobis_bf16 */
NK_PUBLIC void nk_mahalanobis_bf16_neonbfdot(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_NEONBFDOT
|
|
201
|
+
|
|
202
|
+
#if NK_TARGET_SMEF64
// Arm SME with f64 support: f32 and f64 (real and complex) kernels.
// The complex kernels name their output `results`, matching the `@param[out] results`
// documentation pulled in via @copydoc and the naming used by every other backend.
// Parameter names in a C prototype do not affect callers or the ABI.
/** @copydoc nk_bilinear_f32 */
NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                      nk_f64_t *result);
/** @copydoc nk_bilinear_f32c */
NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                       nk_f64c_t *results);
/** @copydoc nk_mahalanobis_f32 */
NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                         nk_f64_t *result);
/** @copydoc nk_bilinear_f64 */
NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                      nk_f64_t *result);
/** @copydoc nk_bilinear_f64c */
NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                       nk_f64c_t *results);
/** @copydoc nk_mahalanobis_f64 */
NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                         nk_f64_t *result);
#endif // NK_TARGET_SMEF64
|
|
222
|
+
|
|
223
|
+
#if NK_TARGET_HASWELL
// x86 Haswell backend: real-valued f32, f16, and bf16 kernels (no complex variants declared).

// Bilinear forms.
/** @copydoc nk_bilinear_f32 */
NK_PUBLIC void nk_bilinear_f32_haswell(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_bilinear_f16 */
NK_PUBLIC void nk_bilinear_f16_haswell(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_bilinear_bf16 */
NK_PUBLIC void nk_bilinear_bf16_haswell(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);

// Mahalanobis distances.
/** @copydoc nk_mahalanobis_f32 */
NK_PUBLIC void nk_mahalanobis_f32_haswell(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_mahalanobis_f16 */
NK_PUBLIC void nk_mahalanobis_f16_haswell(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_mahalanobis_bf16 */
NK_PUBLIC void nk_mahalanobis_bf16_haswell(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_HASWELL
|
|
243
|
+
|
|
244
|
+
#if NK_TARGET_SKYLAKE
// x86 Skylake backend: f64 and f32 kernels, including complex bilinear forms.

// Real bilinear forms.
/** @copydoc nk_bilinear_f64 */
NK_PUBLIC void nk_bilinear_f64_skylake(
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_bilinear_f32 */
NK_PUBLIC void nk_bilinear_f32_skylake(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);

// Complex bilinear forms.
/** @copydoc nk_bilinear_f64c */
NK_PUBLIC void nk_bilinear_f64c_skylake(
    nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n, nk_f64c_t *results);
/** @copydoc nk_bilinear_f32c */
NK_PUBLIC void nk_bilinear_f32c_skylake(
    nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n, nk_f64c_t *results);

// Mahalanobis distances.
/** @copydoc nk_mahalanobis_f64 */
NK_PUBLIC void nk_mahalanobis_f64_skylake(
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_mahalanobis_f32 */
NK_PUBLIC void nk_mahalanobis_f32_skylake(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
#endif // NK_TARGET_SKYLAKE
|
|
264
|
+
|
|
265
|
+
#if NK_TARGET_GENOA
// x86 Genoa backend: bf16 (real and complex) kernels.
/** @copydoc nk_bilinear_bf16 */
NK_PUBLIC void nk_bilinear_bf16_genoa(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_bilinear_bf16c */
NK_PUBLIC void nk_bilinear_bf16c_genoa(
    nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n, nk_f32c_t *results);
/** @copydoc nk_mahalanobis_bf16 */
NK_PUBLIC void nk_mahalanobis_bf16_genoa(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_GENOA
|
|
276
|
+
|
|
277
|
+
#if NK_TARGET_RVV
// RISC-V Vector backend: real-valued f64, f32, f16, and bf16 kernels (no complex variants declared).

// Bilinear forms.
/** @copydoc nk_bilinear_f64 */
NK_PUBLIC void nk_bilinear_f64_rvv(
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_bilinear_f32 */
NK_PUBLIC void nk_bilinear_f32_rvv(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_bilinear_f16 */
NK_PUBLIC void nk_bilinear_f16_rvv(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_bilinear_bf16 */
NK_PUBLIC void nk_bilinear_bf16_rvv(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);

// Mahalanobis distances.
/** @copydoc nk_mahalanobis_f64 */
NK_PUBLIC void nk_mahalanobis_f64_rvv(
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_mahalanobis_f32 */
NK_PUBLIC void nk_mahalanobis_f32_rvv(
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
/** @copydoc nk_mahalanobis_f16 */
NK_PUBLIC void nk_mahalanobis_f16_rvv(
    nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_mahalanobis_bf16 */
NK_PUBLIC void nk_mahalanobis_bf16_rvv(
    nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_RVV
|
|
303
|
+
|
|
304
|
+
/**
|
|
305
|
+
* @brief Returns the output dtype for bilinear forms.
|
|
306
|
+
*/
|
|
307
|
+
NK_INTERNAL nk_dtype_t nk_bilinear_output_dtype(nk_dtype_t dtype) {
|
|
308
|
+
switch (dtype) {
|
|
309
|
+
case nk_f64_k: return nk_f64_k;
|
|
310
|
+
case nk_f32_k: return nk_f64_k;
|
|
311
|
+
case nk_f16_k: return nk_f32_k;
|
|
312
|
+
case nk_bf16_k: return nk_f32_k;
|
|
313
|
+
case nk_f64c_k: return nk_f64c_k;
|
|
314
|
+
case nk_f32c_k: return nk_f64c_k;
|
|
315
|
+
case nk_f16c_k: return nk_f32c_k;
|
|
316
|
+
case nk_bf16c_k: return nk_f32c_k;
|
|
317
|
+
default: return nk_dtype_unknown_k;
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
/**
|
|
322
|
+
* @brief Returns the output dtype for Mahalanobis metrics.
|
|
323
|
+
*/
|
|
324
|
+
NK_INTERNAL nk_dtype_t nk_mahalanobis_output_dtype(nk_dtype_t dtype) {
|
|
325
|
+
switch (dtype) {
|
|
326
|
+
case nk_f64_k: return nk_f64_k;
|
|
327
|
+
case nk_f32_k: return nk_f64_k;
|
|
328
|
+
case nk_f16_k: return nk_f32_k;
|
|
329
|
+
case nk_bf16_k: return nk_f32_k;
|
|
330
|
+
default: return nk_dtype_unknown_k;
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
#if defined(__cplusplus)
|
|
335
|
+
} // extern "C"
|
|
336
|
+
#endif
|
|
337
|
+
|
|
338
|
+
#include "numkong/curved/serial.h"
|
|
339
|
+
#include "numkong/curved/neon.h"
|
|
340
|
+
#include "numkong/curved/neonhalf.h"
|
|
341
|
+
#include "numkong/curved/neonbfdot.h"
|
|
342
|
+
#include "numkong/curved/smef64.h"
|
|
343
|
+
#include "numkong/curved/haswell.h"
|
|
344
|
+
#include "numkong/curved/skylake.h"
|
|
345
|
+
#include "numkong/curved/genoa.h"
|
|
346
|
+
#include "numkong/curved/rvv.h"
|
|
347
|
+
|
|
348
|
+
#if defined(__cplusplus)
|
|
349
|
+
extern "C" {
|
|
350
|
+
#endif
|
|
351
|
+
|
|
352
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
353
|
+
|
|
354
|
+
NK_PUBLIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, //
                               nk_size_t n, nk_f64_t *result) {
    // Static dispatch, first enabled target wins: Skylake, SME-F64, RVV, then serial.
#if NK_TARGET_SKYLAKE
    nk_bilinear_f64_skylake(a, b, c, n, result);
#elif NK_TARGET_SMEF64
    nk_bilinear_f64_smef64(a, b, c, n, result);
#elif NK_TARGET_RVV
    nk_bilinear_f64_rvv(a, b, c, n, result);
#else
    nk_bilinear_f64_serial(a, b, c, n, result);
#endif
}
|
|
365
|
+
|
|
366
|
+
NK_PUBLIC void nk_bilinear_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, //
                               nk_size_t n, nk_f64_t *result) {
    // Static dispatch, first enabled target wins: Skylake, SME-F64, Haswell, NEON, RVV, then serial.
    // The f32 inputs are reduced into an f64 result.
#if NK_TARGET_SKYLAKE
    nk_bilinear_f32_skylake(a, b, c, n, result);
#elif NK_TARGET_SMEF64
    nk_bilinear_f32_smef64(a, b, c, n, result);
#elif NK_TARGET_HASWELL
    nk_bilinear_f32_haswell(a, b, c, n, result);
#elif NK_TARGET_NEON
    nk_bilinear_f32_neon(a, b, c, n, result);
#elif NK_TARGET_RVV
    nk_bilinear_f32_rvv(a, b, c, n, result);
#else
    nk_bilinear_f32_serial(a, b, c, n, result);
#endif
}
|
|
381
|
+
|
|
382
|
+
NK_PUBLIC void nk_bilinear_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, //
                               nk_size_t n, nk_f32_t *result) {
    // Static dispatch, first enabled target wins: Haswell, NEON-half, RVV, then serial.
#if NK_TARGET_HASWELL
    nk_bilinear_f16_haswell(a, b, c, n, result);
#elif NK_TARGET_NEONHALF
    nk_bilinear_f16_neonhalf(a, b, c, n, result);
#elif NK_TARGET_RVV
    nk_bilinear_f16_rvv(a, b, c, n, result);
#else
    nk_bilinear_f16_serial(a, b, c, n, result);
#endif
}
|
|
393
|
+
|
|
394
|
+
NK_PUBLIC void nk_bilinear_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
395
|
+
nk_f32_t *result) {
|
|
396
|
+
#if NK_TARGET_GENOA
|
|
397
|
+
nk_bilinear_bf16_genoa(a, b, c, n, result);
|
|
398
|
+
#elif NK_TARGET_HASWELL
|
|
399
|
+
nk_bilinear_bf16_haswell(a, b, c, n, result);
|
|
400
|
+
#elif NK_TARGET_NEONBFDOT
|
|
401
|
+
nk_bilinear_bf16_neonbfdot(a, b, c, n, result);
|
|
402
|
+
#elif NK_TARGET_RVV
|
|
403
|
+
nk_bilinear_bf16_rvv(a, b, c, n, result);
|
|
404
|
+
#else
|
|
405
|
+
nk_bilinear_bf16_serial(a, b, c, n, result);
|
|
406
|
+
#endif
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
NK_PUBLIC void nk_bilinear_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, //
                                nk_size_t n, nk_f64c_t *results) {
    // Static dispatch, first enabled target wins: Skylake, SME-F64, then serial.
#if NK_TARGET_SKYLAKE
    nk_bilinear_f64c_skylake(a, b, c, n, results);
#elif NK_TARGET_SMEF64
    nk_bilinear_f64c_smef64(a, b, c, n, results);
#else
    nk_bilinear_f64c_serial(a, b, c, n, results);
#endif
}
|
|
419
|
+
|
|
420
|
+
NK_PUBLIC void nk_bilinear_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
421
|
+
nk_f64c_t *results) {
|
|
422
|
+
#if NK_TARGET_SKYLAKE
|
|
423
|
+
nk_bilinear_f32c_skylake(a, b, c, n, results);
|
|
424
|
+
#elif NK_TARGET_SMEF64
|
|
425
|
+
nk_bilinear_f32c_smef64(a, b, c, n, results);
|
|
426
|
+
#elif NK_TARGET_NEON
|
|
427
|
+
nk_bilinear_f32c_neon(a, b, c, n, results);
|
|
428
|
+
#else
|
|
429
|
+
nk_bilinear_f32c_serial(a, b, c, n, results);
|
|
430
|
+
#endif
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
NK_PUBLIC void nk_bilinear_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
|
|
434
|
+
nk_f32c_t *results) {
|
|
435
|
+
#if NK_TARGET_NEONHALF
|
|
436
|
+
nk_bilinear_f16c_neonhalf(a, b, c, n, results);
|
|
437
|
+
#else
|
|
438
|
+
nk_bilinear_f16c_serial(a, b, c, n, results);
|
|
439
|
+
#endif
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
NK_PUBLIC void nk_bilinear_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
|
|
443
|
+
nk_f32c_t *results) {
|
|
444
|
+
#if NK_TARGET_GENOA
|
|
445
|
+
nk_bilinear_bf16c_genoa(a, b, c, n, results);
|
|
446
|
+
#elif NK_TARGET_NEONBFDOT
|
|
447
|
+
nk_bilinear_bf16c_neonbfdot(a, b, c, n, results);
|
|
448
|
+
#else
|
|
449
|
+
nk_bilinear_bf16c_serial(a, b, c, n, results);
|
|
450
|
+
#endif
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
NK_PUBLIC void nk_mahalanobis_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, //
                                  nk_size_t n, nk_f64_t *result) {
    // Static dispatch, first enabled target wins: Skylake, SME-F64, RVV, then serial.
#if NK_TARGET_SKYLAKE
    nk_mahalanobis_f64_skylake(a, b, c, n, result);
#elif NK_TARGET_SMEF64
    nk_mahalanobis_f64_smef64(a, b, c, n, result);
#elif NK_TARGET_RVV
    nk_mahalanobis_f64_rvv(a, b, c, n, result);
#else
    nk_mahalanobis_f64_serial(a, b, c, n, result);
#endif
}
|
|
465
|
+
|
|
466
|
+
NK_PUBLIC void nk_mahalanobis_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
467
|
+
nk_f64_t *result) {
|
|
468
|
+
#if NK_TARGET_SKYLAKE
|
|
469
|
+
nk_mahalanobis_f32_skylake(a, b, c, n, result);
|
|
470
|
+
#elif NK_TARGET_SMEF64
|
|
471
|
+
nk_mahalanobis_f32_smef64(a, b, c, n, result);
|
|
472
|
+
#elif NK_TARGET_HASWELL
|
|
473
|
+
nk_mahalanobis_f32_haswell(a, b, c, n, result);
|
|
474
|
+
#elif NK_TARGET_NEON
|
|
475
|
+
nk_mahalanobis_f32_neon(a, b, c, n, result);
|
|
476
|
+
#elif NK_TARGET_RVV
|
|
477
|
+
nk_mahalanobis_f32_rvv(a, b, c, n, result);
|
|
478
|
+
#else
|
|
479
|
+
nk_mahalanobis_f32_serial(a, b, c, n, result);
|
|
480
|
+
#endif
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
NK_PUBLIC void nk_mahalanobis_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
484
|
+
nk_f32_t *result) {
|
|
485
|
+
#if NK_TARGET_HASWELL
|
|
486
|
+
nk_mahalanobis_f16_haswell(a, b, c, n, result);
|
|
487
|
+
#elif NK_TARGET_NEONHALF
|
|
488
|
+
nk_mahalanobis_f16_neonhalf(a, b, c, n, result);
|
|
489
|
+
#elif NK_TARGET_RVV
|
|
490
|
+
nk_mahalanobis_f16_rvv(a, b, c, n, result);
|
|
491
|
+
#else
|
|
492
|
+
nk_mahalanobis_f16_serial(a, b, c, n, result);
|
|
493
|
+
#endif
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
NK_PUBLIC void nk_mahalanobis_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
497
|
+
nk_f32_t *result) {
|
|
498
|
+
#if NK_TARGET_GENOA
|
|
499
|
+
nk_mahalanobis_bf16_genoa(a, b, c, n, result);
|
|
500
|
+
#elif NK_TARGET_HASWELL
|
|
501
|
+
nk_mahalanobis_bf16_haswell(a, b, c, n, result);
|
|
502
|
+
#elif NK_TARGET_NEONBFDOT
|
|
503
|
+
nk_mahalanobis_bf16_neonbfdot(a, b, c, n, result);
|
|
504
|
+
#elif NK_TARGET_RVV
|
|
505
|
+
nk_mahalanobis_bf16_rvv(a, b, c, n, result);
|
|
506
|
+
#else
|
|
507
|
+
nk_mahalanobis_bf16_serial(a, b, c, n, result);
|
|
508
|
+
#endif
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
512
|
+
|
|
513
|
+
#if defined(__cplusplus)
|
|
514
|
+
} // extern "C"
|
|
515
|
+
#endif
|
|
516
|
+
|
|
517
|
+
#endif // NK_CURVED_H
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Curved-space kernels: bilinear, mahalanobis.
|
|
3
|
+
* @file include/numkong/curved.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 5, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_CURVED_HPP
|
|
8
|
+
#define NK_CURVED_HPP
|
|
9
|
+
|
|
10
|
+
#include <cstdint> // `std::uint32_t`
|
|
11
|
+
#include <type_traits> // `std::is_same_v`
|
|
12
|
+
|
|
13
|
+
#include "numkong/curved.h"
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.hpp"
|
|
16
|
+
|
|
17
|
+
namespace ashvardanian::numkong {
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* @brief Bilinear form: aᵀ × C × b where C is a d×d matrix (row-major)
|
|
21
|
+
* @param[in] a,b Input vectors of length d
|
|
22
|
+
* @param[in] c Matrix of size dxd (row-major)
|
|
23
|
+
* @param[in] d Number of dimensions
|
|
24
|
+
* @param[out] r Pointer to output value
|
|
25
|
+
*
|
|
26
|
+
* @tparam in_type_ Input vector element type (real or complex)
|
|
27
|
+
* @tparam result_type_ Accumulator type, defaults to `in_type_::curved_result_t`
|
|
28
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
29
|
+
*
|
|
30
|
+
* @note For weighted inner products, Mahalanobis distance, etc.
|
|
31
|
+
*/
|
|
32
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
33
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
34
|
+
void bilinear(in_type_ const *a, in_type_ const *b, in_type_ const *c, std::size_t d, result_type_ *r) noexcept {
|
|
35
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
36
|
+
std::is_same_v<result_type_, typename in_type_::curved_result_t>;
|
|
37
|
+
|
|
38
|
+
// Real types
|
|
39
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_bilinear_f64(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
40
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
41
|
+
nk_bilinear_f32(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
42
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
43
|
+
nk_bilinear_f16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
44
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
45
|
+
nk_bilinear_bf16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
46
|
+
// Complex types
|
|
47
|
+
else if constexpr (std::is_same_v<in_type_, f64c_t> && simd)
|
|
48
|
+
nk_bilinear_f64c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
49
|
+
else if constexpr (std::is_same_v<in_type_, f32c_t> && simd)
|
|
50
|
+
nk_bilinear_f32c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
51
|
+
else if constexpr (std::is_same_v<in_type_, f16c_t> && simd)
|
|
52
|
+
nk_bilinear_f16c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
53
|
+
else if constexpr (std::is_same_v<in_type_, bf16c_t> && simd)
|
|
54
|
+
nk_bilinear_bf16c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
55
|
+
// Scalar fallback
|
|
56
|
+
else {
|
|
57
|
+
result_type_ sum {};
|
|
58
|
+
for (std::size_t i = 0; i < d; i++) {
|
|
59
|
+
for (std::size_t j = 0; j < d; j++) {
|
|
60
|
+
sum = sum + result_type_(a[i]) * result_type_(c[i * d + j]) * result_type_(b[j]);
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
*r = sum;
|
|
64
|
+
}
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/**
|
|
68
|
+
* @brief Mahalanobis distance: √((a−b)ᵀ × C × (a−b)) where C is a d×d matrix (row-major)
|
|
69
|
+
* @param[in] a,b Input vectors of length d
|
|
70
|
+
* @param[in] c Covariance matrix of size dxd (row-major)
|
|
71
|
+
* @param[in] d Number of dimensions
|
|
72
|
+
* @param[out] r Pointer to output distance value
|
|
73
|
+
*
|
|
74
|
+
* @tparam in_type_ Input vector element type
|
|
75
|
+
* @tparam result_type_ Accumulator type, defaults to `in_type_::curved_result_t`
|
|
76
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
77
|
+
*/
|
|
78
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
79
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
80
|
+
void mahalanobis(in_type_ const *a, in_type_ const *b, in_type_ const *c, std::size_t d, result_type_ *r) noexcept {
|
|
81
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
82
|
+
std::is_same_v<result_type_, typename in_type_::curved_result_t>;
|
|
83
|
+
|
|
84
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd)
|
|
85
|
+
nk_mahalanobis_f64(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
86
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
87
|
+
nk_mahalanobis_f32(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
88
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
89
|
+
nk_mahalanobis_f16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
90
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
91
|
+
nk_mahalanobis_bf16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
|
|
92
|
+
// Scalar fallback
|
|
93
|
+
else {
|
|
94
|
+
result_type_ sum {};
|
|
95
|
+
for (std::size_t i = 0; i < d; i++) {
|
|
96
|
+
result_type_ di = result_type_(a[i]) - result_type_(b[i]);
|
|
97
|
+
for (std::size_t j = 0; j < d; j++) {
|
|
98
|
+
result_type_ dj = result_type_(a[j]) - result_type_(b[j]);
|
|
99
|
+
sum = sum + di * result_type_(c[i * d + j]) * dj;
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
*r = sum.sqrt();
|
|
103
|
+
}
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
} // namespace ashvardanian::numkong
|
|
107
|
+
|
|
108
|
+
#include "numkong/tensor.hpp"
|
|
109
|
+
|
|
110
|
+
namespace ashvardanian::numkong {
|
|
111
|
+
|
|
112
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
113
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_,
|
|
114
|
+
std::size_t max_rank_c_>
|
|
115
|
+
void bilinear(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b,
|
|
116
|
+
tensor_view<in_type_, max_rank_c_> c, std::size_t d, result_type_ *r) noexcept {
|
|
117
|
+
bilinear<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
121
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
122
|
+
void bilinear(vector_view<in_type_> a, vector_view<in_type_> b, vector_view<in_type_> c, std::size_t d,
|
|
123
|
+
result_type_ *r) noexcept {
|
|
124
|
+
bilinear<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
128
|
+
allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_,
|
|
129
|
+
std::size_t max_rank_c_>
|
|
130
|
+
void mahalanobis(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b,
|
|
131
|
+
tensor_view<in_type_, max_rank_c_> c, std::size_t d, result_type_ *r) noexcept {
|
|
132
|
+
mahalanobis<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
|
|
136
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
137
|
+
void mahalanobis(vector_view<in_type_> a, vector_view<in_type_> b, vector_view<in_type_> c, std::size_t d,
|
|
138
|
+
result_type_ *r) noexcept {
|
|
139
|
+
mahalanobis<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r);
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
} // namespace ashvardanian::numkong
|
|
143
|
+
|
|
144
|
+
#endif // NK_CURVED_HPP
|