numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/c/numkong.c
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Dynamic dispatch library for NumKong.
|
|
3
|
+
* @file c/numkong.c
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 13, 2024
|
|
6
|
+
*/
|
|
7
|
+
#include "dispatch.h"
|
|
8
|
+
|
|
9
|
+
/*  MemorySanitizer cannot track initialization through SIMD intrinsics (SVE, NEON, SSE, AVX),
 *  causing false-positive "use-of-uninitialized-value" reports. We unpoison results after dispatch.
 *
 *  `nk_unpoison_(ptr, size)` forwards to `__msan_unpoison` when building under MemorySanitizer. Otherwise it
 *  falls back to applying `nk_unused_` to each argument. The fallback is wrapped in parentheses so the comma
 *  operator stays a single, self-contained expression even when the macro is used inside a larger expression.
 */
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
#include <sanitizer/msan_interface.h>
#define nk_unpoison_(ptr, size) __msan_unpoison((ptr), (size))
#endif
#endif
#ifndef nk_unpoison_
#define nk_unpoison_(ptr, size) (nk_unused_(ptr), nk_unused_(size))
#endif
|
|
21
|
+
|
|
22
|
+
#ifdef __cplusplus
|
|
23
|
+
extern "C" {
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
// WASM capability detection for standalone Emscripten builds.
|
|
27
|
+
// EM_JS embeds JavaScript probes for runtime SIMD detection. It only works in
|
|
28
|
+
// standalone builds — Pyodide side modules cannot use EM_JS (the linker fails
|
|
29
|
+
// with undefined ___em_js__* symbols). Pyodide builds define NK_PYODIDE_SIDE_MODULE
|
|
30
|
+
// and fall through to compile-time detection in capabilities.h instead.
|
|
31
|
+
#if defined(__EMSCRIPTEN__) && NK_DYNAMIC_DISPATCH && !defined(NK_PYODIDE_SIDE_MODULE)
|
|
32
|
+
#include <emscripten.h>
|
|
33
|
+
|
|
34
|
+
// EM_JS expands to an empty-parameter-list declaration `()` and a trailing `;`,
|
|
35
|
+
// which trigger `-Wstrict-prototypes` and `-Wextra-semi` under Clang/Emscripten.
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang diagnostic push
|
|
38
|
+
#pragma clang diagnostic ignored "-Wstrict-prototypes"
|
|
39
|
+
#pragma clang diagnostic ignored "-Wextra-semi"
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
/*  Runtime probe for WebAssembly fixed-width SIMD: validates a minimal module with one function of type
 *  `() -> v128` whose body is `v128.const 0; end`. Returns 1 when the engine accepts the module, else 0.
 *  Byte counts matter: `v128.const` (0xFD 0x0C) is followed by a full 16-byte immediate, and the declared
 *  body/section sizes must match exactly, or validation fails on every engine.
 */
EM_JS(int, nk_has_v128, (), {
    var test = new Uint8Array([
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, /* magic "\0asm", version 1 */
        0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7b,       /* type section: one type, () -> v128 */
        0x03, 0x02, 0x01, 0x00,                         /* function section: one function of type 0 */
        0x0a, 0x16, 0x01, 0x14, 0x00,                   /* code section (22 bytes): one 20-byte body, no locals */
        0xfd, 0x0c,                                     /* v128.const ... */
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* ... 16-byte immediate (low half) */
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* ... 16-byte immediate (high half) */
        0x0b                                            /* end */
    ]);
    try {
        return WebAssembly.validate(test) ? 1 : 0;
    }
    catch (e) {
        return 0;
    }
});
|
|
54
|
+
|
|
55
|
+
/*  Runtime probe for WebAssembly Relaxed SIMD: validates a minimal module with one function of type
 *  `(v128, v128, v128) -> v128` whose body is `local.get 0; local.get 1; local.get 2; f32x4.relaxed_madd; end`.
 *  `f32x4.relaxed_madd` is opcode 0xFD 0x105, whose LEB128 encoding is 0xFD 0x85 0x02.
 *  The body is 11 bytes and the code section 13 bytes; both sizes must match the payload exactly.
 *  Returns 1 when the engine accepts the module, else 0.
 */
EM_JS(int, nk_has_relaxed, (), {
    var test = new Uint8Array([
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,             /* magic "\0asm", version 1 */
        0x01, 0x08, 0x01, 0x60, 0x03, 0x7b, 0x7b, 0x7b, 0x01, 0x7b, /* type: (v128, v128, v128) -> v128 */
        0x03, 0x02, 0x01, 0x00,                                     /* function section: one function of type 0 */
        0x0a, 0x0d, 0x01, 0x0b, 0x00,                               /* code section (13 bytes): one 11-byte body */
        0x20, 0x00, 0x20, 0x01, 0x20, 0x02,                         /* local.get 0, 1, 2 */
        0xfd, 0x85, 0x02,                                           /* f32x4.relaxed_madd */
        0x0b                                                        /* end */
    ]);
    try {
        return WebAssembly.validate(test) ? 1 : 0;
    }
    catch (e) {
        return 0;
    }
});
|
|
68
|
+
|
|
69
|
+
#if defined(__clang__)
|
|
70
|
+
#pragma clang diagnostic pop
|
|
71
|
+
#endif
|
|
72
|
+
#endif // __EMSCRIPTEN__ && NK_DYNAMIC_DISPATCH && !NK_PYODIDE_SIDE_MODULE
|
|
73
|
+
|
|
74
|
+
/**
|
|
75
|
+
* @brief Fill memory with 0xFF - produces NaN for floats, -1 for signed integers, and MAX for unsigned.
|
|
76
|
+
* Avoids libc dependency on memset.
|
|
77
|
+
*/
|
|
78
|
+
NK_INTERNAL void nk_fill_error_(void *ptr, nk_size_t bytes) {
|
|
79
|
+
nk_u8_t *p = (nk_u8_t *)ptr;
|
|
80
|
+
while (bytes--) *p++ = 0xFF;
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/**
 *  @brief  Fallback for dense two-vector kernels that have no implementation for the current CPU.
 *          Ignores `a`, `b` and `n`, and writes the 0xFF error pattern over one `nk_fmax_t`-sized slot at `d`.
 */
void nk_error_dense_(void const *a, void const *b, nk_size_t n, void *d) {
    nk_unused_(a), nk_unused_(b), nk_unused_(n);
    nk_fill_error_(d, sizeof(nk_fmax_t));
}
|
|
89
|
+
|
|
90
|
+
/**
 *  @brief  Fallback for sparse set-intersection kernels: reports an empty intersection.
 *          `result` is never written; `count`, when non-NULL, is set to zero.
 */
void nk_error_sparse_intersect_(void const *a, void const *b, nk_size_t a_length, nk_size_t b_length, void *result,
                                nk_size_t *count) {
    nk_unused_(a), nk_unused_(b), nk_unused_(a_length), nk_unused_(b_length), nk_unused_(result);
    if (!count) return;
    *count = 0;
}
|
|
99
|
+
|
|
100
|
+
/**
 *  @brief  Fallback for weighted sparse dot-product kernels without an implementation for the current CPU.
 *          All index/weight inputs are ignored; `product` receives one `nk_fmax_t`-sized 0xFF error pattern.
 */
void nk_error_sparse_dot_(void const *a, void const *b, void const *a_weights, void const *b_weights,
                          nk_size_t a_length, nk_size_t b_length, void *product) {
    nk_unused_(a), nk_unused_(a_weights), nk_unused_(a_length);
    nk_unused_(b), nk_unused_(b_weights), nk_unused_(b_length);
    nk_fill_error_(product, sizeof(nk_fmax_t));
}
|
|
110
|
+
|
|
111
|
+
/**
 *  @brief  Fallback for three-operand "curved" kernels without an implementation for the current CPU.
 *          Ignores `a`, `b`, `c` and `n`, and stores one `nk_fmax_t`-sized 0xFF error pattern at `result`.
 */
void nk_error_curved_(void const *a, void const *b, void const *c, nk_size_t n, void *result) {
    nk_unused_(a), nk_unused_(b), nk_unused_(c), nk_unused_(n);
    nk_fill_error_(result, sizeof(nk_fmax_t));
}
|
|
118
|
+
|
|
119
|
+
/**
 *  @brief  Fallback for geospatial kernels without an implementation for the current CPU.
 *          Ignores the coordinate arrays and writes one `nk_fmax_t`-sized 0xFF error pattern at `results`.
 *  NOTE(review): only a single slot is filled even though `n` point pairs are passed; if the real kernels
 *  produce one output per pair, the remaining outputs are left untouched — confirm this is intended.
 */
void nk_error_geospatial_(void const *a_lats, void const *a_lons, void const *b_lats, void const *b_lons, nk_size_t n,
                          void *results) {
    nk_unused_(a_lats), nk_unused_(a_lons);
    nk_unused_(b_lats), nk_unused_(b_lons);
    nk_unused_(n);
    nk_fill_error_(results, sizeof(nk_fmax_t));
}
|
|
128
|
+
|
|
129
|
+
/**
 *  @brief  Fallback for element-wise `fma` kernels without an implementation for the current CPU.
 *          Ignores the operands and scalars and fills `n` `nk_fmax_t`-sized slots of `result` with 0xFF.
 *  NOTE(review): writes `n * sizeof(nk_fmax_t)` bytes; for element types narrower than `nk_fmax_t` this exceeds
 *  an `n`-element output buffer — confirm callers size the buffer for the widest type.
 */
void nk_error_each_fma_(void const *a, void const *b, void const *c, nk_size_t n, void const *alpha, void const *beta,
                        void *result) {
    nk_size_t const error_bytes = n * sizeof(nk_fmax_t);
    nk_unused_(a), nk_unused_(b), nk_unused_(c);
    nk_unused_(alpha), nk_unused_(beta);
    nk_fill_error_(result, error_bytes);
}
|
|
138
|
+
|
|
139
|
+
/**
 *  @brief  Fallback for element-wise `blend` kernels without an implementation for the current CPU.
 *          Ignores the operands and scalars and fills `n` `nk_fmax_t`-sized slots of `result` with 0xFF.
 *  NOTE(review): writes `n * sizeof(nk_fmax_t)` bytes; for element types narrower than `nk_fmax_t` this exceeds
 *  an `n`-element output buffer — confirm callers size the buffer for the widest type.
 */
void nk_error_each_blend_(void const *a, void const *b, nk_size_t n, void const *alpha, void const *beta,
                          void *result) {
    nk_size_t const error_bytes = n * sizeof(nk_fmax_t);
    nk_unused_(a), nk_unused_(b);
    nk_unused_(alpha), nk_unused_(beta);
    nk_fill_error_(result, error_bytes);
}
|
|
147
|
+
|
|
148
|
+
/**
 *  @brief  Fallback for element-wise `scale` kernels without an implementation for the current CPU.
 *          Ignores the input and scalars and fills `n` `nk_fmax_t`-sized slots of `result` with 0xFF.
 *  NOTE(review): writes `n * sizeof(nk_fmax_t)` bytes; for element types narrower than `nk_fmax_t` this exceeds
 *  an `n`-element output buffer — confirm callers size the buffer for the widest type.
 */
void nk_error_each_scale_(void const *a, nk_size_t n, void const *alpha, void const *beta, void *result) {
    nk_size_t const error_bytes = n * sizeof(nk_fmax_t);
    nk_unused_(a), nk_unused_(alpha), nk_unused_(beta);
    nk_fill_error_(result, error_bytes);
}
|
|
154
|
+
|
|
155
|
+
/**
 *  @brief  Fallback for element-wise `sum` kernels without an implementation for the current CPU.
 *          Ignores both inputs and fills `n` `nk_fmax_t`-sized slots of `y` with 0xFF.
 *  NOTE(review): writes `n * sizeof(nk_fmax_t)` bytes; for element types narrower than `nk_fmax_t` this exceeds
 *  an `n`-element output buffer — confirm callers size the buffer for the widest type.
 */
void nk_error_each_sum_(void const *a, void const *b, nk_size_t n, void *y) {
    nk_size_t const error_bytes = n * sizeof(nk_fmax_t);
    nk_unused_(a), nk_unused_(b);
    nk_fill_error_(y, error_bytes);
}
|
|
160
|
+
|
|
161
|
+
/**
 *  @brief  Fallback for element-wise trigonometric kernels without an implementation for the current CPU.
 *          Ignores `x` and fills `n` `nk_fmax_t`-sized slots of `y` with 0xFF.
 *  NOTE(review): writes `n * sizeof(nk_fmax_t)` bytes; for element types narrower than `nk_fmax_t` this exceeds
 *  an `n`-element output buffer — confirm callers size the buffer for the widest type.
 */
void nk_error_trigonometry_(void const *x, nk_size_t n, void *y) {
    nk_size_t const error_bytes = n * sizeof(nk_fmax_t);
    nk_unused_(x);
    nk_fill_error_(y, error_bytes);
}
|
|
165
|
+
|
|
166
|
+
/**
 *  @brief  Fallback for mesh-alignment kernels without an implementation for the current CPU.
 *          Each optional output (`a_centroid`, `b_centroid`, `rotation`, `scale`) is filled with 0xFF only when
 *          non-NULL: 3 slots per centroid, 9 for `rotation`, 1 for `scale`, each slot `nk_fmax_t`-sized.
 *          `result` is always filled with one slot.
 */
void nk_error_mesh_(void const *a, void const *b, nk_size_t n, void *a_centroid, void *b_centroid, void *rotation,
                    void *scale, void *result) {
    nk_size_t const slot = sizeof(nk_fmax_t);
    nk_unused_(a), nk_unused_(b), nk_unused_(n);
    if (a_centroid) nk_fill_error_(a_centroid, 3 * slot);
    if (b_centroid) nk_fill_error_(b_centroid, 3 * slot);
    if (rotation) nk_fill_error_(rotation, 9 * slot);
    if (scale) nk_fill_error_(scale, slot);
    nk_fill_error_(result, slot);
}
|
|
177
|
+
|
|
178
|
+
/**
 *  @brief  Fallback for moment-reduction kernels without an implementation for the current CPU.
 *          Ignores the input sequence and writes one `nk_fmax_t`-sized 0xFF pattern into each of
 *          `sum_ptr` and `sumsq_ptr`.
 */
void nk_error_reduce_moments_(void const *data, nk_size_t count, nk_size_t stride_bytes, void *sum_ptr,
                              void *sumsq_ptr) {
    nk_unused_(data);
    nk_unused_(count);
    nk_unused_(stride_bytes);
    nk_fill_error_(sum_ptr, sizeof(nk_fmax_t));
    nk_fill_error_(sumsq_ptr, sizeof(nk_fmax_t));
}
|
|
184
|
+
|
|
185
|
+
/**
 *  @brief  Fallback for min/max-reduction kernels without an implementation for the current CPU.
 *          Both values receive a `nk_fmax_t`-sized 0xFF pattern. Both indices are set to all-ones,
 *          i.e. the maximum `nk_size_t`.
 */
void nk_error_reduce_minmax_(void const *data, nk_size_t count, nk_size_t stride_bytes, void *min_value,
                             nk_size_t *min_index, void *max_value, nk_size_t *max_index) {
    nk_unused_(data);
    nk_unused_(count);
    nk_unused_(stride_bytes);
    nk_fill_error_(min_value, sizeof(nk_fmax_t));
    nk_fill_error_(min_index, sizeof(nk_size_t));
    nk_fill_error_(max_value, sizeof(nk_fmax_t));
    nk_fill_error_(max_index, sizeof(nk_size_t));
}
|
|
194
|
+
|
|
195
|
+
/**
 *  @brief  Fallback for packed-size queries without an implementation for the current CPU.
 *  @return Always zero, regardless of `n` and `k`.
 */
nk_size_t nk_error_packed_size_(nk_size_t n, nk_size_t k) {
    nk_size_t const no_bytes = 0;
    nk_unused_(n), nk_unused_(k);
    return no_bytes;
}
|
|
200
|
+
|
|
201
|
+
/**
 *  @brief  Fallback for packing kernels without an implementation for the current CPU.
 *          Intentionally a no-op: nothing is read from `b` and nothing is written to `b_packed`.
 */
void nk_error_pack_(void const *b, nk_size_t n, nk_size_t k, nk_size_t b_stride, void *b_packed) {
    nk_unused_(b), nk_unused_(b_stride), nk_unused_(b_packed);
    nk_unused_(n), nk_unused_(k);
}
|
|
208
|
+
|
|
209
|
+
/**
 *  @brief  Fallback for packed matrix-multiplication kernels without an implementation for the current CPU.
 *          For each of the `m` output rows, the first `n` `nk_fmax_t`-sized slots are filled with 0xFF.
 *          Rows are `c_stride` bytes apart.
 */
void nk_error_dots_(void const *a, void const *b_packed, void *c, nk_size_t m, nk_size_t n, nk_size_t k,
                    nk_size_t a_stride, nk_size_t c_stride) {
    nk_size_t const row_bytes = n * sizeof(nk_fmax_t);
    nk_u8_t *const matrix = (nk_u8_t *)c;
    nk_size_t row = 0;
    nk_unused_(a), nk_unused_(b_packed), nk_unused_(k), nk_unused_(a_stride);
    while (row < m) {
        nk_fill_error_(matrix + row * c_stride, row_bytes);
        ++row;
    }
}
|
|
217
|
+
|
|
218
|
+
/**
 *  @brief  Fallback for symmetric Gram-matrix kernels without an implementation for the current CPU.
 *          Fills the first `n_vectors` `nk_fmax_t`-sized slots of each of the `n_vectors` rows with 0xFF.
 *          Rows are `result_stride` bytes apart.
 *  NOTE(review): `row_start`/`row_count` are ignored and the whole matrix is overwritten; if callers split rows
 *  across workers, each call touches rows owned by others — confirm this is acceptable for the error path.
 */
void nk_error_dots_symmetric_(void const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, void *result,
                              nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
    nk_size_t const row_bytes = n_vectors * sizeof(nk_fmax_t);
    nk_u8_t *const matrix = (nk_u8_t *)result;
    nk_unused_(vectors), nk_unused_(depth), nk_unused_(stride);
    nk_unused_(row_start), nk_unused_(row_count);
    for (nk_size_t i = 0; i != n_vectors; ++i) nk_fill_error_(matrix + i * result_stride, row_bytes);
}
|
|
228
|
+
|
|
229
|
+
// Global dispatch table - 64-byte aligned for cache performance
|
|
230
|
+
// Type defined in dispatch.h, made non-static for access from dtype files
|
|
231
|
+
NK_ALIGN64 nk_implementations_t nk_dispatch_table;
|
|
232
|
+
|
|
233
|
+
// Direct dispatch macros using central dispatch table (no lazy initialization)
|
|
234
|
+
#define nk_dispatch_dense_(name, extension, input_type, output_type) \
|
|
235
|
+
NK_DYNAMIC void nk_##name##_##extension(nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t n, \
|
|
236
|
+
nk_##output_type##_t *results) { \
|
|
237
|
+
nk_dispatch_table.name##_##extension(a, b, n, (void *)results); \
|
|
238
|
+
nk_unpoison_((void *)results, sizeof(nk_##output_type##_t)); \
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
/**
 *  @brief  Defines `nk_<name>_<extension>` for sparse set-intersection kernels over `type` elements.
 *
 *  After the kernel runs, `*count` and the first `*count` elements of `result` are marked as initialized for
 *  MemorySanitizer. Both pointers are checked first. The error fallback already tolerates a NULL `count`, so
 *  the wrapper must not dereference it unconditionally. A NULL `result` has nothing to unpoison.
 */
#define nk_dispatch_sparse_(name, extension, type) \
    NK_DYNAMIC void nk_##name##_##extension(nk_##type##_t const *a, nk_##type##_t const *b, nk_size_t a_length, \
                                            nk_size_t b_length, nk_##type##_t *result, nk_size_t *count) { \
        nk_dispatch_table.name##_##extension(a, b, a_length, b_length, (void *)result, count); \
        if (count) { \
            nk_unpoison_(count, sizeof(nk_size_t)); \
            if (result) nk_unpoison_((void *)result, (*count) * sizeof(nk_##type##_t)); \
        } \
    }
|
|
248
|
+
|
|
249
|
+
/**
 *  @brief  Defines `nk_<name>_<index_type><weight_type>` for weighted sparse dot products.
 *          Sorted `index_type` keys are paired with `weight_type` weights, and one `output_type` scalar is
 *          written to `product`.
 */
#define nk_dispatch_sparse_dot_(name, index_type, weight_type, output_type) \
    NK_DYNAMIC void nk_##name##_##index_type##weight_type( \
        nk_##index_type##_t const *a, nk_##index_type##_t const *b, nk_##weight_type##_t const *a_weights, \
        nk_##weight_type##_t const *b_weights, nk_size_t a_length, nk_size_t b_length, \
        nk_##output_type##_t *product) { \
        void *const nk_output_ = (void *)product; \
        nk_dispatch_table.name##_##index_type##weight_type(a, b, a_weights, b_weights, a_length, b_length, \
                                                           nk_output_); \
        nk_unpoison_(nk_output_, sizeof(nk_##output_type##_t)); \
    }
|
|
258
|
+
|
|
259
|
+
/**
 *  @brief  Defines `nk_<name>_<extension>` for three-operand kernels taking `a`, `b` and `c` of `extension` type.
 *          One `output_type` scalar is written to `result`.
 */
#define nk_dispatch_curved_(name, extension, output_type) \
    NK_DYNAMIC void nk_##name##_##extension(nk_##extension##_t const *a, nk_##extension##_t const *b, \
                                            nk_##extension##_t const *c, nk_size_t n, \
                                            nk_##output_type##_t *result) { \
        void *const nk_output_ = (void *)result; \
        nk_dispatch_table.name##_##extension(a, b, c, n, nk_output_); \
        nk_unpoison_(nk_output_, sizeof(nk_##output_type##_t)); \
    }
|
|
265
|
+
|
|
266
|
+
/**
 *  @brief  Defines `nk_<name>_<extension>` for geospatial kernels over latitude/longitude arrays of length `n`.
 *  NOTE(review): only one `output_type` element at `results` is unpoisoned; if the kernel writes `n` outputs,
 *  MemorySanitizer may still flag the rest — confirm the output arity.
 */
#define nk_dispatch_geospatial_(name, extension, output_type) \
    NK_DYNAMIC void nk_##name##_##extension(nk_##extension##_t const *a_lats, nk_##extension##_t const *a_lons, \
                                            nk_##extension##_t const *b_lats, nk_##extension##_t const *b_lons, \
                                            nk_size_t n, nk_##output_type##_t *results) { \
        void *const nk_output_ = (void *)results; \
        nk_dispatch_table.name##_##extension(a_lats, a_lons, b_lats, b_lons, n, nk_output_); \
        nk_unpoison_(nk_output_, sizeof(nk_##output_type##_t)); \
    }
|
|
273
|
+
|
|
274
|
+
/**
 *  @brief  Defines `nk_each_fma_<extension>` for element-wise kernels over three `extension` vectors.
 *          The kernel also takes `scalar_type` coefficients `alpha` and `beta`, and writes `n` outputs to `result`.
 */
#define nk_dispatch_each_fma_(extension, scalar_type) \
    NK_DYNAMIC void nk_each_fma_##extension(nk_##extension##_t const *a, nk_##extension##_t const *b, \
                                            nk_##extension##_t const *c, nk_size_t n, \
                                            nk_##scalar_type##_t const *alpha, nk_##scalar_type##_t const *beta, \
                                            nk_##extension##_t *result) { \
        void const *const nk_alpha_ = (void const *)alpha; \
        void const *const nk_beta_ = (void const *)beta; \
        nk_dispatch_table.each_fma_##extension(a, b, c, n, nk_alpha_, nk_beta_, result); \
        nk_unpoison_((void *)result, n * sizeof(*result)); \
    }
|
|
281
|
+
|
|
282
|
+
/**
 *  @brief  Defines `nk_each_blend_<extension>` for element-wise kernels over two `extension` vectors.
 *          The kernel also takes `scalar_type` coefficients `alpha` and `beta`, and writes `n` outputs to `result`.
 */
#define nk_dispatch_each_blend_(extension, scalar_type) \
    NK_DYNAMIC void nk_each_blend_##extension(nk_##extension##_t const *a, nk_##extension##_t const *b, nk_size_t n, \
                                              nk_##scalar_type##_t const *alpha, nk_##scalar_type##_t const *beta, \
                                              nk_##extension##_t *result) { \
        void const *const nk_alpha_ = (void const *)alpha; \
        void const *const nk_beta_ = (void const *)beta; \
        nk_dispatch_table.each_blend_##extension(a, b, n, nk_alpha_, nk_beta_, result); \
        nk_unpoison_((void *)result, n * sizeof(*result)); \
    }
|
|
289
|
+
|
|
290
|
+
/**
 *  @brief  Defines `nk_each_scale_<extension>` for element-wise kernels over one `extension` vector.
 *          The kernel also takes `scalar_type` coefficients `alpha` and `beta`, and writes `n` outputs to `result`.
 */
#define nk_dispatch_each_scale_(extension, scalar_type) \
    NK_DYNAMIC void nk_each_scale_##extension(nk_##extension##_t const *a, nk_size_t n, \
                                              nk_##scalar_type##_t const *alpha, nk_##scalar_type##_t const *beta, \
                                              nk_##extension##_t *result) { \
        void const *const nk_alpha_ = (void const *)alpha; \
        void const *const nk_beta_ = (void const *)beta; \
        nk_dispatch_table.each_scale_##extension(a, n, nk_alpha_, nk_beta_, result); \
        nk_unpoison_((void *)result, n * sizeof(*result)); \
    }
|
|
297
|
+
|
|
298
|
+
/**
 *  @brief  Defines `nk_each_sum_<extension>` for element-wise kernels combining two `extension` vectors.
 *          The kernel writes `n` outputs to `result`.
 */
#define nk_dispatch_each_sum_(extension) \
    NK_DYNAMIC void nk_each_sum_##extension(nk_##extension##_t const *a, nk_##extension##_t const *b, nk_size_t n, \
                                            nk_##extension##_t *result) { \
        nk_dispatch_table.each_sum_##extension(a, b, n, result); \
        nk_unpoison_((void *)result, n * sizeof(*result)); \
    }
|
|
304
|
+
|
|
305
|
+
/**
 *  @brief  Defines `nk_each_<name>_<extension>` for element-wise unary kernels.
 *          `n` `extension` inputs are mapped to `n` outputs.
 */
#define nk_dispatch_trigonometry_(name, extension) \
    NK_DYNAMIC void nk_each_##name##_##extension(nk_##extension##_t const *inputs, nk_size_t n, \
                                                 nk_##extension##_t *outputs) { \
        nk_dispatch_table.each_##name##_##extension(inputs, n, outputs); \
        nk_unpoison_((void *)outputs, n * sizeof(*outputs)); \
    }
|
|
311
|
+
|
|
312
|
+
/**
 *  @brief  Defines `nk_<name>_<extension>` for mesh kernels over two `extension` point sets of size `n`.
 *          Optional `transform_type` outputs: 3-element centroids, a 9-element rotation and a scalar scale.
 *          Each optional output is unpoisoned only when its pointer is non-NULL. `result` is always unpoisoned.
 */
#define nk_dispatch_mesh_(name, extension, transform_type, metric_type) \
    NK_DYNAMIC void nk_##name##_##extension(nk_##extension##_t const *a, nk_##extension##_t const *b, nk_size_t n, \
                                            nk_##transform_type##_t *a_centroid, nk_##transform_type##_t *b_centroid, \
                                            nk_##transform_type##_t *rotation, nk_##transform_type##_t *scale, \
                                            nk_##metric_type##_t *result) { \
        nk_size_t const nk_transform_bytes_ = sizeof(nk_##transform_type##_t); \
        nk_dispatch_table.name##_##extension(a, b, n, (void *)a_centroid, (void *)b_centroid, (void *)rotation, \
                                             (void *)scale, (void *)result); \
        if (a_centroid) nk_unpoison_((void *)a_centroid, 3 * nk_transform_bytes_); \
        if (b_centroid) nk_unpoison_((void *)b_centroid, 3 * nk_transform_bytes_); \
        if (rotation) nk_unpoison_((void *)rotation, 9 * nk_transform_bytes_); \
        if (scale) nk_unpoison_((void *)scale, nk_transform_bytes_); \
        nk_unpoison_((void *)result, sizeof(nk_##metric_type##_t)); \
    }
|
|
325
|
+
|
|
326
|
+
/**
 *  @brief  Defines `nk_reduce_moments_<extension>` over `count` strided `data_type` elements.
 *          The kernel writes one `sum_type` to `sum_ptr` and one `sumsq_type` to `sumsq_ptr`.
 *          The table slot is called through the type-punned `nk_kernel_reduce_moments_punned_t` signature.
 */
#define nk_dispatch_reduce_moments_(extension, data_type, sum_type, sumsq_type) \
    NK_DYNAMIC void nk_reduce_moments_##extension(data_type const *data, nk_size_t count, nk_size_t stride_bytes, \
                                                  sum_type *sum_ptr, sumsq_type *sumsq_ptr) { \
        nk_kernel_reduce_moments_punned_t const nk_kernel_ = \
            (nk_kernel_reduce_moments_punned_t)nk_dispatch_table.reduce_moments_##extension; \
        nk_kernel_(data, count, stride_bytes, sum_ptr, sumsq_ptr); \
        nk_unpoison_((void *)sum_ptr, sizeof(sum_type)); \
        nk_unpoison_((void *)sumsq_ptr, sizeof(sumsq_type)); \
    }
|
|
334
|
+
|
|
335
|
+
/**
 *  @brief  Defines `nk_reduce_minmax_<extension>` over `count` strided `data_type` elements.
 *          The kernel writes the extreme `minmax_type` values and their `nk_size_t` indices.
 *          The table slot is called through the type-punned `nk_kernel_reduce_minmax_punned_t` signature.
 */
#define nk_dispatch_reduce_minmax_(extension, data_type, minmax_type) \
    NK_DYNAMIC void nk_reduce_minmax_##extension(data_type const *data, nk_size_t count, nk_size_t stride_bytes, \
                                                 minmax_type *min_value, nk_size_t *min_index, \
                                                 minmax_type *max_value, nk_size_t *max_index) { \
        nk_kernel_reduce_minmax_punned_t const nk_kernel_ = \
            (nk_kernel_reduce_minmax_punned_t)nk_dispatch_table.reduce_minmax_##extension; \
        nk_kernel_(data, count, stride_bytes, min_value, min_index, max_value, max_index); \
        nk_unpoison_((void *)min_value, sizeof(minmax_type)); \
        nk_unpoison_(min_index, sizeof(nk_size_t)); \
        nk_unpoison_((void *)max_value, sizeof(minmax_type)); \
        nk_unpoison_(max_index, sizeof(nk_size_t)); \
    }
|
|
346
|
+
|
|
347
|
+
// Defines `nk_<api_name>_packed_size_<name>(n, k)`, returning whatever the dispatched packed-size query reports.
// `input_type` and `accum_type` are unused here; they are accepted so the call sites match the other cross macros.
#define nk_dispatch_cross_packed_size_(api_name, name, input_type, accum_type)                                        \
    NK_DYNAMIC nk_size_t nk_##api_name##_packed_size_##name(nk_size_t n, nk_size_t k) {                                \
        nk_size_t const packed_size = nk_dispatch_table.api_name##_packed_size_##name(n, k);                           \
        return packed_size;                                                                                            \
    }
|
|
351
|
+
|
|
352
|
+
// Defines `nk_<api_name>_pack_<name>`: a thin forwarder that lets the dispatched kernel pack the `n x k` operand `b`
// (row stride `b_stride`) into the caller-provided `b_packed` buffer. `accum_type` is unused here.
#define nk_dispatch_cross_pack_(api_name, name, input_type, accum_type)                                               \
    NK_DYNAMIC void nk_##api_name##_pack_##name(                                                                       \
        nk_##input_type##_t const *b, nk_size_t n, nk_size_t k, nk_size_t b_stride, void *b_packed) {                  \
        nk_dispatch_table.api_name##_pack_##name(b, n, k, b_stride, b_packed);                                         \
    }
|
|
357
|
+
|
|
358
|
+
// Defines `nk_<api_name>_packed_<name>`: forwards `a` (m x k, row stride `a_stride`) and the pre-packed `b` to the
// dispatch-table kernel, then unpoisons the output `c`.
// `c_stride` is treated as the byte distance between consecutive output rows; the previous `m * c_stride` size
// already relied on that. This assumes the kernel writes exactly `n` outputs per row, so only those bytes are
// unpoisoned, row by row. The former single `m * c_stride` span also covered row padding. On the last row it could
// extend past the end of a tightly sized buffer or outside a sub-matrix view, which marks unrelated memory as
// initialized and can hide genuine uninitialized reads under a sanitizer. Nothing is unpoisoned when `m == 0`.
#define nk_dispatch_cross_packed_(api_name, name, input_type, accum_type, output_type)                                \
    NK_DYNAMIC void nk_##api_name##_packed_##name(nk_##input_type##_t const *a, void const *b_packed,                  \
                                                  nk_##output_type##_t *c, nk_size_t m, nk_size_t n, nk_size_t k,      \
                                                  nk_size_t a_stride, nk_size_t c_stride) {                            \
        nk_dispatch_table.api_name##_packed_##name(a, b_packed, c, m, n, k, a_stride, c_stride);                       \
        for (nk_size_t nk_row_ = 0; nk_row_ < m; ++nk_row_) {                                                          \
            nk_unpoison_((void *)((char *)c + nk_row_ * c_stride), n * sizeof(nk_##output_type##_t));                 \
        }                                                                                                              \
    }
|
|
365
|
+
|
|
366
|
+
// Defines `nk_<api_name>_symmetric_<name>`: forwards the request (presumably for the row band
// [row_start, row_start + row_count)) to the dispatch-table kernel, then unpoisons `row_count * result_stride` bytes
// starting at `result`.
// NOTE(review): that span starts at `result`, not at row `row_start`. If `result` addresses the whole output matrix
// and kernels write only the requested band, the unpoisoned rows are misaligned whenever `row_start > 0`. Confirm
// the result-pointer contract before trusting sanitizer reports here.
#define nk_dispatch_cross_symmetric_(api_name, name, input_type, output_type)                                         \
    NK_DYNAMIC void nk_##api_name##_symmetric_##name(nk_##input_type##_t const *vectors, nk_size_t n_vectors,         \
                                                     nk_size_t depth, nk_size_t stride,                               \
                                                     nk_##output_type##_t *result, nk_size_t result_stride,           \
                                                     nk_size_t row_start, nk_size_t row_count) {                      \
        nk_dispatch_table.api_name##_symmetric_##name(                                                                \
            vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);                          \
        nk_unpoison_((void *)result, row_count * result_stride);                                                      \
    }
|
|
374
|
+
|
|
375
|
+
// Defines `nk_maxsim_packed_<name>`: scores pre-packed query and document sets through the dispatch table and
// unpoisons the single scalar result it produces.
#define nk_dispatch_maxsim_packed_(name, output_type)                                                                 \
    NK_DYNAMIC void nk_maxsim_packed_##name(void const *q_packed, void const *d_packed, nk_size_t n_q, nk_size_t n_d,  \
                                            nk_size_t depth, nk_##output_type##_t *result) {                           \
        void *const result_untyped = (void *)result;                                                                   \
        nk_dispatch_table.maxsim_packed_##name(q_packed, d_packed, n_q, n_d, depth, result_untyped);                   \
        nk_unpoison_(result_untyped, sizeof(nk_##output_type##_t));                                                    \
    }
|
|
381
|
+
|
|
382
|
+
// Dot products
|
|
383
|
+
nk_dispatch_dense_(dot, f64c, f64c, f64c)
|
|
384
|
+
nk_dispatch_dense_(dot, f32c, f32c, f64c)
|
|
385
|
+
nk_dispatch_dense_(dot, bf16c, bf16c, f32c)
|
|
386
|
+
nk_dispatch_dense_(dot, f16c, f16c, f32c)
|
|
387
|
+
nk_dispatch_dense_(dot, f64, f64, f64)
|
|
388
|
+
nk_dispatch_dense_(dot, f32, f32, f64)
|
|
389
|
+
nk_dispatch_dense_(dot, bf16, bf16, f32)
|
|
390
|
+
nk_dispatch_dense_(dot, f16, f16, f32)
|
|
391
|
+
nk_dispatch_dense_(dot, e5m2, e5m2, f32)
|
|
392
|
+
nk_dispatch_dense_(dot, e4m3, e4m3, f32)
|
|
393
|
+
nk_dispatch_dense_(dot, e3m2, e3m2, f32)
|
|
394
|
+
nk_dispatch_dense_(dot, e2m3, e2m3, f32)
|
|
395
|
+
nk_dispatch_dense_(dot, i8, i8, i32)
|
|
396
|
+
nk_dispatch_dense_(dot, i4, i4x2, i32)
|
|
397
|
+
nk_dispatch_dense_(dot, u8, u8, u32)
|
|
398
|
+
nk_dispatch_dense_(dot, u4, u4x2, u32)
|
|
399
|
+
nk_dispatch_dense_(dot, u1, u1x8, u32)
|
|
400
|
+
nk_dispatch_dense_(vdot, f64c, f64c, f64c)
|
|
401
|
+
nk_dispatch_dense_(vdot, f32c, f32c, f64c)
|
|
402
|
+
nk_dispatch_dense_(vdot, bf16c, bf16c, f32c)
|
|
403
|
+
nk_dispatch_dense_(vdot, f16c, f16c, f32c)
|
|
404
|
+
|
|
405
|
+
// Spatial distances
|
|
406
|
+
nk_dispatch_dense_(angular, f64, f64, f64)
|
|
407
|
+
nk_dispatch_dense_(angular, f32, f32, f64)
|
|
408
|
+
nk_dispatch_dense_(angular, bf16, bf16, f32)
|
|
409
|
+
nk_dispatch_dense_(angular, f16, f16, f32)
|
|
410
|
+
nk_dispatch_dense_(angular, e5m2, e5m2, f32)
|
|
411
|
+
nk_dispatch_dense_(angular, e4m3, e4m3, f32)
|
|
412
|
+
nk_dispatch_dense_(angular, e3m2, e3m2, f32)
|
|
413
|
+
nk_dispatch_dense_(angular, e2m3, e2m3, f32)
|
|
414
|
+
nk_dispatch_dense_(angular, i8, i8, f32)
|
|
415
|
+
nk_dispatch_dense_(angular, i4, i4x2, f32)
|
|
416
|
+
nk_dispatch_dense_(angular, u8, u8, f32)
|
|
417
|
+
nk_dispatch_dense_(angular, u4, u4x2, f32)
|
|
418
|
+
nk_dispatch_dense_(euclidean, f64, f64, f64)
|
|
419
|
+
nk_dispatch_dense_(euclidean, f32, f32, f64)
|
|
420
|
+
nk_dispatch_dense_(euclidean, bf16, bf16, f32)
|
|
421
|
+
nk_dispatch_dense_(euclidean, f16, f16, f32)
|
|
422
|
+
nk_dispatch_dense_(euclidean, e5m2, e5m2, f32)
|
|
423
|
+
nk_dispatch_dense_(euclidean, e4m3, e4m3, f32)
|
|
424
|
+
nk_dispatch_dense_(euclidean, e3m2, e3m2, f32)
|
|
425
|
+
nk_dispatch_dense_(euclidean, e2m3, e2m3, f32)
|
|
426
|
+
nk_dispatch_dense_(euclidean, i8, i8, f32)
|
|
427
|
+
nk_dispatch_dense_(euclidean, i4, i4x2, f32)
|
|
428
|
+
nk_dispatch_dense_(euclidean, u8, u8, f32)
|
|
429
|
+
nk_dispatch_dense_(euclidean, u4, u4x2, f32)
|
|
430
|
+
nk_dispatch_dense_(sqeuclidean, f64, f64, f64)
|
|
431
|
+
nk_dispatch_dense_(sqeuclidean, f32, f32, f64)
|
|
432
|
+
nk_dispatch_dense_(sqeuclidean, bf16, bf16, f32)
|
|
433
|
+
nk_dispatch_dense_(sqeuclidean, f16, f16, f32)
|
|
434
|
+
nk_dispatch_dense_(sqeuclidean, e5m2, e5m2, f32)
|
|
435
|
+
nk_dispatch_dense_(sqeuclidean, e4m3, e4m3, f32)
|
|
436
|
+
nk_dispatch_dense_(sqeuclidean, e3m2, e3m2, f32)
|
|
437
|
+
nk_dispatch_dense_(sqeuclidean, e2m3, e2m3, f32)
|
|
438
|
+
nk_dispatch_dense_(sqeuclidean, i8, i8, u32)
|
|
439
|
+
nk_dispatch_dense_(sqeuclidean, i4, i4x2, u32)
|
|
440
|
+
nk_dispatch_dense_(sqeuclidean, u8, u8, u32)
|
|
441
|
+
nk_dispatch_dense_(sqeuclidean, u4, u4x2, u32)
|
|
442
|
+
|
|
443
|
+
// Binary distances
|
|
444
|
+
nk_dispatch_dense_(hamming, u8, u8, u32)
|
|
445
|
+
nk_dispatch_dense_(hamming, u1, u1x8, u32)
|
|
446
|
+
nk_dispatch_dense_(jaccard, u32, u32, f32)
|
|
447
|
+
nk_dispatch_dense_(jaccard, u16, u16, f32)
|
|
448
|
+
nk_dispatch_dense_(jaccard, u1, u1x8, f32)
|
|
449
|
+
|
|
450
|
+
// Curved spaces
|
|
451
|
+
nk_dispatch_curved_(bilinear, f64c, f64c)
|
|
452
|
+
nk_dispatch_curved_(bilinear, f32c, f64c)
|
|
453
|
+
nk_dispatch_curved_(bilinear, bf16c, f32c)
|
|
454
|
+
nk_dispatch_curved_(bilinear, f16c, f32c)
|
|
455
|
+
nk_dispatch_curved_(bilinear, f64, f64)
|
|
456
|
+
nk_dispatch_curved_(bilinear, f32, f64)
|
|
457
|
+
nk_dispatch_curved_(bilinear, bf16, f32)
|
|
458
|
+
nk_dispatch_curved_(bilinear, f16, f32)
|
|
459
|
+
nk_dispatch_curved_(mahalanobis, f64, f64)
|
|
460
|
+
nk_dispatch_curved_(mahalanobis, f32, f64)
|
|
461
|
+
nk_dispatch_curved_(mahalanobis, bf16, f32)
|
|
462
|
+
nk_dispatch_curved_(mahalanobis, f16, f32)
|
|
463
|
+
|
|
464
|
+
// Geospatial distances
|
|
465
|
+
nk_dispatch_geospatial_(haversine, f64, f64)
|
|
466
|
+
nk_dispatch_geospatial_(haversine, f32, f32)
|
|
467
|
+
nk_dispatch_geospatial_(vincenty, f64, f64)
|
|
468
|
+
nk_dispatch_geospatial_(vincenty, f32, f32)
|
|
469
|
+
|
|
470
|
+
// Probability distributions
|
|
471
|
+
nk_dispatch_dense_(kld, f64, f64, f64)
|
|
472
|
+
nk_dispatch_dense_(kld, f32, f32, f64)
|
|
473
|
+
nk_dispatch_dense_(kld, bf16, bf16, f32)
|
|
474
|
+
nk_dispatch_dense_(kld, f16, f16, f32)
|
|
475
|
+
nk_dispatch_dense_(jsd, f64, f64, f64)
|
|
476
|
+
nk_dispatch_dense_(jsd, f32, f32, f64)
|
|
477
|
+
nk_dispatch_dense_(jsd, bf16, bf16, f32)
|
|
478
|
+
nk_dispatch_dense_(jsd, f16, f16, f32)
|
|
479
|
+
|
|
480
|
+
// Mesh alignment (RMSD, Kabsch, Umeyama)
|
|
481
|
+
nk_dispatch_mesh_(rmsd, f64, f64, f64)
|
|
482
|
+
nk_dispatch_mesh_(rmsd, f32, f32, f64)
|
|
483
|
+
nk_dispatch_mesh_(rmsd, bf16, f32, f32)
|
|
484
|
+
nk_dispatch_mesh_(rmsd, f16, f32, f32)
|
|
485
|
+
nk_dispatch_mesh_(kabsch, f64, f64, f64)
|
|
486
|
+
nk_dispatch_mesh_(kabsch, f32, f32, f64)
|
|
487
|
+
nk_dispatch_mesh_(kabsch, bf16, f32, f32)
|
|
488
|
+
nk_dispatch_mesh_(kabsch, f16, f32, f32)
|
|
489
|
+
nk_dispatch_mesh_(umeyama, f64, f64, f64)
|
|
490
|
+
nk_dispatch_mesh_(umeyama, f32, f32, f64)
|
|
491
|
+
nk_dispatch_mesh_(umeyama, bf16, f32, f32)
|
|
492
|
+
nk_dispatch_mesh_(umeyama, f16, f32, f32)
|
|
493
|
+
|
|
494
|
+
// Sparse sets
|
|
495
|
+
nk_dispatch_sparse_(sparse_intersect, u64, u64)
|
|
496
|
+
nk_dispatch_sparse_(sparse_intersect, u32, u32)
|
|
497
|
+
nk_dispatch_sparse_(sparse_intersect, u16, u16)
|
|
498
|
+
nk_dispatch_sparse_dot_(sparse_dot, u32, f32, f64)
|
|
499
|
+
nk_dispatch_sparse_dot_(sparse_dot, u16, bf16, f32)
|
|
500
|
+
|
|
501
|
+
// Element-wise operations
|
|
502
|
+
nk_dispatch_each_scale_(f64c, f64c)
|
|
503
|
+
nk_dispatch_each_scale_(f32c, f32c)
|
|
504
|
+
nk_dispatch_each_scale_(f64, f64)
|
|
505
|
+
nk_dispatch_each_scale_(f32, f32)
|
|
506
|
+
nk_dispatch_each_scale_(bf16, f32)
|
|
507
|
+
nk_dispatch_each_scale_(f16, f32)
|
|
508
|
+
nk_dispatch_each_scale_(e5m2, f32)
|
|
509
|
+
nk_dispatch_each_scale_(e4m3, f32)
|
|
510
|
+
nk_dispatch_each_scale_(e3m2, f32)
|
|
511
|
+
nk_dispatch_each_scale_(e2m3, f32)
|
|
512
|
+
nk_dispatch_each_scale_(i64, f64)
|
|
513
|
+
nk_dispatch_each_scale_(i32, f64)
|
|
514
|
+
nk_dispatch_each_scale_(i16, f32)
|
|
515
|
+
nk_dispatch_each_scale_(i8, f32)
|
|
516
|
+
nk_dispatch_each_scale_(u64, f64)
|
|
517
|
+
nk_dispatch_each_scale_(u32, f64)
|
|
518
|
+
nk_dispatch_each_scale_(u16, f32)
|
|
519
|
+
nk_dispatch_each_scale_(u8, f32)
|
|
520
|
+
nk_dispatch_each_sum_(f64c)
|
|
521
|
+
nk_dispatch_each_sum_(f32c)
|
|
522
|
+
nk_dispatch_each_sum_(f64)
|
|
523
|
+
nk_dispatch_each_sum_(f32)
|
|
524
|
+
nk_dispatch_each_sum_(bf16)
|
|
525
|
+
nk_dispatch_each_sum_(f16)
|
|
526
|
+
nk_dispatch_each_sum_(e5m2)
|
|
527
|
+
nk_dispatch_each_sum_(e4m3)
|
|
528
|
+
nk_dispatch_each_sum_(e3m2)
|
|
529
|
+
nk_dispatch_each_sum_(e2m3)
|
|
530
|
+
nk_dispatch_each_sum_(i64)
|
|
531
|
+
nk_dispatch_each_sum_(i32)
|
|
532
|
+
nk_dispatch_each_sum_(i16)
|
|
533
|
+
nk_dispatch_each_sum_(i8)
|
|
534
|
+
nk_dispatch_each_sum_(u64)
|
|
535
|
+
nk_dispatch_each_sum_(u32)
|
|
536
|
+
nk_dispatch_each_sum_(u16)
|
|
537
|
+
nk_dispatch_each_sum_(u8)
|
|
538
|
+
nk_dispatch_each_blend_(f64c, f64c)
|
|
539
|
+
nk_dispatch_each_blend_(f32c, f32c)
|
|
540
|
+
nk_dispatch_each_blend_(f64, f64)
|
|
541
|
+
nk_dispatch_each_blend_(f32, f32)
|
|
542
|
+
nk_dispatch_each_blend_(bf16, f32)
|
|
543
|
+
nk_dispatch_each_blend_(f16, f32)
|
|
544
|
+
nk_dispatch_each_blend_(e5m2, f32)
|
|
545
|
+
nk_dispatch_each_blend_(e4m3, f32)
|
|
546
|
+
nk_dispatch_each_blend_(e3m2, f32)
|
|
547
|
+
nk_dispatch_each_blend_(e2m3, f32)
|
|
548
|
+
nk_dispatch_each_blend_(i64, f64)
|
|
549
|
+
nk_dispatch_each_blend_(i32, f64)
|
|
550
|
+
nk_dispatch_each_blend_(i16, f32)
|
|
551
|
+
nk_dispatch_each_blend_(i8, f32)
|
|
552
|
+
nk_dispatch_each_blend_(u64, f64)
|
|
553
|
+
nk_dispatch_each_blend_(u32, f64)
|
|
554
|
+
nk_dispatch_each_blend_(u16, f32)
|
|
555
|
+
nk_dispatch_each_blend_(u8, f32)
|
|
556
|
+
nk_dispatch_each_fma_(f64c, f64c)
|
|
557
|
+
nk_dispatch_each_fma_(f32c, f32c)
|
|
558
|
+
nk_dispatch_each_fma_(f64, f64)
|
|
559
|
+
nk_dispatch_each_fma_(f32, f32)
|
|
560
|
+
nk_dispatch_each_fma_(bf16, f32)
|
|
561
|
+
nk_dispatch_each_fma_(f16, f32)
|
|
562
|
+
nk_dispatch_each_fma_(e5m2, f32)
|
|
563
|
+
nk_dispatch_each_fma_(e4m3, f32)
|
|
564
|
+
nk_dispatch_each_fma_(e3m2, f32)
|
|
565
|
+
nk_dispatch_each_fma_(e2m3, f32)
|
|
566
|
+
nk_dispatch_each_fma_(i64, f64)
|
|
567
|
+
nk_dispatch_each_fma_(i32, f64)
|
|
568
|
+
nk_dispatch_each_fma_(i16, f32)
|
|
569
|
+
nk_dispatch_each_fma_(i8, f32)
|
|
570
|
+
nk_dispatch_each_fma_(u64, f64)
|
|
571
|
+
nk_dispatch_each_fma_(u32, f64)
|
|
572
|
+
nk_dispatch_each_fma_(u16, f32)
|
|
573
|
+
nk_dispatch_each_fma_(u8, f32)
|
|
574
|
+
|
|
575
|
+
// Trigonometry functions
|
|
576
|
+
nk_dispatch_trigonometry_(sin, f64)
|
|
577
|
+
nk_dispatch_trigonometry_(sin, f32)
|
|
578
|
+
nk_dispatch_trigonometry_(sin, f16)
|
|
579
|
+
nk_dispatch_trigonometry_(cos, f64)
|
|
580
|
+
nk_dispatch_trigonometry_(cos, f32)
|
|
581
|
+
nk_dispatch_trigonometry_(cos, f16)
|
|
582
|
+
nk_dispatch_trigonometry_(atan, f64)
|
|
583
|
+
nk_dispatch_trigonometry_(atan, f32)
|
|
584
|
+
nk_dispatch_trigonometry_(atan, f16)
|
|
585
|
+
|
|
586
|
+
// Horizontal reductions: moments (sum + sum-of-squares)
|
|
587
|
+
nk_dispatch_reduce_moments_(f64, nk_f64_t, nk_f64_t, nk_f64_t)
|
|
588
|
+
nk_dispatch_reduce_moments_(f32, nk_f32_t, nk_f64_t, nk_f64_t)
|
|
589
|
+
nk_dispatch_reduce_moments_(bf16, nk_bf16_t, nk_f32_t, nk_f32_t)
|
|
590
|
+
nk_dispatch_reduce_moments_(f16, nk_f16_t, nk_f32_t, nk_f32_t)
|
|
591
|
+
nk_dispatch_reduce_moments_(e5m2, nk_e5m2_t, nk_f32_t, nk_f32_t)
|
|
592
|
+
nk_dispatch_reduce_moments_(e4m3, nk_e4m3_t, nk_f32_t, nk_f32_t)
|
|
593
|
+
nk_dispatch_reduce_moments_(e3m2, nk_e3m2_t, nk_f32_t, nk_f32_t)
|
|
594
|
+
nk_dispatch_reduce_moments_(e2m3, nk_e2m3_t, nk_f32_t, nk_f32_t)
|
|
595
|
+
nk_dispatch_reduce_moments_(i64, nk_i64_t, nk_i64_t, nk_u64_t)
|
|
596
|
+
nk_dispatch_reduce_moments_(i32, nk_i32_t, nk_i64_t, nk_u64_t)
|
|
597
|
+
nk_dispatch_reduce_moments_(i16, nk_i16_t, nk_i64_t, nk_u64_t)
|
|
598
|
+
nk_dispatch_reduce_moments_(i8, nk_i8_t, nk_i64_t, nk_u64_t)
|
|
599
|
+
nk_dispatch_reduce_moments_(i4, nk_i4x2_t, nk_i64_t, nk_u64_t)
|
|
600
|
+
nk_dispatch_reduce_moments_(u64, nk_u64_t, nk_u64_t, nk_u64_t)
|
|
601
|
+
nk_dispatch_reduce_moments_(u32, nk_u32_t, nk_u64_t, nk_u64_t)
|
|
602
|
+
nk_dispatch_reduce_moments_(u16, nk_u16_t, nk_u64_t, nk_u64_t)
|
|
603
|
+
nk_dispatch_reduce_moments_(u8, nk_u8_t, nk_u64_t, nk_u64_t)
|
|
604
|
+
nk_dispatch_reduce_moments_(u4, nk_u4x2_t, nk_u64_t, nk_u64_t)
|
|
605
|
+
nk_dispatch_reduce_moments_(u1, nk_u1x8_t, nk_u64_t, nk_u64_t)
|
|
606
|
+
|
|
607
|
+
// Horizontal reductions: minmax (min + max with indices)
|
|
608
|
+
nk_dispatch_reduce_minmax_(f64, nk_f64_t, nk_f64_t)
|
|
609
|
+
nk_dispatch_reduce_minmax_(f32, nk_f32_t, nk_f32_t)
|
|
610
|
+
nk_dispatch_reduce_minmax_(bf16, nk_bf16_t, nk_bf16_t)
|
|
611
|
+
nk_dispatch_reduce_minmax_(f16, nk_f16_t, nk_f16_t)
|
|
612
|
+
nk_dispatch_reduce_minmax_(e5m2, nk_e5m2_t, nk_e5m2_t)
|
|
613
|
+
nk_dispatch_reduce_minmax_(e4m3, nk_e4m3_t, nk_e4m3_t)
|
|
614
|
+
nk_dispatch_reduce_minmax_(e3m2, nk_e3m2_t, nk_e3m2_t)
|
|
615
|
+
nk_dispatch_reduce_minmax_(e2m3, nk_e2m3_t, nk_e2m3_t)
|
|
616
|
+
nk_dispatch_reduce_minmax_(i64, nk_i64_t, nk_i64_t)
|
|
617
|
+
nk_dispatch_reduce_minmax_(i32, nk_i32_t, nk_i32_t)
|
|
618
|
+
nk_dispatch_reduce_minmax_(i16, nk_i16_t, nk_i16_t)
|
|
619
|
+
nk_dispatch_reduce_minmax_(i8, nk_i8_t, nk_i8_t)
|
|
620
|
+
nk_dispatch_reduce_minmax_(i4, nk_i4x2_t, nk_i8_t)
|
|
621
|
+
nk_dispatch_reduce_minmax_(u64, nk_u64_t, nk_u64_t)
|
|
622
|
+
nk_dispatch_reduce_minmax_(u32, nk_u32_t, nk_u32_t)
|
|
623
|
+
nk_dispatch_reduce_minmax_(u16, nk_u16_t, nk_u16_t)
|
|
624
|
+
nk_dispatch_reduce_minmax_(u8, nk_u8_t, nk_u8_t)
|
|
625
|
+
nk_dispatch_reduce_minmax_(u4, nk_u4x2_t, nk_u8_t)
|
|
626
|
+
nk_dispatch_reduce_minmax_(u1, nk_u1x8_t, nk_u8_t)
|
|
627
|
+
|
|
628
|
+
// Dots packed sizes
|
|
629
|
+
nk_dispatch_cross_packed_size_(dots, f64, f64, f64)
|
|
630
|
+
nk_dispatch_cross_packed_size_(dots, f32, f32, f32)
|
|
631
|
+
nk_dispatch_cross_packed_size_(dots, bf16, bf16, f32)
|
|
632
|
+
nk_dispatch_cross_packed_size_(dots, f16, f16, f32)
|
|
633
|
+
nk_dispatch_cross_packed_size_(dots, e5m2, e5m2, f32)
|
|
634
|
+
nk_dispatch_cross_packed_size_(dots, e4m3, e4m3, f32)
|
|
635
|
+
nk_dispatch_cross_packed_size_(dots, e3m2, e3m2, f32)
|
|
636
|
+
nk_dispatch_cross_packed_size_(dots, e2m3, e2m3, f32)
|
|
637
|
+
nk_dispatch_cross_packed_size_(dots, i8, i8, i32)
|
|
638
|
+
nk_dispatch_cross_packed_size_(dots, i4, i4x2, i32)
|
|
639
|
+
nk_dispatch_cross_packed_size_(dots, u8, u8, u32)
|
|
640
|
+
nk_dispatch_cross_packed_size_(dots, u4, u4x2, u32)
|
|
641
|
+
nk_dispatch_cross_packed_size_(dots, u1, u1x8, u32)
|
|
642
|
+
|
|
643
|
+
// Dots packing
|
|
644
|
+
nk_dispatch_cross_pack_(dots, f64, f64, f64)
|
|
645
|
+
nk_dispatch_cross_pack_(dots, f32, f32, f32)
|
|
646
|
+
nk_dispatch_cross_pack_(dots, bf16, bf16, f32)
|
|
647
|
+
nk_dispatch_cross_pack_(dots, f16, f16, f32)
|
|
648
|
+
nk_dispatch_cross_pack_(dots, e5m2, e5m2, f32)
|
|
649
|
+
nk_dispatch_cross_pack_(dots, e4m3, e4m3, f32)
|
|
650
|
+
nk_dispatch_cross_pack_(dots, e3m2, e3m2, f32)
|
|
651
|
+
nk_dispatch_cross_pack_(dots, e2m3, e2m3, f32)
|
|
652
|
+
nk_dispatch_cross_pack_(dots, i8, i8, i32)
|
|
653
|
+
nk_dispatch_cross_pack_(dots, i4, i4x2, i32)
|
|
654
|
+
nk_dispatch_cross_pack_(dots, u8, u8, u32)
|
|
655
|
+
nk_dispatch_cross_pack_(dots, u4, u4x2, u32)
|
|
656
|
+
nk_dispatch_cross_pack_(dots, u1, u1x8, u32)
|
|
657
|
+
|
|
658
|
+
// Dots packed
|
|
659
|
+
nk_dispatch_cross_packed_(dots, f64, f64, f64, f64)
|
|
660
|
+
nk_dispatch_cross_packed_(dots, f32, f32, f32, f64)
|
|
661
|
+
nk_dispatch_cross_packed_(dots, bf16, bf16, f32, f32)
|
|
662
|
+
nk_dispatch_cross_packed_(dots, f16, f16, f32, f32)
|
|
663
|
+
nk_dispatch_cross_packed_(dots, e5m2, e5m2, f32, f32)
|
|
664
|
+
nk_dispatch_cross_packed_(dots, e4m3, e4m3, f32, f32)
|
|
665
|
+
nk_dispatch_cross_packed_(dots, e3m2, e3m2, f32, f32)
|
|
666
|
+
nk_dispatch_cross_packed_(dots, e2m3, e2m3, f32, f32)
|
|
667
|
+
nk_dispatch_cross_packed_(dots, i8, i8, i32, i32)
|
|
668
|
+
nk_dispatch_cross_packed_(dots, i4, i4x2, i32, i32)
|
|
669
|
+
nk_dispatch_cross_packed_(dots, u8, u8, u32, u32)
|
|
670
|
+
nk_dispatch_cross_packed_(dots, u4, u4x2, u32, u32)
|
|
671
|
+
nk_dispatch_cross_packed_(dots, u1, u1x8, u32, u32)
|
|
672
|
+
|
|
673
|
+
// Dots symmetric
|
|
674
|
+
nk_dispatch_cross_symmetric_(dots, f64, f64, f64)
|
|
675
|
+
nk_dispatch_cross_symmetric_(dots, f32, f32, f64)
|
|
676
|
+
nk_dispatch_cross_symmetric_(dots, bf16, bf16, f32)
|
|
677
|
+
nk_dispatch_cross_symmetric_(dots, f16, f16, f32)
|
|
678
|
+
nk_dispatch_cross_symmetric_(dots, e5m2, e5m2, f32)
|
|
679
|
+
nk_dispatch_cross_symmetric_(dots, e4m3, e4m3, f32)
|
|
680
|
+
nk_dispatch_cross_symmetric_(dots, e3m2, e3m2, f32)
|
|
681
|
+
nk_dispatch_cross_symmetric_(dots, e2m3, e2m3, f32)
|
|
682
|
+
nk_dispatch_cross_symmetric_(dots, i8, i8, i32)
|
|
683
|
+
nk_dispatch_cross_symmetric_(dots, i4, i4x2, i32)
|
|
684
|
+
nk_dispatch_cross_symmetric_(dots, u8, u8, u32)
|
|
685
|
+
nk_dispatch_cross_symmetric_(dots, u4, u4x2, u32)
|
|
686
|
+
nk_dispatch_cross_symmetric_(dots, u1, u1x8, u32)
|
|
687
|
+
|
|
688
|
+
// Sets packed
|
|
689
|
+
nk_dispatch_cross_packed_(hammings, u1, u1x8, u32, u32)
|
|
690
|
+
nk_dispatch_cross_packed_(jaccards, u1, u1x8, f32, f32)
|
|
691
|
+
|
|
692
|
+
// Sets symmetric
|
|
693
|
+
nk_dispatch_cross_symmetric_(hammings, u1, u1x8, u32)
|
|
694
|
+
nk_dispatch_cross_symmetric_(jaccards, u1, u1x8, f32)
|
|
695
|
+
|
|
696
|
+
// Angulars packed
|
|
697
|
+
nk_dispatch_cross_packed_(angulars, f64, f64, f64, f64)
|
|
698
|
+
nk_dispatch_cross_packed_(angulars, f32, f32, f32, f64)
|
|
699
|
+
nk_dispatch_cross_packed_(angulars, bf16, bf16, f32, f32)
|
|
700
|
+
nk_dispatch_cross_packed_(angulars, f16, f16, f32, f32)
|
|
701
|
+
nk_dispatch_cross_packed_(angulars, e5m2, e5m2, f32, f32)
|
|
702
|
+
nk_dispatch_cross_packed_(angulars, e4m3, e4m3, f32, f32)
|
|
703
|
+
nk_dispatch_cross_packed_(angulars, e3m2, e3m2, f32, f32)
|
|
704
|
+
nk_dispatch_cross_packed_(angulars, e2m3, e2m3, f32, f32)
|
|
705
|
+
nk_dispatch_cross_packed_(angulars, i8, i8, i32, f32)
|
|
706
|
+
nk_dispatch_cross_packed_(angulars, i4, i4x2, i32, f32)
|
|
707
|
+
nk_dispatch_cross_packed_(angulars, u8, u8, u32, f32)
|
|
708
|
+
nk_dispatch_cross_packed_(angulars, u4, u4x2, u32, f32)
|
|
709
|
+
|
|
710
|
+
// Angulars symmetric
|
|
711
|
+
nk_dispatch_cross_symmetric_(angulars, f64, f64, f64)
|
|
712
|
+
nk_dispatch_cross_symmetric_(angulars, f32, f32, f64)
|
|
713
|
+
nk_dispatch_cross_symmetric_(angulars, bf16, bf16, f32)
|
|
714
|
+
nk_dispatch_cross_symmetric_(angulars, f16, f16, f32)
|
|
715
|
+
nk_dispatch_cross_symmetric_(angulars, e5m2, e5m2, f32)
|
|
716
|
+
nk_dispatch_cross_symmetric_(angulars, e4m3, e4m3, f32)
|
|
717
|
+
nk_dispatch_cross_symmetric_(angulars, e3m2, e3m2, f32)
|
|
718
|
+
nk_dispatch_cross_symmetric_(angulars, e2m3, e2m3, f32)
|
|
719
|
+
nk_dispatch_cross_symmetric_(angulars, i8, i8, f32)
|
|
720
|
+
nk_dispatch_cross_symmetric_(angulars, i4, i4x2, f32)
|
|
721
|
+
nk_dispatch_cross_symmetric_(angulars, u8, u8, f32)
|
|
722
|
+
nk_dispatch_cross_symmetric_(angulars, u4, u4x2, f32)
|
|
723
|
+
|
|
724
|
+
// Euclideans packed
|
|
725
|
+
nk_dispatch_cross_packed_(euclideans, f64, f64, f64, f64)
|
|
726
|
+
nk_dispatch_cross_packed_(euclideans, f32, f32, f32, f64)
|
|
727
|
+
nk_dispatch_cross_packed_(euclideans, bf16, bf16, f32, f32)
|
|
728
|
+
nk_dispatch_cross_packed_(euclideans, f16, f16, f32, f32)
|
|
729
|
+
nk_dispatch_cross_packed_(euclideans, e5m2, e5m2, f32, f32)
|
|
730
|
+
nk_dispatch_cross_packed_(euclideans, e4m3, e4m3, f32, f32)
|
|
731
|
+
nk_dispatch_cross_packed_(euclideans, e3m2, e3m2, f32, f32)
|
|
732
|
+
nk_dispatch_cross_packed_(euclideans, e2m3, e2m3, f32, f32)
|
|
733
|
+
nk_dispatch_cross_packed_(euclideans, i8, i8, i32, f32)
|
|
734
|
+
nk_dispatch_cross_packed_(euclideans, i4, i4x2, i32, f32)
|
|
735
|
+
nk_dispatch_cross_packed_(euclideans, u8, u8, u32, f32)
|
|
736
|
+
nk_dispatch_cross_packed_(euclideans, u4, u4x2, u32, f32)
|
|
737
|
+
|
|
738
|
+
// Euclideans symmetric
|
|
739
|
+
nk_dispatch_cross_symmetric_(euclideans, f64, f64, f64)
|
|
740
|
+
nk_dispatch_cross_symmetric_(euclideans, f32, f32, f64)
|
|
741
|
+
nk_dispatch_cross_symmetric_(euclideans, bf16, bf16, f32)
|
|
742
|
+
nk_dispatch_cross_symmetric_(euclideans, f16, f16, f32)
|
|
743
|
+
nk_dispatch_cross_symmetric_(euclideans, e5m2, e5m2, f32)
|
|
744
|
+
nk_dispatch_cross_symmetric_(euclideans, e4m3, e4m3, f32)
|
|
745
|
+
nk_dispatch_cross_symmetric_(euclideans, e3m2, e3m2, f32)
|
|
746
|
+
nk_dispatch_cross_symmetric_(euclideans, e2m3, e2m3, f32)
|
|
747
|
+
nk_dispatch_cross_symmetric_(euclideans, i8, i8, f32)
|
|
748
|
+
nk_dispatch_cross_symmetric_(euclideans, i4, i4x2, f32)
|
|
749
|
+
nk_dispatch_cross_symmetric_(euclideans, u8, u8, f32)
|
|
750
|
+
nk_dispatch_cross_symmetric_(euclideans, u4, u4x2, f32)
|
|
751
|
+
|
|
752
|
+
// MaxSim packed sizes
|
|
753
|
+
nk_dispatch_cross_packed_size_(maxsim, f32, f32, f32)
|
|
754
|
+
nk_dispatch_cross_packed_size_(maxsim, bf16, bf16, f32)
|
|
755
|
+
nk_dispatch_cross_packed_size_(maxsim, f16, f16, f32)
|
|
756
|
+
|
|
757
|
+
// MaxSim packing
|
|
758
|
+
nk_dispatch_cross_pack_(maxsim, f32, f32, f32)
|
|
759
|
+
nk_dispatch_cross_pack_(maxsim, bf16, bf16, f32)
|
|
760
|
+
nk_dispatch_cross_pack_(maxsim, f16, f16, f32)
|
|
761
|
+
|
|
762
|
+
// MaxSim packed scoring
|
|
763
|
+
nk_dispatch_maxsim_packed_(f32, f64)
|
|
764
|
+
nk_dispatch_maxsim_packed_(bf16, f32)
|
|
765
|
+
nk_dispatch_maxsim_packed_(f16, f32)
|
|
766
|
+
|
|
767
|
+
NK_DYNAMIC int nk_uses_dynamic_dispatch(void) { return 1; }
|
|
768
|
+
NK_DYNAMIC int nk_configure_thread(nk_capability_t c) { return nk_configure_thread_(c); }
|
|
769
|
+
|
|
770
|
+
NK_DYNAMIC void nk_cast(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
771
|
+
nk_dispatch_table.cast(from, from_type, n, to, to_type);
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
// Forward declarations for dtype-specific dispatch initialization functions
|
|
775
|
+
void nk_dispatch_f64c_init_(nk_capability_t caps);
|
|
776
|
+
void nk_dispatch_f32c_init_(nk_capability_t caps);
|
|
777
|
+
void nk_dispatch_bf16c_init_(nk_capability_t caps);
|
|
778
|
+
void nk_dispatch_f16c_init_(nk_capability_t caps);
|
|
779
|
+
void nk_dispatch_f64_init_(nk_capability_t caps);
|
|
780
|
+
void nk_dispatch_f32_init_(nk_capability_t caps);
|
|
781
|
+
void nk_dispatch_bf16_init_(nk_capability_t caps);
|
|
782
|
+
void nk_dispatch_f16_init_(nk_capability_t caps);
|
|
783
|
+
void nk_dispatch_e5m2_init_(nk_capability_t caps);
|
|
784
|
+
void nk_dispatch_e4m3_init_(nk_capability_t caps);
|
|
785
|
+
void nk_dispatch_e3m2_init_(nk_capability_t caps);
|
|
786
|
+
void nk_dispatch_e2m3_init_(nk_capability_t caps);
|
|
787
|
+
void nk_dispatch_i64_init_(nk_capability_t caps);
|
|
788
|
+
void nk_dispatch_i32_init_(nk_capability_t caps);
|
|
789
|
+
void nk_dispatch_i16_init_(nk_capability_t caps);
|
|
790
|
+
void nk_dispatch_i8_init_(nk_capability_t caps);
|
|
791
|
+
void nk_dispatch_i4_init_(nk_capability_t caps);
|
|
792
|
+
void nk_dispatch_u64_init_(nk_capability_t caps);
|
|
793
|
+
void nk_dispatch_u32_init_(nk_capability_t caps);
|
|
794
|
+
void nk_dispatch_u16_init_(nk_capability_t caps);
|
|
795
|
+
void nk_dispatch_u8_init_(nk_capability_t caps);
|
|
796
|
+
void nk_dispatch_u4_init_(nk_capability_t caps);
|
|
797
|
+
void nk_dispatch_u1_init_(nk_capability_t caps);
|
|
798
|
+
void nk_dispatch_cast_init_(nk_capability_t caps);
|
|
799
|
+
void nk_dispatch_math_init_(nk_capability_t caps);
|
|
800
|
+
|
|
801
|
+
NK_INTERNAL void nk_dispatch_table_update_implementation_(nk_capability_t caps) {
|
|
802
|
+
nk_dispatch_f64c_init_(caps);
|
|
803
|
+
nk_dispatch_f32c_init_(caps);
|
|
804
|
+
nk_dispatch_bf16c_init_(caps);
|
|
805
|
+
nk_dispatch_f16c_init_(caps);
|
|
806
|
+
nk_dispatch_f64_init_(caps);
|
|
807
|
+
nk_dispatch_f32_init_(caps);
|
|
808
|
+
nk_dispatch_bf16_init_(caps);
|
|
809
|
+
nk_dispatch_f16_init_(caps);
|
|
810
|
+
nk_dispatch_e5m2_init_(caps);
|
|
811
|
+
nk_dispatch_e4m3_init_(caps);
|
|
812
|
+
nk_dispatch_e3m2_init_(caps);
|
|
813
|
+
nk_dispatch_e2m3_init_(caps);
|
|
814
|
+
nk_dispatch_i64_init_(caps);
|
|
815
|
+
nk_dispatch_i32_init_(caps);
|
|
816
|
+
nk_dispatch_i16_init_(caps);
|
|
817
|
+
nk_dispatch_i8_init_(caps);
|
|
818
|
+
nk_dispatch_i4_init_(caps);
|
|
819
|
+
nk_dispatch_u64_init_(caps);
|
|
820
|
+
nk_dispatch_u32_init_(caps);
|
|
821
|
+
nk_dispatch_u16_init_(caps);
|
|
822
|
+
nk_dispatch_u8_init_(caps);
|
|
823
|
+
nk_dispatch_u4_init_(caps);
|
|
824
|
+
nk_dispatch_u1_init_(caps);
|
|
825
|
+
nk_dispatch_cast_init_(caps);
|
|
826
|
+
nk_dispatch_math_init_(caps);
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
NK_INTERNAL void nk_dispatch_table_init(void) { nk_dispatch_table_update_implementation_(nk_capabilities()); }
|
|
830
|
+
|
|
831
|
+
NK_DYNAMIC void nk_dispatch_table_update(nk_capability_t caps) { nk_dispatch_table_update_implementation_(caps); }
|
|
832
|
+
// Returns the capability mask detected for the host CPU.
// Detection can be slow (CPUID alone may exceed 100 cycles), so it runs once and the result is memoized in a
// function-local static; `nk_cap_any_k` doubles as the "not yet detected" sentinel. The dispatch table is populated
// on that first call.
// NOTE(review): the cache is a plain static without synchronization, so concurrent first calls race. The library's
// load-time auto-initializer normally makes the first call before user threads exist.
NK_DYNAMIC nk_capability_t nk_capabilities(void) {
    static nk_capability_t cached = nk_cap_any_k;
    if (cached == nk_cap_any_k) {
        cached = nk_capabilities_();
        // Must come after the assignment. `nk_dispatch_table_init` calls back into this function and relies on the
        // cache already being filled so that it returns immediately.
        nk_dispatch_table_init();
    }
    return cached;
}
|
|
844
|
+
|
|
845
|
+
NK_DYNAMIC void nk_find_kernel_punned( //
|
|
846
|
+
nk_kernel_kind_t kind, //
|
|
847
|
+
nk_dtype_t dtype, //
|
|
848
|
+
nk_capability_t viable, //
|
|
849
|
+
nk_kernel_punned_t *kernel_output, //
|
|
850
|
+
nk_capability_t *capability_output) {
|
|
851
|
+
|
|
852
|
+
// Modern compilers abso-freaking-lutely love optimizing-out my logic!
|
|
853
|
+
// Just marking the variables as `volatile` is not enough, so we have
|
|
854
|
+
// to add inline assembly to further discourage them!
|
|
855
|
+
#if defined(_MSC_VER)
|
|
856
|
+
_ReadWriteBarrier();
|
|
857
|
+
#else
|
|
858
|
+
__asm__ __volatile__("" ::: "memory");
|
|
859
|
+
#endif
|
|
860
|
+
|
|
861
|
+
nk_kernel_punned_t *m = kernel_output;
|
|
862
|
+
nk_capability_t *c = capability_output;
|
|
863
|
+
|
|
864
|
+
switch (dtype) {
|
|
865
|
+
|
|
866
|
+
case nk_f64c_k: nk_dispatch_f64c_find_(viable, kind, m, c); return;
|
|
867
|
+
case nk_f32c_k: nk_dispatch_f32c_find_(viable, kind, m, c); return;
|
|
868
|
+
case nk_bf16c_k: nk_dispatch_bf16c_find_(viable, kind, m, c); return;
|
|
869
|
+
case nk_f16c_k: nk_dispatch_f16c_find_(viable, kind, m, c); return;
|
|
870
|
+
|
|
871
|
+
case nk_f64_k: nk_dispatch_f64_find_(viable, kind, m, c); return;
|
|
872
|
+
case nk_f32_k: nk_dispatch_f32_find_(viable, kind, m, c); return;
|
|
873
|
+
case nk_bf16_k: nk_dispatch_bf16_find_(viable, kind, m, c); return;
|
|
874
|
+
case nk_f16_k: nk_dispatch_f16_find_(viable, kind, m, c); return;
|
|
875
|
+
|
|
876
|
+
case nk_e5m2_k: nk_dispatch_e5m2_find_(viable, kind, m, c); return;
|
|
877
|
+
case nk_e4m3_k: nk_dispatch_e4m3_find_(viable, kind, m, c); return;
|
|
878
|
+
case nk_e3m2_k: nk_dispatch_e3m2_find_(viable, kind, m, c); return;
|
|
879
|
+
case nk_e2m3_k: nk_dispatch_e2m3_find_(viable, kind, m, c); return;
|
|
880
|
+
|
|
881
|
+
case nk_i64_k: nk_dispatch_i64_find_(viable, kind, m, c); return;
|
|
882
|
+
case nk_i32_k: nk_dispatch_i32_find_(viable, kind, m, c); return;
|
|
883
|
+
case nk_i16_k: nk_dispatch_i16_find_(viable, kind, m, c); return;
|
|
884
|
+
case nk_i8_k: nk_dispatch_i8_find_(viable, kind, m, c); return;
|
|
885
|
+
case nk_i4_k: nk_dispatch_i4_find_(viable, kind, m, c); return;
|
|
886
|
+
|
|
887
|
+
case nk_u64_k: nk_dispatch_u64_find_(viable, kind, m, c); return;
|
|
888
|
+
case nk_u32_k: nk_dispatch_u32_find_(viable, kind, m, c); return;
|
|
889
|
+
case nk_u16_k: nk_dispatch_u16_find_(viable, kind, m, c); return;
|
|
890
|
+
case nk_u8_k: nk_dispatch_u8_find_(viable, kind, m, c); return;
|
|
891
|
+
case nk_u4_k: nk_dispatch_u4_find_(viable, kind, m, c); return;
|
|
892
|
+
case nk_u1_k: nk_dispatch_u1_find_(viable, kind, m, c); return;
|
|
893
|
+
|
|
894
|
+
case nk_dtype_unknown_k: nk_dispatch_cast_find_(viable, kind, m, c); return;
|
|
895
|
+
default: break;
|
|
896
|
+
}
|
|
897
|
+
|
|
898
|
+
// Replace with zeros if no suitable implementation was found
|
|
899
|
+
*m = (nk_kernel_punned_t)0;
|
|
900
|
+
*c = (nk_capability_t)0;
|
|
901
|
+
|
|
902
|
+
// Modern compilers abso-freaking-lutely love optimizing-out my logic!
|
|
903
|
+
// Just marking the variables as `volatile` is not enough, so we have
|
|
904
|
+
// to add inline assembly to further discourage them!
|
|
905
|
+
#if defined(_MSC_VER)
|
|
906
|
+
_ReadWriteBarrier();
|
|
907
|
+
#else
|
|
908
|
+
__asm__ __volatile__("" ::: "memory");
|
|
909
|
+
#endif
|
|
910
|
+
}
|
|
911
|
+
|
|
912
|
+
// Load-time initialization: the first `nk_capabilities()` call detects the CPU and fills the dispatch table, so it is
// issued as soon as the library is loaded.
#if defined(__GNUC__) || defined(__clang__)
__attribute__((constructor)) static void nk_auto_init(void) {
    nk_capabilities(); // first call also populates the dispatch table
}
#elif defined(_MSC_VER)
static void nk_auto_init(void) {
    nk_capabilities(); // first call also populates the dispatch table
}
// MSVC has no constructor attribute, so a pointer placed in the CRT initializer section runs it at startup instead.
#pragma section(".CRT$XCU", read)
__declspec(allocate(".CRT$XCU")) static void (*nk_auto_init_ptr)(void) = nk_auto_init;
#ifdef _WIN32
#include <windows.h>
// DLL entry point: also runs the initializer when the DLL is attached to a process. Repeated initialization is
// harmless because `nk_capabilities` caches its result.
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) {
    nk_unused_(hinstDLL);
    nk_unused_(lpReserved);
    switch (fdwReason) {
    case DLL_PROCESS_ATTACH: nk_auto_init(); break;
    default: break;
    }
    return TRUE;
}
#endif
#endif
|
|
934
|
+
|
|
935
|
+
// SME ABI support routines: the lazy-ZA-save helpers called from compiler-generated prologues and epilogues of
// `__arm_new("za")` functions (used by the streaming dots kernels). Some toolchains, e.g. Apple's, may not ship them
// in compiler-rt.
// Inside NumKong TPIDR2_EL0 is always null on entry, because no NK_PUBLIC function carries ZA state. Saving is
// therefore a no-op, and there is never anything to restore.
// Weak linkage lets a real compiler-rt implementation take precedence when one is linked in.
#if NK_TARGET_ARM_ && NK_TARGET_SME
__attribute__((weak, visibility("default"))) void __arm_tpidr2_save(void) {
    // Intentionally empty: no lazy ZA save is ever pending.
}

__attribute__((weak, visibility("default"))) void __arm_tpidr2_restore(void *blk) {
    nk_unused_(blk); // nothing was saved, so nothing needs restoring
}
#endif
|
|
947
|
+
|
|
948
|
+
#if defined(__cplusplus)
} // extern "C"
#endif
|