@genai-fi/nanogpt 0.4.4 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +8 -5
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -127
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -11
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +22 -19
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.d.ts +1 -0
  49. package/dist/ops/cpu/normRMS.js +39 -0
  50. package/dist/ops/cpu/qkv.js +3 -3
  51. package/dist/ops/cpu/rope.js +5 -5
  52. package/dist/ops/cpu/scatterSub.js +8 -8
  53. package/dist/ops/fusedSoftmax.js +1 -1
  54. package/dist/ops/gatherSub.js +1 -1
  55. package/dist/ops/gelu.js +1 -1
  56. package/dist/ops/grads/attentionMask.js +13 -9
  57. package/dist/ops/grads/fusedSoftmax.js +12 -9
  58. package/dist/ops/grads/gelu.js +1 -1
  59. package/dist/ops/grads/matMulGelu.js +1 -1
  60. package/dist/ops/grads/normRMS.d.ts +2 -0
  61. package/dist/ops/grads/normRMS.js +20 -0
  62. package/dist/ops/grads/qkv.js +19 -9
  63. package/dist/ops/grads/rope.js +1 -1
  64. package/dist/ops/matMulGelu.js +1 -1
  65. package/dist/ops/matMulMul.d.ts +2 -0
  66. package/dist/ops/matMulMul.js +9 -0
  67. package/dist/ops/mulDrop.js +1 -1
  68. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  69. package/dist/ops/normRMS.d.ts +2 -0
  70. package/dist/ops/normRMS.js +10 -0
  71. package/dist/ops/qkv.js +1 -1
  72. package/dist/ops/scatterSub.js +1 -1
  73. package/dist/ops/webgl/appendCache.js +1 -1
  74. package/dist/ops/webgl/attentionMask.js +13 -12
  75. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  76. package/dist/ops/webgl/gatherSub.js +1 -1
  77. package/dist/ops/webgl/gelu.js +2 -2
  78. package/dist/ops/webgl/matMulGelu.d.ts +3 -2
  79. package/dist/ops/webgl/matMulGelu.js +77 -75
  80. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  81. package/dist/ops/webgl/matMulMul.js +28 -0
  82. package/dist/ops/webgl/mulDropout.js +1 -1
  83. package/dist/ops/webgl/normRMS.d.ts +1 -0
  84. package/dist/ops/webgl/normRMS.js +86 -0
  85. package/dist/ops/webgl/qkv.js +1 -1
  86. package/dist/ops/webgl/rope.js +1 -1
  87. package/dist/ops/webgl/scatterSub.js +1 -1
  88. package/dist/ops-ObfXLHYQ.js +1269 -0
  89. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  90. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  91. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  92. package/dist/slice_util-D-kaD4ZV.js +49 -0
  93. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  94. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  95. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  96. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  97. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  98. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  99. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  100. package/dist/training/AdamExt.js +1 -1
  101. package/dist/training/DatasetBuilder.js +44 -44
  102. package/dist/training/Evaluator.js +6 -6
  103. package/dist/training/FullTrainer.js +1 -1
  104. package/dist/training/Trainer.js +7 -7
  105. package/dist/training/sparseCrossEntropy.js +4 -4
  106. package/dist/utilities/dummy.js +10 -10
  107. package/dist/utilities/generate.js +3 -3
  108. package/dist/utilities/load.js +1 -1
  109. package/dist/utilities/profile.js +1 -1
  110. package/dist/utilities/save.js +10 -8
  111. package/dist/utilities/weights.js +2 -2
  112. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  113. package/package.json +1 -1
  114. package/dist/slice_util-BdhYwFY_.js +0 -90
  115. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  116. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,17 +1,20 @@
1
- import { g as p, b as m, s as d } from "../../index--6vO-cOz.js";
2
- import { mulDrop as c } from "../mulDrop.js";
3
- import { s as f } from "../../sum-DdkDf2MG.js";
4
- const g = {
1
+ import { h as d, b as u, s as f } from "../../index-iNhkcAEQ.js";
2
+ import { mulDrop as l } from "../mulDrop.js";
3
+ import { s as g } from "../../sum-B_92TaHD.js";
4
+ const T = {
5
5
  kernelName: "FusedSoftmax",
6
6
  outputsToSave: [!0],
7
- gradFunc: (s, a, u) => {
8
- const [o] = a, { dim: i, dropoutRate: t, seed: r } = u, n = !0, e = t && r ? c(s, o, t, r) : m(s, o);
7
+ gradFunc: (o, i, n) => {
8
+ const [s] = i, { dim: a, dropoutRate: t, seed: e } = n, p = !0, r = t && e ? l(o, s, t, e) : u(o, s);
9
9
  return {
10
- logits: () => d(e, m(f(e, [i], n), o))
10
+ logits: () => {
11
+ const m = g(r, [a], p), c = u(m, s);
12
+ return m.dispose(), f(r, c);
13
+ }
11
14
  };
12
15
  }
13
16
  };
14
- p(g);
17
+ d(T);
15
18
  export {
16
- g as softmaxGradConfig
19
+ T as softmaxGradConfig
17
20
  };
@@ -1,4 +1,4 @@
1
- import { g as t, e as n } from "../../index--6vO-cOz.js";
1
+ import { h as t, e as n } from "../../index-iNhkcAEQ.js";
2
2
  import "../cpu/gelu.js";
3
3
  import "../webgl/gelu.js";
4
4
  const o = {
@@ -1,4 +1,4 @@
1
- import { g as a, e as o } from "../../index--6vO-cOz.js";
1
+ import { h as a, e as o } from "../../index-iNhkcAEQ.js";
2
2
  function s(e, n, r) {
3
3
  return o().runKernel("MatMulGeluGrad", { dy: e, x: n, kernel: r });
4
4
  }
@@ -0,0 +1,2 @@
1
+ import { GradConfig } from '@tensorflow/tfjs-core';
2
+ export declare const normRMSGradConfig: GradConfig;
@@ -0,0 +1,20 @@
1
+ import { h as t, e as g } from "../../index-iNhkcAEQ.js";
2
+ function s(r, a, n) {
3
+ return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
4
+ }
5
+ const u = {
6
+ kernelName: "RMSNorm",
7
+ inputsToSave: ["x", "gamma"],
8
+ outputsToSave: [],
9
+ gradFunc: (r, a) => {
10
+ const [n, e] = a, [m, o] = s(r, n, e);
11
+ return {
12
+ x: () => m,
13
+ gamma: () => o
14
+ };
15
+ }
16
+ };
17
+ t(u);
18
+ export {
19
+ u as normRMSGradConfig
20
+ };
@@ -1,20 +1,30 @@
1
- import { g as v } from "../../index--6vO-cOz.js";
2
- const g = {
1
+ import { h as Q } from "../../index-iNhkcAEQ.js";
2
+ const V = {
3
3
  kernelName: "QKV",
4
4
  inputsToSave: ["x", "kernel"],
5
5
  outputsToSave: [],
6
- gradFunc: (k, m) => {
7
- const [x, K, f] = k, [c, r] = m, [t, s, e] = c.shape, d = x.transpose([0, 2, 1, 3]).reshape([t * s, e]), l = K.transpose([0, 2, 1, 3]).reshape([t * s, e]), u = f.transpose([0, 2, 1, 3]).reshape([t * s, e]), i = r.slice([0, 0], [e, e]), h = r.slice([0, e], [e, e]), M = r.slice([0, 2 * e], [e, e]);
6
+ gradFunc: (x, K) => {
7
+ const [f, h, M] = x, [p, l] = K, [t, n, e] = p.shape, i = f.transpose([0, 2, 1, 3]).reshape([t * n, e]), u = h.transpose([0, 2, 1, 3]).reshape([t * n, e]), k = M.transpose([0, 2, 1, 3]).reshape([t * n, e]);
8
8
  return {
9
9
  x: () => {
10
- const n = d.matMul(i, !1, !0), a = l.matMul(h, !1, !0), o = u.matMul(M, !1, !0);
11
- return n.add(a).add(o).reshape([t, s, e]);
10
+ const s = l.slice([0, 0], [e, e]), o = i.matMul(s, !1, !0);
11
+ s.dispose();
12
+ const d = l.slice([0, e], [e, e]), r = u.matMul(d, !1, !0);
13
+ d.dispose();
14
+ const a = o.add(r);
15
+ o.dispose(), r.dispose();
16
+ const c = l.slice([0, 2 * e], [e, e]), m = k.matMul(c, !1, !0);
17
+ c.dispose();
18
+ const v = a.add(m).reshape([t, n, e]);
19
+ return a.dispose(), m.dispose(), v;
12
20
  },
13
21
  kernel: () => {
14
- const n = c.reshape([t * s, e]), a = n.matMul(d, !0, !1), o = n.matMul(l, !0, !1), p = n.matMul(u, !0, !1);
15
- return a.concat(o, 1).concat(p, 1);
22
+ const s = p.reshape([t * n, e]), o = s.matMul(i, !0, !1), d = s.matMul(u, !0, !1), r = o.concat(d, 1);
23
+ o.dispose(), d.dispose();
24
+ const a = s.matMul(k, !0, !1), c = r.concat(a, 1);
25
+ return r.dispose(), a.dispose(), s.dispose(), c;
16
26
  }
17
27
  };
18
28
  }
19
29
  };
20
- v(g);
30
+ Q(V);
@@ -1,4 +1,4 @@
1
- import { g as a, e as i } from "../../index--6vO-cOz.js";
1
+ import { h as a, e as i } from "../../index-iNhkcAEQ.js";
2
2
  function p(n, e, s, o) {
3
3
  return i().runKernel("Rope", { x: n, sin: e, cos: s }, { pastLen: o });
4
4
  }
@@ -1,4 +1,4 @@
1
- import { e as u } from "../index--6vO-cOz.js";
1
+ import { e as u } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/matMulGelu.js";
3
3
  import "./webgl/matMulGelu.js";
4
4
  import "./grads/matMulGelu.js";
@@ -0,0 +1,2 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ export declare function matMulMul(x: Tensor, kernel: Tensor, y: Tensor, transposeA?: boolean, transposeB?: boolean): Tensor;
@@ -0,0 +1,9 @@
1
+ import { e as u } from "../index-iNhkcAEQ.js";
2
+ import "./cpu/matMulMul.js";
3
+ import "./webgl/matMulMul.js";
4
+ function m(e, r, t, l = !1, n = !1) {
5
+ return u().runKernel("MatMulMul", { x: e, kernel: r, y: t }, { transposeA: l, transposeB: n });
6
+ }
7
+ export {
8
+ m as matMulMul
9
+ };
@@ -1,4 +1,4 @@
1
- import { e as t } from "../index--6vO-cOz.js";
1
+ import { e as t } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/mulDropout.js";
3
3
  import "./webgl/mulDropout.js";
4
4
  function m(r, o, e, n) {
@@ -1,4 +1,4 @@
1
- import { r as o } from "../../index--6vO-cOz.js";
1
+ import { r as o } from "../../index-iNhkcAEQ.js";
2
2
  function r(e) {
3
3
  const { logits: t, labels: n } = e.inputs;
4
4
  return e.backend.executeMultipleOutputs("SparseSoftmaxCrossEntropyWithLogits", [], [t, n], 2);
@@ -0,0 +1,2 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ export declare function normRMS(x: Tensor, gamma: Tensor): Tensor;
@@ -0,0 +1,10 @@
1
+ import { e as n } from "../index-iNhkcAEQ.js";
2
+ import "./cpu/normRMS.js";
3
+ import "./webgl/normRMS.js";
4
+ import "./grads/normRMS.js";
5
+ function p(r, o) {
6
+ return n().runKernel("RMSNorm", { x: r, gamma: o });
7
+ }
8
+ export {
9
+ p as normRMS
10
+ };
package/dist/ops/qkv.js CHANGED
@@ -1,4 +1,4 @@
1
- import { e as o } from "../index--6vO-cOz.js";
1
+ import { e as o } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/qkv.js";
3
3
  import "./webgl/qkv.js";
4
4
  import "./grads/qkv.js";
@@ -1,4 +1,4 @@
1
- import { e as i } from "../index--6vO-cOz.js";
1
+ import { e as i } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/scatterSub.js";
3
3
  import "./webgl/scatterSub.js";
4
4
  function c(t, r, e) {
@@ -1,4 +1,4 @@
1
- import { r as p } from "../../index--6vO-cOz.js";
1
+ import { r as p } from "../../index-iNhkcAEQ.js";
2
2
  class m {
3
3
  variableNames = ["cache", "item"];
4
4
  outputShape;
@@ -1,14 +1,15 @@
1
- import { r as h } from "../../index--6vO-cOz.js";
2
- class l {
1
+ import { r as m } from "../../index-iNhkcAEQ.js";
2
+ class h {
3
3
  variableNames = ["q", "k"];
4
4
  outputShape;
5
5
  userCode;
6
6
  customUniforms = [
7
7
  { name: "divisor", type: "float" },
8
- { name: "pastLen", type: "int" }
8
+ { name: "pastLen", type: "int" },
9
+ { name: "inf", type: "float" }
9
10
  ];
10
- constructor(t, s, e, n, a) {
11
- this.outputShape = [t, s, e, n], this.userCode = `
11
+ constructor(t, e, s, n, a) {
12
+ this.outputShape = [t, e, s, n], this.userCode = `
12
13
  void main() {
13
14
  ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
14
15
  int b = coords.x;
@@ -27,18 +28,18 @@ class l {
27
28
  float scaled = sum * divisor;
28
29
 
29
30
  // Mask out future positions
30
- setOutput((t2 > t1 + pastLen) ? -1.0/0.0 : scaled);
31
+ setOutput((t2 > t1 + pastLen) ? inf : scaled);
31
32
  }
32
33
  `;
33
34
  }
34
35
  }
35
- function m(o) {
36
- const { q: t, k: s } = o.inputs, { divisor: e, pastLen: n } = o.attrs, a = o.backend, i = t.shape[0], r = t.shape[2], c = s.shape[2], u = t.shape[1], p = t.shape[3], d = new l(i, u, r, c, p);
37
- return a.runWebGLProgram(d, [t, s], "float32", [[e], [n]]);
36
+ function l(o) {
37
+ const { q: t, k: e } = o.inputs, { divisor: s, pastLen: n } = o.attrs, a = o.backend, i = t.shape[0], r = t.shape[2], c = e.shape[2], u = t.shape[1], p = t.shape[3], d = new h(i, u, r, c, p);
38
+ return a.runWebGLProgram(d, [t, e], "float32", [[s], [n], [Number.NEGATIVE_INFINITY]]);
38
39
  }
39
- const k = {
40
+ const f = {
40
41
  kernelName: "AttentionMask",
41
42
  backendName: "webgl",
42
- kernelFunc: m
43
+ kernelFunc: l
43
44
  };
44
- h(k);
45
+ m(f);
@@ -1,10 +1,11 @@
1
- import { a$ as qt, n as y, O as Ct, ac as J, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, ag as B, b5 as q, b6 as Ut, Q as Gt, j as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, $ as ut, b9 as Bt, L as Ht, ba as Kt, r as Xt } from "../../index--6vO-cOz.js";
2
- import { f as $t, a as Qt, g as yt, b as Yt, c as Jt } from "../../kernel_funcs_utils-C6YBCuOt.js";
3
- import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-QP0LdI1v.js";
4
- import { b as te, i as ee, c as ne } from "../../slice_util-BdhYwFY_.js";
5
- import { r as oe } from "../../reshape-z51Eu-re.js";
1
+ import { a$ as qt, p as y, N as Ct, ad as Q, b0 as nt, b1 as It, b2 as pt, b3 as ft, b4 as kt, Z as ot, ah as B, b5 as q, b6 as Ut, O as Gt, k as jt, t as Zt, b7 as dt, ao as ct, b8 as tt, a1 as ut, b9 as Bt, K as Ht, ba as Kt, r as Xt } from "../../index-iNhkcAEQ.js";
2
+ import { f as $t, a as Yt, g as yt, b as Jt, c as Qt } from "../../kernel_funcs_utils-C4eIk4fE.js";
3
+ import { c as mt, g as Tt, a as Vt, b as Mt, e as wt } from "../../axis_util-97KkkyRQ.js";
4
+ import { b as te } from "../../broadcast_to-CMlkG8NS.js";
5
+ import { r as ee } from "../../reshape-DxTPgnwL.js";
6
+ import { i as ne, c as oe } from "../../slice_util-D-kaD4ZV.js";
6
7
  import { g as se } from "../../_commonjsHelpers-ByX85dGu.js";
7
- import { r as st } from "../../Reshape-CiAY8ltP.js";
8
+ import { r as st } from "../../Reshape-BE5rA4rT.js";
8
9
  function re(t, e) {
9
10
  for (var n = 0; n < e.length; n++) {
10
11
  const o = e[n];
@@ -654,20 +655,20 @@ const _t = /* @__PURE__ */ se(Lt), ae = /* @__PURE__ */ re({
654
655
  * limitations under the License.
655
656
  * =============================================================================
656
657
  */
657
- const Y = (
658
+ const J = (
658
659
  // tslint:disable-next-line
659
660
  _t || ae
660
661
  );
661
662
  function lt(t) {
662
- return Y.fromString(t, !0, 16);
663
+ return J.fromString(t, !0, 16);
663
664
  }
664
- const Nt = lt("c3a5c85c97cb3127"), Q = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
665
+ const Nt = lt("c3a5c85c97cb3127"), Y = lt("b492b66fbe98f273"), W = lt("9ae16a3b2f90404f");
665
666
  function gt(t) {
666
667
  return t.xor(t.shru(47));
667
668
  }
668
669
  function At(t, e, n) {
669
670
  const o = t.slice(e, e + n);
670
- return Y.fromBytes(Array.from(o), !0, !0);
671
+ return J.fromBytes(Array.from(o), !0, !0);
671
672
  }
672
673
  function R(t, e) {
673
674
  return At(t, e, 8);
@@ -708,7 +709,7 @@ function le(t, e = t.length) {
708
709
  return W;
709
710
  }
710
711
  function ce(t, e = t.length) {
711
- const n = W.add(e * 2), o = R(t, 0).mul(Q), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
712
+ const n = W.add(e * 2), o = R(t, 0).mul(Y), s = R(t, 8), r = R(t, e - 8).mul(n), a = R(t, e - 16).mul(W);
712
713
  return X(A(o.add(s), 43).add(A(r, 30)).add(a), o.add(A(s.add(W), 18)).add(r), n);
713
714
  }
714
715
  function he(t, e = t.length) {
@@ -716,19 +717,19 @@ function he(t, e = t.length) {
716
717
  return X(A(c.add(h), 43).add(A(f, 30)).add(w), c.add(A(h.add(o), 18)).add(f), n);
717
718
  }
718
719
  function fe(t, e = t.length) {
719
- const n = Y.fromNumber(81, !0);
720
+ const n = J.fromNumber(81, !0);
720
721
  if (e <= 32)
721
722
  return e <= 16 ? le(t, e) : ce(t, e);
722
723
  if (e <= 64)
723
724
  return he(t, e);
724
- let o = n, s = n.mul(Q).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [Y.UZERO, Y.UZERO], i = [Y.UZERO, Y.UZERO];
725
+ let o = n, s = n.mul(Y).add(113), r = gt(s.mul(W).add(113)).mul(W), a = [J.UZERO, J.UZERO], i = [J.UZERO, J.UZERO];
725
726
  o = o.mul(W).add(R(t, 0));
726
727
  let u = 0;
727
728
  const c = (e - 1 >> 6) * 64, h = c + (e - 1 & 63) - 63;
728
729
  do
729
- o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Q), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Q), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Q), a = it(t, u, a[1].mul(Q), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
730
+ o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(Y), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(Y), o = o.xor(i[1]), s = s.add(a[0]).add(R(t, u + 40)), r = A(r.add(i[0]), 33).mul(Y), a = it(t, u, a[1].mul(Y), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], u += 64;
730
731
  while (u !== c);
731
- const f = Q.add(r.and(255).shl(1));
732
+ const f = Y.add(r.and(255).shl(1));
732
733
  return u = h, i[0] = i[0].add(e - 1 & 63), a[0] = a[0].add(i[0]), i[0] = i[0].add(a[0]), o = A(o.add(s).add(a[0]).add(R(t, u + 8)), 37).mul(f), s = A(s.add(a[1]).add(R(t, u + 48)), 42).mul(f), o = o.xor(i[1].mul(9)), s = s.add(a[0].mul(9).add(R(t, u + 40))), r = A(r.add(i[0]), 33).mul(f), a = it(t, u, a[1].mul(f), o.add(i[0])), i = it(t, u + 32, r.add(i[1]), s.add(R(t, u + 16))), [r, o] = [o, r], X(X(a[0], i[0], f).add(gt(s).mul(Nt)).add(r), X(a[1], i[1], f).add(o), f);
733
734
  }
734
735
  /**
@@ -954,7 +955,7 @@ function Ve(t) {
954
955
  */
955
956
  function C(t) {
956
957
  return (e, n, o, s, r) => {
957
- const a = Ct(e, n), i = a.length, u = J(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = J(e), m = J(n), b = It(e, a), d = It(n, a);
958
+ const a = Ct(e, n), i = a.length, u = Q(a), c = y(a), h = nt(r, c), f = e.length, w = n.length, p = Q(e), m = Q(n), b = It(e, a), d = It(n, a);
958
959
  if (b.length + d.length === 0)
959
960
  for (let g = 0; g < h.length; ++g)
960
961
  h[g] = t(o[g % o.length], s[g % s.length]);
@@ -1415,7 +1416,7 @@ const Xe = K((t) => Math.log(t));
1415
1416
  * limitations under the License.
1416
1417
  * =============================================================================
1417
1418
  */
1418
- function Qe(t, e, n, o) {
1419
+ function Ye(t, e, n, o) {
1419
1420
  const s = nt(o, y(n));
1420
1421
  for (let r = 0; r < s.length; ++r) {
1421
1422
  const a = r * e;
@@ -1444,7 +1445,7 @@ function Qe(t, e, n, o) {
1444
1445
  * limitations under the License.
1445
1446
  * =============================================================================
1446
1447
  */
1447
- const Ye = C((t, e) => Math.max(t, e));
1448
+ const Je = C((t, e) => Math.max(t, e));
1448
1449
  /**
1449
1450
  * @license
1450
1451
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1461,7 +1462,7 @@ const Ye = C((t, e) => Math.max(t, e));
1461
1462
  * limitations under the License.
1462
1463
  * =============================================================================
1463
1464
  */
1464
- const Je = C((t, e) => Math.min(t, e));
1465
+ const Qe = C((t, e) => Math.min(t, e));
1465
1466
  /**
1466
1467
  * @license
1467
1468
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1533,7 +1534,7 @@ const en = C((t, e) => t !== e ? 1 : 0);
1533
1534
  * =============================================================================
1534
1535
  */
1535
1536
  function nn(t, e, n, o, s) {
1536
- const r = e.length, a = y(e), i = J(e), u = J(s), c = nt(n, y(s));
1537
+ const r = e.length, a = y(e), i = Q(e), u = Q(s), c = nt(n, y(s));
1537
1538
  for (let h = 0; h < a; ++h) {
1538
1539
  const f = pt(h, r, i), w = new Array(f.length);
1539
1540
  for (let m = 0; m < w.length; m++)
@@ -1589,7 +1590,7 @@ function on(t, e, n, o) {
1589
1590
  function sn(t, e, n) {
1590
1591
  t.forEach((o, s) => {
1591
1592
  if (o < 0 || o >= n) {
1592
- const r = pt(s, e.length, J(e)).join(",");
1593
+ const r = pt(s, e.length, Q(e)).join(",");
1593
1594
  throw new Error(`indices[${r}] = ${o} is not in [0, ${n})`);
1594
1595
  }
1595
1596
  });
@@ -1944,7 +1945,7 @@ class at {
1944
1945
  if (h.length !== u && h.length !== 1) {
1945
1946
  const m = this.defaultValueShape;
1946
1947
  Zt(() => {
1947
- const b = oe(h, m);
1948
+ const b = ee(h, m);
1948
1949
  h = te(b, i).dataSync();
1949
1950
  });
1950
1951
  }
@@ -2109,9 +2110,9 @@ const wn = K((t) => 1 / (1 + Math.exp(-t)));
2109
2110
  * =============================================================================
2110
2111
  */
2111
2112
  function In(t, e, n, o, s) {
2112
- const r = ee(o, e, n), a = y(n), i = J(o);
2113
+ const r = ne(o, e, n), a = y(n), i = Q(o);
2113
2114
  if (r) {
2114
- const f = ne(e, i);
2115
+ const f = oe(e, i);
2115
2116
  return s === "string" ? t.slice(f, f + a) : t.subarray(f, f + a);
2116
2117
  }
2117
2118
  const u = s === "string" ? $t(t) : t, c = B(o, s, u), h = B(n, s);
@@ -2119,7 +2120,7 @@ function In(t, e, n, o, s) {
2119
2120
  const w = h.indexToLoc(f), p = w.map((m, b) => m + e[b]);
2120
2121
  h.set(c.get(...p), ...w);
2121
2122
  }
2122
- return s === "string" ? Qt(h.values) : h.values;
2123
+ return s === "string" ? Yt(h.values) : h.values;
2123
2124
  }
2124
2125
  /**
2125
2126
  * @license
@@ -2790,9 +2791,9 @@ const An = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
2790
2791
  lessImpl: Be,
2791
2792
  linSpaceImpl: Ke,
2792
2793
  logImpl: Xe,
2793
- maxImpl: Qe,
2794
- maximumImpl: Ye,
2795
- minimumImpl: Je,
2794
+ maxImpl: Ye,
2795
+ maximumImpl: Je,
2796
+ minimumImpl: Qe,
2796
2797
  multiplyImpl: Pt,
2797
2798
  negImpl: tn,
2798
2799
  notEqualImpl: en,
@@ -3170,7 +3171,7 @@ class Un {
3170
3171
  o[h] = e[n[h]];
3171
3172
  if (this.outputShape = o, this.rank = o.length, this.rank > 6)
3172
3173
  throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
3173
- const s = yt(this.rank), r = Yt("rc", this.rank), a = new Array(this.rank);
3174
+ const s = yt(this.rank), r = Jt("rc", this.rank), a = new Array(this.rank);
3174
3175
  for (let h = 0; h < n.length; h++)
3175
3176
  a[n[h]] = r[h];
3176
3177
  const i = `vec2(${a.slice(-2).join()})`, u = `++${r[this.rank - 1]} < ${o[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${i})`;
@@ -3348,7 +3349,7 @@ return a / b;`, Hn = `
3348
3349
  }
3349
3350
 
3350
3351
  return result;
3351
- `, Kn = Jt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
3352
+ `, Kn = Qt({ opSnippet: Bn, packedOpSnippet: Hn, checkOutOfBounds: !0 });
3352
3353
  class Xn {
3353
3354
  variableNames = ["logits", "maxLogits"];
3354
3355
  outputShape;
@@ -3368,7 +3369,7 @@ class Xn {
3368
3369
  `;
3369
3370
  }
3370
3371
  }
3371
- class Qn {
3372
+ class Yn {
3372
3373
  variableNames = ["exp", "sum"];
3373
3374
  outputShape;
3374
3375
  userCode;
@@ -3395,7 +3396,7 @@ class Qn {
3395
3396
  `;
3396
3397
  }
3397
3398
  }
3398
- function Yn(t) {
3399
+ function Jn(t) {
3399
3400
  const { inputs: e, attrs: n } = t, { logits: o } = e, { dim: s, dropoutRate: r, seed: a } = n, i = t.backend;
3400
3401
  if (!o)
3401
3402
  throw new Error("Error in softmax: input logits is null");
@@ -3403,23 +3404,25 @@ function Yn(t) {
3403
3404
  inputs: { x: o },
3404
3405
  backend: i,
3405
3406
  attrs: { reductionIndices: u, keepDims: !1 }
3406
- }), h = wt(c.shape, u), f = new Xn(o.shape), w = i.runWebGLProgram(f, [o, c], "float32"), p = Zn({ inputs: { x: w }, backend: i, attrs: { axis: u, keepDims: !1 } }), m = st({ inputs: { x: p }, backend: i, attrs: { shape: h } });
3407
+ }), h = wt(c.shape, u), f = new Xn(o.shape), w = i.runWebGLProgram(f, [o, c], "float32");
3408
+ i.disposeIntermediateTensorInfo(c);
3409
+ const p = Zn({ inputs: { x: w }, backend: i, attrs: { axis: u, keepDims: !1 } }), m = st({ inputs: { x: p }, backend: i, attrs: { shape: h } });
3407
3410
  if (r !== void 0 && r > 0) {
3408
- const d = new Qn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
3411
+ const d = new Yn(o.shape), g = i.runWebGLProgram(d, [w, m], "float32", [
3409
3412
  [r],
3410
3413
  [a ?? Math.random() * 1e4]
3411
3414
  ]);
3412
- return i.disposeIntermediateTensorInfo(c), i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), g;
3415
+ return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), g;
3413
3416
  }
3414
3417
  const b = Kn({ inputs: { a: w, b: m }, backend: i });
3415
- return i.disposeIntermediateTensorInfo(c), i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), b;
3418
+ return i.disposeIntermediateTensorInfo(w), i.disposeIntermediateTensorInfo(p), i.disposeIntermediateTensorInfo(m), b;
3416
3419
  }
3417
- const Jn = {
3420
+ const Qn = {
3418
3421
  kernelName: "FusedSoftmax",
3419
3422
  backendName: "webgl",
3420
- kernelFunc: Yn
3423
+ kernelFunc: Jn
3421
3424
  };
3422
- Xt(Jn);
3425
+ Xt(Qn);
3423
3426
  export {
3424
- Yn as softmax
3427
+ Jn as softmax
3425
3428
  };
@@ -1,4 +1,4 @@
1
- import { r as l } from "../../index--6vO-cOz.js";
1
+ import { r as l } from "../../index-iNhkcAEQ.js";
2
2
  class u {
3
3
  variableNames = ["labels", "logits", "values"];
4
4
  outputShape;
@@ -1,5 +1,5 @@
1
- import { r as a } from "../../index--6vO-cOz.js";
2
- import { u as s, C as x } from "../../kernel_funcs_utils-C6YBCuOt.js";
1
+ import { r as a } from "../../index-iNhkcAEQ.js";
2
+ import { u as s, C as x } from "../../kernel_funcs_utils-C4eIk4fE.js";
3
3
  const t = 0.7978845608028654, r = 0.044715, c = x + `
4
4
  float x3 = x * x * x;
5
5
  float inner = x + ${r} * x3;
@@ -7,9 +7,10 @@ type BatchMatMulConfig = {
7
7
  transposeA: boolean;
8
8
  transposeB: boolean;
9
9
  backend: MathBackendWebGL;
10
- activationSnippet: string;
10
+ activationSnippet?: string;
11
+ multiplier?: TensorInfo;
11
12
  };
12
- export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, }: BatchMatMulConfig): TensorInfo;
13
+ export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, multiplier, }: BatchMatMulConfig): TensorInfo;
13
14
  export declare function batchMatMulKernel(args: {
14
15
  inputs: {
15
16
  x: TensorInfo;