@genai-fi/nanogpt 0.5.0 → 0.5.2

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they appear in that public registry.
Files changed (104)
  1. package/dist/Generator.js +95 -46
  2. package/dist/NanoGPTModel.d.ts +3 -2
  3. package/dist/NanoGPTModel.js +91 -76
  4. package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/TiedEmbedding-DORsPlNL.js +44 -0
  7. package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
  8. package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
  9. package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
  10. package/dist/dataset-ZHEPJmED.js +1226 -0
  11. package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
  12. package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
  13. package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
  14. package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
  15. package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
  16. package/dist/layers/BaseLayer.js +114 -3
  17. package/dist/layers/CausalSelfAttention.d.ts +2 -3
  18. package/dist/layers/CausalSelfAttention.js +31 -30
  19. package/dist/layers/MLP.js +10 -9
  20. package/dist/layers/RMSNorm.js +12 -11
  21. package/dist/layers/RoPECache.js +3 -3
  22. package/dist/layers/TiedEmbedding.js +8 -6
  23. package/dist/layers/TransformerBlock.js +2 -2
  24. package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
  25. package/dist/main.js +1 -1
  26. package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
  27. package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
  28. package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
  29. package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
  30. package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +5 -5
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +27 -27
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +1 -1
  49. package/dist/ops/grads/fusedSoftmax.js +2 -2
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/matMulGelu.js +1 -1
  56. package/dist/ops/matMulMul.js +1 -1
  57. package/dist/ops/mulDrop.js +1 -1
  58. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  59. package/dist/ops/normRMS.js +1 -1
  60. package/dist/ops/qkv.js +1 -1
  61. package/dist/ops/scatterSub.js +1 -1
  62. package/dist/ops/webgl/appendCache.js +1 -1
  63. package/dist/ops/webgl/attentionMask.js +1 -1
  64. package/dist/ops/webgl/fusedSoftmax.js +36 -36
  65. package/dist/ops/webgl/gatherSub.js +1 -1
  66. package/dist/ops/webgl/gelu.js +2 -2
  67. package/dist/ops/webgl/matMulGelu.js +22 -22
  68. package/dist/ops/webgl/matMulMul.js +1 -1
  69. package/dist/ops/webgl/mulDropout.js +1 -1
  70. package/dist/ops/webgl/normRMS.js +2 -2
  71. package/dist/ops/webgl/qkv.js +1 -1
  72. package/dist/ops/webgl/rope.js +1 -1
  73. package/dist/ops/webgl/scatterSub.js +1 -1
  74. package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
  75. package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
  76. package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
  77. package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
  78. package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
  79. package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
  80. package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
  81. package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
  82. package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
  83. package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
  84. package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
  85. package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
  86. package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
  87. package/dist/tokeniser/CharTokeniser.js +27 -27
  88. package/dist/tokeniser/bpe.d.ts +1 -0
  89. package/dist/tokeniser/bpe.js +38 -35
  90. package/dist/training/AdamExt.js +1 -1
  91. package/dist/training/DatasetBuilder.js +22 -1242
  92. package/dist/training/FullTrainer.js +1 -1
  93. package/dist/training/Trainer.js +5 -5
  94. package/dist/training/sparseCrossEntropy.js +4 -4
  95. package/dist/utilities/dummy.js +2 -2
  96. package/dist/utilities/generate.js +3 -3
  97. package/dist/utilities/load.js +1 -1
  98. package/dist/utilities/profile.js +1 -1
  99. package/dist/utilities/save.js +5 -5
  100. package/dist/utilities/weights.js +2 -2
  101. package/dist/variable-BGvK-VN3.js +23 -0
  102. package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
  103. package/package.json +1 -1
  104. package/dist/BaseLayer-BhrMN8JO.js +0 -135
@@ -1,11 +1,11 @@
1
- import { a$ as qt, p as y, N as Ct, ad as Q, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, ah as B, b5 as q, b6 as Ut, O as Gt, k as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, a1 as ut, b9 as Bt, K as Ht, ba as Kt, r as Xt } from "../../index-iNhkcAEQ.js";
2
- import { f as $t, a as Yt, g as yt, b as Jt, c as Qt } from "../../kernel_funcs_utils-C4eIk4fE.js";
3
- import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-97KkkyRQ.js";
4
- import { b as te } from "../../broadcast_to-CMlkG8NS.js";
5
- import { r as ee } from "../../reshape-DxTPgnwL.js";
6
- import { i as ne, c as oe } from "../../slice_util-D-kaD4ZV.js";
1
+ import { a$ as qt, q as y, Q as Ct, ad as J, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, al as B, b5 as q, b6 as Ut, U as Gt, l as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, a1 as ut, b9 as Bt, N as Ht, ba as Kt, r as Xt } from "../../index-CnHyhpKc.js";
2
+ import { f as $t, a as Qt, g as yt, b as Yt, c as Jt } from "../../kernel_funcs_utils-Dqo82NH4.js";
3
+ import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-CVbf1vmL.js";
4
+ import { b as te } from "../../broadcast_to-BBoMQXbL.js";
5
+ import { r as ee } from "../../reshape-CTIbqjwm.js";
6
+ import { i as ne, c as oe } from "../../slice_util-n4wHKmex.js";
7
7
  import { g as se } from "../../_commonjsHelpers-ByX85dGu.js";
8
- import { r as st } from "../../Reshape-BE5rA4rT.js";
8
+ import { r as st } from "../../Reshape-Bt_t7RNz.js";
9
9
  function re(t, e) {
10
10
  for (var n = 0; n < e.length; n++) {
11
11
  const o = e[n];
@@ -655,20 +655,20 @@ const _t = /* @__PURE__ */ se(Lt), ae = /* @__PURE__ */ re({
655
655
  * limitations under the License.
656
656
  * =============================================================================
657
657
  */
658
- const J = (
658
+ const Y = (
659
659
  // tslint:disable-next-line
660
660
  _t || ae
661
661
  );
662
662
  function lt(t) {
663
- return J.fromString(t, !0, 16);
663
+ return Y.fromString(t, !0, 16);
664
664
  }
665
- const Nt = lt("c3a5c85c97cb3127"), Y = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
665
+ const Nt = lt("c3a5c85c97cb3127"), Q = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
666
666
  function gt(t) {
667
667
  return t.xor(t.shru(47));
668
668
  }
669
669
  function At(t, e, n) {
670
670
  const o = t.slice(e, e + n);
671
- return J.fromBytes(Array.from(o), !0, !0);
671
+ return Y.fromBytes(Array.from(o), !0, !0);
672
672
  }
673
673
  function R(t, e) {
674
674
  return At(t, e, 8);
@@ -709,7 +709,7 @@ function le(t, e = t.length) {
709
709
  return W;
710
710
  }
711
711
  function ce(t, e = t.length) {
712
- const n = W.add(e * 2), o = R(t, 0).mul(Y), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
712
+ const n = W.add(e * 2), o = R(t, 0).mul(Q), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
713
713
  return X(A(o.add(s), 43).add(A(r, 30)).add(a), o.add(A(s.add(W), 18)).add(r), n);
714
714
  }
715
715
  function he(t, e = t.length) {
@@ -717,19 +717,19 @@ function he(t, e = t.length) {
717
717
  return X(A(c.add(h), 43).add(A(f, 30)).add(w), c.add(A(h.add(o), 18)).add(f), n);
718
718
  }
719
719
  function fe(t, e = t.length) {
720
- const n = J.fromNumber(81, !0);
720
+ const n = Y.fromNumber(81, !0);
721
721
  if (e <= 32)
722
722
  return e <= 16 ? le(t, e) : ce(t, e);
723
723
  if (e <= 64)
724
724
  return he(t, e);
725
- let o = n, s = n.mul(Y).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [J.UZERO, J.UZERO], i = [J.UZERO, J.UZERO];
725
+ let o = n, s = n.mul(Q).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [Y.UZERO, Y.UZERO], i = [Y.UZERO, Y.UZERO];
726
726
  o = o.mul(W).add(R(t, 0));
727
727
  let u = 0;
728
728
  const c = (e - 1 >> 6) * 64, h = c + (e - 1 & 63) - 63;
729
729
  do
730
- o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Y), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Y), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Y), a = it(t, u, a[1].mul(Y), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
730
+ o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Q), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Q), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Q), a = it(t, u, a[1].mul(Q), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
731
731
  while (u !== c);
732
- const f = Y.add(r.and(255).shl(1));
732
+ const f = Q.add(r.and(255).shl(1));
733
733
  return u = h, i[0] = i[0].add(e - 1 & 63), a[0] = a[0].add(i[0]), i[0] = i[0].add(a[0]), o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(f), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(f), o = o.xor(i[1].mul(9)), s = s.add(a[0].mul(9).add(R(t, u + 40))), r = A(r.add(i[0]), 33).mul(f), a = it(t, u, a[1].mul(f), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], X(X(a[0], i[0], f).add(gt(s).mul(Nt)).add(r), X(a[1], i[1], f).add(o), f);
734
734
  }
735
735
  /**
@@ -955,7 +955,7 @@ function Ve(t) {
955
955
  */
956
956
  function C(t) {
957
957
  return (e, n, o, s, r) => {
958
- const a = Ct(e, n), i = a.length, u = Q(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = Q(e), m = Q(n), b = It(e, a), d = It(n, a);
958
+ const a = Ct(e, n), i = a.length, u = J(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = J(e), m = J(n), b = It(e, a), d = It(n, a);
959
959
  if (b.length + d.length === 0)
960
960
  for (let g = 0; g < h.length; ++g)
961
961
  h[g] = t(o[g % o.length], s[g % s.length]);
@@ -1416,7 +1416,7 @@ const Xe = K((t) => Math.log(t));
1416
1416
  * limitations under the License.
1417
1417
  * =============================================================================
1418
1418
  */
1419
- function Ye(t, e, n, o) {
1419
+ function Qe(t, e, n, o) {
1420
1420
  const s = nt(o, y(n));
1421
1421
  for (let r = 0; r < s.length; ++r) {
1422
1422
  const a = r * e;
@@ -1445,7 +1445,7 @@ function Ye(t, e, n, o) {
1445
1445
  * limitations under the License.
1446
1446
  * =============================================================================
1447
1447
  */
1448
- const Je = C((t, e) => Math.max(t, e));
1448
+ const Ye = C((t, e) => Math.max(t, e));
1449
1449
  /**
1450
1450
  * @license
1451
1451
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1462,7 +1462,7 @@ const Je = C((t, e) => Math.max(t, e));
1462
1462
  * limitations under the License.
1463
1463
  * =============================================================================
1464
1464
  */
1465
- const Qe = C((t, e) => Math.min(t, e));
1465
+ const Je = C((t, e) => Math.min(t, e));
1466
1466
  /**
1467
1467
  * @license
1468
1468
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1534,7 +1534,7 @@ const en = C((t, e) => t !== e ? 1 : 0);
1534
1534
  * =============================================================================
1535
1535
  */
1536
1536
  function nn(t, e, n, o, s) {
1537
- const r = e.length, a = y(e), i = Q(e), u = Q(s), c = nt(n, y(s));
1537
+ const r = e.length, a = y(e), i = J(e), u = J(s), c = nt(n, y(s));
1538
1538
  for (let h = 0; h < a; ++h) {
1539
1539
  const f = pt(h, r, i), w = new Array(f.length);
1540
1540
  for (let m = 0; m < w.length; m++)
@@ -1590,7 +1590,7 @@ function on(t, e, n, o) {
1590
1590
  function sn(t, e, n) {
1591
1591
  t.forEach((o, s) => {
1592
1592
  if (o < 0 || o >= n) {
1593
- const r = pt(s, e.length, Q(e)).join(",");
1593
+ const r = pt(s, e.length, J(e)).join(",");
1594
1594
  throw new Error(`indices[${r}] = ${o} is not in [0, ${n})`);
1595
1595
  }
1596
1596
  });
@@ -2110,7 +2110,7 @@ const wn = K((t) => 1 / (1 + Math.exp(-t)));
2110
2110
  * =============================================================================
2111
2111
  */
2112
2112
  function In(t, e, n, o, s) {
2113
- const r = ne(o, e, n), a = y(n), i = Q(o);
2113
+ const r = ne(o, e, n), a = y(n), i = J(o);
2114
2114
  if (r) {
2115
2115
  const f = oe(e, i);
2116
2116
  return s === "string" ? t.slice(f, f + a) : t.subarray(f, f + a);
@@ -2120,7 +2120,7 @@ function In(t, e, n, o, s) {
2120
2120
  const w = h.indexToLoc(f), p = w.map((m, b) => m + e[b]);
2121
2121
  h.set(c.get(...p), ...w);
2122
2122
  }
2123
- return s === "string" ? Yt(h.values) : h.values;
2123
+ return s === "string" ? Qt(h.values) : h.values;
2124
2124
  }
2125
2125
  /**
2126
2126
  * @license
@@ -2791,9 +2791,9 @@ const An = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
2791
2791
  lessImpl: Be,
2792
2792
  linSpaceImpl: Ke,
2793
2793
  logImpl: Xe,
2794
- maxImpl: Ye,
2795
- maximumImpl: Je,
2796
- minimumImpl: Qe,
2794
+ maxImpl: Qe,
2795
+ maximumImpl: Ye,
2796
+ minimumImpl: Je,
2797
2797
  multiplyImpl: Pt,
2798
2798
  negImpl: tn,
2799
2799
  notEqualImpl: en,
@@ -3171,7 +3171,7 @@ class Un {
3171
3171
  o[h] = e[n[h]];
3172
3172
  if (this.outputShape = o, this.rank = o.length, this.rank > 6)
3173
3173
  throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
3174
- const s = yt(this.rank), r = Jt("rc", this.rank), a = new Array(this.rank);
3174
+ const s = yt(this.rank), r = Yt("rc", this.rank), a = new Array(this.rank);
3175
3175
  for (let h = 0; h < n.length; h++)
3176
3176
  a[n[h]] = r[h];
3177
3177
  const i = `vec2(${a.slice(-2).join()})`, u = `++${r[this.rank - 1]} < ${o[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${i})`;
@@ -3349,7 +3349,7 @@ return a / b;`, Hn = `
3349
3349
  }
3350
3350
 
3351
3351
  return result;
3352
- `, Kn = Qt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
3352
+ `, Kn = Jt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
3353
3353
  class Xn {
3354
3354
  variableNames = ["logits", "maxLogits"];
3355
3355
  outputShape;
@@ -3369,7 +3369,7 @@ class Xn {
3369
3369
  `;
3370
3370
  }
3371
3371
  }
3372
- class Yn {
3372
+ class Qn {
3373
3373
  variableNames = ["exp", "sum"];
3374
3374
  outputShape;
3375
3375
  userCode;
@@ -3396,7 +3396,7 @@ class Yn {
3396
3396
  `;
3397
3397
  }
3398
3398
  }
3399
- function Jn(t) {
3399
+ function Yn(t) {
3400
3400
  const { inputs: e, attrs: n } = t, { logits: o } = e, { dim: s, dropoutRate: r, seed: a } = n, i = t.backend;
3401
3401
  if (!o)
3402
3402
  throw new Error("Error in softmax: input logits is null");
@@ -3408,7 +3408,7 @@ function Jn(t) {
3408
3408
  i.disposeIntermediateTensorInfo(c);
3409
3409
  const p = Zn({ inputs: { x: w }, backend: i, attrs: { axis: u, keepDims: !1 } }), m = st({ inputs: { x: p }, backend: i, attrs: { shape: h } });
3410
3410
  if (r !== void 0 && r > 0) {
3411
- const d = new Yn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
3411
+ const d = new Qn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
3412
3412
  [r],
3413
3413
  [a ?? Math.random() * 1e4]
3414
3414
  ]);
@@ -3417,12 +3417,12 @@ function Jn(t) {
3417
3417
  const b = Kn({ inputs: { a: w, b: m }, backend: i });
3418
3418
  return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), b;
3419
3419
  }
3420
- const Qn = {
3420
+ const Jn = {
3421
3421
  kernelName: "FusedSoftmax",
3422
3422
  backendName: "webgl",
3423
- kernelFunc: Jn
3423
+ kernelFunc: Yn
3424
3424
  };
3425
- Xt(Qn);
3425
+ Xt(Jn);
3426
3426
  export {
3427
- Jn as softmax
3427
+ Yn as softmax
3428
3428
  };
@@ -1,4 +1,4 @@
1
- import { r as l } from "../../index-iNhkcAEQ.js";
1
+ import { r as l } from "../../index-CnHyhpKc.js";
2
2
  class u {
3
3
  variableNames = ["labels", "logits", "values"];
4
4
  outputShape;
@@ -1,5 +1,5 @@
1
- import { r as a } from "../../index-iNhkcAEQ.js";
2
- import { u as s, C as x } from "../../kernel_funcs_utils-C4eIk4fE.js";
1
+ import { r as a } from "../../index-CnHyhpKc.js";
2
+ import { u as s, C as x } from "../../kernel_funcs_utils-Dqo82NH4.js";
3
3
  const t = 0.7978845608028654, r = 0.044715, c = x + `
4
4
  float x3 = x * x * x;
5
5
  float inner = x + ${r} * x3;
@@ -1,7 +1,7 @@
1
- import { r as C, t as R, e as I, p as G, N as L, k as F, O as U } from "../../index-iNhkcAEQ.js";
2
- import { r as S } from "../../Reshape-BE5rA4rT.js";
3
- import { u as H } from "../../gpgpu_math-C0zyxKFi.js";
4
- import { m as B } from "../../mat_mul-D0SifYfJ.js";
1
+ import { r as C, t as R, e as I, q as G, Q as L, l as U, U as F } from "../../index-CnHyhpKc.js";
2
+ import { r as S } from "../../Reshape-Bt_t7RNz.js";
3
+ import { u as H } from "../../gpgpu_math-Df7gzJWH.js";
4
+ import { m as B } from "../../mat_mul-DeGU1U_C.js";
5
5
  /**
6
6
  * @license
7
7
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -69,7 +69,7 @@ class W {
69
69
  `;
70
70
  }
71
71
  }
72
- const g = 0.7978845608028654, w = 0.044715, j = `
72
+ const g = 0.7978845608028654, w = 0.044715, q = `
73
73
  vec4 x3 = x * x * x;
74
74
  vec4 inner = x + ${w} * x3;
75
75
  inner = ${g} * inner;
@@ -77,7 +77,7 @@ const g = 0.7978845608028654, w = 0.044715, j = `
77
77
  inner = 0.5 * (1.0 + inner);
78
78
  vec4 result = x * inner;
79
79
  return result;
80
- `, q = `
80
+ `, j = `
81
81
  vec4 a2 = a * a;
82
82
  vec4 a3 = a2 * a;
83
83
  vec4 u = ${g} * (a + ${w} * a3);
@@ -97,29 +97,29 @@ function O({
97
97
  multiplier: o
98
98
  }) {
99
99
  const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
100
- F(
100
+ U(
101
101
  l === h,
102
102
  () => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
103
103
  );
104
- const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), k = [A, y], N = Math.max(m, i), E = c, T = U(t.dtype, e.dtype), _ = new W(
104
+ const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), D = [A, y], E = Math.max(m, i), N = c, T = F(t.dtype, e.dtype), _ = new W(
105
105
  f,
106
106
  v,
107
- [N, p, d],
107
+ [E, p, d],
108
108
  s,
109
109
  n,
110
110
  !1,
111
- E,
111
+ N,
112
112
  !!o,
113
113
  !1
114
- ), D = [A, y];
115
- o && D.push(o);
116
- const z = a.runWebGLProgram(_, D, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
117
- k.push(z);
118
- for (const P of k)
114
+ ), k = [A, y];
115
+ o && k.push(o);
116
+ const z = a.runWebGLProgram(_, k, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
117
+ D.push(z);
118
+ for (const P of D)
119
119
  a.disposeIntermediateTensorInfo(P);
120
120
  return K;
121
121
  }
122
- function J(t) {
122
+ function Q(t) {
123
123
  const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
124
124
  if (n === void 0 || a === void 0)
125
125
  throw new Error("BatchMatMul requires two input tensors.");
@@ -129,15 +129,15 @@ function J(t) {
129
129
  transposeA: !1,
130
130
  transposeB: !1,
131
131
  backend: s,
132
- activationSnippet: j
132
+ activationSnippet: q
133
133
  });
134
134
  }
135
- const Q = {
135
+ const J = {
136
136
  kernelName: "MatMulGelu",
137
137
  backendName: "webgl",
138
- kernelFunc: J
138
+ kernelFunc: Q
139
139
  };
140
- C(Q);
140
+ C(J);
141
141
  function V(t) {
142
142
  const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
143
143
  return R(() => {
@@ -148,7 +148,7 @@ function V(t) {
148
148
  transposeA: !1,
149
149
  transposeB: !1,
150
150
  backend: a,
151
- activationSnippet: q,
151
+ activationSnippet: j,
152
152
  multiplier: e
153
153
  })
154
154
  ), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
@@ -164,5 +164,5 @@ C(X);
164
164
  export {
165
165
  se as MATMUL_SHARED_DIM_THRESHOLD,
166
166
  O as batchMatMulGeluImpl,
167
- J as batchMatMulKernel
167
+ Q as batchMatMulKernel
168
168
  };
@@ -1,4 +1,4 @@
1
- import { r as u } from "../../index-iNhkcAEQ.js";
1
+ import { r as u } from "../../index-CnHyhpKc.js";
2
2
  import { batchMatMulGeluImpl as c } from "./matMulGelu.js";
3
3
  const M = `
4
4
  return a * b;
@@ -1,4 +1,4 @@
1
- import { r as m } from "../../index-iNhkcAEQ.js";
1
+ import { r as m } from "../../index-CnHyhpKc.js";
2
2
  class f {
3
3
  variableNames = ["a", "b"];
4
4
  outputShape;
@@ -1,5 +1,5 @@
1
- import { r as p, e as G } from "../../index-iNhkcAEQ.js";
2
- import { s as x } from "../../sum-B_92TaHD.js";
1
+ import { r as p, e as G } from "../../index-CnHyhpKc.js";
2
+ import { s as x } from "../../sum-UdfvaNhB.js";
3
3
  class y {
4
4
  variableNames = ["x", "meanSquare", "gamma"];
5
5
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as i } from "../../index-iNhkcAEQ.js";
1
+ import { r as i } from "../../index-CnHyhpKc.js";
2
2
  class l {
3
3
  variableNames = ["x", "kernel"];
4
4
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as u } from "../../index-iNhkcAEQ.js";
1
+ import { r as u } from "../../index-CnHyhpKc.js";
2
2
  class l {
3
3
  variableNames = ["x", "sin", "cos"];
4
4
  outputShape;
@@ -1,4 +1,4 @@
1
- import { r as i } from "../../index-iNhkcAEQ.js";
1
+ import { r as i } from "../../index-CnHyhpKc.js";
2
2
  class u {
3
3
  variableNames = ["labels", "softmaxProbs", "dy"];
4
4
  outputShape;