@genai-fi/nanogpt 0.5.1 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +90 -41
- package/dist/NanoGPTModel.d.ts +1 -0
- package/dist/NanoGPTModel.js +86 -73
- package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
- package/dist/TeachableLLM.js +1 -1
- package/dist/TiedEmbedding-DORsPlNL.js +44 -0
- package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
- package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
- package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
- package/dist/dataset-ZHEPJmED.js +1226 -0
- package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
- package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
- package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
- package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
- package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
- package/dist/layers/BaseLayer.js +114 -3
- package/dist/layers/CausalSelfAttention.js +29 -28
- package/dist/layers/MLP.js +10 -9
- package/dist/layers/RMSNorm.js +12 -11
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +8 -6
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
- package/dist/main.js +1 -1
- package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
- package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
- package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
- package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
- package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +27 -27
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +36 -36
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +22 -22
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
- package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
- package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
- package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
- package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
- package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
- package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
- package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
- package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
- package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
- package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
- package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
- package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
- package/dist/tokeniser/CharTokeniser.js +27 -27
- package/dist/tokeniser/bpe.d.ts +1 -0
- package/dist/tokeniser/bpe.js +38 -35
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +22 -1242
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-BGvK-VN3.js +23 -0
- package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
- package/package.json +1 -1
- package/dist/BaseLayer-BhrMN8JO.js +0 -135
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
import { a$ as qt,
|
|
2
|
-
import { f as $t, a as
|
|
3
|
-
import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-
|
|
4
|
-
import { b as te } from "../../broadcast_to-
|
|
5
|
-
import { r as ee } from "../../reshape-
|
|
6
|
-
import { i as ne, c as oe } from "../../slice_util-
|
|
1
|
+
import { a$ as qt, q as y, Q as Ct, ad as J, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, al as B, b5 as q, b6 as Ut, U as Gt, l as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, a1 as ut, b9 as Bt, N as Ht, ba as Kt, r as Xt } from "../../index-CnHyhpKc.js";
|
|
2
|
+
import { f as $t, a as Qt, g as yt, b as Yt, c as Jt } from "../../kernel_funcs_utils-Dqo82NH4.js";
|
|
3
|
+
import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-CVbf1vmL.js";
|
|
4
|
+
import { b as te } from "../../broadcast_to-BBoMQXbL.js";
|
|
5
|
+
import { r as ee } from "../../reshape-CTIbqjwm.js";
|
|
6
|
+
import { i as ne, c as oe } from "../../slice_util-n4wHKmex.js";
|
|
7
7
|
import { g as se } from "../../_commonjsHelpers-ByX85dGu.js";
|
|
8
|
-
import { r as st } from "../../Reshape-
|
|
8
|
+
import { r as st } from "../../Reshape-Bt_t7RNz.js";
|
|
9
9
|
function re(t, e) {
|
|
10
10
|
for (var n = 0; n < e.length; n++) {
|
|
11
11
|
const o = e[n];
|
|
@@ -655,20 +655,20 @@ const _t = /* @__PURE__ */ se(Lt), ae = /* @__PURE__ */ re({
|
|
|
655
655
|
* limitations under the License.
|
|
656
656
|
* =============================================================================
|
|
657
657
|
*/
|
|
658
|
-
const
|
|
658
|
+
const Y = (
|
|
659
659
|
// tslint:disable-next-line
|
|
660
660
|
_t || ae
|
|
661
661
|
);
|
|
662
662
|
function lt(t) {
|
|
663
|
-
return
|
|
663
|
+
return Y.fromString(t, !0, 16);
|
|
664
664
|
}
|
|
665
|
-
const Nt = lt("c3a5c85c97cb3127"),
|
|
665
|
+
const Nt = lt("c3a5c85c97cb3127"), Q = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
|
|
666
666
|
function gt(t) {
|
|
667
667
|
return t.xor(t.shru(47));
|
|
668
668
|
}
|
|
669
669
|
function At(t, e, n) {
|
|
670
670
|
const o = t.slice(e, e + n);
|
|
671
|
-
return
|
|
671
|
+
return Y.fromBytes(Array.from(o), !0, !0);
|
|
672
672
|
}
|
|
673
673
|
function R(t, e) {
|
|
674
674
|
return At(t, e, 8);
|
|
@@ -709,7 +709,7 @@ function le(t, e = t.length) {
|
|
|
709
709
|
return W;
|
|
710
710
|
}
|
|
711
711
|
function ce(t, e = t.length) {
|
|
712
|
-
const n = W.add(e * 2), o = R(t, 0).mul(
|
|
712
|
+
const n = W.add(e * 2), o = R(t, 0).mul(Q), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
|
|
713
713
|
return X(A(o.add(s), 43).add(A(r, 30)).add(a), o.add(A(s.add(W), 18)).add(r), n);
|
|
714
714
|
}
|
|
715
715
|
function he(t, e = t.length) {
|
|
@@ -717,19 +717,19 @@ function he(t, e = t.length) {
|
|
|
717
717
|
return X(A(c.add(h), 43).add(A(f, 30)).add(w), c.add(A(h.add(o), 18)).add(f), n);
|
|
718
718
|
}
|
|
719
719
|
function fe(t, e = t.length) {
|
|
720
|
-
const n =
|
|
720
|
+
const n = Y.fromNumber(81, !0);
|
|
721
721
|
if (e <= 32)
|
|
722
722
|
return e <= 16 ? le(t, e) : ce(t, e);
|
|
723
723
|
if (e <= 64)
|
|
724
724
|
return he(t, e);
|
|
725
|
-
let o = n, s = n.mul(
|
|
725
|
+
let o = n, s = n.mul(Q).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [Y.UZERO, Y.UZERO], i = [Y.UZERO, Y.UZERO];
|
|
726
726
|
o = o.mul(W).add(R(t, 0));
|
|
727
727
|
let u = 0;
|
|
728
728
|
const c = (e - 1 >> 6) * 64, h = c + (e - 1 & 63) - 63;
|
|
729
729
|
do
|
|
730
|
-
o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(
|
|
730
|
+
o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Q), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Q), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Q), a = it(t, u, a[1].mul(Q), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
|
|
731
731
|
while (u !== c);
|
|
732
|
-
const f =
|
|
732
|
+
const f = Q.add(r.and(255).shl(1));
|
|
733
733
|
return u = h, i[0] = i[0].add(e - 1 & 63), a[0] = a[0].add(i[0]), i[0] = i[0].add(a[0]), o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(f), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(f), o = o.xor(i[1].mul(9)), s = s.add(a[0].mul(9).add(R(t, u + 40))), r = A(r.add(i[0]), 33).mul(f), a = it(t, u, a[1].mul(f), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], X(X(a[0], i[0], f).add(gt(s).mul(Nt)).add(r), X(a[1], i[1], f).add(o), f);
|
|
734
734
|
}
|
|
735
735
|
/**
|
|
@@ -955,7 +955,7 @@ function Ve(t) {
|
|
|
955
955
|
*/
|
|
956
956
|
function C(t) {
|
|
957
957
|
return (e, n, o, s, r) => {
|
|
958
|
-
const a = Ct(e, n), i = a.length, u =
|
|
958
|
+
const a = Ct(e, n), i = a.length, u = J(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = J(e), m = J(n), b = It(e, a), d = It(n, a);
|
|
959
959
|
if (b.length + d.length === 0)
|
|
960
960
|
for (let g = 0; g < h.length; ++g)
|
|
961
961
|
h[g] = t(o[g % o.length], s[g % s.length]);
|
|
@@ -1416,7 +1416,7 @@ const Xe = K((t) => Math.log(t));
|
|
|
1416
1416
|
* limitations under the License.
|
|
1417
1417
|
* =============================================================================
|
|
1418
1418
|
*/
|
|
1419
|
-
function
|
|
1419
|
+
function Qe(t, e, n, o) {
|
|
1420
1420
|
const s = nt(o, y(n));
|
|
1421
1421
|
for (let r = 0; r < s.length; ++r) {
|
|
1422
1422
|
const a = r * e;
|
|
@@ -1445,7 +1445,7 @@ function Ye(t, e, n, o) {
|
|
|
1445
1445
|
* limitations under the License.
|
|
1446
1446
|
* =============================================================================
|
|
1447
1447
|
*/
|
|
1448
|
-
const
|
|
1448
|
+
const Ye = C((t, e) => Math.max(t, e));
|
|
1449
1449
|
/**
|
|
1450
1450
|
* @license
|
|
1451
1451
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1462,7 +1462,7 @@ const Je = C((t, e) => Math.max(t, e));
|
|
|
1462
1462
|
* limitations under the License.
|
|
1463
1463
|
* =============================================================================
|
|
1464
1464
|
*/
|
|
1465
|
-
const
|
|
1465
|
+
const Je = C((t, e) => Math.min(t, e));
|
|
1466
1466
|
/**
|
|
1467
1467
|
* @license
|
|
1468
1468
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1534,7 +1534,7 @@ const en = C((t, e) => t !== e ? 1 : 0);
|
|
|
1534
1534
|
* =============================================================================
|
|
1535
1535
|
*/
|
|
1536
1536
|
function nn(t, e, n, o, s) {
|
|
1537
|
-
const r = e.length, a = y(e), i =
|
|
1537
|
+
const r = e.length, a = y(e), i = J(e), u = J(s), c = nt(n, y(s));
|
|
1538
1538
|
for (let h = 0; h < a; ++h) {
|
|
1539
1539
|
const f = pt(h, r, i), w = new Array(f.length);
|
|
1540
1540
|
for (let m = 0; m < w.length; m++)
|
|
@@ -1590,7 +1590,7 @@ function on(t, e, n, o) {
|
|
|
1590
1590
|
function sn(t, e, n) {
|
|
1591
1591
|
t.forEach((o, s) => {
|
|
1592
1592
|
if (o < 0 || o >= n) {
|
|
1593
|
-
const r = pt(s, e.length,
|
|
1593
|
+
const r = pt(s, e.length, J(e)).join(",");
|
|
1594
1594
|
throw new Error(`indices[${r}] = ${o} is not in [0, ${n})`);
|
|
1595
1595
|
}
|
|
1596
1596
|
});
|
|
@@ -2110,7 +2110,7 @@ const wn = K((t) => 1 / (1 + Math.exp(-t)));
|
|
|
2110
2110
|
* =============================================================================
|
|
2111
2111
|
*/
|
|
2112
2112
|
function In(t, e, n, o, s) {
|
|
2113
|
-
const r = ne(o, e, n), a = y(n), i =
|
|
2113
|
+
const r = ne(o, e, n), a = y(n), i = J(o);
|
|
2114
2114
|
if (r) {
|
|
2115
2115
|
const f = oe(e, i);
|
|
2116
2116
|
return s === "string" ? t.slice(f, f + a) : t.subarray(f, f + a);
|
|
@@ -2120,7 +2120,7 @@ function In(t, e, n, o, s) {
|
|
|
2120
2120
|
const w = h.indexToLoc(f), p = w.map((m, b) => m + e[b]);
|
|
2121
2121
|
h.set(c.get(...p), ...w);
|
|
2122
2122
|
}
|
|
2123
|
-
return s === "string" ?
|
|
2123
|
+
return s === "string" ? Qt(h.values) : h.values;
|
|
2124
2124
|
}
|
|
2125
2125
|
/**
|
|
2126
2126
|
* @license
|
|
@@ -2791,9 +2791,9 @@ const An = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
2791
2791
|
lessImpl: Be,
|
|
2792
2792
|
linSpaceImpl: Ke,
|
|
2793
2793
|
logImpl: Xe,
|
|
2794
|
-
maxImpl:
|
|
2795
|
-
maximumImpl:
|
|
2796
|
-
minimumImpl:
|
|
2794
|
+
maxImpl: Qe,
|
|
2795
|
+
maximumImpl: Ye,
|
|
2796
|
+
minimumImpl: Je,
|
|
2797
2797
|
multiplyImpl: Pt,
|
|
2798
2798
|
negImpl: tn,
|
|
2799
2799
|
notEqualImpl: en,
|
|
@@ -3171,7 +3171,7 @@ class Un {
|
|
|
3171
3171
|
o[h] = e[n[h]];
|
|
3172
3172
|
if (this.outputShape = o, this.rank = o.length, this.rank > 6)
|
|
3173
3173
|
throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
|
|
3174
|
-
const s = yt(this.rank), r =
|
|
3174
|
+
const s = yt(this.rank), r = Yt("rc", this.rank), a = new Array(this.rank);
|
|
3175
3175
|
for (let h = 0; h < n.length; h++)
|
|
3176
3176
|
a[n[h]] = r[h];
|
|
3177
3177
|
const i = `vec2(${a.slice(-2).join()})`, u = `++${r[this.rank - 1]} < ${o[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${i})`;
|
|
@@ -3349,7 +3349,7 @@ return a / b;`, Hn = `
|
|
|
3349
3349
|
}
|
|
3350
3350
|
|
|
3351
3351
|
return result;
|
|
3352
|
-
`, Kn =
|
|
3352
|
+
`, Kn = Jt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
|
|
3353
3353
|
class Xn {
|
|
3354
3354
|
variableNames = ["logits", "maxLogits"];
|
|
3355
3355
|
outputShape;
|
|
@@ -3369,7 +3369,7 @@ class Xn {
|
|
|
3369
3369
|
`;
|
|
3370
3370
|
}
|
|
3371
3371
|
}
|
|
3372
|
-
class
|
|
3372
|
+
class Qn {
|
|
3373
3373
|
variableNames = ["exp", "sum"];
|
|
3374
3374
|
outputShape;
|
|
3375
3375
|
userCode;
|
|
@@ -3396,7 +3396,7 @@ class Yn {
|
|
|
3396
3396
|
`;
|
|
3397
3397
|
}
|
|
3398
3398
|
}
|
|
3399
|
-
function
|
|
3399
|
+
function Yn(t) {
|
|
3400
3400
|
const { inputs: e, attrs: n } = t, { logits: o } = e, { dim: s, dropoutRate: r, seed: a } = n, i = t.backend;
|
|
3401
3401
|
if (!o)
|
|
3402
3402
|
throw new Error("Error in softmax: input logits is null");
|
|
@@ -3408,7 +3408,7 @@ function Jn(t) {
|
|
|
3408
3408
|
i.disposeIntermediateTensorInfo(c);
|
|
3409
3409
|
const p = Zn({ inputs: { x: w }, backend: i, attrs: { axis: u, keepDims: !1 } }), m = st({ inputs: { x: p }, backend: i, attrs: { shape: h } });
|
|
3410
3410
|
if (r !== void 0 && r > 0) {
|
|
3411
|
-
const d = new
|
|
3411
|
+
const d = new Qn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
|
|
3412
3412
|
[r],
|
|
3413
3413
|
[a ?? Math.random() * 1e4]
|
|
3414
3414
|
]);
|
|
@@ -3417,12 +3417,12 @@ function Jn(t) {
|
|
|
3417
3417
|
const b = Kn({ inputs: { a: w, b: m }, backend: i });
|
|
3418
3418
|
return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), b;
|
|
3419
3419
|
}
|
|
3420
|
-
const
|
|
3420
|
+
const Jn = {
|
|
3421
3421
|
kernelName: "FusedSoftmax",
|
|
3422
3422
|
backendName: "webgl",
|
|
3423
|
-
kernelFunc:
|
|
3423
|
+
kernelFunc: Yn
|
|
3424
3424
|
};
|
|
3425
|
-
Xt(
|
|
3425
|
+
Xt(Jn);
|
|
3426
3426
|
export {
|
|
3427
|
-
|
|
3427
|
+
Yn as softmax
|
|
3428
3428
|
};
|
package/dist/ops/webgl/gelu.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { r as a } from "../../index-
|
|
2
|
-
import { u as s, C as x } from "../../kernel_funcs_utils-
|
|
1
|
+
import { r as a } from "../../index-CnHyhpKc.js";
|
|
2
|
+
import { u as s, C as x } from "../../kernel_funcs_utils-Dqo82NH4.js";
|
|
3
3
|
const t = 0.7978845608028654, r = 0.044715, c = x + `
|
|
4
4
|
float x3 = x * x * x;
|
|
5
5
|
float inner = x + ${r} * x3;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { r as C, t as R, e as I,
|
|
2
|
-
import { r as S } from "../../Reshape-
|
|
3
|
-
import { u as H } from "../../gpgpu_math-
|
|
4
|
-
import { m as B } from "../../mat_mul-
|
|
1
|
+
import { r as C, t as R, e as I, q as G, Q as L, l as U, U as F } from "../../index-CnHyhpKc.js";
|
|
2
|
+
import { r as S } from "../../Reshape-Bt_t7RNz.js";
|
|
3
|
+
import { u as H } from "../../gpgpu_math-Df7gzJWH.js";
|
|
4
|
+
import { m as B } from "../../mat_mul-DeGU1U_C.js";
|
|
5
5
|
/**
|
|
6
6
|
* @license
|
|
7
7
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -69,7 +69,7 @@ class W {
|
|
|
69
69
|
`;
|
|
70
70
|
}
|
|
71
71
|
}
|
|
72
|
-
const g = 0.7978845608028654, w = 0.044715,
|
|
72
|
+
const g = 0.7978845608028654, w = 0.044715, q = `
|
|
73
73
|
vec4 x3 = x * x * x;
|
|
74
74
|
vec4 inner = x + ${w} * x3;
|
|
75
75
|
inner = ${g} * inner;
|
|
@@ -77,7 +77,7 @@ const g = 0.7978845608028654, w = 0.044715, j = `
|
|
|
77
77
|
inner = 0.5 * (1.0 + inner);
|
|
78
78
|
vec4 result = x * inner;
|
|
79
79
|
return result;
|
|
80
|
-
`,
|
|
80
|
+
`, j = `
|
|
81
81
|
vec4 a2 = a * a;
|
|
82
82
|
vec4 a3 = a2 * a;
|
|
83
83
|
vec4 u = ${g} * (a + ${w} * a3);
|
|
@@ -97,29 +97,29 @@ function O({
|
|
|
97
97
|
multiplier: o
|
|
98
98
|
}) {
|
|
99
99
|
const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
|
|
100
|
-
|
|
100
|
+
U(
|
|
101
101
|
l === h,
|
|
102
102
|
() => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
|
|
103
103
|
);
|
|
104
|
-
const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }),
|
|
104
|
+
const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), D = [A, y], E = Math.max(m, i), N = c, T = F(t.dtype, e.dtype), _ = new W(
|
|
105
105
|
f,
|
|
106
106
|
v,
|
|
107
|
-
[
|
|
107
|
+
[E, p, d],
|
|
108
108
|
s,
|
|
109
109
|
n,
|
|
110
110
|
!1,
|
|
111
|
-
|
|
111
|
+
N,
|
|
112
112
|
!!o,
|
|
113
113
|
!1
|
|
114
|
-
),
|
|
115
|
-
o &&
|
|
116
|
-
const z = a.runWebGLProgram(_,
|
|
117
|
-
|
|
118
|
-
for (const P of
|
|
114
|
+
), k = [A, y];
|
|
115
|
+
o && k.push(o);
|
|
116
|
+
const z = a.runWebGLProgram(_, k, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
|
|
117
|
+
D.push(z);
|
|
118
|
+
for (const P of D)
|
|
119
119
|
a.disposeIntermediateTensorInfo(P);
|
|
120
120
|
return K;
|
|
121
121
|
}
|
|
122
|
-
function
|
|
122
|
+
function Q(t) {
|
|
123
123
|
const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
|
|
124
124
|
if (n === void 0 || a === void 0)
|
|
125
125
|
throw new Error("BatchMatMul requires two input tensors.");
|
|
@@ -129,15 +129,15 @@ function J(t) {
|
|
|
129
129
|
transposeA: !1,
|
|
130
130
|
transposeB: !1,
|
|
131
131
|
backend: s,
|
|
132
|
-
activationSnippet:
|
|
132
|
+
activationSnippet: q
|
|
133
133
|
});
|
|
134
134
|
}
|
|
135
|
-
const
|
|
135
|
+
const J = {
|
|
136
136
|
kernelName: "MatMulGelu",
|
|
137
137
|
backendName: "webgl",
|
|
138
|
-
kernelFunc:
|
|
138
|
+
kernelFunc: Q
|
|
139
139
|
};
|
|
140
|
-
C(
|
|
140
|
+
C(J);
|
|
141
141
|
function V(t) {
|
|
142
142
|
const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
|
|
143
143
|
return R(() => {
|
|
@@ -148,7 +148,7 @@ function V(t) {
|
|
|
148
148
|
transposeA: !1,
|
|
149
149
|
transposeB: !1,
|
|
150
150
|
backend: a,
|
|
151
|
-
activationSnippet:
|
|
151
|
+
activationSnippet: j,
|
|
152
152
|
multiplier: e
|
|
153
153
|
})
|
|
154
154
|
), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
|
|
@@ -164,5 +164,5 @@ C(X);
|
|
|
164
164
|
export {
|
|
165
165
|
se as MATMUL_SHARED_DIM_THRESHOLD,
|
|
166
166
|
O as batchMatMulGeluImpl,
|
|
167
|
-
|
|
167
|
+
Q as batchMatMulKernel
|
|
168
168
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { r as p, e as G } from "../../index-
|
|
2
|
-
import { s as x } from "../../sum-
|
|
1
|
+
import { r as p, e as G } from "../../index-CnHyhpKc.js";
|
|
2
|
+
import { s as x } from "../../sum-UdfvaNhB.js";
|
|
3
3
|
class y {
|
|
4
4
|
variableNames = ["x", "meanSquare", "gamma"];
|
|
5
5
|
outputShape;
|
package/dist/ops/webgl/qkv.js
CHANGED
package/dist/ops/webgl/rope.js
CHANGED