@genai-fi/nanogpt 0.10.3 → 0.11.0
This diff compares the publicly released contents of two package versions as they appear in their public registry, and is provided for informational purposes only.
- package/dist/Generator.d.ts +10 -5
- package/dist/Generator.js +1789 -1765
- package/dist/{RealDiv-KAPDe8zB.js → RealDiv-Ds-jvL09.js} +22 -22
- package/dist/{Reshape-BYkmUnAv.js → Reshape-Cd6e-Otn.js} +1 -1
- package/dist/{Reshape-Zt6eb7yh.js → Reshape-Ct266DEk.js} +9 -9
- package/dist/TeachableLLM.d.ts +4 -3
- package/dist/TeachableLLM.js +14 -14
- package/dist/Trainer.d.ts +2 -2
- package/dist/Trainer.js +6 -6
- package/dist/{axis_util-BaG7mf5A.js → axis_util-DofAuy0p.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-RCe-rHaj.js → backend_util-C7NWHpv7.js} +7 -7
- package/dist/{backend_webgpu-DE3ACOLx.js → backend_webgpu-B0Vls736.js} +10 -10
- package/dist/{broadcast_to-B3eYlZm7.js → broadcast_to-DDaNMbX7.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +2 -2
- package/dist/checks/normRMS.js +4 -4
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +2 -2
- package/dist/checks/qkv.js +4 -4
- package/dist/checks/rope.js +2 -2
- package/dist/{clip_by_value-BnO7-a88.js → clip_by_value-Dn5tzexi.js} +4 -4
- package/dist/complex-DClmWqJt.js +11 -0
- package/dist/{concat-BV8bt5H-.js → concat-C6X3AAlQ.js} +1 -1
- package/dist/{concat_util-DpW8mL_l.js → concat_util-CHsJFZJJ.js} +1 -1
- package/dist/{dataset-BcwmTGYc.js → dataset-DcjWqUVQ.js} +7 -7
- package/dist/{dropout-BcvN9JYi.js → dropout-OxuaJz6z.js} +11 -11
- package/dist/{expand_dims-DT4tEPwA.js → expand_dims-BzfJK2uc.js} +3 -3
- package/dist/{exports_initializers-Hta_rEnm.js → exports_initializers-eS9QJ6ut.js} +1 -1
- package/dist/{floor-D5QdR_le.js → floor-DIb-lN_u.js} +1 -1
- package/dist/gather-BcO5UQNJ.js +9 -0
- package/dist/{gelu-CjNPL4OH.js → gelu-DqTbCx5x.js} +1 -1
- package/dist/{gpgpu_math-DAOmgtXR.js → gpgpu_math-CJcbnKPC.js} +2 -2
- package/dist/{index-DOvlwCh-.js → index-D0RBWjq8.js} +52 -52
- package/dist/{index-BwexR4lA.js → index-Dj5TkmPY.js} +89 -89
- package/dist/{kernel_funcs_utils-CCzYdUZg.js → kernel_funcs_utils-CSaumNDs.js} +11 -11
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +17 -17
- package/dist/log_sum_exp-VLZgbFAH.js +39 -0
- package/dist/main.d.ts +1 -1
- package/dist/main.js +9 -9
- package/dist/{matMul16-BWRSOCWB.js → matMul16-cDxwemKj.js} +7 -7
- package/dist/{matMulGelu-CzfgT6Wq.js → matMulGelu-B2s_80-H.js} +18 -18
- package/dist/{mat_mul-SjpJRLyL.js → mat_mul-DxpNTCRz.js} +3 -3
- package/dist/{mod-AnXEvvpo.js → mod-PrOKlFxH.js} +1 -1
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/model.js +9 -9
- package/dist/{ones-D2rT0xk2.js → ones-BX_wEgzB.js} +3 -3
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/concat16.js +2 -2
- package/dist/ops/cpu/adamAdjust.js +6 -6
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +5 -5
- package/dist/ops/cpu/attentionMask.js +10 -10
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +6 -6
- package/dist/ops/cpu/gelu.js +9 -9
- package/dist/ops/cpu/matMul16.js +2 -2
- package/dist/ops/cpu/matMulGelu.js +3 -3
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +3 -3
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +9 -9
- package/dist/ops/cpu/scatterSub.js +11 -11
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.js +4 -4
- package/dist/ops/grads/attentionMask.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMul16.js +3 -3
- package/dist/ops/grads/matMulGelu.js +3 -3
- package/dist/ops/grads/normRMS.js +7 -7
- package/dist/ops/grads/pack16.js +3 -3
- package/dist/ops/grads/qkv.js +6 -6
- package/dist/ops/grads/rope.js +2 -2
- package/dist/ops/grads/softmax16.js +1 -1
- package/dist/ops/grads/unpack16.js +2 -2
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +2 -2
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/reshape16.js +6 -6
- package/dist/ops/rope.js +2 -2
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +2 -2
- package/dist/ops/transpose16.js +3 -3
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +4 -4
- package/dist/ops/webgl/fusedSoftmax.js +6 -6
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMul16.js +11 -11
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +7 -7
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +7 -7
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +4 -4
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/add16.js +1 -1
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +5 -5
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/concat16.js +5 -5
- package/dist/ops/webgpu/gatherSub.js +5 -5
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/matMul16.js +18 -18
- package/dist/ops/webgpu/matMul16_program.js +2 -2
- package/dist/ops/webgpu/mul16.js +4 -4
- package/dist/ops/webgpu/normRMS.js +6 -6
- package/dist/ops/webgpu/normRMSGrad.js +4 -4
- package/dist/ops/webgpu/pack16.js +1 -1
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +6 -6
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/slice16.js +4 -4
- package/dist/ops/webgpu/softmax16.js +2 -2
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +1 -1
- package/dist/ops/webgpu/sub16.js +4 -4
- package/dist/ops/webgpu/sum16.js +6 -6
- package/dist/ops/webgpu/transpose16.js +2 -2
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
- package/dist/ops/webgpu/unpack16.js +3 -3
- package/dist/ops/webgpu/utils/binary_op.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-B5yanEdW.js → ops-FJapAPfm.js} +56 -56
- package/dist/{pack16-nQ6JaLo-.js → pack16-k4jq6aMX.js} +7 -7
- package/dist/patches/webgpu_backend.js +7 -7
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +8 -8
- package/dist/{random_width-or-CEftb.js → random_width-UGQn4OWb.js} +33 -33
- package/dist/range-CuGvVN2c.js +10 -0
- package/dist/{relu-CP0ZcxWO.js → relu-Cf80uA2p.js} +1 -1
- package/dist/{reshape-ByE68wS9.js → reshape-CkjKPPqB.js} +1 -1
- package/dist/{resize_nearest_neighbor-B19mCEg2.js → resize_nearest_neighbor-DB8k9KN_.js} +43 -43
- package/dist/{rope-Ir4mTyD1.js → rope-BmZmp9uP.js} +1 -1
- package/dist/{scatter_nd_util-lvSiX8q4.js → scatter_nd_util-BY22Cc-C.js} +1 -1
- package/dist/{selu_util-kbhpTdYD.js → selu_util-BuLbmbrl.js} +5 -5
- package/dist/{shared-DT1TkE6w.js → shared-B7USJZgw.js} +1 -1
- package/dist/{shared-dntlHIDQ.js → shared-BQboIImQ.js} +86 -86
- package/dist/{slice-BfEGSH82.js → slice-Aqy7KbJh.js} +3 -3
- package/dist/{slice_util-uTKwiEpW.js → slice_util-D8CQRenR.js} +7 -7
- package/dist/{softmax-CA5jFsLR.js → softmax-faLoUZVT.js} +1 -1
- package/dist/{split-CVLc0w--.js → split-BNz5jcGc.js} +3 -3
- package/dist/{squeeze-C7Z2srUo.js → squeeze--YMgaAAf.js} +2 -2
- package/dist/{stack-Cf4n9h0N.js → stack-WJK22CFn.js} +1 -1
- package/dist/{step-CINUs5QB.js → step-dXR33iOg.js} +32 -32
- package/dist/sum-BdplSvq_.js +11 -0
- package/dist/tensor-BQqrDvpx.js +8 -0
- package/dist/tensor1d-LxP9asMm.js +11 -0
- package/dist/{tensor2d-Bs9wZRc7.js → tensor2d-BN1sSfQO.js} +3 -3
- package/dist/{tensor4d-BARPdTaS.js → tensor4d-DVwr7pLF.js} +1 -1
- package/dist/{tfjs_backend-y1cvNhLA.js → tfjs_backend-Vi4JfLzT.js} +28 -28
- package/dist/{tile-mbfagpsB.js → tile-CvN_LyVr.js} +4 -4
- package/dist/tokeniser/BaseTokeniser.d.ts +27 -0
- package/dist/tokeniser/BaseTokeniser.js +94 -0
- package/dist/tokeniser/CharTokeniser.d.ts +4 -3
- package/dist/tokeniser/CharTokeniser.js +46 -32
- package/dist/tokeniser/bpe.d.ts +4 -3
- package/dist/tokeniser/bpe.js +60 -45
- package/dist/tokeniser/type.d.ts +11 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.d.ts +2 -2
- package/dist/training/DatasetBuilder.js +32 -36
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.d.ts +3 -3
- package/dist/training/Trainer.js +2 -2
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/{transpose-ClWiBS_b.js → transpose-JawVKyZy.js} +5 -5
- package/dist/{unsorted_segment_sum-BDDhB_E6.js → unsorted_segment_sum-LAbmE9G4.js} +78 -78
- package/dist/utilities/dummy.js +3 -3
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-WawDEaAb.js → variable-DQ9yYgEU.js} +1 -1
- package/dist/{webgpu_program-DuOXPQol.js → webgpu_program-CAE4RICo.js} +3 -3
- package/dist/{webgpu_util-RxEF33Rj.js → webgpu_util-BdovYhXr.js} +1 -1
- package/dist/{zeros-KnWaWf-X.js → zeros-DeiE2zTa.js} +2 -2
- package/dist/{zeros_like-DvE73F4e.js → zeros_like-BAz3iKru.js} +77 -77
- package/package.json +1 -1
- package/dist/complex-DjxcVmoX.js +0 -11
- package/dist/gather-D3JcZUaI.js +0 -9
- package/dist/log_sum_exp-ngO0-4pK.js +0 -39
- package/dist/range-BklejeeW.js +0 -10
- package/dist/sum-DWAtNGez.js +0 -11
- package/dist/tensor-DJoc7gJU.js +0 -8
- package/dist/tensor1d-D11P_7Dp.js +0 -11
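
For reference, a listing like this can be regenerated locally with npm's built-in diff command (available since npm 7); the only inputs are the package name and the two version specifiers from the header above:

    npm diff --diff=@genai-fi/nanogpt@0.10.3 --diff=@genai-fi/nanogpt@0.11.0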
@@ -1,30 +1,30 @@
-import {
-import { k as C, c as g, m as D } from "./step-CINUs5QB.js";
-import { r as b } from "./reshape-ByE68wS9.js";
-import { m as pn, a as hn, e as w } from "./log_sum_exp-ngO0-4pK.js";
-import { s as K } from "./sum-DWAtNGez.js";
+import { q as h, u as c, E as d, bo as T, bp as q, bq as H, y as l, br as P, N as _, bs as y, bt as I, bu as W, bv as B, bw as A, bx as G, by as L, bz as O, bA as z, bB as F, D as M, $ as j, bC as J, bD as Q, bE as U, a2 as V, c as N, m as X, bF as Y, bG as Z, bH as R, bI as nn, bJ as tn, bK as sn, bL as en, bM as rn, bN as on, bO as an, bP as un, aG as cn, bQ as ln } from "./index-D0RBWjq8.js";
+import { k as C, c as g, m as D } from "./step-dXR33iOg.js";
+import { r as b } from "./reshape-CkjKPPqB.js";
+import { m as pn, a as hn, e as w } from "./log_sum_exp-VLZgbFAH.js";
+import { s as K } from "./sum-BdplSvq_.js";
 function fn(s, n = null, t = !1) {
-const
-return d.runKernel(T,
+const u = { x: c(s, "x", "all", "bool") }, o = { axis: n, keepDims: t };
+return d.runKernel(T, u, o);
 }
 const nt = /* @__PURE__ */ h({ all_: fn });
 function dn(s, n = null, t = !1) {
-const
-return d.runKernel(q,
+const u = { x: c(s, "x", "any", "bool") }, o = { axis: n, keepDims: t };
+return d.runKernel(q, u, o);
 }
 const tt = /* @__PURE__ */ h({ any_: dn });
 function mn(s, n = 0) {
-const e = { x: c(s, "x", "argMax") },
-return d.runKernel(H, e,
+const e = { x: c(s, "x", "argMax") }, u = { axis: n };
+return d.runKernel(H, e, u);
 }
 const st = /* @__PURE__ */ h({ argMax_: mn });
-function $n(s, n, t, e,
+function $n(s, n, t, e, u) {
 const o = c(s, "x", "avgPool", "float32"), p = 1;
 l(C(t, p), () => `Error in avgPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`);
 let r = o, a = !1;
-o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e,
-const
-let f = d.runKernel(P,
+o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e, u);
+const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u };
+let f = d.runKernel(P, i, m);
 return f = _(f, o.dtype), a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
 }
 const et = /* @__PURE__ */ h({ avgPool_: $n });
@@ -34,66 +34,66 @@ function bn(s) {
 }
 const rt = /* @__PURE__ */ h({ tanh_: bn });
 function xn(s, n, t) {
-const e = c(s, "x", "batchToSpaceND"),
-l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] %
+const e = c(s, "x", "batchToSpaceND"), u = n.reduce((r, a) => r * a);
+l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] % u === 0, () => `input tensor batch is ${e.shape[0]} but is not divisible by the product of the elements of blockShape ${n.join(" * ")} === ${u}`);
 const o = { x: e }, p = { blockShape: n, crops: t };
-return d.runKernel(
+return d.runKernel(I, o, p);
 }
 const ot = /* @__PURE__ */ h({ batchToSpaceND_: xn });
 function kn(s) {
 let n;
 return s.rank === 0 || s.rank === 1 ? n = b(s, [1, 1, 1, s.size]) : s.rank === 2 ? n = b(s, [1, 1, s.shape[0], s.shape[1]]) : s.rank === 3 ? n = b(s, [1, s.shape[0], s.shape[1], s.shape[2]]) : n = s, n;
 }
-function vn(s, n, t, e,
+function vn(s, n, t, e, u, o) {
 o == null && (o = 1e-3);
 const p = c(s, "x", "batchNorm"), r = c(n, "mean", "batchNorm"), a = c(t, "variance", "batchNorm");
-let
-
+let i;
+u != null && (i = c(u, "scale", "batchNorm"));
 let m;
-e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(
+e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(i == null || r.rank === i.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
 const x = {
 x: kn(p),
-scale:
+scale: i,
 offset: m,
 mean: r,
 variance: a
-}, k = { varianceEpsilon: o }, $ = d.runKernel(
+}, k = { varianceEpsilon: o }, $ = d.runKernel(W, x, k);
 return b($, p.shape);
 }
 const at = /* @__PURE__ */ h({ batchNorm_: vn });
-function gn(s, n, t, e,
+function gn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
 const r = c(s, "x", "conv2d", "float32"), a = c(n, "filter", "conv2d", "float32");
-let
-r.rank === 3 && (m = !0,
-const f =
+let i = r, m = !1;
+r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${a.rank}.`), g("conv2d", e, p);
+const f = u === "NHWC" ? i.shape[3] : i.shape[1];
 l(f === a.shape[2], () => `Error in conv2d: depth of input (${f}) must match input depth for filter ${a.shape[2]}.`), l(C(t, o), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${t} and dilations '${o}'`), l(D(o), () => "Error in conv2D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv2D: Strides should be larger than 0.");
-const x = { x:
+const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(B, x, k);
 return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
 }
 const S = /* @__PURE__ */ h({ conv2d_: gn });
-function Dn(s, n, t, e,
+function Dn(s, n, t, e, u = "NWC", o = 1, p) {
 const r = c(s, "x", "conv1d"), a = c(n, "filter", "conv1d");
-let
-r.rank === 2 && (m = !0,
-const f = b(a, [1, a.shape[0], a.shape[1], a.shape[2]]), x = b(
+let i = r, m = !1;
+r.rank === 2 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1]])), l(i.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${i.rank}.`), l(a.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${a.rank}.`), g("conv1d", e, p), l(i.shape[2] === a.shape[1], () => `Error in conv1d: depth of input (${i.shape[2]}) must match input depth for filter ${a.shape[1]}.`), l(C(t, o), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${t} and dilation '${o}'`), l(D(o), () => "Error in conv1D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv1D: Stride should be larger than 0."), l(u === "NWC", () => `Error in conv1d: got dataFormat of ${u} but only NWC is currently supported.`);
+const f = b(a, [1, a.shape[0], a.shape[1], a.shape[2]]), x = b(i, [i.shape[0], 1, i.shape[1], i.shape[2]]), v = S(x, f, [1, t], e, "NHWC", [1, o], p);
 return m ? b(v, [v.shape[2], v.shape[3]]) : b(v, [v.shape[0], v.shape[2], v.shape[3]]);
 }
-const
-function Cn(s, n, t, e,
+const ut = /* @__PURE__ */ h({ conv1d_: Dn });
+function Cn(s, n, t, e, u, o = "NHWC", p) {
 l(s.length === n.rank, () => `Length of inShape (${s.length}) and rank of dy (${n.rank}) must match`);
-let r = s, a = n,
-n.rank === 3 && (
+let r = s, a = n, i = !1;
+n.rank === 3 && (i = !0, a = b(n, [1, n.shape[0], n.shape[1], n.shape[2]]), r = [1, s[0], s[1], s[2]]), l(r.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${r.length}.`), l(a.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${a.rank}`), l(t.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${t.rank}`);
 const m = o === "NHWC" ? r[3] : r[1], f = o === "NHWC" ? a.shape[3] : a.shape[1];
-l(m === t.shape[2], () => `Error in conv2dDerInput: depth of input (${m}) must match input depth for filter ${t.shape[2]}.`), l(f === t.shape[3], () => `Error in conv2dDerInput: depth of output (${f}) must match output depth for filter ${t.shape[3]}.`), g("conv2dDerInput",
-const x = { dy: a, filter: t }, k = { strides: e, pad:
-return
+l(m === t.shape[2], () => `Error in conv2dDerInput: depth of input (${m}) must match input depth for filter ${t.shape[2]}.`), l(f === t.shape[3], () => `Error in conv2dDerInput: depth of output (${f}) must match output depth for filter ${t.shape[3]}.`), g("conv2dDerInput", u, p);
+const x = { dy: a, filter: t }, k = { strides: e, pad: u, dataFormat: o, dimRoundingMode: p, inputShape: r }, $ = d.runKernel(A, x, k);
+return i ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
 }
 const En = /* @__PURE__ */ h({ conv2DBackpropInput_: Cn });
-function Nn(s, n, t, e,
+function Nn(s, n, t, e, u, o) {
 const p = c(s, "x", "conv2dTranspose"), r = c(n, "filter", "conv2dTranspose");
-return En(t, p, r, e,
+return En(t, p, r, e, u, "NHWC", o);
 }
-const
+const it = /* @__PURE__ */ h({ conv2dTranspose_: Nn });
 function _n(s) {
 const t = { x: c(s, "x", "cos", "float32") };
 return d.runKernel(G, t);
@@ -114,21 +114,21 @@ function Sn(s, n = 0, t = !1, e = !1) {
 return d.runKernel(z, o, p);
 }
 const ht = /* @__PURE__ */ h({ cumsum_: Sn });
-function Tn(s, n, t, e,
+function Tn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
 const r = c(s, "x", "depthwiseConv2d", "float32"), a = c(n, "filter", "depthwiseConv2d", "float32");
-let
-r.rank === 3 && (m = !0,
-const f =
+let i = r, m = !1;
+r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${a.rank}.`);
+const f = u === "NHWC" ? i.shape[3] : i.shape[1];
 l(f === a.shape[2], () => `Error in depthwiseConv2d: number of input channels (${f}) must match the inChannels dimension in filter ${a.shape[2]}.`), g("depthwiseConv2d", e, p);
-const x = { x:
+const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(F, x, k);
 return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
 }
 const qn = /* @__PURE__ */ h({ depthwiseConv2d_: Tn });
 function Hn(s, n) {
 let t = c(s, "a", "equal", "string_or_numeric"), e = c(n, "b", "equal", "string_or_numeric");
 [t, e] = M(t, e), j(t.shape, e.shape);
-const
-return d.runKernel(J,
+const u = { a: t, b: e };
+return d.runKernel(J, u);
 }
 const ft = /* @__PURE__ */ h({ equal_: Hn });
 function Pn(s) {
@@ -143,36 +143,36 @@ function yn(s) {
 return d.runKernel(U, t);
 }
 const mt = /* @__PURE__ */ h({ softplus_: yn });
-function
+function In(s, n = -1) {
 const t = c(s, "logits", "logSoftmax");
 if (n === -1 && (n = t.rank - 1), n !== t.rank - 1)
 throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${t.rank} and axis was ${n}`);
-return V((
-const r = pn(
-return o([
+return V((u, o) => {
+const r = pn(u, n, !0), a = N(u, r), i = N(_(a, "float32"), hn(K(w(a), n, !0)));
+return o([i]), { value: i, gradFunc: (f, x) => {
 const [k] = x, $ = !0, E = w(k);
 return N(f, X(K(f, n, $), E));
 } };
 })(t);
 }
-const $t = /* @__PURE__ */ h({ logSoftmax_:
-function
+const $t = /* @__PURE__ */ h({ logSoftmax_: In });
+function Wn(s) {
 const t = { x: c(s, "x", "logicalNot", "bool") };
 return d.runKernel(Y, t);
 }
-const bt = /* @__PURE__ */ h({ logicalNot_:
-function
+const bt = /* @__PURE__ */ h({ logicalNot_: Wn });
+function Bn(s, n, t, e, u) {
 const o = c(s, "x", "maxPool"), p = 1;
 let r = o, a = !1;
-o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${r.rank}.`), l(C(t, p), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`), g("maxPool", e,
-const
+o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${r.rank}.`), l(C(t, p), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`), g("maxPool", e, u);
+const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u }, f = d.runKernel(Z, i, m);
 return a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
 }
-const xt = /* @__PURE__ */ h({ maxPool_:
-function An(s, n, t = 1, e = 0,
+const xt = /* @__PURE__ */ h({ maxPool_: Bn });
+function An(s, n, t = 1, e = 0, u = "int32") {
 if (n < 2)
 throw new Error(`Error in oneHot: depth must be >=2, but it is ${n}`);
-const p = { indices: c(s, "indices", "oneHot", "int32") }, r = { dtype:
+const p = { indices: c(s, "indices", "oneHot", "int32") }, r = { dtype: u, depth: n, onValue: t, offValue: e };
 return d.runKernel(R, p, r);
 }
 const kt = /* @__PURE__ */ h({ oneHot_: An });
@@ -185,20 +185,20 @@ function Ln(s, n, t = 0) {
 const e = c(s, "x", "pad");
 if (e.rank === 0)
 throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
-const
-return d.runKernel(tn, o,
+const u = { paddings: n, constantValue: t }, o = { x: e };
+return d.runKernel(tn, o, u);
 }
 const gt = /* @__PURE__ */ h({ pad_: Ln });
 function On(s, n, t) {
 const e = c(s, "x", "spaceToBatchND");
 l(e.rank >= 1 + n.length, () => `input rank ${e.rank} should be > than [blockShape] ${n.length}`), l(t.length === n.length, () => `paddings.shape[0] ${t.length} must be equal to [blockShape] ${n.length}`), l(e.shape.reduce((p, r, a) => a > 0 && a <= n.length ? p && (r + t[a - 1][0] + t[a - 1][1]) % n[a - 1] === 0 : p, !0), () => `input spatial dimensions ${e.shape.slice(1)} with paddings ${t.toString()} must be divisible by blockShapes ${n.toString()}`);
-const
-return d.runKernel(sn,
+const u = { x: e }, o = { blockShape: n, paddings: t };
+return d.runKernel(sn, u, o);
 }
 const Dt = /* @__PURE__ */ h({ spaceToBatchND_: On });
 function zn(s, n) {
-const e = { x: c(s, "x", "reverse") },
-return d.runKernel(en, e,
+const e = { x: c(s, "x", "reverse") }, u = { dims: n };
+return d.runKernel(en, e, u);
 }
 const Ct = /* @__PURE__ */ h({ reverse_: zn });
 function Fn(s) {
@@ -211,15 +211,15 @@ function Mn(s) {
 return d.runKernel(on, t);
 }
 const Nt = /* @__PURE__ */ h({ selu_: Mn });
-function jn(s, n, t, e,
-const r = c(s, "x", "separableConv2d"), a = c(n, "depthwiseFilter", "separableConv2d"),
+function jn(s, n, t, e, u, o = [1, 1], p = "NHWC") {
+const r = c(s, "x", "separableConv2d"), a = c(n, "depthwiseFilter", "separableConv2d"), i = c(t, "pointwiseFilter", "separableConv2d");
 let m = r, f = !1;
 if (r.rank === 3 && (f = !0, m = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), p === "NCHW")
 throw new Error("separableConv2d currently does not support dataFormat NCHW; only NHWC is supported");
-l(m.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${m.rank}.`), l(a.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${a.rank}.`), l(
+l(m.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${m.rank}.`), l(a.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${a.rank}.`), l(i.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${a.rank}.`), l(i.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${i.shape[0]}.`), l(i.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${i.shape[1]}.`);
 const x = a.shape[2], k = a.shape[3];
-l(
-const $ = qn(m, a, e,
+l(i.shape[2] === x * k, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${x * k}, but got ${i.shape[2]}.`);
+const $ = qn(m, a, e, u, p, o), v = S($, i, 1, "valid", p);
 return f ? b(v, [v.shape[1], v.shape[2], v.shape[3]]) : v;
 }
 const _t = /* @__PURE__ */ h({ separableConv2d_: jn });
@@ -234,9 +234,9 @@ function Qn(s) {
 }
 const Kt = /* @__PURE__ */ h({ sinh_: Qn });
 function Un(s, n, t) {
-const e = c(s, "x", "unsortedSegmentSum"),
+const e = c(s, "x", "unsortedSegmentSum"), u = c(n, "segmentIds", "unsortedSegmentSum", "int32");
 l(cn(t), () => "numSegments must be of dtype int");
-const o = { x: e, segmentIds:
+const o = { x: e, segmentIds: u }, p = { numSegments: t };
 return d.runKernel(ln, o, p);
 }
 const St = /* @__PURE__ */ h({ unsortedSegmentSum_: Un });
@@ -258,10 +258,10 @@ export {
 tt as h,
 st as i,
 at as j,
-
+ut as k,
 bt as l,
 xt as m,
-
+it as n,
 S as o,
 lt as p,
 pt as q,
package/dist/utilities/dummy.js
CHANGED
@@ -1,6 +1,6 @@
-import { a as y, e as S, v as w } from "../index-DOvlwCh-.js";
-import { z as m } from "../zeros-KnWaWf-X.js";
-import { o as P } from "../ones-D2rT0xk2.js";
+import { a as y, e as S, v as w } from "../index-D0RBWjq8.js";
+import { z as m } from "../zeros-DeiE2zTa.js";
+import { o as P } from "../ones-BX_wEgzB.js";
 async function b(s) {
 const t = m([1, s.config.blockSize], "int32"), [n, o] = s.forward({ training: !1 }, t);
 await n.data(), n.dispose(), o && o.dispose(), t.dispose();
package/dist/utilities/packed.js
CHANGED
@@ -1,8 +1,8 @@
-import { m as w } from "../index-DOvlwCh-.js";
-import { t as g } from "../tensor2d-Bs9wZRc7.js";
-import { e as y } from "../expand_dims-DT4tEPwA.js";
-import { s as h } from "../sum-DWAtNGez.js";
-import { c as T } from "../concat-BV8bt5H-.js";
+import { m as w } from "../index-D0RBWjq8.js";
+import { t as g } from "../tensor2d-BN1sSfQO.js";
+import { e as y } from "../expand_dims-BzfJK2uc.js";
+import { s as h } from "../sum-BdplSvq_.js";
+import { c as T } from "../concat-C6X3AAlQ.js";
 const p = 16;
 function A(o, t) {
 if (!t)
@@ -1,5 +1,5 @@
-import "../index-DOvlwCh-.js";
-import { t as p } from "../tensor-DJoc7gJU.js";
+import "../index-D0RBWjq8.js";
+import { t as p } from "../tensor-BQqrDvpx.js";
 function h(n) {
 const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
 let t = 0;
@@ -1,4 +1,4 @@
-import {
+import { ad as z, ac as F, ab as E, a9 as j, y as A } from "./index-D0RBWjq8.js";
 function L(t, s) {
 if (Math.max(...t) > 5)
 throw new Error("Cannot symbolically compute strides for rank > 6 tensor.");
@@ -27,7 +27,7 @@ var w;
 })(w || (w = {}));
 const H = (t, s, e, o, i) => {
 const u = { dtype: o.dtype, shape: o.shape }, n = D(e, u, s), r = t.createShaderModule({ code: n, label: s.constructor.name });
-let d =
+let d = E().get("WEBGPU_PRINT_SHADER");
 if (d !== "") {
 d = d.toLowerCase();
 const p = d.split(",");
@@ -281,7 +281,7 @@ function y(t, s = "") {
 const e = t.length, o = s !== "" ? `get${s.charAt(0).toUpperCase() + s.slice(1)}CoordsFromIndex` : "getCoordsFromIndex", i = s !== "" ? `${s.charAt(0).toLowerCase() + s.slice(1)}ShapeStrides` : "outShapeStrides";
 if (e <= 1)
 return `fn ${o}(index : i32) -> i32 { return index; }`;
-const u =
+const u = j(t), n = g(e), r = [];
 for (let p = 0; p < e; p++)
 r.push(`d${p}`);
 if (u.length === 1)
@@ -1,5 +1,5 @@
-import {
-import { c as f } from "./complex-DjxcVmoX.js";
+import { w as n, U as m, V as i, E as c } from "./index-D0RBWjq8.js";
+import { c as f } from "./complex-DClmWqJt.js";
 function e(o, r = "float32") {
 if (n(o), r === "complex64") {
 const s = e(o, "float32"), t = e(o, "float32");
|