@genai-fi/nanogpt 0.4.5 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +1 -1
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -128
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -10
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +1 -1
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.js +1 -1
  49. package/dist/ops/cpu/qkv.js +3 -3
  50. package/dist/ops/cpu/rope.js +5 -5
  51. package/dist/ops/cpu/scatterSub.js +8 -8
  52. package/dist/ops/fusedSoftmax.js +1 -1
  53. package/dist/ops/gatherSub.js +1 -1
  54. package/dist/ops/gelu.js +1 -1
  55. package/dist/ops/grads/attentionMask.js +13 -9
  56. package/dist/ops/grads/fusedSoftmax.js +12 -9
  57. package/dist/ops/grads/gelu.js +1 -1
  58. package/dist/ops/grads/matMulGelu.js +1 -1
  59. package/dist/ops/grads/normRMS.js +1 -1
  60. package/dist/ops/grads/qkv.js +19 -9
  61. package/dist/ops/grads/rope.js +1 -1
  62. package/dist/ops/matMulGelu.js +1 -1
  63. package/dist/ops/matMulMul.d.ts +2 -0
  64. package/dist/ops/matMulMul.js +9 -0
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +13 -12
  72. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/matMulGelu.js +17 -17
  76. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  77. package/dist/ops/webgl/matMulMul.js +28 -0
  78. package/dist/ops/webgl/mulDropout.js +1 -1
  79. package/dist/ops/webgl/normRMS.js +29 -21
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops-ObfXLHYQ.js +1269 -0
  84. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  85. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  86. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  87. package/dist/slice_util-D-kaD4ZV.js +49 -0
  88. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  89. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  90. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  91. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  92. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  93. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  94. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +44 -44
  97. package/dist/training/Evaluator.js +6 -6
  98. package/dist/training/FullTrainer.js +1 -1
  99. package/dist/training/Trainer.js +7 -7
  100. package/dist/training/sparseCrossEntropy.js +4 -4
  101. package/dist/utilities/dummy.js +10 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.js +1 -1
  104. package/dist/utilities/profile.js +1 -1
  105. package/dist/utilities/save.js +10 -8
  106. package/dist/utilities/weights.js +2 -2
  107. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  108. package/package.json +1 -1
  109. package/dist/slice_util-BdhYwFY_.js +0 -90
  110. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  111. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,53 +1,32 @@
1
- import h from "./CausalSelfAttention.js";
2
- import o from "./MLP.js";
3
- import a from "./RMSNorm.js";
4
- import p from "./BaseLayer.js";
5
- import { t as d } from "../index--6vO-cOz.js";
6
- class W extends p {
1
+ import l from "./CausalSelfAttention.js";
2
+ import r from "./MLP.js";
3
+ import o from "./RMSNorm.js";
4
+ import { B as d } from "../BaseLayer-BhrMN8JO.js";
5
+ import { t as p } from "../index-iNhkcAEQ.js";
6
+ class k extends d {
7
7
  ln1;
8
8
  attn;
9
9
  ln2;
10
10
  mlp;
11
11
  index;
12
- _trainable = !0;
13
12
  skipped = !1;
14
- constructor(t, s) {
15
- super(s), this.index = t, this.ln1 = new a(s, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
16
- }
17
- get variables() {
18
- return [
19
- ...this.ln1.trainableWeights.map((t) => t),
20
- ...this.attn.variables,
21
- ...this.ln2.trainableWeights.map((t) => t),
22
- ...this.mlp.variables
23
- ];
24
- }
25
- get trainable() {
26
- return this._trainable;
27
- }
28
- set trainable(t) {
29
- this._trainable = t, this.ln1.trainable = t, this.ln2.trainable = t, this.attn.trainable = t, this.mlp.trainable = t;
30
- }
31
- saveWeights(t) {
32
- this.attn.saveWeights(t), this.mlp.saveWeights(t), t.set(`block_${this.index}_rms1`, this.ln1.getWeights()), t.set(`block_${this.index}_rms2`, this.ln2.getWeights());
33
- }
34
- loadWeights(t) {
35
- this.attn.loadWeights(t), this.mlp.loadWeights(t), this.ln1.setWeights(t.get(`block_${this.index}_rms1`) || []), this.ln2.setWeights(t.get(`block_${this.index}_rms2`) || []);
13
+ constructor(t, s, i) {
14
+ super(s, i), this.index = t, this.ln1 = new o(s, `block_${this.index}_rms1`, this), this.attn = new l(this.index, s, this), this.ln2 = new o(s, `block_${this.index}_rms2`, this), this.mlp = new r(this.index, s, this);
36
15
  }
37
16
  getMLPOutput(t, s) {
38
- const i = this.ln2.apply(t), e = this.mlp.call(i, s);
39
- return t.add(e);
17
+ const i = this.ln2.call({ training: s }, t), e = this.mlp.call({ training: s }, i);
18
+ i.dispose();
19
+ const n = t.add(e);
20
+ return t.dispose(), e.dispose(), n;
40
21
  }
41
- call(t, s = !1, i = !1, e) {
42
- return d(() => {
22
+ forward(t, s) {
23
+ return p(() => {
43
24
  if (this.skipped)
44
- return { output: t };
45
- const l = this.ln1.apply(t), n = this.attn.call(l, s, i, e), r = t.add(n.output);
46
- return {
47
- output: this.getMLPOutput(r, s),
48
- attention: n.attention,
49
- cache: n.presentKV
50
- };
25
+ return s;
26
+ const i = this.ln1.call(t, s), e = this.attn.call(t, i);
27
+ i.dispose();
28
+ const n = s.add(e);
29
+ return e.dispose(), this.getMLPOutput(n, t.training);
51
30
  });
52
31
  }
53
32
  dispose() {
@@ -55,5 +34,5 @@ class W extends p {
55
34
  }
56
35
  }
57
36
  export {
58
- W as default
37
+ k as default
59
38
  };
@@ -1,8 +1,8 @@
1
- import { o as r, h as p, E as u, a6 as h, a7 as E, $, s as S, a8 as d } from "./index--6vO-cOz.js";
2
- import { e as K } from "./axis_util-QP0LdI1v.js";
3
- import { m as T } from "./max-BUShNgfh.js";
4
- import { r as m } from "./reshape-z51Eu-re.js";
5
- import { s as _ } from "./sum-DdkDf2MG.js";
1
+ import { o as r, i as p, E as u, a7 as E, a8 as h, a1 as S, s as $, a9 as d } from "./index-iNhkcAEQ.js";
2
+ import { e as K } from "./axis_util-97KkkyRQ.js";
3
+ import { m as T } from "./max-CYaAjEEp.js";
4
+ import { r as m } from "./reshape-DxTPgnwL.js";
5
+ import { s as _ } from "./sum-B_92TaHD.js";
6
6
  /**
7
7
  * @license
8
8
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -21,7 +21,7 @@ import { s as _ } from "./sum-DdkDf2MG.js";
21
21
  */
22
22
  function b(s) {
23
23
  const o = { x: p(s, "x", "exp") };
24
- return u.runKernel(h, o);
24
+ return u.runKernel(E, o);
25
25
  }
26
26
  const N = /* @__PURE__ */ r({ exp_: b });
27
27
  /**
@@ -42,7 +42,7 @@ const N = /* @__PURE__ */ r({ exp_: b });
42
42
  */
43
43
  function v(s) {
44
44
  const o = { x: p(s, "x", "log", "float32") };
45
- return u.runKernel(E, o);
45
+ return u.runKernel(h, o);
46
46
  }
47
47
  const w = /* @__PURE__ */ r({ log_: v });
48
48
  /**
@@ -61,13 +61,13 @@ const w = /* @__PURE__ */ r({ log_: v });
61
61
  * limitations under the License.
62
62
  * =============================================================================
63
63
  */
64
- function A(s, n = null, o = !1) {
65
- const a = p(s, "x", "logSumExp"), t = $(n, a.shape), x = T(
66
- a,
64
+ function A(s, a = null, o = !1) {
65
+ const n = p(s, "x", "logSumExp"), t = S(a, n.shape), x = T(
66
+ n,
67
67
  t,
68
68
  !0
69
69
  /* keepDims */
70
- ), i = S(a, x), l = N(i), f = _(l, t), c = w(f), e = d(m(x, c.shape), c);
70
+ ), i = $(n, x), l = N(i), f = _(l, t), c = w(f), e = d(m(x, c.shape), c);
71
71
  if (o) {
72
72
  const g = K(e.shape, t);
73
73
  return m(e, g);
package/dist/main.js CHANGED
@@ -5,7 +5,7 @@ import { default as q } from "./tokeniser/bpe.js";
5
5
  import { default as A } from "./utilities/waitForModel.js";
6
6
  import { default as I } from "./data/textLoader.js";
7
7
  import { estimateMemoryUsage as K, estimateParameterCount as O, estimateResources as Q, estimateTrainingMemoryUsage as S, validateConfig as V } from "./utilities/parameters.js";
8
- import "./index--6vO-cOz.js";
8
+ import "./index-iNhkcAEQ.js";
9
9
  import "./ops/cpu/scatterSub.js";
10
10
  import "./ops/webgl/scatterSub.js";
11
11
  import "./ops/cpu/gatherSub.js";
@@ -1,4 +1,4 @@
1
- import { o as m, h as s, p as c, E as M, B as p } from "./index--6vO-cOz.js";
1
+ import { o as m, i as s, q as c, E as M, B as p } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -21,7 +21,7 @@ function f(e, o, n = !1, l = !1) {
21
21
  const r = { a, b: t }, u = { transposeA: n, transposeB: l };
22
22
  return M.runKernel(p, r, u);
23
23
  }
24
- const h = /* @__PURE__ */ m({ matMul_: f });
24
+ const b = /* @__PURE__ */ m({ matMul_: f });
25
25
  export {
26
- h as m
26
+ b as m
27
27
  };
@@ -1,4 +1,4 @@
1
- import { o as r, h as e, E as x, M as c } from "./index--6vO-cOz.js";
1
+ import { o as r, i as e, E as x, M as c } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -15,11 +15,11 @@ import { o as r, h as e, E as x, M as c } from "./index--6vO-cOz.js";
15
15
  * limitations under the License.
16
16
  * =============================================================================
17
17
  */
18
- function m(n, o = null, s = !1) {
18
+ function i(n, o = null, s = !1) {
19
19
  const t = { x: e(n, "x", "max") }, a = { reductionIndices: o, keepDims: s };
20
20
  return x.runKernel(c, t, a);
21
21
  }
22
- const l = /* @__PURE__ */ r({ max_: m });
22
+ const l = /* @__PURE__ */ r({ max_: i });
23
23
  export {
24
24
  l as m
25
25
  };
@@ -1,6 +1,6 @@
1
- import { o as m, h as c, E as f, _ as i, $ as l, a0 as h, s as x, x as d } from "./index--6vO-cOz.js";
2
- import { e as v } from "./axis_util-QP0LdI1v.js";
3
- import { r as E } from "./reshape-z51Eu-re.js";
1
+ import { o as m, i as c, E as i, a0 as f, a1 as l, a2 as h, s as x, y as d } from "./index-iNhkcAEQ.js";
2
+ import { e as v } from "./axis_util-97KkkyRQ.js";
3
+ import { r as E } from "./reshape-DxTPgnwL.js";
4
4
  /**
5
5
  * @license
6
6
  * Copyright 2020 Google Inc. All Rights Reserved.
@@ -19,7 +19,7 @@ import { r as E } from "./reshape-z51Eu-re.js";
19
19
  */
20
20
  function S(a, t = null, e = !1) {
21
21
  const s = { x: c(a, "x", "mean") }, o = { axis: t, keepDims: e };
22
- return f.runKernel(i, s, o);
22
+ return i.runKernel(f, s, o);
23
23
  }
24
24
  const r = /* @__PURE__ */ m({ mean_: S });
25
25
  /**
@@ -46,8 +46,8 @@ function T(a, t = null, e = !1) {
46
46
  const p = h(x(d(a, "float32"), E(s, o))), u = r(p, n, e);
47
47
  return { mean: s, variance: u };
48
48
  }
49
- const K = /* @__PURE__ */ m({ moments_: T });
49
+ const N = /* @__PURE__ */ m({ moments_: T });
50
50
  export {
51
- K as a,
51
+ N as a,
52
52
  r as m
53
53
  };
@@ -1,8 +1,8 @@
1
- import { o as l, h as c, E as y, a1 as E, $ as w, a2 as o, a3 as u, U as v, f as I, a0 as $ } from "./index--6vO-cOz.js";
2
- import { e as A } from "./axis_util-QP0LdI1v.js";
3
- import { m as f } from "./max-BUShNgfh.js";
4
- import { r as h } from "./reshape-z51Eu-re.js";
5
- import { s as t } from "./sum-DdkDf2MG.js";
1
+ import { o as l, i as c, E as y, a3 as E, a1 as w, a4 as o, a5 as u, U as v, f as I, a2 as A } from "./index-iNhkcAEQ.js";
2
+ import { e as $ } from "./axis_util-97KkkyRQ.js";
3
+ import { m as f } from "./max-CYaAjEEp.js";
4
+ import { r as h } from "./reshape-DxTPgnwL.js";
5
+ import { s as t } from "./sum-B_92TaHD.js";
6
6
  /**
7
7
  * @license
8
8
  * Copyright 2020 Google Inc. All Rights Reserved.
@@ -46,7 +46,7 @@ function T(n, e = "euclidean", r = null, m = !1) {
46
46
  let i = a.shape;
47
47
  if (m) {
48
48
  const p = w(r, n.shape);
49
- i = A(a.shape, p);
49
+ i = $(a.shape, p);
50
50
  }
51
51
  return h(a, i);
52
52
  }
@@ -74,7 +74,7 @@ function d(n, e, r = null) {
74
74
  if (e === -1 / 0)
75
75
  return s(t(o(n), r[1]), r[0]);
76
76
  if (e === "fro" || e === "euclidean")
77
- return u(t($(n), r));
77
+ return u(t(A(n), r));
78
78
  throw new Error(`Error in norm: invalid ord value: ${e}`);
79
79
  }
80
80
  throw new Error(`Error in norm: invalid axis: ${r}`);
@@ -1,5 +1,5 @@
1
- import { k as n, l as t, n as m, E as i } from "./index--6vO-cOz.js";
2
- import { z as l, c } from "./zeros-8xl-W2DC.js";
1
+ import { l as n, n as t, p as m, E as i } from "./index-iNhkcAEQ.js";
2
+ import { z as l, c } from "./zeros-NMYTayy7.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,12 +1,12 @@
1
- import { e as a } from "../index--6vO-cOz.js";
1
+ import { e as a } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/appendCache.js";
3
3
  import "./webgl/appendCache.js";
4
- import { z as s } from "../zeros-8xl-W2DC.js";
5
- import { c } from "../concat-DvWM7HGZ.js";
4
+ import { c as s } from "../concat-Cxbo2sOz.js";
5
+ import { z as c } from "../zeros-NMYTayy7.js";
6
6
  function i(r, p, n, o) {
7
7
  if (!o) {
8
8
  const e = r.shape[2];
9
- return c([r, s([r.shape[0], r.shape[1], p - e, r.shape[3]])], 2);
9
+ return s([r, c([r.shape[0], r.shape[1], p - e, r.shape[3]])], 2);
10
10
  }
11
11
  return a().runKernel("AppendCache", { cache: o, item: r }, { maxSize: p, pastLen: n });
12
12
  }
@@ -1,2 +1,2 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- export declare function attentionMask(q: Tensor, k: Tensor, divisor: number, mask?: Tensor, pastLen?: number): Tensor;
2
+ export declare function attentionMask(q: Tensor, k: Tensor, divisor: number, pastLen?: number): Tensor;
@@ -1,10 +1,10 @@
1
- import { e as i } from "../index--6vO-cOz.js";
1
+ import { e as o } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/attentionMask.js";
3
3
  import "./webgl/attentionMask.js";
4
4
  import "./grads/attentionMask.js";
5
- function f(t, n, e, r, o) {
6
- return r ? i().runKernel("AttentionMask", { q: t, k: n, mask: r }, { divisor: e, pastLen: o || 0 }) : i().runKernel("AttentionMask", { q: t, k: n }, { divisor: e, pastLen: o || 0 });
5
+ function s(t, n, e, r) {
6
+ return o().runKernel("AttentionMask", { q: t, k: n }, { divisor: e, pastLen: r || 0 });
7
7
  }
8
8
  export {
9
- f as attentionMask
9
+ s as attentionMask
10
10
  };
@@ -1,5 +1,5 @@
1
- import { r as d } from "../../index--6vO-cOz.js";
2
- import { c as h } from "../../concat-DvWM7HGZ.js";
1
+ import { r as d } from "../../index-iNhkcAEQ.js";
2
+ import { c as h } from "../../concat-Cxbo2sOz.js";
3
3
  function u(p) {
4
4
  const { cache: n, item: s } = p.inputs, { maxSize: r, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
5
5
  if (c + e <= r) {
@@ -1,22 +1,21 @@
1
- import { r as o, f as k } from "../../index--6vO-cOz.js";
2
- import { m as d } from "../../mat_mul-BEHRPMh0.js";
3
- function r(t) {
4
- const { q: e, k: n, mask: s } = t.inputs, { divisor: c } = t.attrs, m = e.shape[2], i = n.shape[2], a = d(e, n, !1, !0).mul(k(c));
5
- if (s) {
6
- const l = s.slice([0, 0], [m, i]).expandDims(0).expandDims(0);
7
- return a.add(l);
8
- }
9
- return a;
1
+ import { r as a, g as p, f as u } from "../../index-iNhkcAEQ.js";
2
+ import { l as N, w as b } from "../../ops-ObfXLHYQ.js";
3
+ import { o as g } from "../../ones-BIeFnPHR.js";
4
+ import { z as A } from "../../zeros-NMYTayy7.js";
5
+ import { m as I } from "../../mat_mul-D0SifYfJ.js";
6
+ function o(n) {
7
+ const { q: s, k: e } = n.inputs, { divisor: r } = n.attrs, c = s.shape[2], t = e.shape[2], m = N.bandPart(g([t, t]), -1, 0).cast("bool"), l = A([t, t]), i = p([t, t], Number.NEGATIVE_INFINITY), f = b(m, l, i), k = I(s, e, !1, !0).mul(u(r)), d = f.slice([0, 0], [c, t]).expandDims(0).expandDims(0);
8
+ return k.add(d);
10
9
  }
11
- const u = {
10
+ const w = {
12
11
  kernelName: "AttentionMask",
13
12
  backendName: "cpu",
14
- kernelFunc: r
13
+ kernelFunc: o
15
14
  };
16
- o(u);
17
- const f = {
15
+ a(w);
16
+ const M = {
18
17
  kernelName: "AttentionMask",
19
18
  backendName: "tensorflow",
20
- kernelFunc: r
19
+ kernelFunc: o
21
20
  };
22
- o(f);
21
+ a(M);
@@ -1,5 +1,5 @@
1
- import { r as n } from "../../index--6vO-cOz.js";
2
- import { s as f } from "../../softmax-Dsxflvdl.js";
1
+ import { r as n } from "../../index-iNhkcAEQ.js";
2
+ import { s as f } from "../../softmax-BjsptB07.js";
3
3
  function r(t) {
4
4
  const { inputs: s, attrs: i } = t, { logits: o } = s, { dim: a, dropoutRate: e } = i;
5
5
  if (!o)
@@ -1,6 +1,6 @@
1
- import { o as u, h as c, E as g, N as h, r as m, s as p } from "../../index--6vO-cOz.js";
2
- import { r as N } from "../../range-C_vpUjBu.js";
3
- import { s as l } from "../../stack-CmqSdsfs.js";
1
+ import { o as u, i as c, E as g, L as h, r as m, s as p } from "../../index-iNhkcAEQ.js";
2
+ import { r as l } from "../../range-BsFU-SNG.js";
3
+ import { s as N } from "../../stack--cqr9Dgc.js";
4
4
  /**
5
5
  * @license
6
6
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,8 +23,8 @@ function f(e, s) {
23
23
  }
24
24
  const b = /* @__PURE__ */ u({ gatherND_: f });
25
25
  function d(e) {
26
- const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], o = N(0, t, 1, "int32"), a = l([o, n], 1), i = b(r, a);
27
- return p(s, i);
26
+ const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], i = l(0, t, 1, "int32"), o = N([i, n], 1), a = b(r, o);
27
+ return p(s, a);
28
28
  }
29
29
  const k = {
30
30
  kernelName: "EfficientGatherSub",
@@ -1,4 +1,4 @@
1
- import { r as t, t as d } from "../../index--6vO-cOz.js";
1
+ import { r as t, t as d } from "../../index-iNhkcAEQ.js";
2
2
  const o = 0.7978845608028654, c = 0.044715;
3
3
  function m(u) {
4
4
  const { inputs: l } = u, { x: e } = l, n = e;
@@ -1,4 +1,4 @@
1
- import { r as a, t as i } from "../../index--6vO-cOz.js";
1
+ import { r as a, t as i } from "../../index-iNhkcAEQ.js";
2
2
  const c = 0.7978845608028654, m = 0.044715;
3
3
  function M(o) {
4
4
  const { inputs: s } = o, { x: t, kernel: l } = s, e = t, u = l;
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,17 @@
1
+ import { r as n, t as M } from "../../index-iNhkcAEQ.js";
2
+ function e(t) {
3
+ const { inputs: r, attrs: o } = t, { transposeA: s, transposeB: l } = o, { x: c, kernel: u, y: a } = r, m = c, i = u, k = a;
4
+ return M(() => m.matMul(i, s, l).mul(k));
5
+ }
6
+ const f = {
7
+ kernelName: "MatMulMul",
8
+ backendName: "cpu",
9
+ kernelFunc: e
10
+ };
11
+ n(f);
12
+ const p = {
13
+ kernelName: "MatMulMul",
14
+ backendName: "tensorflow",
15
+ kernelFunc: e
16
+ };
17
+ n(p);
@@ -1,4 +1,4 @@
1
- import { r as e, b as u } from "../../index--6vO-cOz.js";
1
+ import { r as e, b as u } from "../../index-iNhkcAEQ.js";
2
2
  function n(o) {
3
3
  const { inputs: r } = o, { a: l, b: t } = r;
4
4
  return console.warn("Using fallback mulDrop implementation without dropout."), u(l, t);
@@ -1,4 +1,4 @@
1
- import { r as o, t as d } from "../../index--6vO-cOz.js";
1
+ import { r as o, t as d } from "../../index-iNhkcAEQ.js";
2
2
  function i(t) {
3
3
  const { inputs: e } = t, { x: n, gamma: s } = e, r = n, a = s;
4
4
  return d(() => {
@@ -1,6 +1,6 @@
1
- import { r as q } from "../../index--6vO-cOz.js";
2
- import { r as o } from "../../reshape-z51Eu-re.js";
3
- import { s as x } from "../../split-B_k_jwud.js";
1
+ import { r as q } from "../../index-iNhkcAEQ.js";
2
+ import { r as o } from "../../reshape-DxTPgnwL.js";
3
+ import { s as x } from "../../split-BCbrzthj.js";
4
4
  function v(p) {
5
5
  const { x: c, kernel: K } = p.inputs, { heads: n } = p.attrs, [s, e, t] = c.shape, a = o(c, [s * e, t]), i = a.dot(K);
6
6
  a.dispose();
@@ -1,8 +1,8 @@
1
- import { r as S } from "../../index--6vO-cOz.js";
2
- import { r as F } from "../../range-C_vpUjBu.js";
3
- import { g as I } from "../../gather-C5D8PxwA.js";
4
- import { s as E } from "../../stack-CmqSdsfs.js";
5
- import { c as T } from "../../concat-DvWM7HGZ.js";
1
+ import { r as S } from "../../index-iNhkcAEQ.js";
2
+ import { r as F } from "../../range-BsFU-SNG.js";
3
+ import { g as I } from "../../gather-Bxe1Qip8.js";
4
+ import { s as E } from "../../stack--cqr9Dgc.js";
5
+ import { c as T } from "../../concat-Cxbo2sOz.js";
6
6
  function U(t, c, p, o, r) {
7
7
  const n = o.shape[3], s = p;
8
8
  if (s > n) return o;
@@ -1,7 +1,7 @@
1
- import { o as l, k, h, E as g, a5 as w, r as $, s as d, b as m } from "../../index--6vO-cOz.js";
2
- import { r as b } from "../../range-C_vpUjBu.js";
3
- import { s as E } from "../../stack-CmqSdsfs.js";
4
- import { o as D } from "../../ones-D6kB8bdY.js";
1
+ import { o as l, l as g, i, E as k, a6 as w, r as $, s as d, b as m } from "../../index-iNhkcAEQ.js";
2
+ import { r as b } from "../../range-BsFU-SNG.js";
3
+ import { s as E } from "../../stack--cqr9Dgc.js";
4
+ import { o as D } from "../../ones-BIeFnPHR.js";
5
5
  function N(n, r, t) {
6
6
  const s = r.rank > 1 ? r.shape[r.rank - 1] : 1, e = r.rank > 1 ? r.rank - 1 : 1, o = `Must have updates.shape = indices.shape[:batchDim] + shape[sliceDim:], got updates.shape: ${t.shape}, indices.shape: ${r.shape}, shape: ${n}, sliceDim: ${s}, and batchDim: ${e}.`;
7
7
  if (t.rank < e)
@@ -51,15 +51,15 @@ function S(n, r, t) {
51
51
  * =============================================================================
52
52
  */
53
53
  function y(n, r, t) {
54
- k(t);
55
- const s = h(n, "indices", "scatterND", "int32"), e = h(r, "updates", "scatterND");
54
+ g(t);
55
+ const s = i(n, "indices", "scatterND", "int32"), e = i(r, "updates", "scatterND");
56
56
  S(e, s, t);
57
57
  const o = { indices: s, updates: e }, a = { shape: t };
58
- return g.runKernel(w, o, a);
58
+ return k.runKernel(w, o, a);
59
59
  }
60
60
  const v = /* @__PURE__ */ l({ scatterND_: y });
61
61
  function I(n) {
62
- const { logits: r, labels: t, dy: s } = n.inputs, e = t.shape[0], o = r.shape[1], a = b(0, e, 1, "int32"), i = E([a, t], 1), c = D([e]), p = v(i, c, [e, o]), f = d(r, p), u = s.reshape([e, 1]);
62
+ const { logits: r, labels: t, dy: s } = n.inputs, e = t.shape[0], o = r.shape[1], a = b(0, e, 1, "int32"), h = E([a, t], 1), c = D([e]), p = v(h, c, [e, o]), f = d(r, p), u = s.reshape([e, 1]);
63
63
  return m(f, u);
64
64
  }
65
65
  const T = {
@@ -1,4 +1,4 @@
1
- import { e as t } from "../index--6vO-cOz.js";
1
+ import { e as t } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/fusedSoftmax.js";
3
3
  import "./webgl/fusedSoftmax.js";
4
4
  import "./grads/fusedSoftmax.js";
@@ -1,4 +1,4 @@
1
- import { e as n } from "../index--6vO-cOz.js";
1
+ import { e as n } from "../index-iNhkcAEQ.js";
2
2
  import "./cpu/gatherSub.js";
3
3
  import "./webgl/gatherSub.js";
4
4
  function f(r, e, t) {
package/dist/ops/gelu.js CHANGED
@@ -1,4 +1,4 @@
1
- import "../index--6vO-cOz.js";
1
+ import "../index-iNhkcAEQ.js";
2
2
  import "./cpu/gelu.js";
3
3
  import "./webgl/gelu.js";
4
4
  import { d as e, g as i } from "./grads/gelu.js";
@@ -1,21 +1,25 @@
1
- import { g as i } from "../../index--6vO-cOz.js";
2
- const u = {
1
+ import { h as l, f as a } from "../../index-iNhkcAEQ.js";
2
+ import { matMulMul as i } from "../matMulMul.js";
3
+ const m = {
3
4
  kernelName: "AttentionMask",
4
5
  inputsToSave: ["q", "k"],
5
6
  outputsToSave: [],
6
- gradFunc: (t, a, n) => {
7
+ gradFunc: (t, u, c) => {
7
8
  if (Array.isArray(t))
8
9
  throw new Error("Expected dy to be a single Tensor");
9
- const [r, e] = a, { divisor: s } = n;
10
+ const [e, o] = u, { divisor: n } = c;
10
11
  return {
11
- q: () => t.matMul(e).mul(s),
12
- k: () => r.transpose([0, 1, 3, 2]).matMul(t).mul(s).transpose([0, 1, 3, 2]),
12
+ q: () => i(t, o, a(n)),
13
+ k: () => {
14
+ const s = e.transpose([0, 1, 3, 2]), r = i(s, t, a(n));
15
+ return s.dispose(), r.transpose([0, 1, 3, 2]);
16
+ },
13
17
  mask: () => t,
14
18
  divisor: () => {
15
- const o = r.matMul(e, !1, !0);
16
- return t.mul(o).sum();
19
+ const s = e.matMul(o, !1, !0), r = t.mul(s);
20
+ return s.dispose(), r.sum();
17
21
  }
18
22
  };
19
23
  }
20
24
  };
21
- i(u);
25
+ l(m);
@@ -1,17 +1,20 @@
1
- import { g as p, b as m, s as d } from "../../index--6vO-cOz.js";
2
- import { mulDrop as c } from "../mulDrop.js";
3
- import { s as f } from "../../sum-DdkDf2MG.js";
4
- const g = {
1
+ import { h as d, b as u, s as f } from "../../index-iNhkcAEQ.js";
2
+ import { mulDrop as l } from "../mulDrop.js";
3
+ import { s as g } from "../../sum-B_92TaHD.js";
4
+ const T = {
5
5
  kernelName: "FusedSoftmax",
6
6
  outputsToSave: [!0],
7
- gradFunc: (s, a, u) => {
8
- const [o] = a, { dim: i, dropoutRate: t, seed: r } = u, n = !0, e = t && r ? c(s, o, t, r) : m(s, o);
7
+ gradFunc: (o, i, n) => {
8
+ const [s] = i, { dim: a, dropoutRate: t, seed: e } = n, p = !0, r = t && e ? l(o, s, t, e) : u(o, s);
9
9
  return {
10
- logits: () => d(e, m(f(e, [i], n), o))
10
+ logits: () => {
11
+ const m = g(r, [a], p), c = u(m, s);
12
+ return m.dispose(), f(r, c);
13
+ }
11
14
  };
12
15
  }
13
16
  };
14
- p(g);
17
+ d(T);
15
18
  export {
16
- g as softmaxGradConfig
19
+ T as softmaxGradConfig
17
20
  };
@@ -1,4 +1,4 @@
1
- import { g as t, e as n } from "../../index--6vO-cOz.js";
1
+ import { h as t, e as n } from "../../index-iNhkcAEQ.js";
2
2
  import "../cpu/gelu.js";
3
3
  import "../webgl/gelu.js";
4
4
  const o = {
@@ -1,4 +1,4 @@
1
- import { g as a, e as o } from "../../index--6vO-cOz.js";
1
+ import { h as a, e as o } from "../../index-iNhkcAEQ.js";
2
2
  function s(e, n, r) {
3
3
  return o().runKernel("MatMulGeluGrad", { dy: e, x: n, kernel: r });
4
4
  }
@@ -1,4 +1,4 @@
1
- import { g as t, e as g } from "../../index--6vO-cOz.js";
1
+ import { h as t, e as g } from "../../index-iNhkcAEQ.js";
2
2
  function s(r, a, n) {
3
3
  return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
4
4
  }
@@ -1,20 +1,30 @@
1
- import { g as v } from "../../index--6vO-cOz.js";
2
- const g = {
1
+ import { h as Q } from "../../index-iNhkcAEQ.js";
2
+ const V = {
3
3
  kernelName: "QKV",
4
4
  inputsToSave: ["x", "kernel"],
5
5
  outputsToSave: [],
6
- gradFunc: (k, m) => {
7
- const [x, K, f] = k, [c, r] = m, [t, s, e] = c.shape, d = x.transpose([0, 2, 1, 3]).reshape([t * s, e]), l = K.transpose([0, 2, 1, 3]).reshape([t * s, e]), u = f.transpose([0, 2, 1, 3]).reshape([t * s, e]), i = r.slice([0, 0], [e, e]), h = r.slice([0, e], [e, e]), M = r.slice([0, 2 * e], [e, e]);
6
+ gradFunc: (x, K) => {
7
+ const [f, h, M] = x, [p, l] = K, [t, n, e] = p.shape, i = f.transpose([0, 2, 1, 3]).reshape([t * n, e]), u = h.transpose([0, 2, 1, 3]).reshape([t * n, e]), k = M.transpose([0, 2, 1, 3]).reshape([t * n, e]);
8
8
  return {
9
9
  x: () => {
10
- const n = d.matMul(i, !1, !0), a = l.matMul(h, !1, !0), o = u.matMul(M, !1, !0);
11
- return n.add(a).add(o).reshape([t, s, e]);
10
+ const s = l.slice([0, 0], [e, e]), o = i.matMul(s, !1, !0);
11
+ s.dispose();
12
+ const d = l.slice([0, e], [e, e]), r = u.matMul(d, !1, !0);
13
+ d.dispose();
14
+ const a = o.add(r);
15
+ o.dispose(), r.dispose();
16
+ const c = l.slice([0, 2 * e], [e, e]), m = k.matMul(c, !1, !0);
17
+ c.dispose();
18
+ const v = a.add(m).reshape([t, n, e]);
19
+ return a.dispose(), m.dispose(), v;
12
20
  },
13
21
  kernel: () => {
14
- const n = c.reshape([t * s, e]), a = n.matMul(d, !0, !1), o = n.matMul(l, !0, !1), p = n.matMul(u, !0, !1);
15
- return a.concat(o, 1).concat(p, 1);
22
+ const s = p.reshape([t * n, e]), o = s.matMul(i, !0, !1), d = s.matMul(u, !0, !1), r = o.concat(d, 1);
23
+ o.dispose(), d.dispose();
24
+ const a = s.matMul(k, !0, !1), c = r.concat(a, 1);
25
+ return r.dispose(), a.dispose(), s.dispose(), c;
16
26
  }
17
27
  };
18
28
  }
19
29
  };
20
- v(g);
30
+ Q(V);