numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,689 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief JavaScript bindings for NumKong.
|
|
3
|
+
* @file javascript/numkong.c
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date October 18, 2023
|
|
6
|
+
*
|
|
7
|
+
* @see NodeJS docs: https://nodejs.org/api/n-api.html
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
#include <string.h> // `strcmp` function
|
|
11
|
+
|
|
12
|
+
#include <node_api.h> // `napi_*` functions — N-API v6+ for BigInt (Node ≥ 10.20)
|
|
13
|
+
|
|
14
|
+
#include <numkong/numkong.h> // `nk_*` functions — must be first to bring `_GNU_SOURCE`
|
|
15
|
+
|
|
16
|
+
/** @brief Global variable that caches the CPU capabilities, and is computed just once, when the module is loaded. */
|
|
17
|
+
nk_capability_t static_capabilities = nk_cap_serial_k;
|
|
18
|
+
|
|
19
|
+
#pragma region Helpers
|
|
20
|
+
|
|
21
|
+
/** @brief Parses a dtype string (e.g. "f32", "f16", "bf16") into a nk_dtype_t enum value. */
|
|
22
|
+
static nk_dtype_t parse_dtype_string(const char *str) {
|
|
23
|
+
if (strcmp(str, "f64") == 0) return nk_f64_k;
|
|
24
|
+
else if (strcmp(str, "f32") == 0) return nk_f32_k;
|
|
25
|
+
else if (strcmp(str, "f16") == 0) return nk_f16_k;
|
|
26
|
+
else if (strcmp(str, "bf16") == 0) return nk_bf16_k;
|
|
27
|
+
else if (strcmp(str, "e4m3") == 0) return nk_e4m3_k;
|
|
28
|
+
else if (strcmp(str, "e5m2") == 0) return nk_e5m2_k;
|
|
29
|
+
else if (strcmp(str, "e2m3") == 0) return nk_e2m3_k;
|
|
30
|
+
else if (strcmp(str, "e3m2") == 0) return nk_e3m2_k;
|
|
31
|
+
else if (strcmp(str, "i8") == 0) return nk_i8_k;
|
|
32
|
+
else if (strcmp(str, "u8") == 0) return nk_u8_k;
|
|
33
|
+
else if (strcmp(str, "i16") == 0) return nk_i16_k;
|
|
34
|
+
else if (strcmp(str, "u16") == 0) return nk_u16_k;
|
|
35
|
+
else if (strcmp(str, "i64") == 0) return nk_i64_k;
|
|
36
|
+
else if (strcmp(str, "u64") == 0) return nk_u64_k;
|
|
37
|
+
else if (strcmp(str, "u1") == 0) return nk_u1_k;
|
|
38
|
+
return nk_dtype_unknown_k;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/** @brief Validates that the N-API TypedArray type is compatible with the claimed dtype. */
|
|
42
|
+
static int is_compatible_napi_type(napi_typedarray_type napi_type, nk_dtype_t dtype) {
|
|
43
|
+
switch (dtype) {
|
|
44
|
+
case nk_f64_k: return napi_type == napi_float64_array;
|
|
45
|
+
case nk_f32_k: return napi_type == napi_float32_array;
|
|
46
|
+
case nk_f16_k:
|
|
47
|
+
case nk_bf16_k: return napi_type == napi_uint16_array;
|
|
48
|
+
case nk_e4m3_k:
|
|
49
|
+
case nk_e5m2_k:
|
|
50
|
+
case nk_e2m3_k:
|
|
51
|
+
case nk_e3m2_k:
|
|
52
|
+
case nk_u8_k:
|
|
53
|
+
case nk_u1_k: return napi_type == napi_uint8_array;
|
|
54
|
+
case nk_i8_k: return napi_type == napi_int8_array;
|
|
55
|
+
case nk_i16_k: return napi_type == napi_int16_array;
|
|
56
|
+
case nk_u16_k: return napi_type == napi_uint16_array;
|
|
57
|
+
default: return 0;
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
/**
|
|
62
|
+
* @brief Converts an nk_scalar_buffer_t result to a JavaScript number.
|
|
63
|
+
* @param env N-API environment.
|
|
64
|
+
* @param result The scalar buffer containing the result.
|
|
65
|
+
* @param out_dtype The dtype of the value stored in the buffer.
|
|
66
|
+
* @return napi_value containing the result as a JavaScript Number, or NULL on error.
|
|
67
|
+
*/
|
|
68
|
+
static napi_value scalar_to_js_number(napi_env env, nk_scalar_buffer_t const *result, nk_dtype_t out_dtype) {
|
|
69
|
+
// i64/u64 must return BigInt since they may exceed Number.MAX_SAFE_INTEGER
|
|
70
|
+
if (out_dtype == nk_i64_k) {
|
|
71
|
+
napi_value js_result;
|
|
72
|
+
if (napi_create_bigint_int64(env, result->i64, &js_result) != napi_ok) return NULL;
|
|
73
|
+
return js_result;
|
|
74
|
+
}
|
|
75
|
+
if (out_dtype == nk_u64_k) {
|
|
76
|
+
napi_value js_result;
|
|
77
|
+
if (napi_create_bigint_uint64(env, result->u64, &js_result) != napi_ok) return NULL;
|
|
78
|
+
return js_result;
|
|
79
|
+
}
|
|
80
|
+
double result_f64;
|
|
81
|
+
switch (out_dtype) {
|
|
82
|
+
case nk_f64_k: result_f64 = (double)result->f64; break;
|
|
83
|
+
case nk_f32_k: result_f64 = (double)result->f32; break;
|
|
84
|
+
case nk_f16_k: {
|
|
85
|
+
nk_f32_t t;
|
|
86
|
+
nk_f16_to_f32(&result->f16, &t);
|
|
87
|
+
result_f64 = (double)t;
|
|
88
|
+
break;
|
|
89
|
+
}
|
|
90
|
+
case nk_bf16_k: {
|
|
91
|
+
nk_f32_t t;
|
|
92
|
+
nk_bf16_to_f32(&result->bf16, &t);
|
|
93
|
+
result_f64 = (double)t;
|
|
94
|
+
break;
|
|
95
|
+
}
|
|
96
|
+
case nk_i8_k: result_f64 = (double)result->i8; break;
|
|
97
|
+
case nk_u8_k: result_f64 = (double)result->u8; break;
|
|
98
|
+
case nk_i16_k: result_f64 = (double)result->i16; break;
|
|
99
|
+
case nk_u16_k: result_f64 = (double)result->u16; break;
|
|
100
|
+
case nk_i32_k: result_f64 = (double)result->i32; break;
|
|
101
|
+
case nk_u32_k: result_f64 = (double)result->u32; break;
|
|
102
|
+
default: napi_throw_error(env, NULL, "Unexpected output dtype in result conversion"); return NULL;
|
|
103
|
+
}
|
|
104
|
+
napi_value js_result;
|
|
105
|
+
if (napi_create_double(env, result_f64, &js_result) != napi_ok) return NULL;
|
|
106
|
+
return js_result;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
/** @brief Returns the byte width for a given dtype. */
|
|
110
|
+
static inline size_t dtype_byte_width(nk_dtype_t dtype) { return nk_dtype_bits(dtype) / NK_BITS_PER_BYTE; }
|
|
111
|
+
|
|
112
|
+
/** @brief Returns the N-API typed array type for a given output dtype. */
|
|
113
|
+
static inline napi_typedarray_type napi_type_for_dtype(nk_dtype_t dtype) {
|
|
114
|
+
switch (dtype) {
|
|
115
|
+
case nk_f64_k: return napi_float64_array;
|
|
116
|
+
case nk_f32_k: return napi_float32_array;
|
|
117
|
+
case nk_i32_k: return napi_int32_array;
|
|
118
|
+
case nk_u32_k: return napi_uint32_array;
|
|
119
|
+
default: return napi_float32_array;
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
#pragma endregion Helpers
|
|
124
|
+
|
|
125
|
+
#pragma region Distance API
|
|
126
|
+
|
|
127
|
+
/** @brief Core distance computation — resolves dtype, dispatches kernel, converts result. */
|
|
128
|
+
static napi_value dense(napi_env env, napi_callback_info info, nk_kernel_kind_t kernel_kind, nk_dtype_t dtype) {
|
|
129
|
+
size_t argc = 3;
|
|
130
|
+
napi_value args[3];
|
|
131
|
+
napi_status status;
|
|
132
|
+
|
|
133
|
+
// Get callback info and ensure the argument count is correct (2 or 3 args)
|
|
134
|
+
status = napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
135
|
+
if (status != napi_ok || argc < 2 || argc > 3) {
|
|
136
|
+
napi_throw_error(env, NULL, "Expected 2 or 3 arguments: (a, b[, dtype])");
|
|
137
|
+
return NULL;
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
// Obtain the typed arrays from the arguments
|
|
141
|
+
void *data_a, *data_b;
|
|
142
|
+
size_t length_a, length_b;
|
|
143
|
+
napi_typedarray_type type_a, type_b;
|
|
144
|
+
napi_status status_a, status_b;
|
|
145
|
+
status_a = napi_get_typedarray_info(env, args[0], &type_a, &length_a, &data_a, NULL, NULL);
|
|
146
|
+
status_b = napi_get_typedarray_info(env, args[1], &type_b, &length_b, &data_b, NULL, NULL);
|
|
147
|
+
if (status_a != napi_ok || status_b != napi_ok || type_a != type_b || length_a != length_b) {
|
|
148
|
+
napi_throw_error(env, NULL, "Both arguments must be typed arrays of matching types and dimensionality");
|
|
149
|
+
return NULL;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
// When dtype is unknown, try to resolve from optional 3rd argument or auto-detect
|
|
153
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
154
|
+
if (argc == 3) {
|
|
155
|
+
// Parse explicit dtype string from 3rd argument
|
|
156
|
+
char dtype_str[16];
|
|
157
|
+
size_t str_len;
|
|
158
|
+
if (napi_get_value_string_utf8(env, args[2], dtype_str, sizeof(dtype_str), &str_len) != napi_ok) {
|
|
159
|
+
napi_throw_error(env, NULL, "Third argument must be a dtype string");
|
|
160
|
+
return NULL;
|
|
161
|
+
}
|
|
162
|
+
dtype = parse_dtype_string(dtype_str);
|
|
163
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
164
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
165
|
+
return NULL;
|
|
166
|
+
}
|
|
167
|
+
if (!is_compatible_napi_type(type_a, dtype)) {
|
|
168
|
+
napi_throw_error(env, NULL, "TypedArray type is not compatible with the specified dtype");
|
|
169
|
+
return NULL;
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
else {
|
|
173
|
+
// Auto-detect from N-API TypedArray type (backward-compatible 4-type whitelist)
|
|
174
|
+
if (type_a != napi_float64_array && type_a != napi_float32_array && type_a != napi_int8_array &&
|
|
175
|
+
type_a != napi_uint8_array) {
|
|
176
|
+
napi_throw_error(
|
|
177
|
+
env, NULL,
|
|
178
|
+
"Only f64, f32, i8, u8 arrays are auto-detected; pass dtype string as 3rd argument " "for other " "types");
|
|
179
|
+
return NULL;
|
|
180
|
+
}
|
|
181
|
+
switch (type_a) {
|
|
182
|
+
case napi_float64_array: dtype = nk_f64_k; break;
|
|
183
|
+
case napi_float32_array: dtype = nk_f32_k; break;
|
|
184
|
+
case napi_int8_array: dtype = nk_i8_k; break;
|
|
185
|
+
case napi_uint8_array: dtype = nk_u8_k; break;
|
|
186
|
+
default: break;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
nk_metric_dense_punned_t metric = NULL;
|
|
192
|
+
nk_capability_t capability = nk_cap_serial_k;
|
|
193
|
+
nk_find_kernel_punned(kernel_kind, dtype, static_capabilities, (nk_kernel_punned_t *)&metric, &capability);
|
|
194
|
+
if (!metric || !capability) {
|
|
195
|
+
napi_throw_error(env, NULL, "Unsupported dtype for given metric");
|
|
196
|
+
return NULL;
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
nk_dtype_t out_dtype = nk_kernel_output_dtype(kernel_kind, dtype);
|
|
200
|
+
if (out_dtype == nk_dtype_unknown_k) {
|
|
201
|
+
napi_throw_error(env, NULL, "Unsupported output dtype for given metric/input combination");
|
|
202
|
+
return NULL;
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
// Adjust dimensions for sub-byte packed types (e.g. Uint8Array with u1 dtype → bits)
|
|
206
|
+
nk_size_t to_bits = nk_dtype_bits(dtype);
|
|
207
|
+
size_t dimensions = (to_bits && to_bits < NK_BITS_PER_BYTE) ? length_a * NK_BITS_PER_BYTE / to_bits : length_a;
|
|
208
|
+
|
|
209
|
+
nk_scalar_buffer_t result;
|
|
210
|
+
metric(data_a, data_b, dimensions, &result);
|
|
211
|
+
|
|
212
|
+
return scalar_to_js_number(env, &result, out_dtype);
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
/** @brief N-API entry for inner product (dot). */
|
|
216
|
+
napi_value api_ip(napi_env env, napi_callback_info info) {
|
|
217
|
+
return dense(env, info, nk_kernel_dot_k, nk_dtype_unknown_k);
|
|
218
|
+
}
|
|
219
|
+
/** @brief N-API entry for angular distance. */
|
|
220
|
+
napi_value api_angular(napi_env env, napi_callback_info info) {
|
|
221
|
+
return dense(env, info, nk_kernel_angular_k, nk_dtype_unknown_k);
|
|
222
|
+
}
|
|
223
|
+
/** @brief N-API entry for squared Euclidean distance. */
|
|
224
|
+
napi_value api_sqeuclidean(napi_env env, napi_callback_info info) {
|
|
225
|
+
return dense(env, info, nk_kernel_sqeuclidean_k, nk_dtype_unknown_k);
|
|
226
|
+
}
|
|
227
|
+
/** @brief N-API entry for Euclidean distance. */
|
|
228
|
+
napi_value api_euclidean(napi_env env, napi_callback_info info) {
|
|
229
|
+
return dense(env, info, nk_kernel_euclidean_k, nk_dtype_unknown_k);
|
|
230
|
+
}
|
|
231
|
+
/** @brief N-API entry for Kullback-Leibler divergence. */
|
|
232
|
+
napi_value api_kld(napi_env env, napi_callback_info info) {
|
|
233
|
+
return dense(env, info, nk_kernel_kld_k, nk_dtype_unknown_k);
|
|
234
|
+
}
|
|
235
|
+
/** @brief N-API entry for Jensen-Shannon distance. */
|
|
236
|
+
napi_value api_jsd(napi_env env, napi_callback_info info) {
|
|
237
|
+
return dense(env, info, nk_kernel_jsd_k, nk_dtype_unknown_k);
|
|
238
|
+
}
|
|
239
|
+
/** @brief N-API entry for Hamming distance. */
|
|
240
|
+
napi_value api_hamming(napi_env env, napi_callback_info info) { return dense(env, info, nk_kernel_hamming_k, nk_u1_k); }
|
|
241
|
+
/** @brief N-API entry for Jaccard distance. */
|
|
242
|
+
napi_value api_jaccard(napi_env env, napi_callback_info info) { return dense(env, info, nk_kernel_jaccard_k, nk_u1_k); }
|
|
243
|
+
|
|
244
|
+
#pragma endregion Distance API
|
|
245
|
+
|
|
246
|
+
#pragma region Capabilities API
|
|
247
|
+
|
|
248
|
+
/**
|
|
249
|
+
* @brief Returns the runtime-detected SIMD capabilities as a bitmask.
|
|
250
|
+
* @return BigInt bitmask of nk_capability_t flags (33 flags from NEON to SME2P1)
|
|
251
|
+
*
|
|
252
|
+
* This function exposes the cached capability bitmask to JavaScript users,
|
|
253
|
+
* allowing them to query what SIMD extensions are available at runtime.
|
|
254
|
+
* The capabilities are detected once at module load time and cached in static_capabilities.
|
|
255
|
+
*/
|
|
256
|
+
napi_value api_get_capabilities(napi_env env, napi_callback_info info) {
|
|
257
|
+
napi_value result;
|
|
258
|
+
// Use cached capabilities from module load (static_capabilities set in Init())
|
|
259
|
+
napi_create_bigint_uint64(env, (uint64_t)static_capabilities, &result);
|
|
260
|
+
return result;
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
#pragma endregion Capabilities API
|
|
264
|
+
|
|
265
|
+
#pragma region Cast API
|
|
266
|
+
|
|
267
|
+
/** @brief Converts a single value from a narrow type to f32. Reads uint32 bits, returns double. */
|
|
268
|
+
static napi_value cast_to_f32(napi_env env, napi_callback_info info, nk_dtype_t src_dtype) {
|
|
269
|
+
size_t argc = 1;
|
|
270
|
+
napi_value args[1];
|
|
271
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
272
|
+
if (argc != 1) {
|
|
273
|
+
napi_throw_error(env, NULL, "Expected 1 argument");
|
|
274
|
+
return NULL;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
uint32_t bits;
|
|
278
|
+
if (napi_get_value_uint32(env, args[0], &bits) != napi_ok) {
|
|
279
|
+
napi_throw_error(env, NULL, "Argument must be a number");
|
|
280
|
+
return NULL;
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
nk_f32_t f32_val;
|
|
284
|
+
nk_cast(&bits, src_dtype, 1, &f32_val, nk_f32_k);
|
|
285
|
+
|
|
286
|
+
napi_value result;
|
|
287
|
+
napi_create_double(env, (double)f32_val, &result);
|
|
288
|
+
return result;
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
/** @brief Converts a single f32 value to a narrow type. Reads double, returns uint32 bits. */
|
|
292
|
+
static napi_value cast_from_f32(napi_env env, napi_callback_info info, nk_dtype_t dst_dtype) {
|
|
293
|
+
size_t argc = 1;
|
|
294
|
+
napi_value args[1];
|
|
295
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
296
|
+
if (argc != 1) {
|
|
297
|
+
napi_throw_error(env, NULL, "Expected 1 argument");
|
|
298
|
+
return NULL;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
double f32_dbl;
|
|
302
|
+
if (napi_get_value_double(env, args[0], &f32_dbl) != napi_ok) {
|
|
303
|
+
napi_throw_error(env, NULL, "Argument must be a number");
|
|
304
|
+
return NULL;
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
nk_f32_t f32_val = (nk_f32_t)f32_dbl;
|
|
308
|
+
uint32_t bits = 0;
|
|
309
|
+
nk_cast(&f32_val, nk_f32_k, 1, &bits, dst_dtype);
|
|
310
|
+
|
|
311
|
+
napi_value result;
|
|
312
|
+
napi_create_uint32(env, bits, &result);
|
|
313
|
+
return result;
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
/** @brief N-API entry for scalar f16-to-f32 conversion. */
|
|
317
|
+
napi_value api_cast_f16_to_f32(napi_env env, napi_callback_info info) { return cast_to_f32(env, info, nk_f16_k); }
|
|
318
|
+
/** @brief N-API entry for scalar f32-to-f16 conversion. */
|
|
319
|
+
napi_value api_cast_f32_to_f16(napi_env env, napi_callback_info info) { return cast_from_f32(env, info, nk_f16_k); }
|
|
320
|
+
/** @brief N-API entry for scalar bf16-to-f32 conversion. */
|
|
321
|
+
napi_value api_cast_bf16_to_f32(napi_env env, napi_callback_info info) { return cast_to_f32(env, info, nk_bf16_k); }
|
|
322
|
+
/** @brief N-API entry for scalar f32-to-bf16 conversion. */
|
|
323
|
+
napi_value api_cast_f32_to_bf16(napi_env env, napi_callback_info info) { return cast_from_f32(env, info, nk_bf16_k); }
|
|
324
|
+
/** @brief N-API entry for scalar e4m3-to-f32 conversion. */
|
|
325
|
+
napi_value api_cast_e4m3_to_f32(napi_env env, napi_callback_info info) { return cast_to_f32(env, info, nk_e4m3_k); }
|
|
326
|
+
/** @brief N-API entry for scalar f32-to-e4m3 conversion. */
|
|
327
|
+
napi_value api_cast_f32_to_e4m3(napi_env env, napi_callback_info info) { return cast_from_f32(env, info, nk_e4m3_k); }
|
|
328
|
+
/** @brief N-API entry for scalar e5m2-to-f32 conversion. */
|
|
329
|
+
napi_value api_cast_e5m2_to_f32(napi_env env, napi_callback_info info) { return cast_to_f32(env, info, nk_e5m2_k); }
|
|
330
|
+
/** @brief N-API entry for scalar f32-to-e5m2 conversion. */
|
|
331
|
+
napi_value api_cast_f32_to_e5m2(napi_env env, napi_callback_info info) { return cast_from_f32(env, info, nk_e5m2_k); }
|
|
332
|
+
|
|
333
|
+
/**
|
|
334
|
+
* @brief Buffer casting function using nk_cast.
|
|
335
|
+
* @param env N-API environment
|
|
336
|
+
* @param info Callback info containing 4 arguments:
|
|
337
|
+
* - src: source TypedArray
|
|
338
|
+
* - srcType: source dtype string
|
|
339
|
+
* - dst: destination TypedArray
|
|
340
|
+
* - dstType: destination dtype string
|
|
341
|
+
* @return null (modifies dst in place)
|
|
342
|
+
*/
|
|
343
|
+
napi_value api_cast(napi_env env, napi_callback_info info) {
|
|
344
|
+
size_t argc = 4;
|
|
345
|
+
napi_value args[4];
|
|
346
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
347
|
+
|
|
348
|
+
if (argc != 4) {
|
|
349
|
+
napi_throw_error(env, NULL, "cast requires 4 arguments: (src, srcType, dst, dstType)");
|
|
350
|
+
return NULL;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
// Get source and destination arrays
|
|
354
|
+
void *src_data, *dst_data;
|
|
355
|
+
size_t src_len, dst_len;
|
|
356
|
+
napi_typedarray_type src_type, dst_type;
|
|
357
|
+
|
|
358
|
+
napi_get_typedarray_info(env, args[0], &src_type, &src_len, &src_data, NULL, NULL);
|
|
359
|
+
napi_get_typedarray_info(env, args[2], &dst_type, &dst_len, &dst_data, NULL, NULL);
|
|
360
|
+
|
|
361
|
+
// Get dtype strings
|
|
362
|
+
char src_dtype_str[16], dst_dtype_str[16];
|
|
363
|
+
size_t str_len;
|
|
364
|
+
napi_get_value_string_utf8(env, args[1], src_dtype_str, sizeof(src_dtype_str), &str_len);
|
|
365
|
+
napi_get_value_string_utf8(env, args[3], dst_dtype_str, sizeof(dst_dtype_str), &str_len);
|
|
366
|
+
|
|
367
|
+
// Map dtype strings to nk_dtype_t
|
|
368
|
+
nk_dtype_t src_dtype = parse_dtype_string(src_dtype_str);
|
|
369
|
+
nk_dtype_t dst_dtype = parse_dtype_string(dst_dtype_str);
|
|
370
|
+
|
|
371
|
+
if (src_dtype == nk_dtype_unknown_k || dst_dtype == nk_dtype_unknown_k) {
|
|
372
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
373
|
+
return NULL;
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
// Perform conversion using nk_cast
|
|
377
|
+
nk_cast(src_data, src_dtype, src_len, dst_data, dst_dtype);
|
|
378
|
+
|
|
379
|
+
return NULL; // Modifies dst_data in place
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
#pragma endregion Cast API
|
|
383
|
+
|
|
384
|
+
#pragma region Packed API
|
|
385
|
+
|
|
386
|
+
/** @brief Query packed buffer byte count: dotsPackedSize(width, depth, dtype) → number */
|
|
387
|
+
static napi_value api_dots_packed_size(napi_env env, napi_callback_info info) {
|
|
388
|
+
size_t argc = 3;
|
|
389
|
+
napi_value args[3];
|
|
390
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
391
|
+
if (argc != 3) {
|
|
392
|
+
napi_throw_error(env, NULL, "dotsPackedSize requires 3 arguments: (width, depth, dtype)");
|
|
393
|
+
return NULL;
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
uint32_t width, depth;
|
|
397
|
+
napi_get_value_uint32(env, args[0], &width);
|
|
398
|
+
napi_get_value_uint32(env, args[1], &depth);
|
|
399
|
+
|
|
400
|
+
char dtype_str[16];
|
|
401
|
+
size_t str_len;
|
|
402
|
+
napi_get_value_string_utf8(env, args[2], dtype_str, sizeof(dtype_str), &str_len);
|
|
403
|
+
nk_dtype_t dtype = parse_dtype_string(dtype_str);
|
|
404
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
405
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
406
|
+
return NULL;
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
nk_dots_packed_size_punned_t size_fn = NULL;
|
|
410
|
+
nk_capability_t cap = nk_cap_serial_k;
|
|
411
|
+
nk_find_kernel_punned(nk_kernel_dots_packed_size_k, dtype, static_capabilities, (nk_kernel_punned_t *)&size_fn,
|
|
412
|
+
&cap);
|
|
413
|
+
if (!size_fn) {
|
|
414
|
+
napi_throw_error(env, NULL, "dots_packed_size not available for this dtype");
|
|
415
|
+
return NULL;
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
nk_size_t byte_count = size_fn((nk_size_t)width, (nk_size_t)depth);
|
|
419
|
+
|
|
420
|
+
napi_value result;
|
|
421
|
+
napi_create_double(env, (double)byte_count, &result);
|
|
422
|
+
return result;
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
/** @brief Pack B matrix: dotsPack(data, width, depth, strideBytes, dtype) → { buffer, width, depth, byteLength } */
|
|
426
|
+
static napi_value api_dots_pack(napi_env env, napi_callback_info info) {
|
|
427
|
+
size_t argc = 5;
|
|
428
|
+
napi_value args[5];
|
|
429
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
430
|
+
if (argc != 5) {
|
|
431
|
+
napi_throw_error(env, NULL, "dotsPack requires 5 arguments: (data, width, depth, strideBytes, dtype)");
|
|
432
|
+
return NULL;
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
void *data;
|
|
436
|
+
size_t data_len;
|
|
437
|
+
napi_typedarray_type arr_type;
|
|
438
|
+
napi_get_typedarray_info(env, args[0], &arr_type, &data_len, &data, NULL, NULL);
|
|
439
|
+
|
|
440
|
+
uint32_t width, depth, stride_bytes;
|
|
441
|
+
napi_get_value_uint32(env, args[1], &width);
|
|
442
|
+
napi_get_value_uint32(env, args[2], &depth);
|
|
443
|
+
napi_get_value_uint32(env, args[3], &stride_bytes);
|
|
444
|
+
|
|
445
|
+
char dtype_str[16];
|
|
446
|
+
size_t str_len;
|
|
447
|
+
napi_get_value_string_utf8(env, args[4], dtype_str, sizeof(dtype_str), &str_len);
|
|
448
|
+
nk_dtype_t dtype = parse_dtype_string(dtype_str);
|
|
449
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
450
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
451
|
+
return NULL;
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
// Get packed size
|
|
455
|
+
nk_dots_packed_size_punned_t size_fn = NULL;
|
|
456
|
+
nk_capability_t cap = nk_cap_serial_k;
|
|
457
|
+
nk_find_kernel_punned(nk_kernel_dots_packed_size_k, dtype, static_capabilities, (nk_kernel_punned_t *)&size_fn,
|
|
458
|
+
&cap);
|
|
459
|
+
if (!size_fn) {
|
|
460
|
+
napi_throw_error(env, NULL, "dots_packed_size not available for this dtype");
|
|
461
|
+
return NULL;
|
|
462
|
+
}
|
|
463
|
+
nk_size_t packed_byte_count = size_fn((nk_size_t)width, (nk_size_t)depth);
|
|
464
|
+
|
|
465
|
+
// Allocate V8-managed ArrayBuffer for packed data
|
|
466
|
+
void *packed_data = NULL;
|
|
467
|
+
napi_value arraybuffer;
|
|
468
|
+
if (napi_create_arraybuffer(env, packed_byte_count, &packed_data, &arraybuffer) != napi_ok) {
|
|
469
|
+
napi_throw_error(env, NULL, "Failed to allocate packed buffer");
|
|
470
|
+
return NULL;
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
// Pack
|
|
474
|
+
nk_dots_pack_punned_t pack_fn = NULL;
|
|
475
|
+
cap = nk_cap_serial_k;
|
|
476
|
+
nk_find_kernel_punned(nk_kernel_dots_pack_k, dtype, static_capabilities, (nk_kernel_punned_t *)&pack_fn, &cap);
|
|
477
|
+
if (!pack_fn) {
|
|
478
|
+
napi_throw_error(env, NULL, "dots_pack not available for this dtype");
|
|
479
|
+
return NULL;
|
|
480
|
+
}
|
|
481
|
+
pack_fn(data, (nk_size_t)width, (nk_size_t)depth, (nk_size_t)stride_bytes, packed_data);
|
|
482
|
+
|
|
483
|
+
// Return object { buffer, width, depth, byteLength }
|
|
484
|
+
napi_value result_obj;
|
|
485
|
+
napi_create_object(env, &result_obj);
|
|
486
|
+
|
|
487
|
+
napi_value js_width, js_depth, js_byte_length;
|
|
488
|
+
napi_create_uint32(env, width, &js_width);
|
|
489
|
+
napi_create_uint32(env, depth, &js_depth);
|
|
490
|
+
napi_create_double(env, (double)packed_byte_count, &js_byte_length);
|
|
491
|
+
|
|
492
|
+
napi_set_named_property(env, result_obj, "buffer", arraybuffer);
|
|
493
|
+
napi_set_named_property(env, result_obj, "width", js_width);
|
|
494
|
+
napi_set_named_property(env, result_obj, "depth", js_depth);
|
|
495
|
+
napi_set_named_property(env, result_obj, "byteLength", js_byte_length);
|
|
496
|
+
|
|
497
|
+
return result_obj;
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
/**
|
|
501
|
+
* @brief Shared dispatcher for packed operations (dots, angulars, euclideans).
|
|
502
|
+
* Args: TypedArray a, ArrayBuffer packed, TypedArray result, numbers height/width/depth/aStride/resultStride, string
|
|
503
|
+
* dtype
|
|
504
|
+
*/
|
|
505
|
+
static napi_value api_packed_common(napi_env env, napi_callback_info info, nk_kernel_kind_t kernel_kind) {
|
|
506
|
+
size_t argc = 9;
|
|
507
|
+
napi_value args[9];
|
|
508
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
509
|
+
if (argc != 9) {
|
|
510
|
+
napi_throw_error(env, NULL, "Packed operation requires 9 arguments");
|
|
511
|
+
return NULL;
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
// arg[0]: TypedArray a
|
|
515
|
+
void *a_data;
|
|
516
|
+
size_t a_len;
|
|
517
|
+
napi_typedarray_type a_type;
|
|
518
|
+
napi_get_typedarray_info(env, args[0], &a_type, &a_len, &a_data, NULL, NULL);
|
|
519
|
+
|
|
520
|
+
// arg[1]: ArrayBuffer packed
|
|
521
|
+
void *packed_data;
|
|
522
|
+
size_t packed_len;
|
|
523
|
+
napi_get_arraybuffer_info(env, args[1], &packed_data, &packed_len);
|
|
524
|
+
|
|
525
|
+
// arg[2]: TypedArray result
|
|
526
|
+
void *result_data;
|
|
527
|
+
size_t result_len;
|
|
528
|
+
napi_typedarray_type result_type;
|
|
529
|
+
napi_get_typedarray_info(env, args[2], &result_type, &result_len, &result_data, NULL, NULL);
|
|
530
|
+
|
|
531
|
+
// args[3..7]: height, width, depth, aStride, resultStride
|
|
532
|
+
uint32_t height, width, depth, a_stride, result_stride;
|
|
533
|
+
napi_get_value_uint32(env, args[3], &height);
|
|
534
|
+
napi_get_value_uint32(env, args[4], &width);
|
|
535
|
+
napi_get_value_uint32(env, args[5], &depth);
|
|
536
|
+
napi_get_value_uint32(env, args[6], &a_stride);
|
|
537
|
+
napi_get_value_uint32(env, args[7], &result_stride);
|
|
538
|
+
|
|
539
|
+
// arg[8]: dtype string
|
|
540
|
+
char dtype_str[16];
|
|
541
|
+
size_t str_len;
|
|
542
|
+
napi_get_value_string_utf8(env, args[8], dtype_str, sizeof(dtype_str), &str_len);
|
|
543
|
+
nk_dtype_t dtype = parse_dtype_string(dtype_str);
|
|
544
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
545
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
546
|
+
return NULL;
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
nk_dots_packed_punned_t kernel = NULL;
|
|
550
|
+
nk_capability_t cap = nk_cap_serial_k;
|
|
551
|
+
nk_find_kernel_punned(kernel_kind, dtype, static_capabilities, (nk_kernel_punned_t *)&kernel, &cap);
|
|
552
|
+
if (!kernel) {
|
|
553
|
+
napi_throw_error(env, NULL, "Packed kernel not available for this dtype");
|
|
554
|
+
return NULL;
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
kernel(a_data, packed_data, result_data, (nk_size_t)height, (nk_size_t)width, (nk_size_t)depth, (nk_size_t)a_stride,
|
|
558
|
+
(nk_size_t)result_stride);
|
|
559
|
+
return NULL;
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
/** @brief `dotsPacked` export: runs the shared packed dispatcher with the dots_packed kernel kind. */
static napi_value api_dots_packed(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_dots_packed_k;
    return api_packed_common(env, info, kind);
}

/** @brief `angularsPacked` export: runs the shared packed dispatcher with the angulars_packed kernel kind. */
static napi_value api_angulars_packed(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_angulars_packed_k;
    return api_packed_common(env, info, kind);
}

/** @brief `euclideansPacked` export: runs the shared packed dispatcher with the euclideans_packed kernel kind. */
static napi_value api_euclideans_packed(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_euclideans_packed_k;
    return api_packed_common(env, info, kind);
}
|
|
571
|
+
|
|
572
|
+
/**
|
|
573
|
+
* @brief Shared dispatcher for symmetric operations (dots, angulars, euclideans).
|
|
574
|
+
* Args: TypedArray vectors, TypedArray result, numbers nVectors/depth/vectorsStride/resultStride/rowStart/rowCount,
|
|
575
|
+
* string dtype
|
|
576
|
+
*/
|
|
577
|
+
static napi_value api_symmetric_common(napi_env env, napi_callback_info info, nk_kernel_kind_t kernel_kind) {
|
|
578
|
+
size_t argc = 9;
|
|
579
|
+
napi_value args[9];
|
|
580
|
+
napi_get_cb_info(env, info, &argc, args, NULL, NULL);
|
|
581
|
+
if (argc != 9) {
|
|
582
|
+
napi_throw_error(env, NULL, "Symmetric operation requires 9 arguments");
|
|
583
|
+
return NULL;
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
// arg[0]: TypedArray vectors
|
|
587
|
+
void *vectors_data;
|
|
588
|
+
size_t vectors_len;
|
|
589
|
+
napi_typedarray_type vectors_type;
|
|
590
|
+
napi_get_typedarray_info(env, args[0], &vectors_type, &vectors_len, &vectors_data, NULL, NULL);
|
|
591
|
+
|
|
592
|
+
// arg[1]: TypedArray result
|
|
593
|
+
void *result_data;
|
|
594
|
+
size_t result_len;
|
|
595
|
+
napi_typedarray_type result_type;
|
|
596
|
+
napi_get_typedarray_info(env, args[1], &result_type, &result_len, &result_data, NULL, NULL);
|
|
597
|
+
|
|
598
|
+
// args[2..7]: nVectors, depth, vectorsStride, resultStride, rowStart, rowCount
|
|
599
|
+
uint32_t n_vectors, depth, vectors_stride, result_stride, row_start, row_count;
|
|
600
|
+
napi_get_value_uint32(env, args[2], &n_vectors);
|
|
601
|
+
napi_get_value_uint32(env, args[3], &depth);
|
|
602
|
+
napi_get_value_uint32(env, args[4], &vectors_stride);
|
|
603
|
+
napi_get_value_uint32(env, args[5], &result_stride);
|
|
604
|
+
napi_get_value_uint32(env, args[6], &row_start);
|
|
605
|
+
napi_get_value_uint32(env, args[7], &row_count);
|
|
606
|
+
|
|
607
|
+
// arg[8]: dtype string
|
|
608
|
+
char dtype_str[16];
|
|
609
|
+
size_t str_len;
|
|
610
|
+
napi_get_value_string_utf8(env, args[8], dtype_str, sizeof(dtype_str), &str_len);
|
|
611
|
+
nk_dtype_t dtype = parse_dtype_string(dtype_str);
|
|
612
|
+
if (dtype == nk_dtype_unknown_k) {
|
|
613
|
+
napi_throw_error(env, NULL, "Unsupported dtype string");
|
|
614
|
+
return NULL;
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
nk_dots_symmetric_punned_t kernel = NULL;
|
|
618
|
+
nk_capability_t cap = nk_cap_serial_k;
|
|
619
|
+
nk_find_kernel_punned(kernel_kind, dtype, static_capabilities, (nk_kernel_punned_t *)&kernel, &cap);
|
|
620
|
+
if (!kernel) {
|
|
621
|
+
napi_throw_error(env, NULL, "Symmetric kernel not available for this dtype");
|
|
622
|
+
return NULL;
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
kernel(vectors_data, (nk_size_t)n_vectors, (nk_size_t)depth, (nk_size_t)vectors_stride, result_data,
|
|
626
|
+
(nk_size_t)result_stride, (nk_size_t)row_start, (nk_size_t)row_count);
|
|
627
|
+
return NULL;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
/** @brief `dotsSymmetric` export: runs the shared symmetric dispatcher with the dots_symmetric kernel kind. */
static napi_value api_dots_symmetric(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_dots_symmetric_k;
    return api_symmetric_common(env, info, kind);
}

/** @brief `angularsSymmetric` export: runs the shared symmetric dispatcher with the angulars_symmetric kernel kind. */
static napi_value api_angulars_symmetric(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_angulars_symmetric_k;
    return api_symmetric_common(env, info, kind);
}

/** @brief `euclideansSymmetric` export: runs the shared symmetric dispatcher with the euclideans_symmetric kernel kind. */
static napi_value api_euclideans_symmetric(napi_env env, napi_callback_info info) {
    nk_kernel_kind_t const kind = nk_kernel_euclideans_symmetric_k;
    return api_symmetric_common(env, info, kind);
}
|
|
639
|
+
|
|
640
|
+
#pragma endregion Packed API
|
|
641
|
+
|
|
642
|
+
#pragma region Module Init
|
|
643
|
+
|
|
644
|
+
/** @brief Registers a C function as a named JavaScript export. */
|
|
645
|
+
static napi_status export_function(napi_env env, napi_value exports, char const *name, napi_callback func) {
|
|
646
|
+
napi_value fn;
|
|
647
|
+
napi_status status = napi_create_function(env, name, NAPI_AUTO_LENGTH, func, NULL, &fn);
|
|
648
|
+
if (status != napi_ok) return status;
|
|
649
|
+
return napi_set_named_property(env, exports, name, fn);
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
/** @brief Module initialization — exports all functions, detects CPU capabilities. */
|
|
653
|
+
napi_value Init(napi_env env, napi_value exports) {
|
|
654
|
+
if (export_function(env, exports, "dot", api_ip) != napi_ok ||
|
|
655
|
+
export_function(env, exports, "inner", api_ip) != napi_ok ||
|
|
656
|
+
export_function(env, exports, "sqeuclidean", api_sqeuclidean) != napi_ok ||
|
|
657
|
+
export_function(env, exports, "euclidean", api_euclidean) != napi_ok ||
|
|
658
|
+
export_function(env, exports, "angular", api_angular) != napi_ok ||
|
|
659
|
+
export_function(env, exports, "hamming", api_hamming) != napi_ok ||
|
|
660
|
+
export_function(env, exports, "jaccard", api_jaccard) != napi_ok ||
|
|
661
|
+
export_function(env, exports, "kullbackleibler", api_kld) != napi_ok ||
|
|
662
|
+
export_function(env, exports, "jensenshannon", api_jsd) != napi_ok ||
|
|
663
|
+
export_function(env, exports, "getCapabilities", api_get_capabilities) != napi_ok ||
|
|
664
|
+
export_function(env, exports, "castF16ToF32", api_cast_f16_to_f32) != napi_ok ||
|
|
665
|
+
export_function(env, exports, "castF32ToF16", api_cast_f32_to_f16) != napi_ok ||
|
|
666
|
+
export_function(env, exports, "castBF16ToF32", api_cast_bf16_to_f32) != napi_ok ||
|
|
667
|
+
export_function(env, exports, "castF32ToBF16", api_cast_f32_to_bf16) != napi_ok ||
|
|
668
|
+
export_function(env, exports, "castE4M3ToF32", api_cast_e4m3_to_f32) != napi_ok ||
|
|
669
|
+
export_function(env, exports, "castF32ToE4M3", api_cast_f32_to_e4m3) != napi_ok ||
|
|
670
|
+
export_function(env, exports, "castE5M2ToF32", api_cast_e5m2_to_f32) != napi_ok ||
|
|
671
|
+
export_function(env, exports, "castF32ToE5M2", api_cast_f32_to_e5m2) != napi_ok ||
|
|
672
|
+
export_function(env, exports, "cast", api_cast) != napi_ok ||
|
|
673
|
+
export_function(env, exports, "dotsPackedSize", api_dots_packed_size) != napi_ok ||
|
|
674
|
+
export_function(env, exports, "dotsPack", api_dots_pack) != napi_ok ||
|
|
675
|
+
export_function(env, exports, "dotsPacked", api_dots_packed) != napi_ok ||
|
|
676
|
+
export_function(env, exports, "angularsPacked", api_angulars_packed) != napi_ok ||
|
|
677
|
+
export_function(env, exports, "euclideansPacked", api_euclideans_packed) != napi_ok ||
|
|
678
|
+
export_function(env, exports, "dotsSymmetric", api_dots_symmetric) != napi_ok ||
|
|
679
|
+
export_function(env, exports, "angularsSymmetric", api_angulars_symmetric) != napi_ok ||
|
|
680
|
+
export_function(env, exports, "euclideansSymmetric", api_euclideans_symmetric) != napi_ok) {
|
|
681
|
+
return NULL;
|
|
682
|
+
}
|
|
683
|
+
static_capabilities = nk_capabilities();
|
|
684
|
+
return exports;
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
#pragma endregion Module Init
|
|
688
|
+
|
|
689
|
+
/* Registers `Init` as this addon's N-API module initializer, under the build-provided NODE_GYP_MODULE_NAME. */
NAPI_MODULE(NODE_GYP_MODULE_NAME, Init)
|