@genai-fi/nanogpt 0.4.3 → 0.4.5

Files changed (101)
  1. package/dist/Generator.js +3 -3
  2. package/dist/NanoGPTModel.js +8 -8
  3. package/dist/Reshape-CiAY8ltP.js +212 -0
  4. package/dist/TeachableLLM.js +14 -5
  5. package/dist/{TiedEmbedding-CnJ1bx4q.js → TiedEmbedding-DznFwzcB.js} +244 -244
  6. package/dist/{axis_util-BgTGy5w8.js → axis_util-QP0LdI1v.js} +1 -1
  7. package/dist/{concat-CuRsVY-K.js → concat-DvWM7HGZ.js} +1 -1
  8. package/dist/data/parquet.js +9 -6
  9. package/dist/data/textLoader.js +6 -5
  10. package/dist/{dropout-DfDdklfL.js → dropout-DFEXTPV0.js} +4 -4
  11. package/dist/{gather-ZYRWhmXR.js → gather-C5D8PxwA.js} +1 -1
  12. package/dist/gpgpu_math-CUzjlO9A.js +23 -0
  13. package/dist/{index-C4JCoBvj.js → index--6vO-cOz.js} +87 -87
  14. package/dist/{kernel_funcs_utils-CAd1h9X1.js → kernel_funcs_utils-C6YBCuOt.js} +72 -91
  15. package/dist/layers/CausalSelfAttention.js +47 -46
  16. package/dist/layers/MLP.js +31 -33
  17. package/dist/layers/RMSNorm.d.ts +1 -2
  18. package/dist/layers/RMSNorm.js +10 -10
  19. package/dist/layers/RoPECache.js +3 -3
  20. package/dist/layers/TiedEmbedding.js +5 -5
  21. package/dist/layers/TransformerBlock.js +2 -2
  22. package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} +7 -7
  23. package/dist/main.js +28 -19
  24. package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} +1 -1
  25. package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} +1 -1
  26. package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} +5 -5
  27. package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} +13 -13
  28. package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} +2 -2
  29. package/dist/ops/appendCache.js +3 -3
  30. package/dist/ops/attentionMask.js +1 -1
  31. package/dist/ops/cpu/appendCache.js +2 -2
  32. package/dist/ops/cpu/attentionMask.js +2 -2
  33. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  34. package/dist/ops/cpu/gatherSub.js +4 -4
  35. package/dist/ops/cpu/gelu.js +1 -1
  36. package/dist/ops/cpu/matMulGelu.d.ts +1 -0
  37. package/dist/ops/cpu/matMulGelu.js +40 -0
  38. package/dist/ops/cpu/mulDropout.js +1 -1
  39. package/dist/ops/cpu/normRMS.d.ts +1 -0
  40. package/dist/ops/cpu/normRMS.js +39 -0
  41. package/dist/ops/cpu/qkv.js +3 -3
  42. package/dist/ops/cpu/rope.js +5 -5
  43. package/dist/ops/cpu/scatterSub.js +4 -4
  44. package/dist/ops/fusedSoftmax.js +1 -1
  45. package/dist/ops/gatherSub.js +1 -1
  46. package/dist/ops/gelu.js +2 -2
  47. package/dist/ops/grads/attentionMask.js +1 -1
  48. package/dist/ops/grads/fusedSoftmax.js +2 -2
  49. package/dist/ops/grads/gelu.js +24 -3
  50. package/dist/ops/grads/matMulGelu.d.ts +1 -0
  51. package/dist/ops/grads/matMulGelu.js +17 -0
  52. package/dist/ops/grads/normRMS.d.ts +2 -0
  53. package/dist/ops/grads/normRMS.js +20 -0
  54. package/dist/ops/grads/qkv.js +1 -1
  55. package/dist/ops/grads/rope.js +1 -1
  56. package/dist/ops/matMulGelu.d.ts +3 -0
  57. package/dist/ops/matMulGelu.js +14 -0
  58. package/dist/ops/mulDrop.js +1 -1
  59. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  60. package/dist/ops/normRMS.d.ts +2 -0
  61. package/dist/ops/normRMS.js +10 -0
  62. package/dist/ops/qkv.js +1 -1
  63. package/dist/ops/scatterSub.js +1 -1
  64. package/dist/ops/webgl/appendCache.js +1 -1
  65. package/dist/ops/webgl/attentionMask.js +1 -1
  66. package/dist/ops/webgl/fusedSoftmax.js +689 -895
  67. package/dist/ops/webgl/gatherSub.js +1 -1
  68. package/dist/ops/webgl/gelu.js +2 -2
  69. package/dist/ops/webgl/matMulGelu.d.ts +21 -0
  70. package/dist/ops/webgl/matMulGelu.js +168 -0
  71. package/dist/ops/webgl/mulDropout.js +1 -1
  72. package/dist/ops/webgl/normRMS.d.ts +1 -0
  73. package/dist/ops/webgl/normRMS.js +78 -0
  74. package/dist/ops/webgl/qkv.js +1 -1
  75. package/dist/ops/webgl/rope.js +1 -1
  76. package/dist/ops/webgl/scatterSub.js +1 -1
  77. package/dist/{range-9AzeApCc.js → range-C_vpUjBu.js} +1 -1
  78. package/dist/{reshape-Boe4DuIO.js → reshape-z51Eu-re.js} +1 -1
  79. package/dist/{sin-KmhiDuMa.js → sin-H567uayl.js} +1 -1
  80. package/dist/{slice_util-19zDNNSn.js → slice_util-BdhYwFY_.js} +2 -2
  81. package/dist/{softmax-Cujsg4ay.js → softmax-Dsxflvdl.js} +1 -1
  82. package/dist/{split-DbcNm1-i.js → split-B_k_jwud.js} +1 -1
  83. package/dist/{stack-D1YjmgKN.js → stack-CmqSdsfs.js} +1 -1
  84. package/dist/{sum-R28pucR5.js → sum-DdkDf2MG.js} +1 -1
  85. package/dist/{tensor-BVeHdl7V.js → tensor-BGYi41cj.js} +1 -1
  86. package/dist/{tensor2d-DqFGNs_K.js → tensor2d-DUr_htjt.js} +1 -1
  87. package/dist/{tfjs_backend-Cug-PH75.js → tfjs_backend-DuKis_xG.js} +46 -46
  88. package/dist/training/AdamExt.js +1 -1
  89. package/dist/training/DatasetBuilder.js +18 -18
  90. package/dist/training/FullTrainer.js +1 -1
  91. package/dist/training/Trainer.js +5 -5
  92. package/dist/training/sparseCrossEntropy.js +4 -4
  93. package/dist/utilities/dummy.js +2 -2
  94. package/dist/utilities/generate.js +3 -3
  95. package/dist/utilities/load.js +1 -1
  96. package/dist/utilities/profile.js +1 -1
  97. package/dist/utilities/weights.js +2 -2
  98. package/dist/{variable-LJT9Ld63.js → variable-BJTZ3jOy.js} +1 -1
  99. package/dist/{zeros-dnQxFgAD.js → zeros-8xl-W2DC.js} +1 -1
  100. package/package.json +1 -1
  101. package/dist/gelu-CnCt17Lk.js +0 -26
package/dist/layers/CausalSelfAttention.js CHANGED
@@ -3,15 +3,15 @@ import T from "./BaseLayer.js";
  import { qkv as y } from "../ops/qkv.js";
  import { rope as w } from "../ops/rope.js";
  import { appendCache as E } from "../ops/appendCache.js";
- import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index-C4JCoBvj.js";
+ import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index--6vO-cOz.js";
  import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
- import { l as W, w as M, d as x } from "../tfjs_backend-Cug-PH75.js";
- import { o as N } from "../ones-Bf3YR48P.js";
- import { v as A } from "../variable-LJT9Ld63.js";
- import { z as q } from "../zeros-dnQxFgAD.js";
- import { r as C, d as I } from "../dropout-DfDdklfL.js";
- import { r as B } from "../reshape-Boe4DuIO.js";
- import { m as F } from "../mat_mul-415y5Qn2.js";
+ import { l as W, w as M, d as x } from "../tfjs_backend-DuKis_xG.js";
+ import { o as q } from "../ones-D6kB8bdY.js";
+ import { v as b } from "../variable-BJTZ3jOy.js";
+ import { z as B } from "../zeros-8xl-W2DC.js";
+ import { r as C, d as I } from "../dropout-DFEXTPV0.js";
+ import { r as F } from "../reshape-z51Eu-re.js";
+ import { m as H } from "../mat_mul-BEHRPMh0.js";
  class nt extends T {
  cAttn = null;
  cProj = null;
@@ -23,16 +23,16 @@ class nt extends T {
  units;
  projUnits;
  constructor(t, s) {
- super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
- const o = q([s.gpt.blockSize, s.gpt.blockSize]), e = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
- this.maskInf = M(this.bias, o, e);
+ super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
+ const e = B([s.gpt.blockSize, s.gpt.blockSize]), o = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
+ this.maskInf = M(this.bias, e, o);
  }
  build() {
- this.cAttn === null && (this.cAttn = A(
+ this.cAttn === null && (this.cAttn = b(
  C([this.config.gpt.nEmbed, this.units], 0, 0.02),
  !0
  //`block_${this.index}_attn_cAttn_kernel`
- )), this.cProj === null && (this.cProj = A(
+ )), this.cProj === null && (this.cProj = b(
  C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
  !0
  //`block_${this.index}_attn_cProj_kernel`
@@ -53,57 +53,58 @@ class nt extends T {
  t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
  }
  loadWeights(t) {
- const s = t.get(`block_${this.index}_cAttn`)?.[0], o = t.get(`block_${this.index}_cProj`)?.[0];
+ const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
  if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
- if (!o) throw new Error(`Weights for block_${this.index}_cProj not found`);
- this.cAttn ? this.cAttn.assign(s) : this.cAttn = A(s, !0), this.cProj ? this.cProj.assign(o) : this.cProj = A(o, !0);
+ if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
+ this.cAttn ? this.cAttn.assign(s) : this.cAttn = b(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = b(e, !0);
  }
- getAttentionScores(t, s, o, e) {
+ getAttentionScores(t, s, e, o) {
  const i = P(t, s, this.divisor, this.maskInf);
- return _(i, o ? this.config.gpt.dropout : 0, e);
+ return _(i, e ? this.config.gpt.dropout : 0, o);
  }
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
- getAttentionScoresWithPast(t, s, o) {
- const e = P(t, s, this.divisor, void 0, o);
- return _(e, 0, 0);
+ getAttentionScoresWithPast(t, s, e) {
+ const o = P(t, s, this.divisor, void 0, e);
+ return _(o, 0, 0);
  }
  getQKV(t) {
  return y(t, this.cAttn, this.config.gpt.nHead);
  }
  getOutputProjection(t) {
- const s = t.shape[0], o = t.shape[2], e = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = B(i, [s, o, e]);
+ const s = t.shape[0], e = t.shape[2], o = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = F(i, [s, e, o]);
  return x(n, this.cProj);
  }
- updateCache(t, s, o, e) {
- const i = this.config.gpt.blockSize, n = t.shape[2], r = e?.length || 0, a = o ? t : E(t, i, r, e?.k), p = o ? s : E(s, i, r, e?.v);
- return {
+ updateCache(t, s, e, o) {
+ const i = this.config.gpt.blockSize, n = t.shape[2], r = o?.length || 0, a = e ? t : E(t, i, r, o?.k);
+ e || (t.dispose(), o?.k.dispose());
+ const p = e ? s : E(s, i, r, o?.v);
+ return e || (s.dispose(), o?.v.dispose()), {
  k: S(a),
  v: S(p),
  length: Math.min(r + n, i),
- cumulativeLength: e ? e.cumulativeLength + n : n
+ cumulativeLength: o ? o.cumulativeLength + n : n
  };
  }
- forward(t, s = !1, o, e = !1, i) {
+ forward(t, s = !1, e, o = !1, i) {
  return $(() => {
  this.startMemory();
- const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
+ const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, A = c ? w(r, c, p) : r;
  c && (n.dispose(), r.dispose());
- const g = i ? i.length : 0, d = this.updateCache(f, a, s, i), l = d.k, m = d.v;
- i && (f.dispose(), a.dispose());
+ const f = i ? i.length : 0, d = this.updateCache(A, a, s, i), l = d.k, g = d.v;
  let h;
- g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, o), u.dispose(), s && l.dispose();
- const b = F(h, m);
- e || h.dispose(), s && m.dispose();
- const k = this.getOutputProjection(b);
- b.dispose();
- const v = e ? h.mean(1) : void 0;
+ f > 0 ? h = this.getAttentionScoresWithPast(u, l, f) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
+ const m = H(h, g);
+ o || h.dispose(), s && g.dispose();
+ const k = this.getOutputProjection(m);
+ m.dispose();
+ const v = o ? h.mean(1) : void 0;
  return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
  });
  }
- call(t, s = !1, o = !1, e) {
- if (e && !this.config.gpt.useRope)
+ call(t, s = !1, e = !1, o) {
+ if (o && !this.config.gpt.useRope)
  throw new Error("Cannot use pastKV without RoPE enabled");
- if (s && e)
+ if (s && o)
  throw new Error("Cannot use pastKV during training");
  if (t.shape.length !== 3)
  throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
@@ -115,15 +116,15 @@
  const r = L(
  // @ts-expect-error Invalid params
  (a, p, c, u) => {
- const f = this.forward(a, !0, i);
+ const A = this.forward(a, !0, i);
  u([a]);
- const g = (d, l) => {
- const [m] = l, h = j().state.activeTape;
+ const f = (d, l) => {
+ const [g] = l, h = j().state.activeTape;
  j().state.activeTape = [];
- const b = O((k, v, R) => this.forward(k, !0, i).output)([m, p, c], d);
- return j().state.activeTape = h, b;
+ const m = O((k, v, R) => this.forward(k, !0, i).output)([g, p, c], d);
+ return j().state.activeTape = h, m;
  };
- return { value: f.output, gradFunc: g };
+ return { value: A.output, gradFunc: f };
  }
  )(t, this.cAttn, this.cProj);
  if (this.config.gpt.dropout > 0) {
@@ -132,7 +133,7 @@ class nt extends T {
  } else
  return { output: r };
  } else {
- const n = this.forward(t, s, i, o, e);
+ const n = this.forward(t, s, i, e, o);
  if (this.config.gpt.dropout > 0) {
  const r = I(n.output, this.config.gpt.dropout);
  return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
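Note: beyond the re-hashed chunk imports and alias churn, the substantive change in this file is in `updateCache`: the pre-append key/value tensors and the superseded cache entries are now disposed inside `updateCache` itself, and the old cleanup in `forward` (`i && (f.dispose(), a.dispose())`) is dropped. A minimal sketch of the new flow with descriptive names; `appendCache`'s signature is inferred from the call sites above, and the real code additionally retains the results via the minified `S(...)` wrapper (presumably a keep):

    import * as tf from "@tensorflow/tfjs-core";

    // Stand-in for this package's ops/appendCache export; signature inferred
    // from the call sites: appendCache(item, maxSize, pastLen, past?).
    declare function appendCache(
      item: tf.Tensor, maxSize: number, pastLen: number, past?: tf.Tensor
    ): tf.Tensor;

    interface KVCache { k: tf.Tensor; v: tf.Tensor; length: number; cumulativeLength: number; }

    function updateCache(
      k: tf.Tensor, v: tf.Tensor, training: boolean,
      past: KVCache | undefined, blockSize: number
    ): KVCache {
      const tCur = k.shape[2];
      const pastLen = past?.length ?? 0;
      const nextK = training ? k : appendCache(k, blockSize, pastLen, past?.k);
      // New in 0.4.5: free the stale input and the superseded cache tensor
      // right after appending, instead of leaving this to forward().
      if (!training) { k.dispose(); past?.k.dispose(); }
      const nextV = training ? v : appendCache(v, blockSize, pastLen, past?.v);
      if (!training) { v.dispose(); past?.v.dispose(); }
      return {
        k: nextK, v: nextV,
        length: Math.min(pastLen + tCur, blockSize),
        cumulativeLength: past ? past.cumulativeLength + tCur : tCur,
      };
    }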
package/dist/layers/MLP.js CHANGED
@@ -1,11 +1,11 @@
- import { t as _, c as M, e as d, H as v } from "../index-C4JCoBvj.js";
- import x from "./BaseLayer.js";
- import { g as L } from "../gelu-CnCt17Lk.js";
- import { v as n } from "../variable-LJT9Ld63.js";
- import { r as p, d as u } from "../dropout-DfDdklfL.js";
- import { r as l } from "../reshape-Boe4DuIO.js";
- import { m as f } from "../mat_mul-415y5Qn2.js";
- class V extends x {
+ import { t as F, c as _, e as h, H as M } from "../index--6vO-cOz.js";
+ import v from "./BaseLayer.js";
+ import { matMulGelu as x } from "../ops/matMulGelu.js";
+ import { v as c } from "../variable-BJTZ3jOy.js";
+ import { r as d, d as u } from "../dropout-DFEXTPV0.js";
+ import { r as p } from "../reshape-z51Eu-re.js";
+ import { m as L } from "../mat_mul-BEHRPMh0.js";
+ class G extends v {
  cFc = null;
  cProj = null;
  index;
@@ -15,12 +15,12 @@ class V extends x {
  super(s), this.index = t, this.hiddenUnits = s.gpt.mlpFactor * s.gpt.nEmbed;
  }
  build() {
- this.cFc === null && (this.cFc = n(
- p([this.config.gpt.nEmbed, this.hiddenUnits], 0, 0.02),
+ this.cFc === null && (this.cFc = c(
+ d([this.config.gpt.nEmbed, this.hiddenUnits], 0, 0.02),
  !0
  //`block_${this.index}_attn_cAttn_kernel`
- )), this.cProj === null && (this.cProj = n(
- p(
+ )), this.cProj === null && (this.cProj = c(
+ d(
  [this.hiddenUnits, this.config.gpt.nEmbed],
  0,
  0.02 / Math.sqrt(2 * this.config.gpt.nLayer)
@@ -45,43 +45,41 @@ class V extends x {
  const s = t.get(`block_${this.index}_mlpOut`)?.[0], i = t.get(`block_${this.index}_mlpHidden`)?.[0];
  if (!s || !i)
  throw new Error(`Weights for block ${this.index} not found`);
- this.cFc ? this.cFc.assign(i) : this.cFc = n(i, !0), this.cProj ? this.cProj.assign(s) : this.cProj = n(s, !0);
+ this.cFc ? this.cFc.assign(i) : this.cFc = c(i, !0), this.cProj ? this.cProj.assign(s) : this.cProj = c(s, !0);
  }
  forward(t) {
- return _(() => {
+ return F(() => {
  this.startMemory();
- const [s, i, o] = t.shape, r = l(t, [s * i, o]), e = f(r, this.cFc), c = L(e);
+ const [s, i, r] = t.shape, o = p(t, [s * i, r]), e = x(o, this.cFc), n = L(e, this.cProj);
  e.dispose();
- const a = f(c, this.cProj);
- c.dispose();
- const h = l(a, [s, i, o]);
- return this.endMemory("MLP"), h;
+ const a = p(n, [s, i, r]);
+ return this.endMemory("MLP"), a;
  });
  }
  call(t, s = !1) {
  if (this.build(), s && this.config.layerConfig.checkpointMLP) {
- const o = M(
+ const r = _(
  // @ts-expect-error Invalid params
- (r, e, c, a) => {
- const h = this.forward(r);
- return a([r]), { value: h, gradFunc: (g, m) => {
- const [b] = m, P = d().state.activeTape;
- d().state.activeTape = [];
- const j = v((F, w, T) => this.forward(F))([b, e, c], g);
- return d().state.activeTape = P, j;
+ (o, e, n, a) => {
+ const l = this.forward(o);
+ return a([o]), { value: l, gradFunc: (f, g) => {
+ const [m] = g, b = h().state.activeTape;
+ h().state.activeTape = [];
+ const P = M((j, w, T) => this.forward(j))([m, e, n], f);
+ return h().state.activeTape = b, P;
  } };
  }
  )(t, this.cFc, this.cProj);
  if (this.config.gpt.dropout > 0) {
- const r = u(o, this.config.gpt.dropout);
- return o.dispose(), r;
+ const o = u(r, this.config.gpt.dropout);
+ return r.dispose(), o;
  }
- return o;
+ return r;
  } else {
  const i = this.forward(t);
  if (s && this.config.gpt.dropout > 0) {
- const o = u(i, this.config.gpt.dropout);
- return i.dispose(), o;
+ const r = u(i, this.config.gpt.dropout);
+ return i.dispose(), r;
  }
  return i;
  }
@@ -91,5 +89,5 @@ class V extends x {
  }
  }
  export {
- V as default
+ G as default
  };
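Note: the MLP change is more than alias churn. `forward` previously ran `matMul(x, cFc)` followed by a separate GELU (from the now-deleted `gelu-CnCt17Lk.js` chunk, file 101 in the list above); 0.4.5 calls the new fused `matMulGelu` op instead, which main.js below registers with cpu and webgl kernels and a gradient. A reference sketch of what the fusion computes, assuming the tanh-based GELU approximation whose constants appear in `ops/cpu/gelu.js` further down; the fused kernel presumably evaluates the activation in the matmul epilogue rather than materialising the pre-activation tensor:

    import * as tf from "@tensorflow/tfjs-core";

    // Unfused reference for matMulGelu(x, w): one matmul, then tanh-GELU.
    function matMulGeluReference(x: tf.Tensor2D, w: tf.Tensor2D): tf.Tensor {
      const y = tf.matMul(x, w);        // [N, hidden] pre-activation
      const c0 = 0.7978845608028654;    // sqrt(2 / pi)
      const c1 = 0.044715;              // cubic coefficient
      const inner = tf.tanh(tf.mul(c0, tf.add(y, tf.mul(c1, tf.pow(y, 3)))));
      return tf.mul(tf.mul(y, 0.5), tf.add(inner, 1)); // 0.5 * y * (1 + tanh(...))
    }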
package/dist/layers/RMSNorm.d.ts CHANGED
@@ -2,8 +2,7 @@ import { Tensor, Variable } from '@tensorflow/tfjs-core';
  import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
  export default class RMSNorm extends BaseLayer {
  private gamma;
- private epsilon;
- constructor(config: GPTLayerConfig, epsilon?: number, name?: string);
+ constructor(config: GPTLayerConfig, name?: string);
  get trainableWeights(): Variable[];
  set trainable(value: boolean);
  getWeights(): Tensor[];
package/dist/layers/RMSNorm.js CHANGED
@@ -1,12 +1,12 @@
- import { t as r } from "../index-C4JCoBvj.js";
+ import { t as r } from "../index--6vO-cOz.js";
  import m from "./BaseLayer.js";
- import { v as i } from "../variable-LJT9Ld63.js";
- import { o } from "../ones-Bf3YR48P.js";
- class d extends m {
+ import { normRMS as s } from "../ops/normRMS.js";
+ import { v as e } from "../variable-BJTZ3jOy.js";
+ import { o as i } from "../ones-D6kB8bdY.js";
+ class u extends m {
  gamma;
- epsilon;
- constructor(t, s = 1e-8, a = "") {
- super(t), this.epsilon = s, this.gamma = i(o([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
+ constructor(t, a = "") {
+ super(t), this.gamma = e(i([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
  }
  get trainableWeights() {
  return [this.gamma];
@@ -23,8 +23,8 @@ class d extends m {
  apply(t) {
  return r(() => {
  this.startMemory();
- const a = t.square().mean(-1, !0).add(this.epsilon).rsqrt(), e = t.mul(a).mul(this.gamma);
- return this.endMemory("RMSNorm"), e;
+ const a = s(t, this.gamma);
+ return this.endMemory("RMSNorm"), a;
  });
  }
  dispose() {
@@ -32,5 +32,5 @@ class d extends m {
  }
  }
  export {
- d as default
+ u as default
  };
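Note: together with the `.d.ts` change above, RMSNorm loses its configurable epsilon and delegates the arithmetic to the new fused `normRMS` op. The minus side of the second hunk shows the exact reference semantics; a sketch, assuming the fused op keeps the same math with the old default epsilon of 1e-8 baked in:

    import * as tf from "@tensorflow/tfjs-core";

    // Mirrors the removed inline code: y = x * rsqrt(mean(x^2) + eps) * gamma.
    function normRMSReference(x: tf.Tensor, gamma: tf.Tensor, epsilon = 1e-8): tf.Tensor {
      const invRms = x.square().mean(-1, true).add(epsilon).rsqrt(); // per-position 1/RMS
      return x.mul(invRms).mul(gamma);                               // normalise, apply learned gain
    }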
package/dist/layers/RoPECache.js CHANGED
@@ -1,6 +1,6 @@
- import { o as h, h as c, E as f, N as l, f as n, O as m, t as u, F as p } from "../index-C4JCoBvj.js";
- import { c as d, s as C } from "../sin-KmhiDuMa.js";
- import { r as a } from "../range-9AzeApCc.js";
+ import { o as h, h as c, E as f, T as l, f as n, U as m, t as u, F as p } from "../index--6vO-cOz.js";
+ import { c as d, s as C } from "../sin-H567uayl.js";
+ import { r as a } from "../range-C_vpUjBu.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
package/dist/layers/TiedEmbedding.js CHANGED
@@ -1,8 +1,8 @@
- import { T as a } from "../TiedEmbedding-CnJ1bx4q.js";
- import "../index-C4JCoBvj.js";
- import "../tfjs_backend-Cug-PH75.js";
- import "../variable-LJT9Ld63.js";
- import "../gather-ZYRWhmXR.js";
+ import { T as a } from "../TiedEmbedding-DznFwzcB.js";
+ import "../index--6vO-cOz.js";
+ import "../tfjs_backend-DuKis_xG.js";
+ import "../variable-BJTZ3jOy.js";
+ import "../gather-C5D8PxwA.js";
  export {
  a as default
  };
package/dist/layers/TransformerBlock.js CHANGED
@@ -2,7 +2,7 @@ import h from "./CausalSelfAttention.js";
  import o from "./MLP.js";
  import a from "./RMSNorm.js";
  import p from "./BaseLayer.js";
- import { t as d } from "../index-C4JCoBvj.js";
+ import { t as d } from "../index--6vO-cOz.js";
  class W extends p {
  ln1;
  attn;
@@ -12,7 +12,7 @@ class W extends p {
  _trainable = !0;
  skipped = !1;
  constructor(t, s) {
- super(s), this.index = t, this.ln1 = new a(s, 1e-8, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
+ super(s), this.index = t, this.ln1 = new a(s, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
  }
  get variables() {
  return [
package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} RENAMED
@@ -1,8 +1,8 @@
- import { o as r, h as p, E as u, a3 as h, a4 as E, Y as S, s as $, a5 as d } from "./index-C4JCoBvj.js";
- import { e as K } from "./axis_util-BgTGy5w8.js";
- import { m as T } from "./max-CP_9O2Yd.js";
- import { r as m } from "./reshape-Boe4DuIO.js";
- import { s as _ } from "./sum-R28pucR5.js";
+ import { o as r, h as p, E as u, a6 as h, a7 as E, $, s as S, a8 as d } from "./index--6vO-cOz.js";
+ import { e as K } from "./axis_util-QP0LdI1v.js";
+ import { m as T } from "./max-BUShNgfh.js";
+ import { r as m } from "./reshape-z51Eu-re.js";
+ import { s as _ } from "./sum-DdkDf2MG.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -62,12 +62,12 @@ const w = /* @__PURE__ */ r({ log_: v });
  * =============================================================================
  */
  function A(s, n = null, o = !1) {
- const a = p(s, "x", "logSumExp"), t = S(n, a.shape), x = T(
+ const a = p(s, "x", "logSumExp"), t = $(n, a.shape), x = T(
  a,
  t,
  !0
  /* keepDims */
- ), i = $(a, x), l = N(i), f = _(l, t), c = w(f), e = d(m(x, c.shape), c);
+ ), i = S(a, x), l = N(i), f = _(l, t), c = w(f), e = d(m(x, c.shape), c);
  if (o) {
  const g = K(e.shape, t);
  return m(e, g);
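Note: the two changed lines in this hunk only swap which minified aliases (`$` and `S`) carry the subtract op and the axis-parsing helper after re-bundling; the computation is unchanged. For reference, the visible code is the standard max-shifted form

    \operatorname{logsumexp}(x) = m + \log \sum_i e^{x_i - m}, \qquad m = \max_i x_i,

which keeps the exponentials from overflowing for large logits.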
package/dist/main.js CHANGED
@@ -1,11 +1,11 @@
- import { default as k } from "./NanoGPTModel.js";
- import { default as L } from "./TeachableLLM.js";
- import { default as b } from "./tokeniser/CharTokeniser.js";
- import { default as w } from "./tokeniser/bpe.js";
- import { default as D } from "./utilities/waitForModel.js";
- import { default as F } from "./data/textLoader.js";
- import { estimateMemoryUsage as N, estimateParameterCount as R, estimateResources as j, estimateTrainingMemoryUsage as q, validateConfig as z } from "./utilities/parameters.js";
- import "./index-C4JCoBvj.js";
+ import { default as E } from "./NanoGPTModel.js";
+ import { default as G } from "./TeachableLLM.js";
+ import { default as R } from "./tokeniser/CharTokeniser.js";
+ import { default as q } from "./tokeniser/bpe.js";
+ import { default as A } from "./utilities/waitForModel.js";
+ import { default as I } from "./data/textLoader.js";
+ import { estimateMemoryUsage as K, estimateParameterCount as O, estimateResources as Q, estimateTrainingMemoryUsage as S, validateConfig as V } from "./utilities/parameters.js";
+ import "./index--6vO-cOz.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./ops/cpu/gatherSub.js";
@@ -25,16 +25,25 @@ import "./ops/webgl/appendCache.js";
  import "./ops/cpu/fusedSoftmax.js";
  import "./ops/webgl/fusedSoftmax.js";
  import "./ops/grads/fusedSoftmax.js";
+ import "./ops/cpu/matMulGelu.js";
+ import "./ops/webgl/matMulGelu.js";
+ import "./ops/grads/matMulGelu.js";
+ import "./ops/cpu/gelu.js";
+ import "./ops/webgl/gelu.js";
+ import "./ops/grads/gelu.js";
+ import "./ops/cpu/normRMS.js";
+ import "./ops/webgl/normRMS.js";
+ import "./ops/grads/normRMS.js";
  export {
- w as BPETokeniser,
- b as CharTokeniser,
- k as NanoGPT,
- L as TeachableLLM,
- N as estimateMemoryUsage,
- R as estimateParameterCount,
- j as estimateResources,
- q as estimateTrainingMemoryUsage,
- F as loadTextData,
- z as validateConfig,
- D as waitForModel
+ q as BPETokeniser,
+ R as CharTokeniser,
+ E as NanoGPT,
+ G as TeachableLLM,
+ K as estimateMemoryUsage,
+ O as estimateParameterCount,
+ Q as estimateResources,
+ S as estimateTrainingMemoryUsage,
+ I as loadTextData,
+ V as validateConfig,
+ A as waitForModel
  };
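Note: the nine added imports exist purely for side effects: each module registers a backend kernel (or gradient) for one of the new fused ops with the tfjs registry, mirroring the existing scatterSub/gatherSub pattern. Roughly, and with a hypothetical kernel name since the real one is minified:

    import { registerKernel, KernelConfig } from "@tensorflow/tfjs-core";

    // Hypothetical shape of ops/cpu/matMulGelu.js; only the registration
    // pattern is being illustrated here, not the actual kernel body.
    const matMulGeluCpuConfig: KernelConfig = {
      kernelName: "MatMulGelu",   // assumed name
      backendName: "cpu",
      kernelFunc: () => { throw new Error("sketch only"); },
    };
    registerKernel(matMulGeluCpuConfig);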
package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} RENAMED
@@ -1,4 +1,4 @@
- import { o as m, h as s, p as c, E as M, B as p } from "./index-C4JCoBvj.js";
+ import { o as m, h as s, p as c, E as M, B as p } from "./index--6vO-cOz.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} RENAMED
@@ -1,4 +1,4 @@
- import { o as r, h as e, E as x, M as c } from "./index-C4JCoBvj.js";
+ import { o as r, h as e, E as x, M as c } from "./index--6vO-cOz.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} RENAMED
@@ -1,6 +1,6 @@
- import { o as m, h as c, E as f, X as i, Y as l, Z as h, s as x, x as d } from "./index-C4JCoBvj.js";
- import { e as v } from "./axis_util-BgTGy5w8.js";
- import { r as E } from "./reshape-Boe4DuIO.js";
+ import { o as m, h as c, E as f, _ as i, $ as l, a0 as h, s as x, x as d } from "./index--6vO-cOz.js";
+ import { e as v } from "./axis_util-QP0LdI1v.js";
+ import { r as E } from "./reshape-z51Eu-re.js";
  /**
  * @license
  * Copyright 2020 Google Inc. All Rights Reserved.
@@ -46,8 +46,8 @@ function T(a, t = null, e = !1) {
  const p = h(x(d(a, "float32"), E(s, o))), u = r(p, n, e);
  return { mean: s, variance: u };
  }
- const N = /* @__PURE__ */ m({ moments_: T });
+ const K = /* @__PURE__ */ m({ moments_: T });
  export {
- N as a,
+ K as a,
  r as m
  };
package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} RENAMED
@@ -1,8 +1,8 @@
- import { o as l, h as c, E as y, _ as E, Y as w, $ as o, a0 as u, O as v, f as I, Z as $ } from "./index-C4JCoBvj.js";
- import { e as A } from "./axis_util-BgTGy5w8.js";
- import { m as f } from "./max-CP_9O2Yd.js";
- import { r as h } from "./reshape-Boe4DuIO.js";
- import { s as t } from "./sum-R28pucR5.js";
+ import { o as l, h as c, E as y, a1 as E, $ as w, a2 as o, a3 as u, U as v, f as I, a0 as $ } from "./index--6vO-cOz.js";
+ import { e as A } from "./axis_util-QP0LdI1v.js";
+ import { m as f } from "./max-BUShNgfh.js";
+ import { r as h } from "./reshape-z51Eu-re.js";
+ import { s as t } from "./sum-DdkDf2MG.js";
  /**
  * @license
  * Copyright 2020 Google Inc. All Rights Reserved.
@@ -20,8 +20,8 @@ import { s as t } from "./sum-R28pucR5.js";
  * =============================================================================
  */
  function k(n, e = null, r = !1) {
- const i = { x: c(n, "x", "min") }, a = { axis: e, keepDims: r };
- return y.runKernel(E, i, a);
+ const a = { x: c(n, "x", "min") }, i = { axis: e, keepDims: r };
+ return y.runKernel(E, a, i);
  }
  const s = /* @__PURE__ */ l({ min_: k });
  /**
@@ -42,13 +42,13 @@ const s = /* @__PURE__ */ l({ min_: k });
  */
  function T(n, e = "euclidean", r = null, m = !1) {
  n = c(n, "x", "norm");
- const i = d(n, e, r);
- let a = i.shape;
+ const a = d(n, e, r);
+ let i = a.shape;
  if (m) {
  const p = w(r, n.shape);
- a = A(i.shape, p);
+ i = A(a.shape, p);
  }
- return h(i, a);
+ return h(a, i);
  }
  function d(n, e, r = null) {
  if (n.rank === 0)
@@ -79,8 +79,8 @@ function d(n, e, r = null) {
  }
  throw new Error(`Error in norm: invalid axis: ${r}`);
  }
- const K = /* @__PURE__ */ l({ norm_: T });
+ const N = /* @__PURE__ */ l({ norm_: T });
  export {
  s as m,
- K as n
+ N as n
  };
package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} RENAMED
@@ -1,5 +1,5 @@
- import { k as n, l as t, n as m, E as i } from "./index-C4JCoBvj.js";
- import { z as l, c } from "./zeros-dnQxFgAD.js";
+ import { k as n, l as t, n as m, E as i } from "./index--6vO-cOz.js";
+ import { z as l, c } from "./zeros-8xl-W2DC.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
package/dist/ops/appendCache.js CHANGED
@@ -1,8 +1,8 @@
- import { e as a } from "../index-C4JCoBvj.js";
+ import { e as a } from "../index--6vO-cOz.js";
  import "./cpu/appendCache.js";
  import "./webgl/appendCache.js";
- import { z as s } from "../zeros-dnQxFgAD.js";
- import { c } from "../concat-CuRsVY-K.js";
+ import { z as s } from "../zeros-8xl-W2DC.js";
+ import { c } from "../concat-DvWM7HGZ.js";
  function i(r, p, n, o) {
  if (!o) {
  const e = r.shape[2];
package/dist/ops/attentionMask.js CHANGED
@@ -1,4 +1,4 @@
- import { e as i } from "../index-C4JCoBvj.js";
+ import { e as i } from "../index--6vO-cOz.js";
  import "./cpu/attentionMask.js";
  import "./webgl/attentionMask.js";
  import "./grads/attentionMask.js";
package/dist/ops/cpu/appendCache.js CHANGED
@@ -1,5 +1,5 @@
- import { r as d } from "../../index-C4JCoBvj.js";
- import { c as h } from "../../concat-CuRsVY-K.js";
+ import { r as d } from "../../index--6vO-cOz.js";
+ import { c as h } from "../../concat-DvWM7HGZ.js";
  function u(p) {
  const { cache: n, item: s } = p.inputs, { maxSize: r, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
  if (c + e <= r) {
package/dist/ops/cpu/attentionMask.js CHANGED
@@ -1,5 +1,5 @@
- import { r as o, f as k } from "../../index-C4JCoBvj.js";
- import { m as d } from "../../mat_mul-415y5Qn2.js";
+ import { r as o, f as k } from "../../index--6vO-cOz.js";
+ import { m as d } from "../../mat_mul-BEHRPMh0.js";
  function r(t) {
  const { q: e, k: n, mask: s } = t.inputs, { divisor: c } = t.attrs, m = e.shape[2], i = n.shape[2], a = d(e, n, !1, !0).mul(k(c));
  if (s) {
package/dist/ops/cpu/fusedSoftmax.js CHANGED
@@ -1,5 +1,5 @@
- import { r as n } from "../../index-C4JCoBvj.js";
- import { s as f } from "../../softmax-Cujsg4ay.js";
+ import { r as n } from "../../index--6vO-cOz.js";
+ import { s as f } from "../../softmax-Dsxflvdl.js";
  function r(t) {
  const { inputs: s, attrs: i } = t, { logits: o } = s, { dim: a, dropoutRate: e } = i;
  if (!o)
package/dist/ops/cpu/gatherSub.js CHANGED
@@ -1,6 +1,6 @@
- import { o as u, h as c, E as g, L as h, r as m, s as p } from "../../index-C4JCoBvj.js";
- import { r as l } from "../../range-9AzeApCc.js";
- import { s as N } from "../../stack-D1YjmgKN.js";
+ import { o as u, h as c, E as g, N as h, r as m, s as p } from "../../index--6vO-cOz.js";
+ import { r as N } from "../../range-C_vpUjBu.js";
+ import { s as l } from "../../stack-CmqSdsfs.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -23,7 +23,7 @@ function f(e, s) {
  }
  const b = /* @__PURE__ */ u({ gatherND_: f });
  function d(e) {
- const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], o = l(0, t, 1, "int32"), a = N([o, n], 1), i = b(r, a);
+ const { values: s, labels: n, logits: r } = e.inputs, t = n.shape[0], o = N(0, t, 1, "int32"), a = l([o, n], 1), i = b(r, a);
  return p(s, i);
  }
  const k = {
package/dist/ops/cpu/gelu.js CHANGED
@@ -1,4 +1,4 @@
- import { r as t, t as d } from "../../index-C4JCoBvj.js";
+ import { r as t, t as d } from "../../index--6vO-cOz.js";
  const o = 0.7978845608028654, c = 0.044715;
  function m(u) {
  const { inputs: l } = u, { x: e } = l, n = e;
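Note: only the chunk import is re-pointed here, but the constants on the unchanged line identify the kernel: `o = 0.7978845608028654` is \sqrt{2/\pi} and `c = 0.044715` is the cubic coefficient of the tanh GELU approximation

    \operatorname{GELU}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\, x^{3}\right)\right)\right),

the same approximation assumed in the matMulGelu sketch above.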
package/dist/ops/cpu/matMulGelu.d.ts ADDED
@@ -0,0 +1 @@
+ export {};