@genai-fi/nanogpt 0.4.4 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +44 -41
- package/dist/NanoGPTModel.d.ts +12 -16
- package/dist/NanoGPTModel.js +128 -138
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +8 -5
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +22 -24
- package/dist/layers/CausalSelfAttention.js +73 -127
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -11
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +22 -19
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +3 -2
- package/dist/ops/webgl/matMulGelu.js +77 -75
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +86 -0
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +10 -8
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
package/dist/ops/grads/fusedSoftmax.js CHANGED

@@ -1,17 +1,20 @@
-import {
-import { mulDrop as
-import { s as
-const
+import { h as d, b as u, s as f } from "../../index-iNhkcAEQ.js";
+import { mulDrop as l } from "../mulDrop.js";
+import { s as g } from "../../sum-B_92TaHD.js";
+const T = {
   kernelName: "FusedSoftmax",
   outputsToSave: [!0],
-  gradFunc: (
-  const [
+  gradFunc: (o, i, n) => {
+    const [s] = i, { dim: a, dropoutRate: t, seed: e } = n, p = !0, r = t && e ? l(o, s, t, e) : u(o, s);
     return {
-      logits: () =>
+      logits: () => {
+        const m = g(r, [a], p), c = u(m, s);
+        return m.dispose(), f(r, c);
+      }
     };
   }
 };
-
+d(T);
 export {
-
+  T as softmaxGradConfig
 };
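The rewritten gradient now disposes the intermediate sum it creates. With y = softmax(x) saved from the forward pass, it computes dx = y ⊙ dy − y ⊙ Σ(y ⊙ dy) along the softmax axis (the dropout branch reuses the fused mulDrop op). A minimal reference sketch with stock @tensorflow/tfjs-core ops, dropout omitted and names purely illustrative:

import * as tf from "@tensorflow/tfjs-core";

// Reference softmax backward pass: dx = y * (dy - sum(y * dy, dim, keepDims)).
// `y` is the saved softmax output, `dy` the incoming gradient.
function softmaxGradReference(dy: tf.Tensor, y: tf.Tensor, dim: number): tf.Tensor {
  return tf.tidy(() => {
    const prod = tf.mul(dy, y);            // y ⊙ dy
    const sum = tf.sum(prod, [dim], true); // Σ over the softmax axis, kept for broadcasting
    return tf.sub(prod, tf.mul(sum, y));   // y ⊙ dy − y ⊙ Σ(y ⊙ dy)
  });
}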
package/dist/ops/grads/gelu.js CHANGED
package/dist/ops/grads/normRMS.js ADDED

@@ -0,0 +1,20 @@
+import { h as t, e as g } from "../../index-iNhkcAEQ.js";
+function s(r, a, n) {
+  return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
+}
+const u = {
+  kernelName: "RMSNorm",
+  inputsToSave: ["x", "gamma"],
+  outputsToSave: [],
+  gradFunc: (r, a) => {
+    const [n, e] = a, [m, o] = s(r, n, e);
+    return {
+      x: () => m,
+      gamma: () => o
+    };
+  }
+};
+t(u);
+export {
+  u as normRMSGradConfig
+};
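The new gradient config routes the backward pass through a dedicated "RMSNormGrad" kernel and hands back the gradients for x and gamma. For orientation, a plain-ops sketch of the RMSNorm forward pass this pairs with; the trailing feature axis and the 1e-6 epsilon are assumptions, not taken from the package:

import * as tf from "@tensorflow/tfjs-core";

// RMSNorm forward: y = x / sqrt(mean(x^2) + eps) * gamma.
function rmsNormReference(x: tf.Tensor, gamma: tf.Tensor, eps = 1e-6): tf.Tensor {
  return tf.tidy(() => {
    const meanSquare = tf.mean(tf.square(x), -1, true); // mean of squares over the feature axis
    const invRms = tf.rsqrt(tf.add(meanSquare, eps));   // 1 / RMS
    return tf.mul(tf.mul(x, invRms), gamma);            // normalise, then scale by gamma
  });
}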
package/dist/ops/grads/qkv.js CHANGED
@@ -1,20 +1,30 @@
-import {
-const
+import { h as Q } from "../../index-iNhkcAEQ.js";
+const V = {
   kernelName: "QKV",
   inputsToSave: ["x", "kernel"],
   outputsToSave: [],
-  gradFunc: (
-  const [
+  gradFunc: (x, K) => {
+    const [f, h, M] = x, [p, l] = K, [t, n, e] = p.shape, i = f.transpose([0, 2, 1, 3]).reshape([t * n, e]), u = h.transpose([0, 2, 1, 3]).reshape([t * n, e]), k = M.transpose([0, 2, 1, 3]).reshape([t * n, e]);
     return {
       x: () => {
-        const
-
+        const s = l.slice([0, 0], [e, e]), o = i.matMul(s, !1, !0);
+        s.dispose();
+        const d = l.slice([0, e], [e, e]), r = u.matMul(d, !1, !0);
+        d.dispose();
+        const a = o.add(r);
+        o.dispose(), r.dispose();
+        const c = l.slice([0, 2 * e], [e, e]), m = k.matMul(c, !1, !0);
+        c.dispose();
+        const v = a.add(m).reshape([t, n, e]);
+        return a.dispose(), m.dispose(), v;
       },
       kernel: () => {
-        const
-
+        const s = p.reshape([t * n, e]), o = s.matMul(i, !0, !1), d = s.matMul(u, !0, !1), r = o.concat(d, 1);
+        o.dispose(), d.dispose();
+        const a = s.matMul(k, !0, !1), c = r.concat(a, 1);
+        return r.dispose(), a.dispose(), s.dispose(), c;
       }
     };
   }
 };
-
+Q(V);
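The x() and kernel() closures above are the standard gradients of a fused QKV projection, now with explicit disposal of every intermediate tensor. Assuming the fused kernel is laid out column-wise as [Wq | Wk | Wv] (an assumption consistent with the slices at offsets 0, e and 2e), the underlying maths in 2-D form is:

import * as tf from "@tensorflow/tfjs-core";

// Sketch of the fused-QKV gradients, ignoring the per-head reshape/transpose:
// dX = dQ·Wqᵀ + dK·Wkᵀ + dV·Wvᵀ and dW = [Xᵀ·dQ | Xᵀ·dK | Xᵀ·dV].
function qkvGradReference(x: tf.Tensor2D, kernel: tf.Tensor2D,
                          dq: tf.Tensor2D, dk: tf.Tensor2D, dv: tf.Tensor2D) {
  return tf.tidy(() => {
    const e = x.shape[1];
    const wq = kernel.slice([0, 0], [e, e]);
    const wk = kernel.slice([0, e], [e, e]);
    const wv = kernel.slice([0, 2 * e], [e, e]);
    const dx = dq.matMul(wq, false, true)
      .add(dk.matMul(wk, false, true))
      .add(dv.matMul(wv, false, true));
    const dKernel = tf.concat(
      [x.matMul(dq, true, false), x.matMul(dk, true, false), x.matMul(dv, true, false)],
      1
    );
    return { dx, dKernel };
  });
}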
package/dist/ops/grads/rope.js CHANGED
package/dist/ops/matMulGelu.js CHANGED
package/dist/ops/matMulMul.js ADDED

@@ -0,0 +1,9 @@
+import { e as u } from "../index-iNhkcAEQ.js";
+import "./cpu/matMulMul.js";
+import "./webgl/matMulMul.js";
+function m(e, r, t, l = !1, n = !1) {
+  return u().runKernel("MatMulMul", { x: e, kernel: r, y: t }, { transposeA: l, transposeB: n });
+}
+export {
+  m as matMulMul
+};
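matMulMul follows the same pattern as the package's other custom ops: importing the file registers the CPU and WebGL kernels as a side effect, and the exported wrapper dispatches through the engine's runKernel. Judging only by the kernel name and inputs, the op fuses a matrix multiply with an element-wise multiply by y; that reading is an inference, so the non-fused sketch below is an assumed equivalent, not the package's definition:

import * as tf from "@tensorflow/tfjs-core";

// Assumed non-fused equivalent of the "MatMulMul" kernel:
// matMul(x, kernel) followed by an element-wise multiply with y.
function matMulMulFallback(x: tf.Tensor, kernel: tf.Tensor, y: tf.Tensor,
                           transposeA = false, transposeB = false): tf.Tensor {
  return tf.tidy(() => tf.mul(tf.matMul(x, kernel, transposeA, transposeB), y));
}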
package/dist/ops/mulDrop.js CHANGED
package/dist/ops/qkv.js CHANGED
package/dist/ops/scatterSub.js CHANGED
package/dist/ops/webgl/attentionMask.js CHANGED

@@ -1,14 +1,15 @@
-import { r as
-class
+import { r as m } from "../../index-iNhkcAEQ.js";
+class h {
   variableNames = ["q", "k"];
   outputShape;
   userCode;
   customUniforms = [
     { name: "divisor", type: "float" },
-    { name: "pastLen", type: "int" }
+    { name: "pastLen", type: "int" },
+    { name: "inf", type: "float" }
   ];
-  constructor(t,
-  this.outputShape = [t,
+  constructor(t, e, s, n, a) {
+    this.outputShape = [t, e, s, n], this.userCode = `
       void main() {
         ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
         int b = coords.x;
@@ -27,18 +28,18 @@ class l {
         float scaled = sum * divisor;

         // Mask out future positions
-        setOutput((t2 > t1 + pastLen) ?
+        setOutput((t2 > t1 + pastLen) ? inf : scaled);
       }
     `;
   }
 }
-function
-const { q: t, k:
-return a.runWebGLProgram(d, [t,
+function l(o) {
+  const { q: t, k: e } = o.inputs, { divisor: s, pastLen: n } = o.attrs, a = o.backend, i = t.shape[0], r = t.shape[2], c = e.shape[2], u = t.shape[1], p = t.shape[3], d = new h(i, u, r, c, p);
+  return a.runWebGLProgram(d, [t, e], "float32", [[s], [n], [Number.NEGATIVE_INFINITY]]);
 }
-const
+const f = {
   kernelName: "AttentionMask",
   backendName: "webgl",
-  kernelFunc:
+  kernelFunc: l
 };
-
+m(f);
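The notable change here is that the mask value is now supplied as an inf uniform (set to Number.NEGATIVE_INFINITY on the JS side) instead of being baked into the shader source. Functionally the program still scales the q·k dot products and overwrites future positions. A plain-ops sketch of that masking step, given precomputed scores; the helper name and exact shapes are illustrative only:

import * as tf from "@tensorflow/tfjs-core";

// Causal masking over attention scores of shape [batch, heads, t1, t2]:
// scale everything, then set positions with t2 > t1 + pastLen to -Infinity.
function maskFutureScores(scores: tf.Tensor4D, divisor: number, pastLen: number): tf.Tensor4D {
  return tf.tidy(() => {
    const [b, nh, t1, t2] = scores.shape;
    const rows = tf.range(0, t1, 1, "int32").reshape([t1, 1]);
    const cols = tf.range(0, t2, 1, "int32").reshape([1, t2]);
    const future = tf.broadcastTo(
      tf.greater(cols, tf.add(rows, tf.scalar(pastLen, "int32"))), // t2 > t1 + pastLen
      [b, nh, t1, t2]
    );
    const scaled = tf.mul(scores, divisor);
    const negInf = tf.fill([b, nh, t1, t2], Number.NEGATIVE_INFINITY);
    return tf.where(future, negInf, scaled) as tf.Tensor4D;
  });
}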
package/dist/ops/webgl/fusedSoftmax.js CHANGED

@@ -1,10 +1,11 @@
-import { a$ as qt,
-import { f as $t, a as
-import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-
-import { b as te
-import { r as
+import { a$ as qt, p as y, N as Ct, ad as Q, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, ah as B, b5 as q, b6 as Ut, O as Gt, k as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, a1 as ut, b9 as Bt, K as Ht, ba as Kt, r as Xt } from "../../index-iNhkcAEQ.js";
+import { f as $t, a as Yt, g as yt, b as Jt, c as Qt } from "../../kernel_funcs_utils-C4eIk4fE.js";
+import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-97KkkyRQ.js";
+import { b as te } from "../../broadcast_to-CMlkG8NS.js";
+import { r as ee } from "../../reshape-DxTPgnwL.js";
+import { i as ne, c as oe } from "../../slice_util-D-kaD4ZV.js";
 import { g as se } from "../../_commonjsHelpers-ByX85dGu.js";
-import { r as st } from "../../Reshape-
+import { r as st } from "../../Reshape-BE5rA4rT.js";
 function re(t, e) {
   for (var n = 0; n < e.length; n++) {
     const o = e[n];
@@ -654,20 +655,20 @@ const _t = /* @__PURE__ */ se(Lt), ae = /* @__PURE__ */ re({
  * limitations under the License.
  * =============================================================================
  */
-const
+const J = (
   // tslint:disable-next-line
   _t || ae
 );
 function lt(t) {
-  return
+  return J.fromString(t, !0, 16);
 }
-const Nt = lt("c3a5c85c97cb3127"),
+const Nt = lt("c3a5c85c97cb3127"), Y = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
 function gt(t) {
   return t.xor(t.shru(47));
 }
 function At(t, e, n) {
   const o = t.slice(e, e + n);
-  return
+  return J.fromBytes(Array.from(o), !0, !0);
 }
 function R(t, e) {
   return At(t, e, 8);
@@ -708,7 +709,7 @@ function le(t, e = t.length) {
   return W;
 }
 function ce(t, e = t.length) {
-  const n = W.add(e * 2), o = R(t, 0).mul(
+  const n = W.add(e * 2), o = R(t, 0).mul(Y), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
   return X(A(o.add(s), 43).add(A(r, 30)).add(a), o.add(A(s.add(W), 18)).add(r), n);
 }
 function he(t, e = t.length) {
@@ -716,19 +717,19 @@ function he(t, e = t.length) {
   return X(A(c.add(h), 43).add(A(f, 30)).add(w), c.add(A(h.add(o), 18)).add(f), n);
 }
 function fe(t, e = t.length) {
-  const n =
+  const n = J.fromNumber(81, !0);
   if (e <= 32)
     return e <= 16 ? le(t, e) : ce(t, e);
   if (e <= 64)
     return he(t, e);
-  let o = n, s = n.mul(
+  let o = n, s = n.mul(Y).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [J.UZERO, J.UZERO], i = [J.UZERO, J.UZERO];
   o = o.mul(W).add(R(t, 0));
   let u = 0;
   const c = (e - 1 >> 6) * 64, h = c + (e - 1 & 63) - 63;
   do
-    o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(
+    o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Y), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Y), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Y), a = it(t, u, a[1].mul(Y), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
   while (u !== c);
-  const f =
+  const f = Y.add(r.and(255).shl(1));
   return u = h, i[0] = i[0].add(e - 1 & 63), a[0] = a[0].add(i[0]), i[0] = i[0].add(a[0]), o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(f), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(f), o = o.xor(i[1].mul(9)), s = s.add(a[0].mul(9).add(R(t, u + 40))), r = A(r.add(i[0]), 33).mul(f), a = it(t, u, a[1].mul(f), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], X(X(a[0], i[0], f).add(gt(s).mul(Nt)).add(r), X(a[1], i[1], f).add(o), f);
 }
 /**
@@ -954,7 +955,7 @@ function Ve(t) {
  */
 function C(t) {
   return (e, n, o, s, r) => {
-    const a = Ct(e, n), i = a.length, u =
+    const a = Ct(e, n), i = a.length, u = Q(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = Q(e), m = Q(n), b = It(e, a), d = It(n, a);
     if (b.length + d.length === 0)
       for (let g = 0; g < h.length; ++g)
         h[g] = t(o[g % o.length], s[g % s.length]);
@@ -1415,7 +1416,7 @@ const Xe = K((t) => Math.log(t));
  * limitations under the License.
  * =============================================================================
  */
-function
+function Ye(t, e, n, o) {
   const s = nt(o, y(n));
   for (let r = 0; r < s.length; ++r) {
     const a = r * e;
@@ -1444,7 +1445,7 @@ function Qe(t, e, n, o) {
  * limitations under the License.
  * =============================================================================
  */
-const
+const Je = C((t, e) => Math.max(t, e));
 /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1461,7 +1462,7 @@ const Ye = C((t, e) => Math.max(t, e));
  * limitations under the License.
  * =============================================================================
  */
-const
+const Qe = C((t, e) => Math.min(t, e));
 /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1533,7 +1534,7 @@ const en = C((t, e) => t !== e ? 1 : 0);
  * =============================================================================
  */
 function nn(t, e, n, o, s) {
-  const r = e.length, a = y(e), i =
+  const r = e.length, a = y(e), i = Q(e), u = Q(s), c = nt(n, y(s));
   for (let h = 0; h < a; ++h) {
     const f = pt(h, r, i), w = new Array(f.length);
     for (let m = 0; m < w.length; m++)
@@ -1589,7 +1590,7 @@ function on(t, e, n, o) {
 function sn(t, e, n) {
   t.forEach((o, s) => {
     if (o < 0 || o >= n) {
-      const r = pt(s, e.length,
+      const r = pt(s, e.length, Q(e)).join(",");
       throw new Error(`indices[${r}] = ${o} is not in [0, ${n})`);
     }
   });
@@ -1944,7 +1945,7 @@ class at {
   if (h.length !== u && h.length !== 1) {
     const m = this.defaultValueShape;
     Zt(() => {
-      const b =
+      const b = ee(h, m);
       h = te(b, i).dataSync();
     });
   }
@@ -2109,9 +2110,9 @@ const wn = K((t) => 1 / (1 + Math.exp(-t)));
  * =============================================================================
  */
 function In(t, e, n, o, s) {
-  const r =
+  const r = ne(o, e, n), a = y(n), i = Q(o);
   if (r) {
-    const f =
+    const f = oe(e, i);
     return s === "string" ? t.slice(f, f + a) : t.subarray(f, f + a);
   }
   const u = s === "string" ? $t(t) : t, c = B(o, s, u), h = B(n, s);
@@ -2119,7 +2120,7 @@ function In(t, e, n, o, s) {
     const w = h.indexToLoc(f), p = w.map((m, b) => m + e[b]);
     h.set(c.get(...p), ...w);
   }
-  return s === "string" ?
+  return s === "string" ? Yt(h.values) : h.values;
 }
 /**
  * @license
@@ -2790,9 +2791,9 @@ const An = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
   lessImpl: Be,
   linSpaceImpl: Ke,
   logImpl: Xe,
-  maxImpl:
-  maximumImpl:
-  minimumImpl:
+  maxImpl: Ye,
+  maximumImpl: Je,
+  minimumImpl: Qe,
   multiplyImpl: Pt,
   negImpl: tn,
   notEqualImpl: en,
@@ -3170,7 +3171,7 @@ class Un {
     o[h] = e[n[h]];
   if (this.outputShape = o, this.rank = o.length, this.rank > 6)
     throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
-  const s = yt(this.rank), r =
+  const s = yt(this.rank), r = Jt("rc", this.rank), a = new Array(this.rank);
   for (let h = 0; h < n.length; h++)
     a[n[h]] = r[h];
   const i = `vec2(${a.slice(-2).join()})`, u = `++${r[this.rank - 1]} < ${o[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${i})`;
@@ -3348,7 +3349,7 @@ return a / b;`, Hn = `
 }

 return result;
-`, Kn =
+`, Kn = Qt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
 class Xn {
   variableNames = ["logits", "maxLogits"];
   outputShape;
@@ -3368,7 +3369,7 @@ class Xn {
   `;
   }
 }
-class
+class Yn {
   variableNames = ["exp", "sum"];
   outputShape;
   userCode;
@@ -3395,7 +3396,7 @@ class Qn {
   `;
   }
 }
-function
+function Jn(t) {
   const { inputs: e, attrs: n } = t, { logits: o } = e, { dim: s, dropoutRate: r, seed: a } = n, i = t.backend;
   if (!o)
     throw new Error("Error in softmax: input logits is null");
@@ -3403,23 +3404,25 @@ function Yn(t) {
     inputs: { x: o },
     backend: i,
     attrs: { reductionIndices: u, keepDims: !1 }
-  }), h = wt(c.shape, u), f = new Xn(o.shape), w = i.runWebGLProgram(f, [o, c], "float32")
+  }), h = wt(c.shape, u), f = new Xn(o.shape), w = i.runWebGLProgram(f, [o, c], "float32");
+  i.disposeIntermediateTensorInfo(c);
+  const p = Zn({ inputs: { x: w }, backend: i, attrs: { axis: u, keepDims: !1 } }), m = st({ inputs: { x: p }, backend: i, attrs: { shape: h } });
   if (r !== void 0 && r > 0) {
-    const d = new
+    const d = new Yn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
       [r],
       [a ?? Math.random() * 1e4]
     ]);
-    return i.disposeIntermediateTensorInfo(
+    return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), g;
   }
   const b = Kn({ inputs: { a: w, b: m }, backend: i });
-  return i.disposeIntermediateTensorInfo(
+  return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), b;
 }
-const
+const Qn = {
   kernelName: "FusedSoftmax",
   backendName: "webgl",
-  kernelFunc:
+  kernelFunc: Jn
 };
-Xt(
+Xt(Qn);
 export {
-
+  Jn as softmax
 };
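Beyond the renamed chunk imports, the WebGL FusedSoftmax kernel now disposes its intermediate tensor infos (the max, exponent, sum and reshape results) on both the dropout and plain paths instead of only some of them. The fused pipeline corresponds to an ordinary numerically stable softmax; a reference sketch with stock ops, dropout omitted and names illustrative:

import * as tf from "@tensorflow/tfjs-core";

// Numerically stable softmax along `dim`: subtract the row max before exponentiating.
function softmaxReference(logits: tf.Tensor, dim: number): tf.Tensor {
  return tf.tidy(() => {
    const max = tf.max(logits, dim, true);
    const exps = tf.exp(tf.sub(logits, max));
    const sum = tf.sum(exps, dim, true);
    return tf.div(exps, sum);
  });
}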
package/dist/ops/webgl/gelu.js CHANGED
@@ -1,5 +1,5 @@
-import { r as a } from "../../index
-import { u as s, C as x } from "../../kernel_funcs_utils-
+import { r as a } from "../../index-iNhkcAEQ.js";
+import { u as s, C as x } from "../../kernel_funcs_utils-C4eIk4fE.js";
 const t = 0.7978845608028654, r = 0.044715, c = x + `
 float x3 = x * x * x;
 float inner = x + ${r} * x3;
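Only the chunk imports change here; the shader body is untouched. The constants belong to the tanh approximation of GELU, gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))), with 0.7978845608028654 ≈ √(2/π). A plain-ops sketch of the same element-wise function, for reference only:

import * as tf from "@tensorflow/tfjs-core";

// Tanh approximation of GELU, using the same constants as the shader.
function geluReference(x: tf.Tensor): tf.Tensor {
  return tf.tidy(() => {
    const sqrt2OverPi = 0.7978845608028654;
    const inner = tf.mul(sqrt2OverPi, tf.add(x, tf.mul(0.044715, tf.pow(x, 3))));
    return tf.mul(tf.mul(0.5, x), tf.add(1, tf.tanh(inner)));
  });
}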
package/dist/ops/webgl/matMulGelu.d.ts CHANGED

@@ -7,9 +7,10 @@ type BatchMatMulConfig = {
     transposeA: boolean;
     transposeB: boolean;
     backend: MathBackendWebGL;
-    activationSnippet
+    activationSnippet?: string;
+    multiplier?: TensorInfo;
 };
-export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, }: BatchMatMulConfig): TensorInfo;
+export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, multiplier, }: BatchMatMulConfig): TensorInfo;
 export declare function batchMatMulKernel(args: {
     inputs: {
         x: TensorInfo;