numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
package/wasm/numkong.js
ADDED
|
@@ -0,0 +1,1124 @@
|
|
|
1
|
+
// javascript/dist/esm/types.js
|
|
2
|
+
/** Holder for the cast routines used by the typed-array wrappers below; not assigned here. */
var conversionFunctions;
/** Two-way enum: `DType.NAME` yields the numeric code and `DType[code]` yields the name. */
var DType;
(function(DType2) {
  // The position of a name in this list is the numeric code it receives.
  const names = [
    "F64", "F32", "F16", "BF16",
    "E4M3", "E5M2", "E2M3", "E3M2",
    "I8", "U8", "U1", "I32", "U32"
  ];
  names.forEach((name, code) => {
    DType2[name] = code;
    DType2[code] = name;
  });
})(DType || (DType = {}));
|
|
19
|
+
/** Lower-case dtype names, indexed by numeric dtype code. */
var DTYPE_STRINGS = [
  // floating point, widest first
  "f64", "f32", "f16", "bf16",
  // 8-bit and smaller float formats
  "e4m3", "e5m2", "e2m3", "e3m2",
  // integers and bits
  "i8", "u8", "u1", "i32", "u32"
];

/**
 * Look up the lower-case name of a dtype code.
 * @param d Numeric dtype code used as an index into DTYPE_STRINGS.
 * @returns The matching name, or undefined when the code is outside the table.
 */
function dtypeToString(d) {
  const name = DTYPE_STRINGS[d];
  return name;
}
|
|
37
|
+
/**
 * Infer the DType code of a typed array from its class.
 *
 * BFloat16Array, E4M3Array and E5M2Array (defined later in this file) extend
 * Uint16Array / Uint8Array, so they are tested before their base classes.
 * Without that ordering a BFloat16Array is reported as F16 and the 8-bit
 * float wrappers as U8. Float16Array needs no dedicated test because the
 * Uint16Array check already maps it to F16.
 *
 * NOTE(review): any further wrapper classes defined outside the visible part
 * of this file would need the same treatment - confirm.
 *
 * @param arr Typed array to classify.
 * @returns The matching DType code.
 * @throws {Error} When the array's class has no dtype mapping.
 */
function inferDtype(arr) {
  // Wrapper subclasses first: `instanceof` also matches their base classes.
  if (arr instanceof BFloat16Array)
    return DType.BF16;
  if (arr instanceof E4M3Array)
    return DType.E4M3;
  if (arr instanceof E5M2Array)
    return DType.E5M2;
  if (arr instanceof Float64Array)
    return DType.F64;
  if (arr instanceof Float32Array)
    return DType.F32;
  if (arr instanceof Int32Array)
    return DType.I32;
  if (arr instanceof Int8Array)
    return DType.I8;
  if (arr instanceof Uint8Array)
    return DType.U8;
  // Plain Uint16Array (and Float16Array, which extends it) carries f16 bits.
  if (arr instanceof Uint16Array)
    return DType.F16;
  if (arr instanceof Uint32Array)
    return DType.U32;
  throw new Error(`Cannot infer dtype from ${arr.constructor.name}`);
}
|
|
54
|
+
/**
 * Common descriptor shared by the tensor classes: a buffer, a byte offset into
 * it and a dtype code. Subclasses are expected to provide `length`.
 */
var TensorBase = class {
  /**
   * @param buffer Backing buffer, stored as-is.
   * @param byteOffset Byte offset into `buffer`, stored as-is.
   * @param dtype DType code describing the elements.
   */
  constructor(buffer, byteOffset, dtype) {
    Object.assign(this, { buffer, byteOffset, dtype });
  }
  /** @brief Bytes per element for this tensor's dtype; dtypes not listed here report 1. */
  get bytesPerElement() {
    const code = this.dtype;
    if (code === DType.F64) {
      return 8;
    }
    if (code === DType.F32 || code === DType.I32 || code === DType.U32) {
      return 4;
    }
    if (code === DType.F16 || code === DType.BF16) {
      return 2;
    }
    return 1;
  }
  /** @brief Total byte length of the tensor data: `length` times `bytesPerElement`. */
  get byteLength() {
    const count = this.length;
    return count * this.bytesPerElement;
  }
};
|
|
81
|
+
/** Rank-1 tensor descriptor: a TensorBase plus an element count. */
var VectorBase = class extends TensorBase {
  /**
   * @param buffer Backing buffer, forwarded to TensorBase.
   * @param byteOffset Byte offset, forwarded to TensorBase.
   * @param length Number of elements.
   * @param dtype DType code, forwarded to TensorBase.
   */
  constructor(buffer, byteOffset, length, dtype) {
    super(buffer, byteOffset, dtype);
    Object.assign(this, { length });
  }
  /** @returns Always 1. */
  get rank() {
    const dimensions = 1;
    return dimensions;
  }
};
|
|
90
|
+
/** Vector descriptor that refers to a buffer it was handed rather than one it allocated. */
var VectorView = class _VectorView extends VectorBase {
  /**
   * @param buffer Backing buffer.
   * @param byteOffset Byte offset of the first element.
   * @param length Number of elements.
   * @param dtype DType code.
   */
  constructor(buffer, byteOffset, length, dtype) {
    super(buffer, byteOffset, length, dtype);
  }
  /** @returns Short description: element count and dtype name. */
  toString() {
    const count = this.length;
    const label = dtypeToString(this.dtype);
    return `VectorView(${count}, ${label})`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
  /**
   * @brief Create a VectorView from any TypedArray, inferring or accepting dtype.
   * @param arr Source typed array; its buffer is referenced, not copied.
   * @param dtype Optional DType code; inferred from `arr` when null or undefined.
   */
  static from(arr, dtype) {
    let resolved = dtype;
    if (resolved === null || resolved === undefined) {
      resolved = inferDtype(arr);
    }
    const { buffer, byteOffset, length } = arr;
    return new _VectorView(buffer, byteOffset, length, resolved);
  }
};
|
|
106
|
+
/** Vector whose data starts at offset 0 of its own buffer. */
var Vector = class _Vector extends VectorBase {
  /**
   * Two call forms, selected by the type of the first argument:
   *  - (length, dtype): allocate a fresh ArrayBuffer sized for `length` elements;
   *  - (buffer, length, dtype): wrap `buffer` starting at byte offset 0.
   */
  constructor(lengthOrBuffer, dtypeOrLength, dtype) {
    if (typeof lengthOrBuffer !== "number") {
      // Wrapping form: the second argument is the element count.
      super(lengthOrBuffer, 0, dtypeOrLength, dtype);
      return;
    }
    // Allocating form: the second argument is the dtype code.
    const code = dtypeOrLength;
    let width = 1;
    if (code === DType.F64) {
      width = 8;
    } else if (code === DType.F32 || code === DType.I32 || code === DType.U32) {
      width = 4;
    } else if (code === DType.F16 || code === DType.BF16) {
      width = 2;
    }
    super(new ArrayBuffer(lengthOrBuffer * width), 0, lengthOrBuffer, code);
  }
  /** @returns Short description: element count and dtype name. */
  toString() {
    const count = this.length;
    const label = dtypeToString(this.dtype);
    return `Vector(${count}, ${label})`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
  /**
   * @brief Create an owning Vector by copying data from a TypedArray.
   * @param arr Source typed array; only the bytes it covers are copied.
   * @param dtype Optional DType code; inferred from `arr` when null or undefined.
   */
  static fromTypedArray(arr, dtype) {
    let resolved = dtype;
    if (resolved === null || resolved === undefined) {
      resolved = inferDtype(arr);
    }
    const begin = arr.byteOffset;
    const copy = arr.buffer.slice(begin, begin + arr.byteLength);
    return new _Vector(copy, arr.length, resolved);
  }
  /**
   * @brief Create an owning Vector by copying data from any TensorBase.
   * @param view Source tensor; the bytes from its byteOffset through byteLength are copied.
   */
  static fromView(view) {
    const begin = view.byteOffset;
    const copy = view.buffer.slice(begin, begin + view.byteLength);
    return new _Vector(copy, view.length, view.dtype);
  }
  /**
   * @brief Return a TypedArray view over this Vector's owned buffer (zero-copy).
   * F16 and BF16 are exposed as Uint16Array; dtypes without a dedicated case
   * are exposed as Uint8Array.
   */
  toTypedArray() {
    const code = this.dtype;
    let ArrayType = Uint8Array;
    if (code === DType.F64) {
      ArrayType = Float64Array;
    } else if (code === DType.F32) {
      ArrayType = Float32Array;
    } else if (code === DType.I32) {
      ArrayType = Int32Array;
    } else if (code === DType.U32) {
      ArrayType = Uint32Array;
    } else if (code === DType.F16 || code === DType.BF16) {
      ArrayType = Uint16Array;
    } else if (code === DType.I8) {
      ArrayType = Int8Array;
    }
    return new ArrayType(this.buffer, 0, this.length);
  }
};
|
|
170
|
+
/** Rank-2 tensor descriptor: a TensorBase plus row/column counts and strides. */
var MatrixBase = class extends TensorBase {
  /**
   * @param buffer Backing buffer, forwarded to TensorBase.
   * @param byteOffset Byte offset, forwarded to TensorBase.
   * @param dtype DType code, forwarded to TensorBase.
   * @param rows Number of rows.
   * @param cols Number of columns.
   * @param rowStride Stride between rows, stored as-is.
   * @param colStride Stride between columns, stored as-is.
   */
  constructor(buffer, byteOffset, dtype, rows, cols, rowStride, colStride) {
    super(buffer, byteOffset, dtype);
    Object.assign(this, { rows, cols, rowStride, colStride });
  }
  /** @returns Total element count, rows times cols. */
  get length() {
    const { rows, cols } = this;
    return rows * cols;
  }
  /** @returns Always 2. */
  get rank() {
    const dimensions = 2;
    return dimensions;
  }
};
|
|
185
|
+
/** Matrix descriptor that can allocate its own buffer or wrap an existing one. */
var Matrix = class _Matrix extends MatrixBase {
  /**
   * Two call forms, selected by the type of the first argument:
   *  - (rows, cols, dtype): allocate a fresh ArrayBuffer; strides are
   *    `cols * bytesPerElement` per row and `bytesPerElement` per column;
   *  - (buffer, byteOffset, dtype, rows, cols, rowStride?, colStride?): wrap
   *    `buffer`; null/undefined strides default to the same values as above.
   */
  constructor(rowsOrBuffer, colsOrByteOffset, dtype, rows, cols, rowStride, colStride) {
    // Element width in bytes; both call forms derive it from `dtype`.
    let width = 1;
    if (dtype === DType.F64) {
      width = 8;
    } else if (dtype === DType.F32 || dtype === DType.I32 || dtype === DType.U32) {
      width = 4;
    } else if (dtype === DType.F16 || dtype === DType.BF16) {
      width = 2;
    }
    if (typeof rowsOrBuffer === "number") {
      const rowCount = rowsOrBuffer;
      const colCount = colsOrByteOffset;
      super(new ArrayBuffer(rowCount * colCount * width), 0, dtype, rowCount, colCount, colCount * width, width);
    } else {
      super(rowsOrBuffer, colsOrByteOffset, dtype, rows, cols, rowStride ?? cols * width, colStride ?? width);
    }
  }
  /** @returns Short description: shape and dtype name. */
  toString() {
    const shape = `${this.rows}\xD7${this.cols}`;
    return `Matrix(${shape}, ${dtypeToString(this.dtype)})`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
  /**
   * Create a Matrix from a copy of the bytes covered by a typed array.
   * @param array Source typed array.
   * @param rows Number of rows.
   * @param cols Number of columns.
   * @param dtype Optional DType code; inferred from `array` when null or undefined.
   */
  static fromTypedArray(array, rows, cols, dtype) {
    let resolved = dtype;
    if (resolved === null || resolved === undefined) {
      resolved = inferDtype(array);
    }
    const begin = array.byteOffset;
    const copy = array.buffer.slice(begin, begin + array.byteLength);
    return new _Matrix(copy, 0, resolved, rows, cols);
  }
  /**
   * Return a typed array over this matrix's buffer, starting at `byteOffset`
   * and spanning rows * cols elements. The strides are not consulted. F16 and
   * BF16 are exposed as Uint16Array; dtypes without a dedicated case are
   * exposed as Uint8Array.
   */
  toTypedArray() {
    const code = this.dtype;
    let ArrayType = Uint8Array;
    if (code === DType.F64) {
      ArrayType = Float64Array;
    } else if (code === DType.F32) {
      ArrayType = Float32Array;
    } else if (code === DType.I32) {
      ArrayType = Int32Array;
    } else if (code === DType.U32) {
      ArrayType = Uint32Array;
    } else if (code === DType.F16 || code === DType.BF16) {
      ArrayType = Uint16Array;
    } else if (code === DType.I8) {
      ArrayType = Int8Array;
    }
    return new ArrayType(this.buffer, this.byteOffset, this.rows * this.cols);
  }
  /**
   * Describe one row as a VectorView over the same buffer.
   * @param index Row index; it is not range-checked.
   */
  row(index) {
    const start = this.byteOffset + index * this.rowStride;
    return new VectorView(this.buffer, start, this.cols, this.dtype);
  }
};
|
|
267
|
+
/** Record describing a packed matrix buffer, with a disposed flag. */
var PackedMatrix = class {
  /**
   * @param buffer Packed data, stored as-is.
   * @param width Width, stored as-is.
   * @param depth Depth, stored as-is.
   * @param dtype DType code.
   * @param byteLength Size of the packed data in bytes.
   */
  constructor(buffer, width, depth, dtype, byteLength) {
    // `_disposed` goes first so the own-property order stays the same.
    Object.assign(this, { _disposed: false, buffer, width, depth, dtype, byteLength });
  }
  /** Mark this object as disposed; only the flag is changed. */
  dispose() {
    this._disposed = true;
  }
  /** @returns Whether dispose() has been called. */
  get disposed() {
    const flag = this._disposed;
    return flag;
  }
  /** @returns Short description: shape, dtype name and byte size. */
  toString() {
    const shape = `${this.width}\xD7${this.depth}`;
    const label = dtypeToString(this.dtype);
    return `PackedMatrix(${shape}, ${label}, ${this.byteLength} bytes)`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
};
|
|
289
|
+
/**
 * Map an input dtype to the dtype of the result.
 *
 * F64 and F32 inputs give F64. I8 and U8 inputs give I32 / U32 when `family`
 * is "dots" and F32 otherwise. Every other input gives F32.
 *
 * @param family Operation family name; only compared against "dots".
 * @param input DType code of the input.
 * @returns DType code of the output.
 */
function outputDtype(family, input) {
  if (input === DType.F64 || input === DType.F32) {
    return DType.F64;
  }
  const isI8 = input === DType.I8;
  if (isI8 || input === DType.U8) {
    if (family !== "dots") {
      return DType.F32;
    }
    return isI8 ? DType.I32 : DType.U32;
  }
  return DType.F32;
}
|
|
310
|
+
/**
 * Uint16Array subclass whose elements hold f16 bit patterns. Encoding and
 * decoding go through `conversionFunctions`.
 */
var Float16Array = class extends Uint16Array {
  /**
   * Three call forms:
   *  - a number: allocate that many elements;
   *  - an ArrayBuffer or ArrayBuffer view: forwarded unchanged to Uint16Array
   *    together with `byteOffset` and `arrayLength`;
   *  - anything else: treated as an array-like of numbers and encoded element
   *    by element. When `conversionFunctions` is not set the elements stay zero.
   */
  constructor(length, byteOffset, arrayLength) {
    if (typeof length === "number") {
      super(length);
      return;
    }
    if (ArrayBuffer.isView(length) || length instanceof ArrayBuffer) {
      super(length, byteOffset, arrayLength);
      return;
    }
    const encoded = new Uint16Array(length.length);
    if (conversionFunctions) {
      let i = 0;
      while (i < length.length) {
        encoded[i] = conversionFunctions.castF32ToF16(length[i]);
        i += 1;
      }
    }
    super(encoded);
  }
  /**
   * @brief Converts the entire f16 array to f32 (Float32Array).
   * @returns Float32Array with decoded values
   * @throws {Error} When `conversionFunctions` is not set.
   */
  toFloat32Array() {
    if (conversionFunctions) {
      return Float32Array.from(this, (bits) => conversionFunctions.castF16ToF32(bits));
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * @brief Gets the f32 value at the specified index.
   * @param index Array index
   * @returns Decoded f32 value
   * @throws {Error} When `conversionFunctions` is not set.
   */
  getFloat32(index) {
    if (conversionFunctions) {
      return conversionFunctions.castF16ToF32(this[index]);
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * @brief Sets the value at the specified index from an f32 value.
   * @param index Array index
   * @param value f32 value to encode and store
   * @throws {Error} When `conversionFunctions` is not set.
   */
  setFloat32(index, value) {
    if (conversionFunctions) {
      this[index] = conversionFunctions.castF32ToF16(value);
      return;
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Describe the array. At most the first 20 elements are printed, each as the
   * decoded value followed by its 4-digit hex bit pattern; only the length is
   * printed when `conversionFunctions` is not set.
   */
  toString() {
    const header = `Float16Array(${this.length})`;
    if (!conversionFunctions) {
      return header;
    }
    const shown = Math.min(this.length, 20);
    const parts = Array.from({ length: shown }, (_, i) => {
      const decoded = conversionFunctions.castF16ToF32(this[i]);
      const hex = this[i].toString(16).padStart(4, "0");
      return `${decoded} [0x${hex}]`;
    });
    const suffix = this.length > 20 ? ", ..." : "";
    return `${header} [${parts.join(", ")}${suffix}]`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
};
|
|
379
|
+
/**
 * Uint16Array subclass whose elements hold bf16 bit patterns. Encoding and
 * decoding go through `conversionFunctions`.
 */
var BFloat16Array = class extends Uint16Array {
  /**
   * Three call forms:
   *  - a number: allocate that many elements;
   *  - an ArrayBuffer or ArrayBuffer view: forwarded unchanged to Uint16Array
   *    together with `byteOffset` and `arrayLength`;
   *  - anything else: treated as an array-like of numbers and encoded element
   *    by element. When `conversionFunctions` is not set the elements stay zero.
   */
  constructor(length, byteOffset, arrayLength) {
    if (typeof length === "number") {
      super(length);
      return;
    }
    if (ArrayBuffer.isView(length) || length instanceof ArrayBuffer) {
      super(length, byteOffset, arrayLength);
      return;
    }
    const encoded = new Uint16Array(length.length);
    if (conversionFunctions) {
      let i = 0;
      while (i < length.length) {
        encoded[i] = conversionFunctions.castF32ToBF16(length[i]);
        i += 1;
      }
    }
    super(encoded);
  }
  /**
   * Decode every element into a new Float32Array.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  toFloat32Array() {
    if (conversionFunctions) {
      return Float32Array.from(this, (bits) => conversionFunctions.castBF16ToF32(bits));
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Decode the element at `index`.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  getFloat32(index) {
    if (conversionFunctions) {
      return conversionFunctions.castBF16ToF32(this[index]);
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Encode `value` and store it at `index`.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  setFloat32(index, value) {
    if (conversionFunctions) {
      this[index] = conversionFunctions.castF32ToBF16(value);
      return;
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Describe the array. At most the first 20 elements are printed, each as the
   * decoded value followed by its 4-digit hex bit pattern; only the length is
   * printed when `conversionFunctions` is not set.
   */
  toString() {
    const header = `BFloat16Array(${this.length})`;
    if (!conversionFunctions) {
      return header;
    }
    const shown = Math.min(this.length, 20);
    const parts = Array.from({ length: shown }, (_, i) => {
      const decoded = conversionFunctions.castBF16ToF32(this[i]);
      const hex = this[i].toString(16).padStart(4, "0");
      return `${decoded} [0x${hex}]`;
    });
    const suffix = this.length > 20 ? ", ..." : "";
    return `${header} [${parts.join(", ")}${suffix}]`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
};
|
|
434
|
+
/**
 * Uint8Array subclass whose elements hold e4m3 bit patterns. Encoding and
 * decoding go through `conversionFunctions`.
 */
var E4M3Array = class extends Uint8Array {
  /**
   * Three call forms:
   *  - a number: allocate that many elements;
   *  - an ArrayBuffer or ArrayBuffer view: forwarded unchanged to Uint8Array
   *    together with `byteOffset` and `arrayLength`;
   *  - anything else: treated as an array-like of numbers and encoded element
   *    by element. When `conversionFunctions` is not set the elements stay zero.
   */
  constructor(length, byteOffset, arrayLength) {
    if (typeof length === "number") {
      super(length);
      return;
    }
    if (ArrayBuffer.isView(length) || length instanceof ArrayBuffer) {
      super(length, byteOffset, arrayLength);
      return;
    }
    const encoded = new Uint8Array(length.length);
    if (conversionFunctions) {
      let i = 0;
      while (i < length.length) {
        encoded[i] = conversionFunctions.castF32ToE4M3(length[i]);
        i += 1;
      }
    }
    super(encoded);
  }
  /**
   * Decode every element into a new Float32Array.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  toFloat32Array() {
    if (conversionFunctions) {
      return Float32Array.from(this, (bits) => conversionFunctions.castE4M3ToF32(bits));
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Decode the element at `index`.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  getFloat32(index) {
    if (conversionFunctions) {
      return conversionFunctions.castE4M3ToF32(this[index]);
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Encode `value` and store it at `index`.
   * @throws {Error} When `conversionFunctions` is not set.
   */
  setFloat32(index, value) {
    if (conversionFunctions) {
      this[index] = conversionFunctions.castF32ToE4M3(value);
      return;
    }
    throw new Error("Conversion functions not initialized");
  }
  /**
   * Describe the array. At most the first 20 elements are printed, each as the
   * decoded value followed by its 2-digit hex bit pattern; only the length is
   * printed when `conversionFunctions` is not set.
   */
  toString() {
    const header = `E4M3Array(${this.length})`;
    if (!conversionFunctions) {
      return header;
    }
    const shown = Math.min(this.length, 20);
    const parts = Array.from({ length: shown }, (_, i) => {
      const decoded = conversionFunctions.castE4M3ToF32(this[i]);
      const hex = this[i].toString(16).padStart(2, "0");
      return `${decoded} [0x${hex}]`;
    });
    const suffix = this.length > 20 ? ", ..." : "";
    return `${header} [${parts.join(", ")}${suffix}]`;
  }
  /** Node.js inspect hook; prints the same text as toString(). */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
};
|
|
489
|
+
/**
 * @brief Byte-backed array whose elements are raw E5M2 bit patterns (one byte each).
 *
 * Float conversion is delegated to the module-level `conversionFunctions` object
 * (`castF32ToE5M2` / `castE5M2ToF32`); the accessors throw until it is set.
 */
var E5M2Array = class extends Uint8Array {
  /**
   * @brief Creates an E5M2 array.
   * @param length Element count, an ArrayBuffer / SharedArrayBuffer / view (forwarded to
   *   Uint8Array as raw bytes), or an array-like / iterable of numbers to encode.
   *   `undefined` or `null` yields an empty array, like `new Uint8Array()`.
   * @param byteOffset Byte offset, used only with a buffer source
   * @param arrayLength Element count, used only with a buffer source
   */
  constructor(length, byteOffset, arrayLength) {
    // SharedArrayBuffer may be undefined in non-isolated browser contexts.
    const isSharedBuffer = typeof SharedArrayBuffer !== "undefined" && length instanceof SharedArrayBuffer;
    if (length == null) {
      super();
    } else if (typeof length === "number") {
      super(length);
    } else if (ArrayBuffer.isView(length) || length instanceof ArrayBuffer || isSharedBuffer) {
      super(length, byteOffset, arrayLength);
    } else {
      // Array.from accepts both array-likes and iterables (Set, generators, ...).
      const values = Array.from(length);
      const encoded = new Uint8Array(values.length);
      // NOTE(review): when the converters are not initialized the elements stay zero
      // instead of throwing; kept as-is because callers may rely on it.
      if (conversionFunctions) {
        for (let i = 0; i < values.length; i++) {
          encoded[i] = conversionFunctions.castF32ToE5M2(values[i]);
        }
      }
      super(encoded);
    }
  }
  /**
   * @brief Decodes every element into a new Float32Array.
   * @returns Float32Array of the same length
   * @throws Error if the conversion functions are not initialized
   */
  toFloat32Array() {
    if (!conversionFunctions) {
      throw new Error("Conversion functions not initialized");
    }
    const result = new Float32Array(this.length);
    for (let i = 0; i < this.length; i++) {
      result[i] = conversionFunctions.castE5M2ToF32(this[i]);
    }
    return result;
  }
  /**
   * @brief Decodes the element at `index`.
   * @param index Element index
   * @returns Value produced by `castE5M2ToF32` for the stored byte
   * @throws Error if the conversion functions are not initialized
   */
  getFloat32(index) {
    if (!conversionFunctions) {
      throw new Error("Conversion functions not initialized");
    }
    return conversionFunctions.castE5M2ToF32(this[index]);
  }
  /**
   * @brief Encodes `value` and stores it at `index`.
   * @param index Element index
   * @param value Number to encode with `castF32ToE5M2`
   * @throws Error if the conversion functions are not initialized
   */
  setFloat32(index, value) {
    if (!conversionFunctions) {
      throw new Error("Conversion functions not initialized");
    }
    this[index] = conversionFunctions.castF32ToE5M2(value);
  }
  /**
   * @brief Formats up to the first 20 elements as `value [0xNN]`.
   * @returns Only `E5M2Array(length)` when the converters are unavailable
   */
  toString() {
    if (!conversionFunctions) {
      return `E5M2Array(${this.length})`;
    }
    const limit = Math.min(this.length, 20);
    const parts = [];
    for (let i = 0; i < limit; i++) {
      const f = conversionFunctions.castE5M2ToF32(this[i]);
      parts.push(`${f} [0x${this[i].toString(16).padStart(2, "0")}]`);
    }
    const suffix = this.length > 20 ? ", ..." : "";
    return `E5M2Array(${this.length}) [${parts.join(", ")}${suffix}]`;
  }
  /** @brief Node.js `util.inspect` hook; reuses `toString()`. */
  [Symbol.for("nodejs.util.inspect.custom")]() {
    return this.toString();
  }
};
|
|
544
|
+
var BinaryArray = class _BinaryArray extends Uint8Array {
|
|
545
|
+
constructor(bitLength) {
|
|
546
|
+
const byteLength = Math.ceil(bitLength / 8);
|
|
547
|
+
super(byteLength);
|
|
548
|
+
this._bitLength = bitLength;
|
|
549
|
+
}
|
|
550
|
+
/**
|
|
551
|
+
* @brief Gets the bit value at the specified index.
|
|
552
|
+
* @param index Bit index (0 to bitLength-1)
|
|
553
|
+
* @returns 0 or 1
|
|
554
|
+
*/
|
|
555
|
+
getBit(index) {
|
|
556
|
+
if (index < 0 || index >= this._bitLength) {
|
|
557
|
+
throw new RangeError("Index out of bounds");
|
|
558
|
+
}
|
|
559
|
+
const byteIndex = index >>> 3;
|
|
560
|
+
const bitIndex = index & 7;
|
|
561
|
+
return this[byteIndex] >>> bitIndex & 1;
|
|
562
|
+
}
|
|
563
|
+
/**
|
|
564
|
+
* @brief Sets the bit value at the specified index.
|
|
565
|
+
* @param index Bit index (0 to bitLength-1)
|
|
566
|
+
* @param value 0 or 1
|
|
567
|
+
*/
|
|
568
|
+
setBit(index, value) {
|
|
569
|
+
if (index < 0 || index >= this._bitLength) {
|
|
570
|
+
throw new RangeError("Index out of bounds");
|
|
571
|
+
}
|
|
572
|
+
const byteIndex = index >>> 3;
|
|
573
|
+
const bitIndex = index & 7;
|
|
574
|
+
if (value) {
|
|
575
|
+
this[byteIndex] |= 1 << bitIndex;
|
|
576
|
+
} else {
|
|
577
|
+
this[byteIndex] &= ~(1 << bitIndex);
|
|
578
|
+
}
|
|
579
|
+
}
|
|
580
|
+
/**
|
|
581
|
+
* @brief Returns the logical bit length of the array.
|
|
582
|
+
*/
|
|
583
|
+
get bitLength() {
|
|
584
|
+
return this._bitLength;
|
|
585
|
+
}
|
|
586
|
+
/**
|
|
587
|
+
* @brief Creates a BinaryArray from a Float32Array (positive values = 1, else 0).
|
|
588
|
+
* @param vector Source floating-point vector
|
|
589
|
+
* @returns Binary array with quantized values
|
|
590
|
+
*/
|
|
591
|
+
static fromFloat32Array(vector) {
|
|
592
|
+
const binary = new _BinaryArray(vector.length);
|
|
593
|
+
for (let i = 0; i < vector.length; i++) {
|
|
594
|
+
if (vector[i] > 0) {
|
|
595
|
+
binary.setBit(i, 1);
|
|
596
|
+
}
|
|
597
|
+
}
|
|
598
|
+
return binary;
|
|
599
|
+
}
|
|
600
|
+
/**
|
|
601
|
+
* @brief Creates a BinaryArray from a Float64Array (positive values = 1, else 0).
|
|
602
|
+
* @param vector Source floating-point vector
|
|
603
|
+
* @returns Binary array with quantized values
|
|
604
|
+
*/
|
|
605
|
+
static fromFloat64Array(vector) {
|
|
606
|
+
const binary = new _BinaryArray(vector.length);
|
|
607
|
+
for (let i = 0; i < vector.length; i++) {
|
|
608
|
+
if (vector[i] > 0) {
|
|
609
|
+
binary.setBit(i, 1);
|
|
610
|
+
}
|
|
611
|
+
}
|
|
612
|
+
return binary;
|
|
613
|
+
}
|
|
614
|
+
toString() {
|
|
615
|
+
const limit = Math.min(this.length, 20);
|
|
616
|
+
const parts = [];
|
|
617
|
+
for (let i = 0; i < limit; i++) {
|
|
618
|
+
parts.push(`0b${this[i].toString(2).padStart(8, "0")}`);
|
|
619
|
+
}
|
|
620
|
+
const suffix = this.length > 20 ? ", ..." : "";
|
|
621
|
+
return `BinaryArray(${this._bitLength}) [${parts.join(", ")}${suffix}]`;
|
|
622
|
+
}
|
|
623
|
+
[Symbol.for("nodejs.util.inspect.custom")]() {
|
|
624
|
+
return this.toString();
|
|
625
|
+
}
|
|
626
|
+
};
|
|
627
|
+
/**
 * @brief Type guard for Float16Array instances.
 * @param obj Value to test
 * @returns true when `obj` is an instance of Float16Array
 */
function isFloat16Array(obj) {
  const matches = obj instanceof Float16Array;
  return matches;
}
|
|
630
|
+
/**
 * @brief Type guard for BFloat16Array instances.
 * @param obj Value to test
 * @returns true when `obj` is an instance of BFloat16Array
 */
function isBFloat16Array(obj) {
  const matches = obj instanceof BFloat16Array;
  return matches;
}
|
|
633
|
+
/**
 * @brief Type guard for E4M3Array instances.
 * @param obj Value to test
 * @returns true when `obj` is an instance of E4M3Array
 */
function isE4M3Array(obj) {
  const matches = obj instanceof E4M3Array;
  return matches;
}
|
|
636
|
+
/**
 * @brief Type guard for E5M2Array instances.
 * @param obj Value to test
 * @returns true when `obj` is an instance of E5M2Array
 */
function isE5M2Array(obj) {
  const matches = obj instanceof E5M2Array;
  return matches;
}
|
|
639
|
+
/**
 * @brief Type guard for BinaryArray instances.
 * @param obj Value to test
 * @returns true when `obj` is an instance of BinaryArray
 */
function isBinaryArray(obj) {
  const matches = obj instanceof BinaryArray;
  return matches;
}
|
|
642
|
+
|
|
643
|
+
// javascript/dist/esm/numkong-wasm.js
|
|
644
|
+
// WASM module instance used by every function in this backend; null until initWasm() assigns it.
var Module = null;
|
|
645
|
+
// Set by the probe in initWasm(): true when a kernel export accepted BigInt pointer arguments.
// toWasmPtr() consults it to decide how to marshal pointers.
var isMemory64 = false;
|
|
646
|
+
// Heap address of the 8-byte scratch slot allocated in initWasm(); scalar kernels write their
// result there and readResult() reads it back.
var resultPtr = 0;
|
|
647
|
+
// Int32Array view over the WASM memory buffer, created in initWasm().
var HEAP32;
|
|
648
|
+
// Uint8Array view over the WASM memory buffer, created in initWasm(); used for byte copies.
var HEAPU8;
|
|
649
|
+
// Uint32Array view over the WASM memory buffer, created in initWasm().
var HEAPU32;
|
|
650
|
+
// Float32Array view over the WASM memory buffer, created in initWasm().
var HEAPF32;
|
|
651
|
+
// Float64Array view over the WASM memory buffer, created in initWasm().
var HEAPF64;
|
|
652
|
+
/**
 * @brief Marshals a heap address for a kernel call.
 * @param n Address as a JS number
 * @returns `BigInt(n)` when `isMemory64` is set, otherwise `n` unchanged
 */
function toWasmPtr(n) {
  if (isMemory64) {
    return BigInt(n);
  }
  return n;
}
|
|
655
|
+
/**
 * @brief Binds this backend to a WASM module instance.
 *
 * Stores the module, creates typed views over its memory buffer, detects whether kernel exports
 * take BigInt pointers, and allocates the 8-byte scratch slot used for scalar results.
 *
 * @param wasmModule Module object exposing `wasmMemory`, `_malloc`, `_free` and `_nk_dot_f32`
 */
function initWasm(wasmModule) {
  Module = wasmModule;
  const buffer = wasmModule.wasmMemory.buffer;
  HEAP32 = new Int32Array(buffer);
  HEAPU8 = new Uint8Array(buffer);
  HEAPU32 = new Uint32Array(buffer);
  HEAPF32 = new Float32Array(buffer);
  HEAPF64 = new Float64Array(buffer);
  // Probe: call a kernel with BigInt pointers and zero elements. If the export rejects BigInt
  // arguments the call throws and pointers are passed as plain numbers from then on.
  const probe = wasmModule._malloc(8);
  try {
    wasmModule._nk_dot_f32(BigInt(probe), BigInt(probe), 0, BigInt(probe));
    isMemory64 = true;
  } catch {
    isMemory64 = false;
  } finally {
    // Release the probe on both outcomes; freeing only after a successful call leaked it
    // every time the probe threw.
    wasmModule._free(probe);
  }
  resultPtr = wasmModule._malloc(8);
}
|
|
673
|
+
/**
 * @brief Maps a typed array to the dtype descriptor used to pick a kernel.
 *
 * Name-based checks run first: E4M3Array, E5M2Array and BinaryArray all extend Uint8Array, so
 * testing `instanceof Uint8Array` before them made those branches unreachable and silently
 * treated such inputs as plain U8 data.
 *
 * NOTE(review): for BinaryArray, callers derive the element count from `arr.length`, which is a
 * byte count; confirm the unit the U1 kernels expect.
 *
 * @param arr Typed array (or subclass instance)
 * @returns `{ dtype, bytesPerElement, heapView, resultType }`
 * @throws Error for E4M3/E5M2 inputs and for unrecognised array types
 */
function detectType(arr) {
  const constructorName = arr.constructor.name;
  switch (constructorName) {
    case "Float16Array":
      return { dtype: DType.F16, bytesPerElement: 2, heapView: "HEAPU8", resultType: "f32" };
    case "BFloat16Array":
      return { dtype: DType.BF16, bytesPerElement: 2, heapView: "HEAPU8", resultType: "f32" };
    case "E4M3Array":
      throw new Error("E4M3 not yet supported in WASM backend");
    case "E5M2Array":
      throw new Error("E5M2 not yet supported in WASM backend");
    case "BinaryArray":
      return { dtype: DType.U1, bytesPerElement: 1, heapView: "HEAPU8", resultType: "u32" };
    default:
      break;
  }
  // Built-in element types; instanceof keeps user subclasses of these working.
  if (arr instanceof Float64Array) {
    return { dtype: DType.F64, bytesPerElement: 8, heapView: "HEAPF64", resultType: "f64" };
  }
  if (arr instanceof Float32Array) {
    return { dtype: DType.F32, bytesPerElement: 4, heapView: "HEAPF32", resultType: "f64" };
  }
  if (arr instanceof Int8Array) {
    return { dtype: DType.I8, bytesPerElement: 1, heapView: "HEAPU8", resultType: "i32" };
  }
  if (arr instanceof Uint8Array) {
    return { dtype: DType.U8, bytesPerElement: 1, heapView: "HEAPU8", resultType: "u32" };
  }
  throw new Error(`Unsupported array type: ${constructorName}`);
}
|
|
697
|
+
/**
 * @brief Builds the dtype descriptor for an explicit DType value.
 * @param dtype One of the DType members handled below
 * @returns `{ dtype, bytesPerElement, heapView, resultType }`
 * @throws Error for any other dtype
 */
function typeInfoFromDtype(dtype) {
  const describe = (bytesPerElement, heapView, resultType) => ({ dtype, bytesPerElement, heapView, resultType });
  switch (dtype) {
    case DType.F64:
      return describe(8, "HEAPF64", "f64");
    case DType.F32:
      return describe(4, "HEAPF32", "f64");
    // Both half-precision formats share a two-byte layout and an f32 result.
    case DType.F16:
    case DType.BF16:
      return describe(2, "HEAPU8", "f32");
    case DType.I8:
      return describe(1, "HEAPU8", "i32");
    case DType.U8:
    case DType.U1:
      return describe(1, "HEAPU8", "u32");
    default:
      throw new Error(`Unsupported dtype: ${dtype}`);
  }
}
|
|
717
|
+
/**
 * @brief Normalises a tensor or typed array into the fields needed to copy it onto the heap.
 *
 * TensorBase instances supply their own `byteLength` and `dtype`; anything else is classified
 * by detectType() and its byte length is `length * bytesPerElement`.
 *
 * @param a TensorBase instance or typed array
 * @returns `{ buffer, byteOffset, length, byteLength, typeInfo }`
 */
function resolveInput(a) {
  const isTensor = a instanceof TensorBase;
  const typeInfo = isTensor ? typeInfoFromDtype(a.dtype) : detectType(a);
  const byteLength = isTensor ? a.byteLength : a.length * typeInfo.bytesPerElement;
  return {
    buffer: a.buffer,
    byteOffset: a.byteOffset,
    length: a.length,
    byteLength,
    typeInfo,
  };
}
|
|
736
|
+
/**
 * @brief Allocates heap memory and copies a byte range of a JS buffer into it.
 * @param buffer Source ArrayBuffer
 * @param byteOffset Offset of the first byte to copy
 * @param byteLength Number of bytes to copy
 * @returns Heap address of the copy; the caller must release it with `Module._free`
 * @throws Error if the module is not initialized or the allocation fails
 * @throws RangeError if the requested range lies outside `buffer`
 */
function allocAndCopyResolved(buffer, byteOffset, byteLength) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  // Build the source view first so an invalid range fails before anything is allocated.
  const src = new Uint8Array(buffer, byteOffset, byteLength);
  const ptr = Module._malloc(byteLength);
  if (!ptr && byteLength > 0) {
    throw new Error(`Failed to allocate ${byteLength} bytes on the WASM heap`);
  }
  try {
    // Use a fresh view of the current memory buffer: if the module was built with memory
    // growth, `_malloc` may have replaced the buffer and detached the views cached at init.
    new Uint8Array(Module.wasmMemory.buffer).set(src, ptr);
  } catch (err) {
    Module._free(ptr);
    throw err;
  }
  return ptr;
}
|
|
744
|
+
/**
 * @brief Reads a scalar kernel result from the heap.
 * @param ptr Heap address of the value
 * @param resultType One of "f64", "f32", "i32", "u32"
 * @returns The decoded number
 * @throws Error if the module is not initialized or `resultType` is unknown
 */
function readResult(ptr, resultType) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  // Read through a DataView over the current memory buffer rather than the views cached at
  // init, which go stale if the memory grows. WebAssembly linear memory is little-endian.
  const view = new DataView(Module.wasmMemory.buffer);
  switch (resultType) {
    case "f64":
      return view.getFloat64(ptr, true);
    case "f32":
      return view.getFloat32(ptr, true);
    case "i32":
      return view.getInt32(ptr, true);
    case "u32":
      return view.getUint32(ptr, true);
    default:
      // Previously fell through and returned undefined, hiding the mistake from callers.
      throw new Error(`Unsupported result type: ${resultType}`);
  }
}
|
|
758
|
+
/**
 * @brief Runs the scalar kernel `_nk_<metric>_<dtype>` on two equally sized inputs.
 *
 * Inputs already backed by the WASM memory buffer are passed by offset; others are copied to
 * the heap for the duration of the call and freed afterwards.
 *
 * @param metric Kernel name fragment, e.g. "dot" or "euclidean"
 * @param a First input (TensorBase or typed array)
 * @param b Second input, same length and dtype as `a`
 * @returns The kernel result read back from the shared scratch slot
 * @throws Error if the module is not initialized, lengths or dtypes differ, or the kernel is missing
 */
function distance(metric, a, b) {
  if (!Module) {
    throw new Error("WASM module not initialized. Call initWasm() first.");
  }
  const resolvedA = resolveInput(a);
  const resolvedB = resolveInput(b);
  if (resolvedA.length !== resolvedB.length) {
    throw new Error(`Array length mismatch: ${resolvedA.length} !== ${resolvedB.length}`);
  }
  // The kernel is chosen from a's dtype alone, so b must use the same element layout.
  if (resolvedA.typeInfo.dtype !== resolvedB.typeInfo.dtype) {
    throw new Error(`Array dtype mismatch: ${dtypeToString(resolvedA.typeInfo.dtype)} !== ${dtypeToString(resolvedB.typeInfo.dtype)}`);
  }
  const fnName = `_nk_${metric}_${dtypeToString(resolvedA.typeInfo.dtype)}`;
  const fn = Module[fnName];
  if (typeof fn !== "function") {
    throw new Error(`Function ${fnName} not available in WASM module`);
  }
  const heapBuffer = Module.wasmMemory.buffer;
  const isOnHeapA = resolvedA.buffer === heapBuffer;
  const isOnHeapB = resolvedB.buffer === heapBuffer;
  let aOff = resolvedA.byteOffset;
  let bOff = resolvedB.byteOffset;
  let ownsA = false;
  let ownsB = false;
  try {
    // Allocate inside the try so a failure on the second copy still frees the first.
    if (!isOnHeapA) {
      aOff = allocAndCopyResolved(resolvedA.buffer, resolvedA.byteOffset, resolvedA.byteLength);
      ownsA = true;
    }
    if (!isOnHeapB) {
      bOff = allocAndCopyResolved(resolvedB.buffer, resolvedB.byteOffset, resolvedB.byteLength);
      ownsB = true;
    }
    fn(toWasmPtr(aOff), toWasmPtr(bOff), resolvedA.length, toWasmPtr(resultPtr));
    return readResult(resultPtr, resolvedA.typeInfo.resultType);
  } finally {
    if (ownsA) {
      Module._free(aOff);
    }
    if (ownsB) {
      Module._free(bOff);
    }
  }
}
|
|
787
|
+
/**
 * @brief Runs the "sqeuclidean" kernel on two vectors via distance().
 * @param a First vector
 * @param b Second vector
 * @returns Result of `distance("sqeuclidean", a, b)`
 */
function sqeuclidean(a, b) {
  const metric = "sqeuclidean";
  return distance(metric, a, b);
}
|
|
790
|
+
/**
 * @brief Runs the "euclidean" kernel on two vectors via distance().
 * @param a First vector
 * @param b Second vector
 * @returns Result of `distance("euclidean", a, b)`
 */
function euclidean(a, b) {
  const metric = "euclidean";
  return distance(metric, a, b);
}
|
|
793
|
+
/**
 * @brief Runs the "angular" kernel on two vectors via distance().
 * @param a First vector
 * @param b Second vector
 * @returns Result of `distance("angular", a, b)`
 */
function angular(a, b) {
  const metric = "angular";
  return distance(metric, a, b);
}
|
|
796
|
+
/**
 * @brief Runs the "dot" kernel on two vectors via distance().
 * @param a First vector
 * @param b Second vector
 * @returns Result of `distance("dot", a, b)`
 */
function dot(a, b) {
  const metric = "dot";
  return distance(metric, a, b);
}
|
|
799
|
+
// Alias: `inner` is the same function object as `dot`.
var inner = dot;
|
|
800
|
+
/**
 * @brief Runs the `_nk_hamming_u1` kernel on two byte-backed inputs.
 *
 * For TensorBase inputs `byteLength` bytes are copied; for anything else `length` bytes are
 * copied, i.e. the input is assumed to hold one byte per element.
 * NOTE(review): the element count handed to the kernel is `a.length`; confirm whether the
 * kernel expects bits or bytes.
 *
 * @param a First input
 * @param b Second input, same `length` as `a`
 * @returns Kernel result read as an unsigned 32-bit integer
 * @throws Error if the module is not initialized, lengths differ, or the kernel is missing
 */
function hamming(a, b) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const lengthA = a.length;
  const lengthB = b.length;
  if (lengthA !== lengthB) {
    throw new Error(`Array length mismatch: ${lengthA} !== ${lengthB}`);
  }
  const fn = Module._nk_hamming_u1;
  if (typeof fn !== "function") {
    throw new Error("Function _nk_hamming_u1 not available in WASM module");
  }
  const byteLengthA = a instanceof TensorBase ? a.byteLength : lengthA;
  const byteLengthB = b instanceof TensorBase ? b.byteLength : lengthB;
  const heapBuffer = Module.wasmMemory.buffer;
  const isOnHeapA = a.buffer === heapBuffer;
  const isOnHeapB = b.buffer === heapBuffer;
  let aOff = a.byteOffset;
  let bOff = b.byteOffset;
  let ownsA = false;
  let ownsB = false;
  try {
    // Allocate inside the try so a failure on the second copy still frees the first.
    if (!isOnHeapA) {
      aOff = allocAndCopyResolved(a.buffer, a.byteOffset, byteLengthA);
      ownsA = true;
    }
    if (!isOnHeapB) {
      bOff = allocAndCopyResolved(b.buffer, b.byteOffset, byteLengthB);
      ownsB = true;
    }
    fn(toWasmPtr(aOff), toWasmPtr(bOff), lengthA, toWasmPtr(resultPtr));
    return readResult(resultPtr, "u32");
  } finally {
    if (ownsA) {
      Module._free(aOff);
    }
    if (ownsB) {
      Module._free(bOff);
    }
  }
}
|
|
829
|
+
/**
 * @brief Runs the `_nk_jaccard_u1` kernel on two byte-backed inputs.
 *
 * For TensorBase inputs `byteLength` bytes are copied; for anything else `length` bytes are
 * copied, i.e. the input is assumed to hold one byte per element.
 * NOTE(review): the element count handed to the kernel is `a.length`; confirm whether the
 * kernel expects bits or bytes.
 *
 * @param a First input
 * @param b Second input, same `length` as `a`
 * @returns Kernel result read as a 32-bit float
 * @throws Error if the module is not initialized, lengths differ, or the kernel is missing
 */
function jaccard(a, b) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const lengthA = a.length;
  const lengthB = b.length;
  if (lengthA !== lengthB) {
    throw new Error(`Array length mismatch: ${lengthA} !== ${lengthB}`);
  }
  const fn = Module._nk_jaccard_u1;
  if (typeof fn !== "function") {
    throw new Error("Function _nk_jaccard_u1 not available in WASM module");
  }
  const byteLengthA = a instanceof TensorBase ? a.byteLength : lengthA;
  const byteLengthB = b instanceof TensorBase ? b.byteLength : lengthB;
  const heapBuffer = Module.wasmMemory.buffer;
  const isOnHeapA = a.buffer === heapBuffer;
  const isOnHeapB = b.buffer === heapBuffer;
  let aOff = a.byteOffset;
  let bOff = b.byteOffset;
  let ownsA = false;
  let ownsB = false;
  try {
    // Allocate inside the try so a failure on the second copy still frees the first.
    if (!isOnHeapA) {
      aOff = allocAndCopyResolved(a.buffer, a.byteOffset, byteLengthA);
      ownsA = true;
    }
    if (!isOnHeapB) {
      bOff = allocAndCopyResolved(b.buffer, b.byteOffset, byteLengthB);
      ownsB = true;
    }
    fn(toWasmPtr(aOff), toWasmPtr(bOff), lengthA, toWasmPtr(resultPtr));
    return readResult(resultPtr, "f32");
  } finally {
    if (ownsA) {
      Module._free(aOff);
    }
    if (ownsB) {
      Module._free(bOff);
    }
  }
}
|
|
858
|
+
/**
 * @brief Runs the `_nk_kld_<dtype>` kernel on two equally sized inputs.
 *
 * Delegates to distance() with the "kld" metric, which performs the same resolve / copy /
 * call / free sequence this function used to duplicate line for line.
 *
 * @param a First input (TensorBase or typed array)
 * @param b Second input, same length as `a`
 * @returns The kernel result
 * @throws Error if the module is not initialized, lengths differ, or the kernel is missing
 */
function kullbackleibler(a, b) {
  // Checked here to keep this entry point's original error message.
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  return distance("kld", a, b);
}
|
|
887
|
+
/**
 * @brief Runs the `_nk_jsd_<dtype>` kernel on two equally sized inputs.
 *
 * Delegates to distance() with the "jsd" metric, which performs the same resolve / copy /
 * call / free sequence this function used to duplicate line for line.
 *
 * @param a First input (TensorBase or typed array)
 * @param b Second input, same length as `a`
 * @returns The kernel result
 * @throws Error if the module is not initialized, lengths differ, or the kernel is missing
 */
function jensenshannon(a, b) {
  // Checked here to keep this entry point's original error message.
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  return distance("jsd", a, b);
}
|
|
916
|
+
/**
 * @brief Queries the module's capability bitmask.
 * @returns The value of `_nk_capabilities()` as a BigInt
 * @throws Error if the module is not initialized
 */
function getCapabilities() {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const raw = Module._nk_capabilities();
  if (typeof raw === "bigint") {
    return raw;
  }
  return BigInt(raw);
}
|
|
923
|
+
/**
 * @brief Tests whether any bit of `cap` is set in the capability bitmask.
 * @param cap BigInt mask to test
 * @returns true when `getCapabilities() & cap` is non-zero
 */
function hasCapability(cap) {
  const overlap = getCapabilities() & cap;
  return overlap !== 0n;
}
|
|
926
|
+
// Lazily created FinalizationRegistry that frees the heap block of a collected WasmPackedMatrix.
var packedRegistry = null;
/**
 * @brief PackedMatrix whose packed bytes also live in a heap block owned by this object.
 *
 * The constructor snapshots the heap bytes into a JS ArrayBuffer for the base class and keeps
 * the heap pointer, which is released by dispose() or, failing that, by the finalizer.
 */
var WasmPackedMatrix = class extends PackedMatrix {
  /**
   * @param heapPointer Heap address of the packed data; ownership transfers to this object
   * @param byteLength Size of the packed data in bytes
   * @param width Forwarded to PackedMatrix
   * @param depth Forwarded to PackedMatrix
   * @param dtype Forwarded to PackedMatrix
   */
  constructor(heapPointer, byteLength, width, depth, dtype) {
    const buffer = new ArrayBuffer(byteLength);
    // Fresh view of the current memory buffer; the view cached at init may be stale if the
    // memory has grown since.
    const heap = new Uint8Array(Module.wasmMemory.buffer);
    new Uint8Array(buffer).set(heap.subarray(heapPointer, heapPointer + byteLength));
    super(buffer, width, depth, dtype, byteLength);
    this._wasmDisposed = false;
    this._heapPointer = heapPointer;
    if (!packedRegistry && typeof FinalizationRegistry !== "undefined") {
      packedRegistry = new FinalizationRegistry((ptr) => {
        if (Module) {
          Module._free(ptr);
        }
      });
    }
    if (packedRegistry) {
      // `this` doubles as the unregister token so dispose() can cancel the finalizer.
      packedRegistry.register(this, heapPointer, this);
    }
  }
  /** @brief Heap address of the packed data. */
  get heapPointer() {
    return this._heapPointer;
  }
  /**
   * @brief Frees the heap block once, then runs the base-class dispose.
   */
  dispose() {
    if (!this._wasmDisposed && Module) {
      // Cancel the finalizer first; otherwise garbage collection would free the same
      // pointer a second time.
      if (packedRegistry) {
        packedRegistry.unregister(this);
      }
      Module._free(this._heapPointer);
      this._wasmDisposed = true;
    }
    super.dispose();
  }
};
|
|
955
|
+
/**
 * @brief Copies a matrix's storage (`rows * rowStride` bytes from `byteOffset`) onto the heap.
 * @param matrix Object exposing `buffer`, `byteOffset`, `rows` and `rowStride`
 * @returns Heap address returned by allocAndCopyResolved(); the caller frees it
 */
function allocAndCopyMatrix(matrix) {
  const totalBytes = matrix.rows * matrix.rowStride;
  const { buffer, byteOffset } = matrix;
  return allocAndCopyResolved(buffer, byteOffset, totalBytes);
}
|
|
959
|
+
/**
 * @brief Asks the module how many bytes a packed matrix of the given shape needs.
 * @param width First argument forwarded to the size kernel
 * @param depth Second argument forwarded to the size kernel
 * @param dtype DType selecting the `_nk_dots_packed_size_<dtype>` export
 * @returns Whatever the size kernel returns
 * @throws Error if the module is not initialized or the export is missing
 */
function dotsPackedSize(width, depth, dtype) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const exportName = `_nk_dots_packed_size_${dtypeToString(dtype)}`;
  const sizeFn = Module[exportName];
  if (typeof sizeFn !== "function") {
    throw new Error(`Function ${exportName} not available in WASM module`);
  }
  return sizeFn(width, depth);
}
|
|
969
|
+
/**
 * @brief Packs a matrix with `_nk_dots_pack_<dtype>` and wraps the result.
 * @param matrix Object exposing `dtype`, `rows`, `cols`, `rowStride`, `buffer`, `byteOffset`
 * @returns WasmPackedMatrix that owns the packed heap block
 * @throws Error if the module is not initialized or the pack exports are missing
 */
function dotsPack(matrix) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const dtypeStr = dtypeToString(matrix.dtype);
  const sizeFn = Module[`_nk_dots_packed_size_${dtypeStr}`];
  const packFn = Module[`_nk_dots_pack_${dtypeStr}`];
  if (!sizeFn || !packFn) {
    throw new Error(`Pack functions not available for dtype ${dtypeStr}`);
  }
  const packedByteCount = sizeFn(matrix.rows, matrix.cols);
  const packedPtr = Module._malloc(packedByteCount);
  try {
    const matrixPtr = allocAndCopyMatrix(matrix);
    try {
      packFn(toWasmPtr(matrixPtr), matrix.rows, matrix.cols, matrix.rowStride, toWasmPtr(packedPtr));
    } finally {
      Module._free(matrixPtr);
    }
    return new WasmPackedMatrix(packedPtr, packedByteCount, matrix.rows, matrix.cols, matrix.dtype);
  } catch (err) {
    // Ownership of packedPtr only passes to the wrapper on success; free it on any failure.
    Module._free(packedPtr);
    throw err;
  }
}
|
|
990
|
+
/**
 * @brief Runs a packed kernel `_nk_<metricPrefix>_<dtype>` and writes the result into `out`.
 *
 * All operands are copied to the heap for the call and freed afterwards. The heap result
 * buffer covers the full strided extent of `out`, so a caller-supplied matrix with row padding
 * is neither overrun nor has its padding bytes clobbered.
 *
 * NOTE(review): the packed operand is always re-copied from `packed.buffer`, as before; a live
 * WasmPackedMatrix could pass its `heapPointer` directly and skip that copy.
 *
 * @param metricPrefix Kernel name fragment, e.g. "dots_packed"
 * @param family Metric family passed to outputDtype()
 * @param a Left-hand matrix
 * @param packed Packed right-hand matrix; `packed.depth` must equal `a.cols`
 * @param out Optional destination with at least `a.rows` x `packed.width` cells
 * @returns The destination matrix
 * @throws Error on shape mismatch, undersized `out`, missing kernel or uninitialized module
 */
function wasmPackedOperation(metricPrefix, family, a, packed, out) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  if (a.cols !== packed.depth) {
    throw new Error(`Matrix cols (${a.cols}) must match packed depth (${packed.depth})`);
  }
  const outDtype = outputDtype(family, a.dtype);
  const result = out ? out : new Matrix(a.rows, packed.width, outDtype);
  // The kernel writes a.rows rows of packed.width cells; a smaller destination would be overrun.
  if (result.rows < a.rows || result.cols < packed.width) {
    throw new Error(`Output matrix (${result.rows}x${result.cols}) is too small for a ${a.rows}x${packed.width} result`);
  }
  const fnName = `_nk_${metricPrefix}_${dtypeToString(a.dtype)}`;
  const fn = Module[fnName];
  if (typeof fn !== "function") {
    throw new Error(`Function ${fnName} not available in WASM module`);
  }
  const rowBytes = result.cols * result.bytesPerElement;
  // Bytes from the first cell to the end of the last row, honouring the row stride.
  const span = result.rows > 0 ? (result.rows - 1) * result.rowStride + rowBytes : 0;
  // Cells the kernel does not write (padding, extra rows or columns) must survive the round
  // trip, so seed the heap buffer from the destination unless the kernel covers all of it.
  const kernelCoversAll = result.rowStride === rowBytes && result.rows === a.rows && result.cols === packed.width;
  let aPtr = 0;
  let packedPtr = 0;
  let resultPtr2 = 0;
  try {
    aPtr = allocAndCopyMatrix(a);
    packedPtr = allocAndCopyResolved(packed.buffer, 0, packed.byteLength);
    resultPtr2 = kernelCoversAll
      ? Module._malloc(span)
      : allocAndCopyResolved(result.buffer, result.byteOffset, span);
    fn(toWasmPtr(aPtr), toWasmPtr(packedPtr), toWasmPtr(resultPtr2), a.rows, packed.width, a.cols, a.rowStride, result.rowStride);
    // Fresh view: the cached heap views may be stale if the memory grew during the calls above.
    const heap = new Uint8Array(Module.wasmMemory.buffer);
    new Uint8Array(result.buffer, result.byteOffset, span).set(heap.subarray(resultPtr2, resultPtr2 + span));
  } finally {
    if (aPtr) {
      Module._free(aPtr);
    }
    if (resultPtr2) {
      Module._free(resultPtr2);
    }
    if (packedPtr) {
      Module._free(packedPtr);
    }
  }
  return result;
}
|
|
1031
|
+
/**
 * @brief Runs a symmetric kernel `_nk_<metricPrefix>_<dtype>` over the rows of `vectors`.
 *
 * The heap result buffer is seeded with the current contents of `out` before the call, so
 * cells the kernel does not write (for example rows outside `[rowStart, rowStart + count)`
 * when the work is split across several calls) keep their values instead of being overwritten
 * with uninitialized heap bytes.
 *
 * @param metricPrefix Kernel name fragment, e.g. "dots_symmetric"
 * @param family Metric family passed to outputDtype()
 * @param vectors Input matrix
 * @param out Optional destination; defaults to a new `vectors.rows` x `vectors.rows` Matrix
 * @param rowStart First row forwarded to the kernel (default 0)
 * @param rowCount Row count forwarded to the kernel (default: all rows from `rowStart`)
 * @returns The destination matrix
 * @throws Error if the module is not initialized or the kernel is missing
 */
function wasmSymmetricOperation(metricPrefix, family, vectors, out, rowStart = 0, rowCount) {
  if (!Module) {
    throw new Error("WASM module not initialized");
  }
  const count = rowCount ?? vectors.rows - rowStart;
  const outDtype = outputDtype(family, vectors.dtype);
  const result = out ? out : new Matrix(vectors.rows, vectors.rows, outDtype);
  const fnName = `_nk_${metricPrefix}_${dtypeToString(vectors.dtype)}`;
  const fn = Module[fnName];
  if (typeof fn !== "function") {
    throw new Error(`Function ${fnName} not available in WASM module`);
  }
  const rowBytes = result.cols * result.bytesPerElement;
  // Bytes from the first cell to the end of the last row, honouring the row stride.
  const span = result.rows > 0 ? (result.rows - 1) * result.rowStride + rowBytes : 0;
  let vectorsPtr = 0;
  let resultPtr2 = 0;
  try {
    vectorsPtr = allocAndCopyMatrix(vectors);
    resultPtr2 = allocAndCopyResolved(result.buffer, result.byteOffset, span);
    fn(toWasmPtr(vectorsPtr), vectors.rows, vectors.cols, vectors.rowStride, toWasmPtr(resultPtr2), result.rowStride, rowStart, count);
    // Fresh view: the cached heap views may be stale if the memory grew during the calls above.
    const heap = new Uint8Array(Module.wasmMemory.buffer);
    new Uint8Array(result.buffer, result.byteOffset, span).set(heap.subarray(resultPtr2, resultPtr2 + span));
  } finally {
    if (vectorsPtr) {
      Module._free(vectorsPtr);
    }
    if (resultPtr2) {
      Module._free(resultPtr2);
    }
  }
  return result;
}
|
|
1058
|
+
/**
 * @brief Runs the "dots_packed" kernel of the "dots" family via wasmPackedOperation().
 * @param a Left-hand matrix
 * @param packed Packed right-hand matrix
 * @param out Optional destination matrix
 * @returns The destination matrix
 */
function dotsPacked(a, packed, out) {
  const kernel = "dots_packed";
  const family = "dots";
  return wasmPackedOperation(kernel, family, a, packed, out);
}
|
|
1061
|
+
/**
 * @brief Runs the "angulars_packed" kernel of the "angulars" family via wasmPackedOperation().
 * @param a Left-hand matrix
 * @param packed Packed right-hand matrix
 * @param out Optional destination matrix
 * @returns The destination matrix
 */
function angularsPacked(a, packed, out) {
  const kernel = "angulars_packed";
  const family = "angulars";
  return wasmPackedOperation(kernel, family, a, packed, out);
}
|
|
1064
|
+
/**
 * @brief Runs the "euclideans_packed" kernel of the "euclideans" family via wasmPackedOperation().
 * @param a Left-hand matrix
 * @param packed Packed right-hand matrix
 * @param out Optional destination matrix
 * @returns The destination matrix
 */
function euclideansPacked(a, packed, out) {
  const kernel = "euclideans_packed";
  const family = "euclideans";
  return wasmPackedOperation(kernel, family, a, packed, out);
}
|
|
1067
|
+
/**
 * @brief Runs the "dots_symmetric" kernel of the "dots" family via wasmSymmetricOperation().
 * @param vectors Input matrix
 * @param out Optional destination matrix
 * @param options Optional `{ rowStart, rowCount }`; `rowStart` defaults to 0
 * @returns The destination matrix
 */
function dotsSymmetric(vectors, out, options) {
  const rowStart = options?.rowStart ?? 0;
  const rowCount = options?.rowCount;
  return wasmSymmetricOperation("dots_symmetric", "dots", vectors, out, rowStart, rowCount);
}
|
|
1070
|
+
/**
 * @brief Runs the "angulars_symmetric" kernel of the "angulars" family via wasmSymmetricOperation().
 * @param vectors Input matrix
 * @param out Optional destination matrix
 * @param options Optional `{ rowStart, rowCount }`; `rowStart` defaults to 0
 * @returns The destination matrix
 */
function angularsSymmetric(vectors, out, options) {
  const rowStart = options?.rowStart ?? 0;
  const rowCount = options?.rowCount;
  return wasmSymmetricOperation("angulars_symmetric", "angulars", vectors, out, rowStart, rowCount);
}
|
|
1073
|
+
/**
 * @brief Runs the "euclideans_symmetric" kernel of the "euclideans" family via wasmSymmetricOperation().
 * @param vectors Input matrix
 * @param out Optional destination matrix
 * @param options Optional `{ rowStart, rowCount }`; `rowStart` defaults to 0
 * @returns The destination matrix
 */
function euclideansSymmetric(vectors, out, options) {
  const rowStart = options?.rowStart ?? 0;
  const rowCount = options?.rowCount;
  return wasmSymmetricOperation("euclideans_symmetric", "euclideans", vectors, out, rowStart, rowCount);
}
|
|
1076
|
+
|
|
1077
|
+
// javascript/dist/esm/numkong-browser.js
|
|
1078
|
+
// Browser bootstrap: load the glue script that sits next to this module, instantiate the
// WASM module, and bind the backend to it before any export is used.
var glueUrl = new URL("./numkong-emscripten.js", import.meta.url);
var { default: NumKongModule } = await import(glueUrl.href);
var wasmInstance = await NumKongModule({
  // Resolve files requested by the glue script relative to the glue script itself.
  locateFile(path) {
    return new URL(path, glueUrl).href;
  },
});
initWasm(wasmInstance);
|
|
1084
|
+
export {
|
|
1085
|
+
BFloat16Array,
|
|
1086
|
+
BinaryArray,
|
|
1087
|
+
DType,
|
|
1088
|
+
E4M3Array,
|
|
1089
|
+
E5M2Array,
|
|
1090
|
+
Float16Array,
|
|
1091
|
+
Matrix,
|
|
1092
|
+
MatrixBase,
|
|
1093
|
+
PackedMatrix,
|
|
1094
|
+
TensorBase,
|
|
1095
|
+
Vector,
|
|
1096
|
+
VectorBase,
|
|
1097
|
+
VectorView,
|
|
1098
|
+
angular,
|
|
1099
|
+
angularsPacked,
|
|
1100
|
+
angularsSymmetric,
|
|
1101
|
+
dot,
|
|
1102
|
+
dotsPack,
|
|
1103
|
+
dotsPacked,
|
|
1104
|
+
dotsPackedSize,
|
|
1105
|
+
dotsSymmetric,
|
|
1106
|
+
dtypeToString,
|
|
1107
|
+
euclidean,
|
|
1108
|
+
euclideansPacked,
|
|
1109
|
+
euclideansSymmetric,
|
|
1110
|
+
getCapabilities,
|
|
1111
|
+
hamming,
|
|
1112
|
+
hasCapability,
|
|
1113
|
+
inner,
|
|
1114
|
+
isBFloat16Array,
|
|
1115
|
+
isBinaryArray,
|
|
1116
|
+
isE4M3Array,
|
|
1117
|
+
isE5M2Array,
|
|
1118
|
+
isFloat16Array,
|
|
1119
|
+
jaccard,
|
|
1120
|
+
jensenshannon,
|
|
1121
|
+
kullbackleibler,
|
|
1122
|
+
outputDtype,
|
|
1123
|
+
sqeuclidean
|
|
1124
|
+
};
|