@genai-fi/nanogpt 0.4.5 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +52 -49
  3. package/dist/NanoGPTModel.d.ts +13 -17
  4. package/dist/NanoGPTModel.js +128 -136
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +1 -1
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +21 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -128
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -10
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +1 -1
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.js +1 -1
  49. package/dist/ops/cpu/qkv.js +3 -3
  50. package/dist/ops/cpu/rope.js +5 -5
  51. package/dist/ops/cpu/scatterSub.js +8 -8
  52. package/dist/ops/fusedSoftmax.js +1 -1
  53. package/dist/ops/gatherSub.js +1 -1
  54. package/dist/ops/gelu.js +1 -1
  55. package/dist/ops/grads/attentionMask.js +13 -9
  56. package/dist/ops/grads/fusedSoftmax.js +12 -9
  57. package/dist/ops/grads/gelu.js +1 -1
  58. package/dist/ops/grads/matMulGelu.js +1 -1
  59. package/dist/ops/grads/normRMS.js +1 -1
  60. package/dist/ops/grads/qkv.js +19 -9
  61. package/dist/ops/grads/rope.js +1 -1
  62. package/dist/ops/matMulGelu.js +1 -1
  63. package/dist/ops/matMulMul.d.ts +2 -0
  64. package/dist/ops/matMulMul.js +9 -0
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +13 -12
  72. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/matMulGelu.js +17 -17
  76. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  77. package/dist/ops/webgl/matMulMul.js +28 -0
  78. package/dist/ops/webgl/mulDropout.js +1 -1
  79. package/dist/ops/webgl/normRMS.js +29 -21
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops-ObfXLHYQ.js +1269 -0
  84. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  85. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  86. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  87. package/dist/slice_util-D-kaD4ZV.js +49 -0
  88. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  89. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  90. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  91. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  92. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  93. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  94. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +44 -44
  97. package/dist/training/Evaluator.js +6 -6
  98. package/dist/training/FullTrainer.js +1 -1
  99. package/dist/training/Trainer.js +7 -7
  100. package/dist/training/sparseCrossEntropy.js +4 -4
  101. package/dist/utilities/dummy.js +10 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.js +1 -1
  104. package/dist/utilities/profile.js +1 -1
  105. package/dist/utilities/save.js +13 -11
  106. package/dist/utilities/weights.js +2 -2
  107. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  108. package/package.json +1 -1
  109. package/dist/slice_util-BdhYwFY_.js +0 -90
  110. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  111. package/dist/variable-BJTZ3jOy.js +0 -23
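Most of the churn above is one refactor: per-layer bookkeeping (the `variables` getter, the `trainable` accessors, `saveWeights`/`loadWeights`, `dispose`) moves out of the individual layers into a shared chunk (`BaseLayer-BhrMN8JO.js`), weights are registered by name via `addVariable`/`getVariable` instead of being held as fields, the old `call(x, training, ...)`/`apply(x)` entry points collapse into a single `forward(attrs, x)` that takes a `ForwardAttributes` object, and the gradient-checkpointing `customGrad` paths are dropped. The remaining renames (e.g. `index--6vO-cOz.js` → `index-iNhkcAEQ.js`) are bundler hash changes. Pieced together from the `.d.ts` diffs below, the new layer contract looks roughly like the following sketch — a reconstruction for orientation, not the library's source; the `Map`-backed store, the method bodies, and the `training` field (inferred from `t.training` in the minified code) are assumptions:

    import { Tensor, Variable } from '@tensorflow/tfjs-core';

    // Layers now receive one attributes object instead of positional
    // training/seed/cache arguments (inferred from the declaration diffs).
    export interface ForwardAttributes {
      training?: boolean;
    }

    export abstract class BaseLayer<A extends ForwardAttributes = ForwardAttributes> {
      // Named variable registry replacing per-layer fields (cAttn, cProj, gamma, ...).
      private vars = new Map<string, Variable | null>();

      constructor(protected config: unknown, protected parent?: BaseLayer) {}

      protected addVariable(name: string, v: Variable | null = null): void {
        this.vars.set(name, v); // declare the slot; build() may fill it later
      }
      protected hasVariable(name: string): boolean {
        return this.vars.get(name) != null;
      }
      protected setVariable(name: string, v: Variable): void {
        this.vars.set(name, v);
      }
      protected getVariable(name: string): Variable {
        const v = this.vars.get(name);
        if (!v) throw new Error(`Variable ${name} not built`);
        return v;
      }

      // The single entry point every layer now implements.
      abstract forward(attrs: A, x: Tensor): Tensor;
    }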
package/dist/layers/CausalSelfAttention.js
@@ -1,150 +1,95 @@
- import { attentionMask as P } from "../ops/attentionMask.js";
- import T from "./BaseLayer.js";
- import { qkv as y } from "../ops/qkv.js";
- import { rope as w } from "../ops/rope.js";
- import { appendCache as E } from "../ops/appendCache.js";
- import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index--6vO-cOz.js";
- import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
- import { l as W, w as M, d as x } from "../tfjs_backend-DuKis_xG.js";
- import { o as q } from "../ones-D6kB8bdY.js";
- import { v as b } from "../variable-BJTZ3jOy.js";
- import { z as B } from "../zeros-8xl-W2DC.js";
- import { r as C, d as I } from "../dropout-DFEXTPV0.js";
- import { r as F } from "../reshape-z51Eu-re.js";
- import { m as H } from "../mat_mul-BEHRPMh0.js";
- class nt extends T {
- cAttn = null;
- cProj = null;
- bias;
- maskInf;
+ import { attentionMask as g } from "../ops/attentionMask.js";
+ import { B as O, v } from "../BaseLayer-BhrMN8JO.js";
+ import { qkv as P } from "../ops/qkv.js";
+ import { rope as V } from "../ops/rope.js";
+ import { appendCache as T } from "../ops/appendCache.js";
+ import { F as c, t as C } from "../index-iNhkcAEQ.js";
+ import { fusedSoftmax as b } from "../ops/fusedSoftmax.js";
+ import { d as y } from "../tfjs_backend-NucKez4s.js";
+ import { r as k, d as L } from "../dropout-kbDY39Ci.js";
+ import { r as N } from "../reshape-DxTPgnwL.js";
+ import { m as R } from "../mat_mul-D0SifYfJ.js";
+ class W extends O {
  divisor;
  index;
- _trainable = !0;
  units;
  projUnits;
- constructor(t, s) {
- super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
- const e = B([s.gpt.blockSize, s.gpt.blockSize]), o = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
- this.maskInf = M(this.bias, e, o);
+ ATTN;
+ PROJ;
+ constructor(t, i, s) {
+ super(i, s), this.index = t, this.units = i.gpt.nEmbed * 3, this.projUnits = i.gpt.nEmbed, this.ATTN = `block_${this.index}_cAttn`, this.PROJ = `block_${this.index}_cProj`, this.addVariable(this.ATTN), this.addVariable(this.PROJ), this.divisor = 1 / Math.sqrt(i.gpt.nEmbed / i.gpt.nHead);
  }
  build() {
- this.cAttn === null && (this.cAttn = b(
- C([this.config.gpt.nEmbed, this.units], 0, 0.02),
- !0
- //`block_${this.index}_attn_cAttn_kernel`
- )), this.cProj === null && (this.cProj = b(
- C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
- !0
- //`block_${this.index}_attn_cProj_kernel`
- ));
+ this.hasVariable(this.ATTN) === !1 && this.setVariable(
+ this.ATTN,
+ v(
+ k([this.config.gpt.nEmbed, this.units], 0, 0.02),
+ !0
+ //`block_${this.index}_attn_cAttn_kernel`
+ )
+ ), this.hasVariable(this.PROJ) === !1 && this.setVariable(
+ this.PROJ,
+ v(
+ k([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
+ !0
+ //`block_${this.index}_attn_cProj_kernel`
+ )
+ );
  }
- get variables() {
- if (this.cAttn === null)
- throw new Error("Layer not built yet");
- return [this.cAttn, this.cProj];
- }
- get trainable() {
- return this._trainable;
- }
- set trainable(t) {
- this._trainable = t, this.cAttn && (this.cAttn.trainable = t), this.cProj && (this.cProj.trainable = t);
- }
- saveWeights(t) {
- t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
- }
- loadWeights(t) {
- const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
- if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
- if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
- this.cAttn ? this.cAttn.assign(s) : this.cAttn = b(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = b(e, !0);
- }
- getAttentionScores(t, s, e, o) {
- const i = P(t, s, this.divisor, this.maskInf);
- return _(i, e ? this.config.gpt.dropout : 0, o);
+ getAttentionScores(t, i, s, o) {
+ const e = g(t, i, this.divisor), n = b(e, s ? this.config.gpt.dropout : 0, o);
+ return e.dispose(), n;
  }
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
- getAttentionScoresWithPast(t, s, e) {
- const o = P(t, s, this.divisor, void 0, e);
- return _(o, 0, 0);
+ getAttentionScoresWithPast(t, i, s) {
+ const o = g(t, i, this.divisor, s), e = b(o, 0, 0);
+ return o.dispose(), e;
  }
  getQKV(t) {
- return y(t, this.cAttn, this.config.gpt.nHead);
+ return P(t, this.getVariable(this.ATTN), this.config.gpt.nHead);
  }
  getOutputProjection(t) {
- const s = t.shape[0], e = t.shape[2], o = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = F(i, [s, e, o]);
- return x(n, this.cProj);
+ const i = t.shape[0], s = t.shape[2], o = this.config.gpt.nEmbed, e = t.transpose([0, 2, 1, 3]), n = N(e, [i, s, o]), p = y(n, this.getVariable(this.PROJ));
+ return n.dispose(), e.dispose(), p;
  }
- updateCache(t, s, e, o) {
- const i = this.config.gpt.blockSize, n = t.shape[2], r = o?.length || 0, a = e ? t : E(t, i, r, o?.k);
- e || (t.dispose(), o?.k.dispose());
- const p = e ? s : E(s, i, r, o?.v);
- return e || (s.dispose(), o?.v.dispose()), {
- k: S(a),
- v: S(p),
- length: Math.min(r + n, i),
- cumulativeLength: o ? o.cumulativeLength + n : n
- };
+ updateCache(t, i, s) {
+ const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p = T(t, o, n, s.k);
+ t.dispose(), s.k && s.k.dispose();
+ const r = T(i, o, n, s.v);
+ i.dispose(), s.v && s.v.dispose();
+ const d = Math.min(n + e, o), h = s.cumulativeLength + e;
+ s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(r);
  }
- forward(t, s = !1, e, o = !1, i) {
- return $(() => {
+ forward(t, i) {
+ return C(() => {
  this.startMemory();
- const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, A = c ? w(r, c, p) : r;
- c && (n.dispose(), r.dispose());
- const f = i ? i.length : 0, d = this.updateCache(A, a, s, i), l = d.k, g = d.v;
- let h;
- f > 0 ? h = this.getAttentionScoresWithPast(u, l, f) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
- const m = H(h, g);
- o || h.dispose(), s && g.dispose();
- const k = this.getOutputProjection(m);
- m.dispose();
- const v = o ? h.mean(1) : void 0;
- return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
+ const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache, r = p ? V(s, p, n) : s, d = p ? V(o, p, n) : o;
+ p && (s.dispose(), o.dispose());
+ const h = t.pastKV ? t.pastKV.length : 0;
+ t.pastKV && !t.training && this.updateCache(d, e, t.pastKV);
+ const u = t.pastKV?.k ? t.pastKV.k : d, l = t.pastKV?.v ? t.pastKV.v : e;
+ let a;
+ h > 0 ? a = this.getAttentionScoresWithPast(r, u, h) : a = this.getAttentionScores(r, u, t.training, t.seed || 0), r.dispose(), t.pastKV || u.dispose();
+ const m = R(a, l), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
+ f || a.dispose(), t.pastKV || l.dispose();
+ const A = this.getOutputProjection(m);
+ if (m.dispose(), f && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
+ const K = a.shape[1], S = a.shape[2];
+ t.attentionScores.attentionOut?.push(
+ c(a.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
+ );
+ }
+ return this.endMemory("CausalSelfAttention"), A;
  });
  }
- call(t, s = !1, e = !1, o) {
- if (o && !this.config.gpt.useRope)
- throw new Error("Cannot use pastKV without RoPE enabled");
- if (s && o)
- throw new Error("Cannot use pastKV during training");
- if (t.shape.length !== 3)
- throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
- if (t.shape[2] !== this.config.gpt.nEmbed)
- throw new Error(`Input tensor last dimension must be ${this.config.gpt.nEmbed}, got ${t.shape[2]}`);
- this.build();
- const i = Math.random() * 1e9;
- if (s && this.config.layerConfig.checkpointAttention) {
- const r = L(
- // @ts-expect-error Invalid params
- (a, p, c, u) => {
- const A = this.forward(a, !0, i);
- u([a]);
- const f = (d, l) => {
- const [g] = l, h = j().state.activeTape;
- j().state.activeTape = [];
- const m = O((k, v, R) => this.forward(k, !0, i).output)([g, p, c], d);
- return j().state.activeTape = h, m;
- };
- return { value: A.output, gradFunc: f };
- }
- )(t, this.cAttn, this.cProj);
- if (this.config.gpt.dropout > 0) {
- const a = I(r, this.config.gpt.dropout);
- return r.dispose(), { output: a };
- } else
- return { output: r };
- } else {
- const n = this.forward(t, s, i, e, o);
- if (this.config.gpt.dropout > 0) {
- const r = I(n.output, this.config.gpt.dropout);
- return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
- } else
- return n;
- }
- }
- dispose() {
- this.cAttn?.dispose(), this.cProj?.dispose(), this.bias.dispose(), this.maskInf.dispose();
+ dropout(t) {
+ if (this.config.gpt.dropout > 0) {
+ const i = L(t, this.config.gpt.dropout);
+ return t.dispose(), i;
+ } else
+ return t;
  }
  }
  export {
- nt as default
+ W as default
  };
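Two behavioral notes on this file. First, the precomputed `blockSize`×`blockSize` `bias`/`maskInf` tensors are gone: `attentionMask` is now called without a mask argument (and with the past length passed directly in the with-past case), so causal masking evidently lives inside the op. Second, `updateCache` no longer returns a fresh `{k, v, length, cumulativeLength}` record; it mutates the caller's `KVCache` in place, and `forward` reads `pastKV.k`/`pastKV.v` afterwards (it is also only invoked when not training). A deminified sketch of the new cache update — readable names are mine, the logic follows the minified lines above, and `keep` is presumably `tf.keep`:

    import { Tensor, keep } from '@tensorflow/tfjs-core';

    interface KVCache {
      k?: Tensor;               // [B, nHead, blockSize, headDim]
      v?: Tensor;
      length: number;           // valid entries, capped at blockSize
      cumulativeLength: number; // total tokens seen; used as the RoPE offset
    }

    // appendCache is the package's op that writes T_cur new steps into the
    // fixed-size cache; its exact signature here is an assumption.
    type AppendCache = (x: Tensor, blockSize: number, pastLen: number, prev?: Tensor) => Tensor;

    function updateCache(k: Tensor, v: Tensor, cache: KVCache, blockSize: number, append: AppendCache): void {
      const tCur = k.shape[2];   // new tokens this step (k is [B, nHead, T, headDim])
      const pastLen = cache.length || 0;

      const newK = append(k, blockSize, pastLen, cache.k);
      k.dispose();
      cache.k?.dispose();
      const newV = append(v, blockSize, pastLen, cache.v);
      v.dispose();
      cache.v?.dispose();

      // In-place update: 0.4.5 returned a new cache object here instead.
      cache.length = Math.min(pastLen + tCur, blockSize);
      cache.cumulativeLength += tCur;
      cache.k = keep(newK);      // keep() so an enclosing tidy() doesn't reclaim them
      cache.v = keep(newV);
    }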
package/dist/layers/MLP.d.ts
@@ -1,19 +1,12 @@
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
- import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
+ import { Tensor } from '@tensorflow/tfjs-core';
+ import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
  export default class MLP extends BaseLayer {
- private cFc;
- private cProj;
  private index;
- private _trainable;
  private hiddenUnits;
- constructor(index: number, config: GPTLayerConfig);
- private build;
- get variables(): Variable[];
- get trainable(): boolean;
- set trainable(value: boolean);
- saveWeights(map: Map<string, Tensor[]>): void;
- loadWeights(weights: Map<string, Tensor[]>): void;
- forward(x: Tensor): Tensor;
- call(x: Tensor, training?: boolean): Tensor;
- dispose(): void;
+ private MLPHIDDEN;
+ private MLPOUT;
+ constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
+ protected build(): void;
+ forward(_: ForwardAttributes, x: Tensor): Tensor;
+ protected dropout(x: Tensor): Tensor;
  }
package/dist/layers/MLP.js
@@ -1,93 +1,55 @@
- import { t as F, c as _, e as h, H as M } from "../index--6vO-cOz.js";
- import v from "./BaseLayer.js";
- import { matMulGelu as x } from "../ops/matMulGelu.js";
- import { v as c } from "../variable-BJTZ3jOy.js";
- import { r as d, d as u } from "../dropout-DFEXTPV0.js";
- import { r as p } from "../reshape-z51Eu-re.js";
- import { m as L } from "../mat_mul-BEHRPMh0.js";
- class G extends v {
- cFc = null;
- cProj = null;
+ import { t as l } from "../index-iNhkcAEQ.js";
+ import { B as u, v as o } from "../BaseLayer-BhrMN8JO.js";
+ import { matMulGelu as M } from "../ops/matMulGelu.js";
+ import { r as h, d as c } from "../dropout-kbDY39Ci.js";
+ import { r as d } from "../reshape-DxTPgnwL.js";
+ import { m as f } from "../mat_mul-D0SifYfJ.js";
+ class O extends u {
  index;
- _trainable = !0;
  hiddenUnits;
- constructor(t, s) {
- super(s), this.index = t, this.hiddenUnits = s.gpt.mlpFactor * s.gpt.nEmbed;
+ MLPHIDDEN;
+ MLPOUT;
+ constructor(i, t, s) {
+ super(t, s), this.index = i, this.hiddenUnits = t.gpt.mlpFactor * t.gpt.nEmbed, this.MLPHIDDEN = `block_${this.index}_mlpHidden`, this.MLPOUT = `block_${this.index}_mlpOut`, this.addVariable(this.MLPHIDDEN), this.addVariable(this.MLPOUT);
  }
  build() {
- this.cFc === null && (this.cFc = c(
- d([this.config.gpt.nEmbed, this.hiddenUnits], 0, 0.02),
- !0
- //`block_${this.index}_attn_cAttn_kernel`
- )), this.cProj === null && (this.cProj = c(
- d(
- [this.hiddenUnits, this.config.gpt.nEmbed],
- 0,
- 0.02 / Math.sqrt(2 * this.config.gpt.nLayer)
- ),
- !0
- //`block_${this.index}_attn_cProj_kernel`
- ));
- }
- get variables() {
- return [this.cFc, this.cProj];
- }
- get trainable() {
- return this._trainable;
- }
- set trainable(t) {
- this._trainable = t, this.cFc && (this.cFc.trainable = t), this.cProj && (this.cProj.trainable = t);
- }
- saveWeights(t) {
- t.set(`block_${this.index}_mlpHidden`, this.cFc ? [this.cFc.clone()] : []), t.set(`block_${this.index}_mlpOut`, this.cProj ? [this.cProj.clone()] : []);
- }
- loadWeights(t) {
- const s = t.get(`block_${this.index}_mlpOut`)?.[0], i = t.get(`block_${this.index}_mlpHidden`)?.[0];
- if (!s || !i)
- throw new Error(`Weights for block ${this.index} not found`);
- this.cFc ? this.cFc.assign(i) : this.cFc = c(i, !0), this.cProj ? this.cProj.assign(s) : this.cProj = c(s, !0);
- }
- forward(t) {
- return F(() => {
+ this.hasVariable(this.MLPHIDDEN) === !1 && this.setVariable(
+ this.MLPHIDDEN,
+ o(
+ h([this.config.gpt.nEmbed, this.hiddenUnits], 0, 0.02),
+ !0
+ //`block_${this.index}_attn_cAttn_kernel`
+ )
+ ), this.hasVariable(this.MLPOUT) === !1 && this.setVariable(
+ this.MLPOUT,
+ o(
+ h(
+ [this.hiddenUnits, this.config.gpt.nEmbed],
+ 0,
+ 0.02 / Math.sqrt(2 * this.config.gpt.nLayer)
+ ),
+ !0
+ //`block_${this.index}_attn_cProj_kernel`
+ )
+ );
+ }
+ forward(i, t) {
+ return l(() => {
  this.startMemory();
- const [s, i, r] = t.shape, o = p(t, [s * i, r]), e = x(o, this.cFc), n = L(e, this.cProj);
- e.dispose();
- const a = p(n, [s, i, r]);
- return this.endMemory("MLP"), a;
+ const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p = f(a, this.getVariable(this.MLPOUT));
+ a.dispose();
+ const m = d(p, [s, r, e]);
+ return this.endMemory("MLP"), m;
  });
  }
- call(t, s = !1) {
- if (this.build(), s && this.config.layerConfig.checkpointMLP) {
- const r = _(
- // @ts-expect-error Invalid params
- (o, e, n, a) => {
- const l = this.forward(o);
- return a([o]), { value: l, gradFunc: (f, g) => {
- const [m] = g, b = h().state.activeTape;
- h().state.activeTape = [];
- const P = M((j, w, T) => this.forward(j))([m, e, n], f);
- return h().state.activeTape = b, P;
- } };
- }
- )(t, this.cFc, this.cProj);
- if (this.config.gpt.dropout > 0) {
- const o = u(r, this.config.gpt.dropout);
- return r.dispose(), o;
- }
- return r;
- } else {
- const i = this.forward(t);
- if (s && this.config.gpt.dropout > 0) {
- const r = u(i, this.config.gpt.dropout);
- return i.dispose(), r;
- }
- return i;
+ dropout(i) {
+ if (this.config.gpt.dropout > 0) {
+ const t = c(i, this.config.gpt.dropout);
+ return i.dispose(), t;
  }
- }
- dispose() {
- this.cFc?.dispose(), this.cProj?.dispose();
+ return i;
  }
  }
  export {
- G as default
+ O as default
  };
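The MLP's math is untouched by this refactor; what disappears is the `checkpointMLP` `customGrad` path and the per-layer weight plumbing, while dropout moves into the `dropout()` hook (presumably invoked by the base class around `forward`). For reference, the forward pass still computes reshape → matMulGelu → matMul → reshape. An unfused plain-tfjs equivalent — the package fuses the first matmul with GELU in its `matMulGelu` op, and the tanh GELU approximation below is the usual one, assumed rather than read from that op:

    import * as tf from '@tensorflow/tfjs-core';

    // tanh-approximate GELU: 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))
    function gelu(x: tf.Tensor): tf.Tensor {
      const k = Math.sqrt(2 / Math.PI);
      const inner = tf.mul(k, tf.add(x, tf.mul(0.044715, tf.pow(x, 3))));
      return tf.mul(tf.mul(0.5, x), tf.add(1, tf.tanh(inner)));
    }

    // [B, T, C] -> flatten -> expand to hidden (mlpFactor·nEmbed) -> project back
    function mlpForward(x: tf.Tensor3D, wHidden: tf.Tensor2D, wOut: tf.Tensor2D): tf.Tensor3D {
      return tf.tidy(() => {
        const [b, t, c] = x.shape;
        const flat = tf.reshape(x, [b * t, c]);
        const hidden = gelu(tf.matMul(flat, wHidden)); // fused as matMulGelu in the package
        const out = tf.matMul(hidden, wOut);
        return tf.reshape(out, [b, t, c]) as tf.Tensor3D;
      });
    }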
package/dist/layers/RMSNorm.d.ts
@@ -1,12 +1,7 @@
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
- import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
+ import { Tensor } from '@tensorflow/tfjs-core';
+ import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
  export default class RMSNorm extends BaseLayer {
- private gamma;
- constructor(config: GPTLayerConfig, name?: string);
- get trainableWeights(): Variable[];
- set trainable(value: boolean);
- getWeights(): Tensor[];
- setWeights(weights: Tensor[]): void;
- apply(x: Tensor): Tensor;
- dispose(): void;
+ private GAMMA;
+ constructor(config: GPTLayerConfig, name?: string, parent?: BaseLayer);
+ forward(_: ForwardAttributes, x: Tensor): Tensor;
  }
package/dist/layers/RMSNorm.js
@@ -1,36 +1,20 @@
- import { t as r } from "../index--6vO-cOz.js";
- import m from "./BaseLayer.js";
- import { normRMS as s } from "../ops/normRMS.js";
- import { v as e } from "../variable-BJTZ3jOy.js";
- import { o as i } from "../ones-D6kB8bdY.js";
- class u extends m {
- gamma;
- constructor(t, a = "") {
- super(t), this.gamma = e(i([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
+ import { t as e } from "../index-iNhkcAEQ.js";
+ import { B as o, v as a } from "../BaseLayer-BhrMN8JO.js";
+ import { normRMS as i } from "../ops/normRMS.js";
+ import { o as M } from "../ones-BIeFnPHR.js";
+ class l extends o {
+ GAMMA;
+ constructor(r, t = "", s) {
+ super(r, s), this.GAMMA = t, this.addVariable(this.GAMMA, a(M([r.gpt.nEmbed]), !0, this.GAMMA, "float32"));
  }
- get trainableWeights() {
- return [this.gamma];
- }
- set trainable(t) {
- this.gamma.trainable = t;
- }
- getWeights() {
- return [this.gamma];
- }
- setWeights(t) {
- this.gamma.assign(t[0]);
- }
- apply(t) {
- return r(() => {
+ forward(r, t) {
+ return e(() => {
  this.startMemory();
- const a = s(t, this.gamma);
- return this.endMemory("RMSNorm"), a;
+ const s = i(t, this.getVariable(this.GAMMA));
+ return this.endMemory("RMSNorm"), s;
  });
  }
- dispose() {
- this.gamma.dispose();
- }
  }
  export {
- u as default
+ l as default
  };
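The layer now just registers GAMMA and implements `forward`; the `normRMS` op it calls is unchanged. For reference, RMSNorm scales each vector by the reciprocal of its root-mean-square over the channel axis and then applies the learned gain γ. A plain-tfjs rendering of that computation — the ε value is a typical default, not read from the package's op:

    import * as tf from '@tensorflow/tfjs-core';

    // Reference RMSNorm: y = x · 1/√(mean(x²) + ε) · gamma, gamma shaped [nEmbed].
    function rmsNorm(x: tf.Tensor, gamma: tf.Tensor, eps = 1e-6): tf.Tensor {
      return tf.tidy(() => {
        const meanSq = tf.mean(tf.square(x), -1, true); // mean of x² over channels
        const invRms = tf.rsqrt(tf.add(meanSq, eps));   // 1 / sqrt(meanSq + eps)
        return tf.mul(tf.mul(x, invRms), gamma);
      });
    }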
package/dist/layers/RoPECache.js
@@ -1,6 +1,6 @@
- import { o as h, h as c, E as f, T as l, f as n, U as m, t as u, F as p } from "../index--6vO-cOz.js";
- import { c as d, s as C } from "../sin-H567uayl.js";
- import { r as a } from "../range-C_vpUjBu.js";
+ import { o as c, i as f, E as l, Q as m, f as n, U as u, t as p, F as a } from "../index-iNhkcAEQ.js";
+ import { c as d, s as C } from "../sin-BOX-JVAj.js";
+ import { r as h } from "../range-BsFU-SNG.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -18,10 +18,10 @@ import { r as a } from "../range-C_vpUjBu.js";
  * =============================================================================
  */
  function x(r) {
- const s = { x: c(r, "x", "reciprocal") };
- return f.runKernel(l, s);
+ const s = { x: f(r, "x", "reciprocal") };
+ return l.runKernel(m, s);
  }
- const S = /* @__PURE__ */ h({ reciprocal_: x });
+ const S = /* @__PURE__ */ c({ reciprocal_: x });
  class y {
  constructor(o) {
  this.config = o;
@@ -29,8 +29,8 @@ class y {
  if (this.rotaryDim = s, this.rotaryDim % 2 !== 0)
  throw new Error("rotaryDim must be even");
  this.ropeBase = 1e4;
- const i = a(0, this.rotaryDim, 2, "float32"), e = i.div(n(this.rotaryDim, "float32")), t = m(n(this.ropeBase, "float32"), e);
- this.ropeInvFreq = S(t), e.dispose(), t.dispose(), i.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : u(() => {
+ const i = h(0, this.rotaryDim, 2, "float32"), e = i.div(n(this.rotaryDim, "float32")), t = u(n(this.ropeBase, "float32"), e);
+ this.ropeInvFreq = S(t), e.dispose(), t.dispose(), i.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : p(() => {
  this.ensureRopeCache(this.config.blockSize * 4);
  });
  }
@@ -43,10 +43,12 @@ class y {
  // [cacheLen, rotaryDim/2]
  ropeCacheLen = 0;
  ensureRopeCache(o) {
- if (o <= this.ropeCacheLen) return;
- this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
- const s = Math.max(o, this.ropeCacheLen + this.config.blockSize * 4), e = a(0, s, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
- this.ropeCos = p(d(e).expandDims(-1)), this.ropeSin = p(C(e).expandDims(-1)), this.ropeCacheLen = s;
+ p(() => {
+ if (o <= this.ropeCacheLen) return;
+ this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
+ const s = Math.max(o, this.ropeCacheLen + this.config.blockSize * 4), e = h(0, s, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
+ this.ropeCos = a(d(e).expandDims(-1)), this.ropeSin = a(C(e).expandDims(-1)), this.ropeCacheLen = s;
+ });
  }
  getCos() {
  return this.ropeCos;
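The functional change in this file is small: `ensureRopeCache` is now wrapped in a tidy with `keep()` on the two tables, so intermediates from a cache rebuild are reclaimed instead of leaking. The tables themselves are built exactly as before; in plain tfjs, the computation in the minified lines above is roughly:

    import * as tf from '@tensorflow/tfjs-core';

    // RoPE tables: invFreq[i] = base^(-2i/rotaryDim) for i = 0..rotaryDim/2-1,
    // angles[p][i] = p · invFreq[i]; cos/sin are kept past the tidy.
    function buildRopeCache(rotaryDim: number, cacheLen: number, base = 1e4) {
      return tf.tidy(() => {
        const idx = tf.range(0, rotaryDim, 2, 'float32');                // [rotaryDim/2]
        const invFreq = tf.reciprocal(tf.pow(base, tf.div(idx, rotaryDim)));
        const pos = tf.range(0, cacheLen, 1, 'float32');                 // [cacheLen]
        const angles = tf.mul(pos.expandDims(1), invFreq.expandDims(0)); // [cacheLen, rotaryDim/2]
        return {
          cos: tf.keep(tf.cos(angles).expandDims(-1)),
          sin: tf.keep(tf.sin(angles).expandDims(-1)),
        };
      });
    }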
package/dist/layers/TiedEmbedding.d.ts
@@ -1,22 +1,12 @@
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
- export default class TiedEmbeddingOutputLayer {
+ import { Tensor } from '@tensorflow/tfjs-core';
+ import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
+ export default class TiedEmbeddingOutputLayer extends BaseLayer {
  private vocabSize;
  private embedDim;
- private tiedWeights;
  private initializer;
- constructor(config: {
- vocabSize: number;
- embedDim: number;
- name?: string;
- }, name?: string);
- get variables(): Variable[];
+ private WEIGHTS;
+ constructor(config: GPTLayerConfig, name: string, parent?: BaseLayer);
  embed(inputs: Tensor): Tensor;
  project(inputs: Tensor): Tensor;
- getWeights(): Tensor[];
- setWeights(weights: Tensor[]): void;
- getConfig(): {
- vocabSize: number;
- embedDim: number;
- };
- dispose(): void;
+ forward(_: ForwardAttributes, x: Tensor): Tensor;
  }
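TiedEmbeddingOutputLayer now extends BaseLayer and registers its matrix under a `WEIGHTS` name, but the tying itself is unchanged: one `[vocabSize, embedDim]` matrix serves both the token lookup (`embed`) and the output projection (`project`). In plain tfjs terms — a sketch of the concept, not the layer's code:

    import * as tf from '@tensorflow/tfjs-core';

    // One shared matrix W: rows are token embeddings on the way in, and the
    // transpose maps hidden states back to vocab logits on the way out.
    function embed(W: tf.Tensor2D, tokenIds: tf.Tensor): tf.Tensor {
      return tf.gather(W, tokenIds.cast('int32'));   // [..., embedDim]
    }

    function project(W: tf.Tensor2D, hidden: tf.Tensor2D): tf.Tensor2D {
      return tf.matMul(hidden, W, false, true);      // hidden · Wᵀ -> [N, vocabSize]
    }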
package/dist/layers/TiedEmbedding.js
@@ -1,8 +1,8 @@
- import { T as a } from "../TiedEmbedding-DznFwzcB.js";
- import "../index--6vO-cOz.js";
- import "../tfjs_backend-DuKis_xG.js";
- import "../variable-BJTZ3jOy.js";
- import "../gather-C5D8PxwA.js";
+ import { T as a } from "../TiedEmbedding-DsDRvLB0.js";
+ import "../index-iNhkcAEQ.js";
+ import "../tfjs_backend-NucKez4s.js";
+ import "../BaseLayer-BhrMN8JO.js";
+ import "../gather-Bxe1Qip8.js";
  export {
  a as default
  };
package/dist/layers/TransformerBlock.d.ts
@@ -1,25 +1,21 @@
- import { KVCache } from './CausalSelfAttention';
- import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
- import { Tensor, Variable } from '@tensorflow/tfjs-core';
- export default class Block extends BaseLayer {
+ import { AttentionScores, KVCache } from './CausalSelfAttention';
+ import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
+ import { Tensor } from '@tensorflow/tfjs-core';
+ interface BlockAttributes extends ForwardAttributes {
+ pastKV?: KVCache;
+ seed?: number;
+ attentionScores?: AttentionScores;
+ }
+ export default class Block extends BaseLayer<BlockAttributes> {
  private ln1;
  private attn;
  private ln2;
  private mlp;
  private index;
- private _trainable;
  skipped: boolean;
- constructor(index: number, config: GPTLayerConfig);
- get variables(): Variable[];
- get trainable(): boolean;
- set trainable(value: boolean);
- saveWeights(map: Map<string, Tensor[]>): void;
- loadWeights(weights: Map<string, Tensor[]>): void;
+ constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
  private getMLPOutput;
- call(x: Tensor, training?: boolean, includeAttention?: boolean, cache?: KVCache): {
- output: Tensor;
- attention?: Tensor;
- cache?: KVCache;
- };
+ forward(attrs: BlockAttributes, x: Tensor): Tensor;
  dispose(): void;
  }
+ export {};
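For callers, BlockAttributes replaces the old positional signature: KV state is mutated in place through `pastKV`, and attention maps are pushed into `attentionScores.attentionOut` when the caller supplies one. A hedged before/after — types reconstructed from this declaration diff, with `AttentionScores` assumed to collect per-block tensors (which matches the `.push()` in the attention layer's diff):

    import { Tensor } from '@tensorflow/tfjs-core';

    interface KVCache { k?: Tensor; v?: Tensor; length: number; cumulativeLength: number; }
    interface AttentionScores { attentionOut?: Tensor[]; }
    interface BlockAttributes { training?: boolean; pastKV?: KVCache; seed?: number; attentionScores?: AttentionScores; }

    declare const block: { forward(attrs: BlockAttributes, x: Tensor): Tensor };
    declare const x: Tensor; // [B, T, nEmbed]

    // 0.4.5: const { output, attention, cache } = block.call(x, false, true, kvCache);
    // 0.5.1: one attrs object in, one Tensor out; results travel via the attrs.
    const kvCache: KVCache = { length: 0, cumulativeLength: 0 };
    const scores: AttentionScores = { attentionOut: [] };
    const output = block.forward({ training: false, pastKV: kvCache, attentionScores: scores }, x);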