@genai-fi/nanogpt 0.1.8 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,18 +1,16 @@
1
- import { default as NanoGPT } from './NanoGPTModel';
1
+ import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
2
2
  import { ITokeniser } from './tokeniser/type';
3
3
  import { default as EE } from 'eventemitter3';
4
- export interface IGenerateOptions {
4
+ export interface IGenerateOptions extends GenerateOptions {
5
5
  maxLength?: number;
6
- temperature?: number;
7
- topK?: number;
8
- usePadding?: boolean;
9
- includeAttention?: boolean;
10
- includeProbabilities?: boolean;
11
6
  }
12
7
  export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
13
8
  private readonly model;
14
9
  private readonly tokeniser;
15
10
  constructor(model: NanoGPT, tokeniser: ITokeniser);
16
- private generateBlockOfTokens;
11
+ private tokenisePrompt;
12
+ private generateNoCache;
13
+ private processResponse;
14
+ private generateCache;
17
15
  generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
18
16
  }
package/dist/Generator.js CHANGED
@@ -1,62 +1,65 @@
1
- import { E as m } from "./index-SOhdqzHq.js";
2
- const b = 4;
3
- class x extends m {
4
- constructor(a, t) {
5
- super(), this.model = a, this.tokeniser = t;
1
+ import { E as u } from "./index-SOhdqzHq.js";
2
+ class k extends u {
3
+ constructor(s, e) {
4
+ super(), this.model = s, this.tokeniser = e;
6
5
  }
7
- generateBlockOfTokens(a, t) {
8
- const g = t?.temperature ?? 1, c = t?.topK, d = t?.usePadding ?? t?.includeAttention ?? !1, k = t?.includeAttention ?? !1, h = t?.includeProbabilities ?? !1;
9
- let i = a, n, s;
10
- for (let e = 0; e < b; e++) {
6
+ async tokenisePrompt(s) {
7
+ const e = s ? await this.tokeniser.tokenise([s], !0) : [[this.tokeniser.eosToken]];
8
+ return this.model.tf.tensor2d(e, [1, e[0].length], "int32");
9
+ }
10
+ async generateNoCache(s, e) {
11
+ let t = await this.tokenisePrompt(s), n = s || "";
12
+ const a = e?.maxLength ?? 1e3;
13
+ for (let i = 0; i < a; i++) {
11
14
  const {
12
- output: u,
13
- attention: l,
14
- probabilities: r
15
- } = this.model.generate(i, {
16
- temperature: g,
17
- topK: c,
18
- usePadding: d,
19
- includeAttention: k,
20
- includeProbabilities: h
21
- }), p = i;
22
- if (i = this.model.tf.concat([i, u], 1), n && l) {
23
- const o = n;
24
- n = this.model.tf.concat([n, l], 0), o.dispose();
25
- } else l && (n = l);
26
- if (s && r) {
27
- const o = s;
28
- s = this.model.tf.concat([s, r], 0), o.dispose();
29
- } else r && (s = r);
30
- p.dispose(), u.dispose();
15
+ output: o,
16
+ attention: c,
17
+ probabilities: l
18
+ } = this.model.generate(t, void 0, e), h = t;
19
+ t = this.model.tf.concat([t, o], 1), h.dispose();
20
+ const r = await this.processResponse(o, c, l);
21
+ if (o.dispose(), r === null)
22
+ break;
23
+ n += r;
31
24
  }
32
- return { output: i, attention: n, probabilities: s };
25
+ return t.dispose(), n;
33
26
  }
34
- async generate(a, t) {
35
- const g = a ? await this.tokeniser.tokenise([a], !0) : [[this.tokeniser.eosToken]];
36
- let c = this.model.tf.tensor2d(g, [1, g[0].length], "int32");
37
- this.emit("start");
38
- let d = a || "";
39
- for (; ; ) {
40
- const { output: k, attention: h, probabilities: i } = this.generateBlockOfTokens(c, t), n = c;
41
- c = k;
42
- const s = k.slice([0, n.shape[1]], [1, b]), e = (await s.array())[0];
43
- n.dispose(), s.dispose();
44
- let u = !1, l = !1;
45
- const r = e.indexOf(this.tokeniser.eosToken);
46
- r !== -1 && (u = !0, e.splice(r)), e.length + d.length >= (t?.maxLength ?? 1e3) && (l = !0, e.splice(
47
- t?.maxLength ? t.maxLength - d.length : e.length
48
- ));
49
- const p = await this.tokeniser.decode(e);
50
- d += p;
51
- let o;
52
- h && (o = await h.array(), h.dispose(), o.length > e.length && (o = o.slice(0, e.length)));
53
- let f;
54
- if (i && (f = await i.array(), i.dispose(), f.length > e.length && (f = f.slice(0, e.length))), this.emit("tokens", e, p, o, f), u || l)
27
+ async processResponse(s, e, t) {
28
+ const n = (await s.array())[0][0];
29
+ if (n === this.tokeniser.eosToken)
30
+ return null;
31
+ const a = await this.tokeniser.decode([n]);
32
+ let i;
33
+ e && (i = await e.array(), e.dispose());
34
+ let o;
35
+ return t && (o = await t.array(), t.dispose()), this.emit("tokens", [n], a, i, o), a;
36
+ }
37
+ async generateCache(s, e) {
38
+ let t = await this.tokenisePrompt(s), n = s || "";
39
+ const a = new Array(this.model.config.nLayer).fill(void 0), i = e?.maxLength ?? 1e3;
40
+ for (let o = 0; o < i; o++) {
41
+ const {
42
+ output: c,
43
+ attention: l,
44
+ probabilities: h
45
+ } = this.model.generate(t, a, {
46
+ ...e,
47
+ usePadding: !1
48
+ });
49
+ t.dispose(), t = c;
50
+ const r = await this.processResponse(c, l, h);
51
+ if (r === null)
55
52
  break;
53
+ n += r;
56
54
  }
57
- return c.dispose(), this.emit("stop"), d;
55
+ return t.dispose(), n;
56
+ }
57
+ async generate(s, e) {
58
+ this.emit("start");
59
+ const t = this.model.config.useRope ? this.generateCache(s, e) : this.generateNoCache(s, e);
60
+ return this.emit("stop"), t;
58
61
  }
59
62
  }
60
63
  export {
61
- x as default
64
+ k as default
62
65
  };
@@ -1,5 +1,6 @@
1
1
  import { default as TF } from '@tensorflow/tfjs';
2
2
  import { GPTConfig } from './config';
3
+ import { KVCache } from './layers/CausalSelfAttention';
3
4
  export interface TrainingLogEntry {
4
5
  loss: number;
5
6
  valLoss?: number;
@@ -18,10 +19,11 @@ export interface GenerateOptions {
18
19
  export default class NanoGPT {
19
20
  readonly config: GPTConfig;
20
21
  private wte;
21
- private wpe;
22
+ private wpe?;
22
23
  private drop;
23
24
  private blocks;
24
25
  private lnF;
26
+ private ropeCache?;
25
27
  readonly tf: typeof TF;
26
28
  log: TrainingLogEntry[];
27
29
  constructor(tf: typeof TF, config?: Partial<GPTConfig>);
@@ -35,12 +37,12 @@ export default class NanoGPT {
35
37
  private validateInput;
36
38
  private calculateLoss;
37
39
  private computeAttentionRollout;
38
- forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean): {
40
+ forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
39
41
  logits: TF.Tensor;
40
42
  loss?: TF.Tensor;
41
43
  attention?: TF.Tensor;
42
44
  };
43
- generate(idx: TF.Tensor, options?: GenerateOptions): {
45
+ generate(idx: TF.Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
44
46
  output: TF.Tensor;
45
47
  attention?: TF.Tensor;
46
48
  probabilities?: TF.Tensor;
@@ -1,8 +1,9 @@
1
- import { defaultConfig as z } from "./config.js";
2
- import v from "./layers/TransformerBlock.js";
3
- import S from "./layers/TiedEmbedding.js";
4
- import _ from "./layers/LayerNorm.js";
5
- class $ {
1
+ import { defaultConfig as v } from "./config.js";
2
+ import S from "./layers/TransformerBlock.js";
3
+ import _ from "./layers/TiedEmbedding.js";
4
+ import L from "./layers/RoPECache.js";
5
+ import I from "./layers/RMSNorm.js";
6
+ class F {
6
7
  config;
7
8
  wte;
8
9
  // Token embeddings
@@ -13,27 +14,28 @@ class $ {
13
14
  blocks;
14
15
  lnF;
15
16
  // Final layer norm
17
+ ropeCache;
16
18
  tf;
17
19
  log = [];
18
20
  // Training log
19
21
  constructor(t, e = {}) {
20
- this.tf = t, this.config = { ...z, ...e }, this.wte = new S(t, {
22
+ this.tf = t, this.config = { ...v, ...e }, this.wte = new _(t, {
21
23
  vocabSize: this.config.vocabSize,
22
24
  embedDim: this.config.nEmbed,
23
25
  name: "token_embedding"
24
- }), this.wpe = this.tf.layers.embedding({
26
+ }), this.config.useRope === !1 ? this.wpe = this.tf.layers.embedding({
25
27
  inputDim: this.config.blockSize,
26
28
  outputDim: this.config.nEmbed,
27
29
  name: "positional_embedding",
28
30
  embeddingsInitializer: this.tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
29
- }), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
31
+ }) : this.ropeCache = new L(t, this.config), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
30
32
  for (let s = 0; s < this.config.nLayer; s++)
31
- this.blocks.push(new v(this.tf, s, this.config));
32
- this.lnF = new _(t, [this.config.nEmbed], 1e-5, "final_layer_norm");
33
+ this.blocks.push(new S(this.tf, s, this.config, this.ropeCache));
34
+ this.lnF = new I(t, [this.config.nEmbed], 1e-8, "final_rms_norm");
33
35
  }
34
36
  get variables() {
35
37
  return [
36
- ...this.wpe.trainableWeights.map((t) => t.read()),
38
+ //...this.wpe.trainableWeights.map((v) => v.read() as TF.Variable),
37
39
  ...this.blocks.flatMap((t) => t.variables),
38
40
  ...this.lnF.trainableWeights.map((t) => t),
39
41
  ...this.wte.variables
@@ -41,21 +43,28 @@ class $ {
41
43
  }
42
44
  saveWeights() {
43
45
  const t = /* @__PURE__ */ new Map();
44
- t.set("token_embedding", this.wte.getWeights()), t.set("positional_embedding", this.wpe.getWeights());
46
+ t.set("token_embedding", this.wte.getWeights()), this.wpe && t.set("positional_embedding", this.wpe.getWeights());
45
47
  for (let e = 0; e < this.blocks.length; e++)
46
48
  this.blocks[e].saveWeights(t);
47
- return t.set("final_layer_norm", this.lnF.getWeights()), t;
49
+ return t.set("final_rms_norm", this.lnF.getWeights()), t;
48
50
  }
49
51
  loadWeights(t) {
50
- this.wte.setWeights(t.get("token_embedding") || []), this.wpe.setWeights(t.get("positional_embedding") || []);
52
+ this.wte.setWeights(t.get("token_embedding") || []), this.wpe && this.wpe.setWeights(t.get("positional_embedding") || []);
51
53
  for (let e = 0; e < this.blocks.length; e++)
52
54
  this.blocks[e].loadWeights(t);
53
- this.lnF.setWeights(t.get("final_layer_norm") || []);
55
+ this.lnF.setWeights(t.get("final_rms_norm") || []);
54
56
  }
55
- inputPhase(t, e = !1) {
57
+ inputPhase(t, e, s = !1) {
56
58
  return this.tf.tidy(() => {
57
- const [, s] = t.shape, i = this.wte.embed(t), n = this.tf.range(0, s, 1, "int32"), h = this.wpe.apply(n), o = i.add(h);
58
- return this.drop.apply(o, { training: e });
59
+ const o = this.wte.embed(t);
60
+ if (this.config.useRope === !1) {
61
+ const [, i] = t.shape, a = this.config.blockSize, n = this.tf.range(0, i, 1, "int32"), h = this.tf.mod(
62
+ this.tf.add(n, this.tf.scalar(e, "int32")),
63
+ this.tf.scalar(a, "int32")
64
+ ), c = this.wpe.apply(h), r = o.add(c);
65
+ return this.drop.apply(r, { training: s });
66
+ } else
67
+ return this.drop.apply(o, { training: s });
59
68
  });
60
69
  }
61
70
  setSkipMask(t) {
@@ -73,7 +82,7 @@ class $ {
73
82
  set trainable(t) {
74
83
  for (const e of this.blocks)
75
84
  e.trainable = t;
76
- this.wpe.trainable = t, this.lnF.trainable = t;
85
+ this.lnF.trainable = t;
77
86
  }
78
87
  validateInput(t) {
79
88
  if (t.shape.length !== 2)
@@ -96,60 +105,67 @@ class $ {
96
105
  return this.tf.tidy(() => {
97
106
  if (t.length === 0)
98
107
  throw new Error("No attentions for rollout");
99
- const e = t[0].shape[0], s = t[0].shape[1], i = this.tf.eye(s, s).expandDims(0);
100
- let n = i.tile([e, 1, 1]);
101
- for (const h of t) {
102
- let o = h.add(i);
103
- o = o.div(o.sum(-1, !0)), n = o.matMul(n);
108
+ const e = t[0].shape[0], s = t[0].shape[1], o = this.tf.eye(s, s).expandDims(0);
109
+ let i = o.tile([e, 1, 1]);
110
+ for (const a of t) {
111
+ let n = a.add(o);
112
+ n = n.div(n.sum(-1, !0)), i = n.matMul(i);
104
113
  }
105
- return n;
114
+ return i;
106
115
  });
107
116
  }
108
- forward(t, e, s = !1, i = !1) {
117
+ forward(t, e, s = !1, o = !1, i) {
109
118
  return this.validateInput(t), this.tf.tidy(() => {
110
- let n = this.inputPhase(t, s);
119
+ const a = i?.[0]?.length ?? 0;
120
+ let n = this.inputPhase(t, a, s);
111
121
  const h = [];
112
- for (const c of this.blocks) {
113
- const { output: d, attention: l } = c.call(n, s, i);
114
- n = d, i && l && h.push(l);
122
+ if (i && i.length !== this.blocks.length)
123
+ throw console.error("Cache", i), new Error(`Cache length ${i.length} does not match number of blocks ${this.blocks.length}`);
124
+ for (let l = 0; l < this.blocks.length; l++) {
125
+ const d = this.blocks[l], {
126
+ output: g,
127
+ attention: b,
128
+ cache: p
129
+ } = d.call(n, s, o, i ? i[l] : void 0);
130
+ n = g, o && b && h.push(b), i && p ? (i[l]?.k.dispose(), i[l]?.v.dispose(), i[l] = p) : p && (p.k.dispose(), p.v.dispose());
115
131
  }
116
- let o;
117
- i && h.length > 0 && (o = this.computeAttentionRollout(h)), n = this.lnF.apply(n);
118
- const a = this.wte.project(n);
119
- let r;
120
- return e && (r = this.calculateLoss(a, e)), { logits: a, loss: r, attention: i ? o : void 0 };
132
+ let c;
133
+ o && h.length > 0 && (c = this.computeAttentionRollout(h)), n = this.lnF.apply(n);
134
+ const r = this.wte.project(n);
135
+ let f;
136
+ return e && (f = this.calculateLoss(r, e)), { logits: r, loss: f, attention: o ? c : void 0 };
121
137
  });
122
138
  }
123
- generate(t, e) {
124
- const s = e?.temperature ?? 1, i = e?.topK, n = e?.usePadding ?? !1, h = e?.includeAttention ?? !1;
139
+ generate(t, e, s) {
140
+ const o = s?.temperature ?? 1, i = s?.topK, a = s?.usePadding ?? !1, n = s?.includeAttention ?? !1;
125
141
  return this.tf.tidy(() => {
126
- const o = t, a = o.shape[1], r = a <= this.config.blockSize ? o : o.slice(
127
- [0, a - this.config.blockSize],
128
- [o.shape[0], this.config.blockSize]
129
- ), c = n ? this.config.blockSize - r.shape[1] : 0, d = c > 0 ? this.tf.pad(r, [
142
+ const h = t, c = h.shape[1], r = c <= this.config.blockSize ? h : h.slice(
143
+ [0, c - this.config.blockSize],
144
+ [h.shape[0], this.config.blockSize]
145
+ ), f = a ? this.config.blockSize - r.shape[1] : 0, l = f > 0 ? this.tf.pad(r, [
130
146
  [0, 0],
131
- [0, c]
132
- ]) : r, { logits: l, attention: p } = this.forward(d, void 0, !1, h), b = l.shape[1] - 1 - c, u = l.slice([0, b, 0], [l.shape[0], 1, l.shape[2]]), k = p ? p.slice([0, b, 0], [p.shape[0], 1, p.shape[2]]) : void 0, g = u.div(s);
133
- let f;
147
+ [0, f]
148
+ ]) : r, { logits: d, attention: g } = this.forward(l, void 0, !1, n, e), b = d.shape[1] - 1 - f, p = d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]), w = g ? g.slice([0, b, 0], [g.shape[0], 1, g.shape[2]]) : void 0, u = p.div(o);
149
+ let m;
134
150
  if (i) {
135
- const { values: w, indices: E } = this.tf.topk(g, i), y = this.tf.multinomial(w.squeeze([1]), 1);
136
- f = this.tf.gather(E.squeeze([1]), y, 1);
151
+ const { values: E, indices: y } = this.tf.topk(u, i), z = this.tf.multinomial(E.squeeze([1]), 1);
152
+ m = this.tf.gather(y.squeeze([1]), z, 1);
137
153
  } else
138
- f = this.tf.multinomial(g.squeeze([1]), 1);
139
- let m;
140
- return e?.includeProbabilities && (m = this.tf.softmax(g.squeeze([1]))), f = f.reshape([1, 1]), { output: f, attention: k?.squeeze([1]), probabilities: m };
154
+ m = this.tf.multinomial(u.squeeze([1]), 1);
155
+ let k;
156
+ return s?.includeProbabilities && (k = this.tf.softmax(u.squeeze([1]))), m = m.reshape([1, 1]), { output: m, attention: w?.squeeze([1]), probabilities: k };
141
157
  });
142
158
  }
143
159
  getNumParams() {
144
160
  const t = this.config.vocabSize * this.config.nEmbed + this.config.blockSize * this.config.nEmbed, e = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // qkv + proj
145
161
  2 * this.config.nEmbed), s = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // fc
146
- this.config.nEmbed * 4 * this.config.nEmbed), i = this.config.nEmbed + this.config.vocabSize * this.config.nEmbed;
147
- return t + e + s + i;
162
+ this.config.nEmbed * 4 * this.config.nEmbed), o = this.config.nEmbed + this.config.vocabSize * this.config.nEmbed;
163
+ return t + e + s + o;
148
164
  }
149
165
  dispose() {
150
- this.wte.dispose(), this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
166
+ this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
151
167
  }
152
168
  }
153
169
  export {
154
- $ as default
170
+ F as default
155
171
  };
@@ -47,25 +47,25 @@ class a extends c {
47
47
  }
48
48
  static loadModel(t, r) {
49
49
  const e = new a(t);
50
- return l(t, r).then(({ model: i, tokeniser: o }) => {
51
- e._model = i, e._tokeniser = o, e._config = i.config, e.setStatus("warmup"), h(i).then(() => {
50
+ return l(t, r).then(({ model: s, tokeniser: o }) => {
51
+ e._model = s, e._tokeniser = o, e._config = s.config, e.setStatus("warmup"), h(s).then(() => {
52
52
  e.setStatus("ready");
53
- }).catch((s) => {
54
- e.setStatus("error"), e.emit("error", s);
53
+ }).catch((i) => {
54
+ e.setStatus("error"), e.emit("error", i);
55
55
  });
56
- }).catch((i) => {
57
- e.setStatus("error"), e.emit("error", i);
56
+ }).catch((s) => {
57
+ e.setStatus("error"), e.emit("error", s);
58
58
  }), e;
59
59
  }
60
60
  static create(t, r = {}) {
61
- const e = { ...u, ...r }, i = new g(e.vocabSize), o = new d(t, e), s = new a(t, i, o);
62
- return s.setStatus("warmup"), h(o).then(() => {
63
- s.setStatus("awaitingTokens"), s.tokeniser.once("trainStatus", (n) => {
64
- n === "trained" && s.setStatus("ready");
65
- });
61
+ const e = { ...u, ...r }, s = new g(e.vocabSize), o = new d(t, e), i = new a(t, s, o);
62
+ return i.setStatus("warmup"), h(o).then(() => {
63
+ i.tokeniser.trained ? i.setStatus("ready") : (i.setStatus("awaitingTokens"), i.tokeniser.once("trainStatus", (n) => {
64
+ n === "trained" && i.setStatus("ready");
65
+ }));
66
66
  }).catch((n) => {
67
- s.setStatus("error"), s.emit("error", n);
68
- }), s;
67
+ i.setStatus("error"), i.emit("error", n);
68
+ }), i;
69
69
  }
70
70
  getNumParams() {
71
71
  if (!this._model)
@@ -78,8 +78,8 @@ class a extends c {
78
78
  const t = new _(this._model, this._tokeniser);
79
79
  return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (r) => {
80
80
  const e = this.listeners("trainStep");
81
- for (const i of e)
82
- await i(r);
81
+ for (const s of e)
82
+ await s(r);
83
83
  }), t;
84
84
  }
85
85
  train(t, r) {
package/dist/Trainer.d.ts CHANGED
@@ -12,7 +12,9 @@ export interface ITrainerOptions {
12
12
  }
13
13
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
14
14
  private trainer;
15
+ private hasTrained;
15
16
  constructor(model: NanoGPT, tokeniser: ITokeniser);
16
17
  stop(): void;
18
+ reset(): void;
17
19
  train(text: string[], options?: ITrainerOptions): Promise<void>;
18
20
  }
package/dist/Trainer.js CHANGED
@@ -1,11 +1,16 @@
1
1
  import { E as l } from "./index-SOhdqzHq.js";
2
- import o from "./training/FullTrainer.js";
3
- class d extends l {
2
+ import h from "./training/FullTrainer.js";
3
+ class m extends l {
4
4
  trainer;
5
+ hasTrained = !1;
5
6
  constructor(a, t) {
6
- super(), this.trainer = new o(a.tf, a, t, 1e-3);
7
+ super(), this.trainer = new h(a.tf, a, t, 1e-3);
7
8
  }
8
9
  stop() {
10
+ this.trainer.stop();
11
+ }
12
+ reset() {
13
+ this.hasTrained = !1, this.trainer.reset();
9
14
  }
10
15
  async train(a, t) {
11
16
  const { trainDataset: e, validationDataset: r } = await this.trainer.createTrainValidationSplit(
@@ -13,7 +18,7 @@ class d extends l {
13
18
  t?.batchSize || 32,
14
19
  t?.validationSplit || 0.1
15
20
  );
16
- this.trainer.setLearningRate(t?.learningRate || 1e-3), this.emit("start"), await this.trainer.trainOnDataset(
21
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
17
22
  e,
18
23
  {
19
24
  prompt: t?.prompt,
@@ -31,5 +36,5 @@ class d extends l {
31
36
  }
32
37
  }
33
38
  export {
34
- d as default
39
+ m as default
35
40
  };
package/dist/config.d.ts CHANGED
@@ -8,5 +8,6 @@ export interface GPTConfig {
8
8
  biasInLinear: boolean;
9
9
  biasInLayerNorm: boolean;
10
10
  mlpFactor: number;
11
+ useRope: boolean;
11
12
  }
12
13
  export declare const defaultConfig: GPTConfig;
package/dist/config.js CHANGED
@@ -1,4 +1,4 @@
1
- const a = {
1
+ const e = {
2
2
  vocabSize: 50304,
3
3
  // GPT-2 vocab size
4
4
  blockSize: 1024,
@@ -13,8 +13,10 @@ const a = {
13
13
  // Dropout probability
14
14
  biasInLinear: !1,
15
15
  biasInLayerNorm: !1,
16
- mlpFactor: 4
16
+ mlpFactor: 4,
17
+ useRope: !1
18
+ // Use Rotary Position Embeddings
17
19
  };
18
20
  export {
19
- a as defaultConfig
21
+ e as defaultConfig
20
22
  };
@@ -1,6 +1,14 @@
1
1
  import { default as TF } from '@tensorflow/tfjs';
2
2
  import { GPTConfig } from '../config';
3
+ import { default as RoPECache } from './RoPECache';
4
+ export type KVCache = {
5
+ k: TF.Tensor;
6
+ v: TF.Tensor;
7
+ length: number;
8
+ cumulativeLength: number;
9
+ };
3
10
  export default class CausalSelfAttention {
11
+ private readonly ropeCache?;
4
12
  private config;
5
13
  private cAttn;
6
14
  private cProj;
@@ -12,18 +20,20 @@ export default class CausalSelfAttention {
12
20
  private divisor;
13
21
  private index;
14
22
  private _trainable;
15
- constructor(tf: typeof TF, index: number, config: GPTConfig);
23
+ constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache | undefined);
16
24
  get variables(): TF.Variable[];
17
25
  get trainable(): boolean;
18
26
  set trainable(value: boolean);
19
27
  saveWeights(map: Map<string, TF.Tensor[]>): void;
20
28
  loadWeights(weights: Map<string, TF.Tensor[]>): void;
21
29
  private getAttentionScores;
30
+ private getAttentionScoresWithPast;
22
31
  private getQKV;
23
32
  private getOutputProjection;
24
- call(x: TF.Tensor, training?: boolean, includeAttention?: boolean): {
33
+ call(x: TF.Tensor, training?: boolean, includeAttention?: boolean, pastKV?: KVCache): {
25
34
  output: TF.Tensor;
26
35
  attention?: TF.Tensor;
36
+ presentKV?: KVCache;
27
37
  };
28
38
  dispose(): void;
29
39
  }
@@ -1,20 +1,9 @@
1
- class m {
2
- config;
3
- cAttn;
4
- cProj;
5
- attnDropout;
6
- residDropout;
7
- bias;
8
- maskInf;
9
- tf;
10
- divisor;
11
- index;
12
- _trainable = !0;
13
- constructor(t, e, s) {
14
- this.config = s, this.tf = t, this.index = e, this.cAttn = this.tf.layers.dense({
1
+ class S {
2
+ constructor(t, i, s, e) {
3
+ this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.cAttn = this.tf.layers.dense({
15
4
  units: 3 * s.nEmbed,
16
5
  useBias: s.biasInLinear,
17
- name: `block_${e}_attn_cAttn`,
6
+ name: `block_${i}_attn_cAttn`,
18
7
  kernelInitializer: this.tf.initializers.randomNormal({
19
8
  mean: 0,
20
9
  stddev: 0.02
@@ -23,14 +12,27 @@ class m {
23
12
  }), this.cProj = this.tf.layers.dense({
24
13
  units: s.nEmbed,
25
14
  useBias: s.biasInLinear,
26
- name: `block_${e}_attn_cProj`,
15
+ name: `block_${i}_attn_cProj`,
27
16
  kernelInitializer: this.tf.initializers.randomNormal({
28
17
  mean: 0,
29
18
  stddev: 0.02 / Math.sqrt(2 * s.nLayer)
30
19
  }),
31
20
  biasInitializer: "zeros"
32
- }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = this.tf.scalar(1 / Math.sqrt(s.nEmbed / s.nHead)), this.maskInf = this.tf.zeros([s.blockSize, s.blockSize]).where(this.bias, -1 / 0);
21
+ }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = this.tf.scalar(1 / Math.sqrt(s.nEmbed / s.nHead));
22
+ const a = this.tf.zeros([s.blockSize, s.blockSize]), h = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
23
+ this.maskInf = this.tf.where(this.bias, a, h);
33
24
  }
25
+ config;
26
+ cAttn;
27
+ cProj;
28
+ attnDropout;
29
+ residDropout;
30
+ bias;
31
+ maskInf;
32
+ tf;
33
+ divisor;
34
+ index;
35
+ _trainable = !0;
34
36
  get variables() {
35
37
  return [
36
38
  ...this.cAttn.trainableWeights.map((t) => t.read()),
@@ -49,34 +51,65 @@ class m {
49
51
  loadWeights(t) {
50
52
  this.cAttn.setWeights(t.get(`block_${this.index}_cAttn`) || []), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
51
53
  }
52
- getAttentionScores(t, e, s) {
53
- const a = t.shape[2], o = this.tf.matMul(t, e, !1, !0).mul(this.divisor), i = this.maskInf.slice([0, 0], [a, a]), n = o.add(i), h = this.tf.softmax(n, -1);
54
- return this.attnDropout.apply(h, { training: s });
54
+ getAttentionScores(t, i, s) {
55
+ const e = t.shape[2], h = this.tf.matMul(t, i, !1, !0).mul(this.divisor), n = this.maskInf.slice([0, 0], [e, e]).expandDims(0).expandDims(0), r = h.add(n), o = this.tf.softmax(r, -1);
56
+ return this.attnDropout.apply(o, { training: s });
57
+ }
58
+ // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
59
+ getAttentionScoresWithPast(t, i, s, e) {
60
+ const a = t.shape[2];
61
+ let n = this.tf.matMul(t, i, !1, !0).mul(this.divisor);
62
+ if (a > 1 && e > 0)
63
+ throw new Error("Cannot use past with T_cur > 1");
64
+ if (a > 1) {
65
+ const o = this.maskInf.slice([0, 0], [a, a]).expandDims(0).expandDims(0);
66
+ n = n.add(o);
67
+ }
68
+ const r = this.tf.softmax(n, -1);
69
+ return this.attnDropout.apply(r, { training: s });
55
70
  }
56
71
  getQKV(t) {
57
- const [e, s, a] = t.shape, r = this.cAttn.apply(t), [o, i, n] = this.tf.split(r, 3, -1);
58
- r.dispose();
59
- const h = a / this.config.nHead, c = this.tf.reshape(o, [e, s, this.config.nHead, h]);
60
- o.dispose();
61
- const l = c.transpose([0, 2, 1, 3]);
62
- c.dispose();
63
- const d = this.tf.reshape(i, [e, s, this.config.nHead, h]);
64
- i.dispose();
65
- const u = d.transpose([0, 2, 1, 3]);
66
- d.dispose();
67
- const p = this.tf.reshape(n, [e, s, this.config.nHead, h]);
72
+ const [i, s, e] = t.shape, a = this.cAttn.apply(t), [h, n, r] = this.tf.split(a, 3, -1);
73
+ a.dispose();
74
+ const o = e / this.config.nHead, u = this.tf.reshape(h, [i, s, this.config.nHead, o]);
75
+ h.dispose();
76
+ const f = u.transpose([0, 2, 1, 3]);
77
+ u.dispose();
78
+ const d = this.tf.reshape(n, [i, s, this.config.nHead, o]);
68
79
  n.dispose();
69
- const b = p.transpose([0, 2, 1, 3]);
70
- return p.dispose(), [l, u, b];
80
+ const c = d.transpose([0, 2, 1, 3]);
81
+ d.dispose();
82
+ const l = this.tf.reshape(r, [i, s, this.config.nHead, o]);
83
+ r.dispose();
84
+ const p = l.transpose([0, 2, 1, 3]);
85
+ return l.dispose(), [f, c, p];
71
86
  }
72
- getOutputProjection(t, e) {
73
- const s = t.shape[0], a = t.shape[2], r = this.config.nEmbed, o = t.transpose([0, 2, 1, 3]), i = this.tf.reshape(o, [s, a, r]), n = this.cProj.apply(i);
74
- return this.residDropout.apply(n, { training: e });
87
+ getOutputProjection(t, i) {
88
+ const s = t.shape[0], e = t.shape[2], a = this.config.nEmbed, h = t.transpose([0, 2, 1, 3]), n = this.tf.reshape(h, [s, e, a]), r = this.cProj.apply(n);
89
+ return this.residDropout.apply(r, { training: i });
75
90
  }
76
- call(t, e = !1, s = !1) {
91
+ // Added optional KV cache support (pastKV). Returns presentKV for chaining.
92
+ call(t, i = !1, s = !1, e) {
93
+ if (e && !this.config.useRope)
94
+ throw new Error("Cannot use pastKV without RoPE enabled");
77
95
  return this.tf.tidy(() => {
78
- const [a, r, o] = this.getQKV(t), i = this.getAttentionScores(a, r, e), n = this.tf.matMul(i, o);
79
- return { output: this.getOutputProjection(n, e), attention: s ? i.mean(1) : void 0 };
96
+ const [a, h, n] = this.getQKV(t), r = a.shape[2], o = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(a, h, u) : [a, h];
97
+ let c = d, l = n, p = 0;
98
+ e && (p = e.length, c = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, n], 2));
99
+ const b = c.shape[2];
100
+ if (b > o) {
101
+ const k = b - o, g = c.shape[0], v = c.shape[1], I = c.shape[3];
102
+ c = c.slice([0, 0, k, 0], [g, v, o, I]), l = l.slice([0, 0, k, 0], [g, v, o, I]), p = o - r;
103
+ }
104
+ let m;
105
+ p > 0 ? m = this.getAttentionScoresWithPast(f, c, i, p) : m = this.getAttentionScores(f, c, i);
106
+ const _ = this.tf.matMul(m, l), A = this.getOutputProjection(_, i), P = {
107
+ k: this.tf.keep(c),
108
+ v: this.tf.keep(l),
109
+ length: p + r,
110
+ cumulativeLength: e ? e.cumulativeLength + r : r
111
+ };
112
+ return { output: A, attention: s ? m.mean(1) : void 0, presentKV: P };
80
113
  });
81
114
  }
82
115
  dispose() {
@@ -84,5 +117,5 @@ class m {
84
117
  }
85
118
  }
86
119
  export {
87
- m as default
120
+ S as default
88
121
  };
@@ -0,0 +1,13 @@
1
+ import { default as TF } from '@tensorflow/tfjs';
2
+ export default class RMSNorm {
3
+ private gamma;
4
+ private epsilon;
5
+ private tf;
6
+ constructor(tf: typeof TF, shape: number[], epsilon?: number, name?: string);
7
+ get trainableWeights(): TF.Variable[];
8
+ set trainable(value: boolean);
9
+ getWeights(): TF.Tensor[];
10
+ setWeights(weights: TF.Tensor[]): void;
11
+ apply(x: TF.Tensor): TF.Tensor;
12
+ dispose(): void;
13
+ }
@@ -0,0 +1,32 @@
1
+ class m {
2
+ gamma;
3
+ epsilon;
4
+ tf;
5
+ constructor(a, s, t = 1e-8, e = "") {
6
+ this.tf = a, this.epsilon = t, this.gamma = a.variable(a.ones(s), !0, `${e}_gamma`, "float32");
7
+ }
8
+ get trainableWeights() {
9
+ return [this.gamma];
10
+ }
11
+ set trainable(a) {
12
+ this.gamma.trainable = a;
13
+ }
14
+ getWeights() {
15
+ return [this.gamma];
16
+ }
17
+ setWeights(a) {
18
+ this.gamma.assign(a[0]);
19
+ }
20
+ apply(a) {
21
+ return this.tf.tidy(() => {
22
+ const t = a.square().mean(-1, !0).add(this.epsilon).rsqrt();
23
+ return a.mul(t).mul(this.gamma);
24
+ });
25
+ }
26
+ dispose() {
27
+ this.gamma.dispose();
28
+ }
29
+ }
30
+ export {
31
+ m as default
32
+ };
@@ -0,0 +1,16 @@
1
+ import { default as TF } from '@tensorflow/tfjs';
2
+ import { GPTConfig } from '../config';
3
+ export default class RoPECache {
4
+ private readonly tf;
5
+ private readonly config;
6
+ private rotaryDim;
7
+ private ropeBase;
8
+ private ropeInvFreq;
9
+ private ropeCos;
10
+ private ropeSin;
11
+ private ropeCacheLen;
12
+ constructor(tf: typeof TF, config: GPTConfig);
13
+ private ensureRopeCache;
14
+ applyRoPE(q: TF.Tensor, k: TF.Tensor, pastLen: number): [TF.Tensor, TF.Tensor];
15
+ dispose(): void;
16
+ }
@@ -0,0 +1,39 @@
1
+ class E {
2
+ constructor(t, c) {
3
+ this.tf = t, this.config = c;
4
+ const e = this.config.nEmbed / this.config.nHead;
5
+ if (this.rotaryDim = e, this.rotaryDim % 2 !== 0)
6
+ throw new Error("rotaryDim must be even");
7
+ this.ropeBase = 1e4;
8
+ const o = this.tf.range(0, this.rotaryDim, 2, "float32").div(this.tf.scalar(this.rotaryDim, "float32")), s = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), o);
9
+ this.ropeInvFreq = this.tf.reciprocal(s), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.ensureRopeCache(this.config.blockSize * 4);
10
+ }
11
+ rotaryDim;
12
+ ropeBase;
13
+ ropeInvFreq;
14
+ ropeCos = null;
15
+ // [cacheLen, rotaryDim/2]
16
+ ropeSin = null;
17
+ // [cacheLen, rotaryDim/2]
18
+ ropeCacheLen = 0;
19
+ ensureRopeCache(t) {
20
+ if (t <= this.ropeCacheLen) return;
21
+ this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
22
+ const e = this.tf.range(0, t, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
23
+ this.ropeCos = this.tf.keep(this.tf.cos(e).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(e).expandDims(-1)), this.ropeCacheLen = t;
24
+ }
25
+ applyRoPE(t, c, e) {
26
+ const h = t.shape[3], o = this.rotaryDim;
27
+ if (o > h) return [t, c];
28
+ const s = t.shape[2], S = e + s;
29
+ this.ensureRopeCache(S);
30
+ const n = o / 2, g = this.ropeCos.slice([e, 0, 0], [s, n, 1]), v = this.ropeSin.slice([e, 0, 0], [s, n, 1]), l = g.reshape([1, 1, s, n, 1]), f = v.reshape([1, 1, s, n, 1]), p = this.tf.concat([t, c], 0), r = p.shape[0], i = p.shape[1], y = p.slice([0, 0, 0, 0], [r, i, s, o]), u = o < h ? p.slice([0, 0, 0, o], [r, i, s, h - o]) : null, d = y.reshape([r, i, s, n, 2]), m = d.slice([0, 0, 0, 0, 0], [r, i, s, n, 1]), C = d.slice([0, 0, 0, 0, 1], [r, i, s, n, 1]), B = m.mul(l).sub(C.mul(f)), b = C.mul(l).add(m.mul(f)), D = this.tf.concat([B, b], -1).reshape([r, i, s, o]), R = u ? this.tf.concat([D, u], 3) : D, a = r / 2, x = R.slice([0, 0, 0, 0], [a, i, s, h]), P = R.slice([a, 0, 0, 0], [a, i, s, h]);
31
+ return [x, P];
32
+ }
33
+ dispose() {
34
+ this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose(), this.ropeInvFreq.dispose();
35
+ }
36
+ }
37
+ export {
38
+ E as default
39
+ };
@@ -1,5 +1,7 @@
1
1
  import { default as TF } from '@tensorflow/tfjs';
2
2
  import { GPTConfig } from '../config';
3
+ import { KVCache } from './CausalSelfAttention';
4
+ import { default as RoPECache } from './RoPECache';
3
5
  export default class Block {
4
6
  private ln1;
5
7
  private attn;
@@ -9,16 +11,17 @@ export default class Block {
9
11
  private index;
10
12
  private _trainable;
11
13
  skipped: boolean;
12
- constructor(tf: typeof TF, index: number, config: GPTConfig);
14
+ constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache);
13
15
  get variables(): TF.Variable[];
14
16
  get trainable(): boolean;
15
17
  set trainable(value: boolean);
16
18
  saveWeights(map: Map<string, TF.Tensor[]>): void;
17
19
  loadWeights(weights: Map<string, TF.Tensor[]>): void;
18
20
  private getMLPOutput;
19
- call(x: TF.Tensor, training?: boolean, includeAttention?: boolean): {
21
+ call(x: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: KVCache): {
20
22
  output: TF.Tensor;
21
23
  attention?: TF.Tensor;
24
+ cache?: KVCache;
22
25
  };
23
26
  dispose(): void;
24
27
  }
@@ -1,6 +1,6 @@
1
- import h from "./CausalSelfAttention.js";
2
- import r from "./MLP.js";
3
- import l from "./LayerNorm.js";
1
+ import r from "./CausalSelfAttention.js";
2
+ import o from "./MLP.js";
3
+ import a from "./RMSNorm.js";
4
4
  class u {
5
5
  ln1;
6
6
  attn;
@@ -10,8 +10,8 @@ class u {
10
10
  index;
11
11
  _trainable = !0;
12
12
  skipped = !1;
13
- constructor(t, i, s) {
14
- this.tf = t, this.index = i, this.ln1 = new l(t, [s.nEmbed], 1e-5, `block_${this.index}_ln1`), this.attn = new h(this.tf, this.index, s), this.ln2 = new l(t, [s.nEmbed], 1e-5, `block_${this.index}_ln2`), this.mlp = new r(this.tf, this.index, s);
13
+ constructor(t, i, s, e) {
14
+ this.tf = t, this.index = i, this.ln1 = new a(t, [s.nEmbed], 1e-8, `block_${this.index}_rms1`), this.attn = new r(this.tf, this.index, s, e), this.ln2 = new a(t, [s.nEmbed], 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.tf, this.index, s);
15
15
  }
16
16
  get variables() {
17
17
  return [
@@ -28,21 +28,25 @@ class u {
28
28
  this._trainable = t, this.ln1.trainable = t, this.ln2.trainable = t, this.attn.trainable = t, this.mlp.trainable = t;
29
29
  }
30
30
  saveWeights(t) {
31
- this.attn.saveWeights(t), this.mlp.saveWeights(t), t.set(`block_${this.index}_ln1`, this.ln1.getWeights()), t.set(`block_${this.index}_ln2`, this.ln2.getWeights());
31
+ this.attn.saveWeights(t), this.mlp.saveWeights(t), t.set(`block_${this.index}_rms1`, this.ln1.getWeights()), t.set(`block_${this.index}_rms2`, this.ln2.getWeights());
32
32
  }
33
33
  loadWeights(t) {
34
- this.attn.loadWeights(t), this.mlp.loadWeights(t), this.ln1.setWeights(t.get(`block_${this.index}_ln1`) || []), this.ln2.setWeights(t.get(`block_${this.index}_ln2`) || []);
34
+ this.attn.loadWeights(t), this.mlp.loadWeights(t), this.ln1.setWeights(t.get(`block_${this.index}_rms1`) || []), this.ln2.setWeights(t.get(`block_${this.index}_rms2`) || []);
35
35
  }
36
36
  getMLPOutput(t, i) {
37
37
  const s = this.ln2.apply(t), e = this.mlp.call(s, i);
38
38
  return t.add(e);
39
39
  }
40
- call(t, i = !1, s = !1) {
40
+ call(t, i = !1, s = !1, e) {
41
41
  return this.tf.tidy(() => {
42
42
  if (this.skipped)
43
43
  return { output: t };
44
- const e = this.ln1.apply(t), n = this.attn.call(e, i, s), a = t.add(n.output);
45
- return { output: this.getMLPOutput(a, i), attention: n.attention };
44
+ const l = this.ln1.apply(t), n = this.attn.call(l, i, s, e), h = t.add(n.output);
45
+ return {
46
+ output: this.getMLPOutput(h, i),
47
+ attention: n.attention,
48
+ cache: n.presentKV
49
+ };
46
50
  });
47
51
  }
48
52
  dispose() {
@@ -1,70 +1,68 @@
1
1
  import { generateText as L } from "../utilities/generate.js";
2
2
  import w from "./Trainer.js";
3
- import g from "./Evaluator.js";
4
- const x = {
3
+ import x from "./Evaluator.js";
4
+ const g = {
5
5
  desiredLoss: 0.01,
6
6
  logInterval: 1,
7
7
  maxSteps: 1e3
8
8
  };
9
- class D extends w {
9
+ class P extends w {
10
10
  constructor(r, i, o, n = 3e-4) {
11
11
  super(r, i, o, n);
12
12
  }
13
13
  // Train for multiple epochs using Dataset API - FIXED memory leaks
14
14
  async trainOnDataset(r, i, o) {
15
- const { desiredLoss: n, logInterval: d, onStep: l, prompt: p, maxSteps: m } = {
16
- ...x,
15
+ const { desiredLoss: n, logInterval: m, onStep: l, prompt: c, maxSteps: d } = {
16
+ ...g,
17
17
  ...i
18
- }, s = {
19
- pass: 0,
20
- depth: 1,
18
+ }, t = {
21
19
  step: 0,
22
- stepSinceDepthChange: 0,
23
20
  lastLoss: 1e6,
24
21
  totalSteps: 0,
25
22
  losses: [],
26
- validationLosses: []
23
+ validationLosses: [],
24
+ ...this.lastState || {}
27
25
  };
28
- this.dummyPass(), this.model.trainable = !0;
26
+ this.lastState = t, this.dummyPass(), this.model.trainable = !0;
29
27
  const u = Date.now();
30
28
  this.running = !0;
31
- const c = o ? new g(this.model, o) : void 0, f = await r.iterator();
29
+ const h = o ? new x(this.model, o) : void 0, f = await r.iterator();
32
30
  try {
33
- for (; this.running && !(s.lastLoss < n); ) {
31
+ for (; this.running && !(t.lastLoss < n); ) {
34
32
  const e = await f.next();
35
33
  if (e.done) break;
36
- const h = e.value, v = this.trainBatch(s, h), a = {
37
- loss: s.lastLoss,
38
- step: s.step,
34
+ const p = e.value, v = this.trainBatch(t, p), a = {
35
+ loss: t.lastLoss,
36
+ step: t.step,
39
37
  time: Date.now() - u,
40
- batchSize: h.xs.shape[0]
38
+ batchSize: p.xs.shape[0]
41
39
  };
42
- if (this.model.log.push(a), s.step % d === 0) {
43
- if (await v, c)
40
+ if (this.model.log.push(a), t.step % m === 0) {
41
+ if (await v, h)
44
42
  try {
45
- const t = await c.evaluate(5);
46
- s.validationLosses.push(t), a.valLoss = t;
47
- } catch (t) {
48
- console.error("Validation error:", t);
43
+ const s = await h.evaluate(5);
44
+ t.validationLosses.push(s), a.valLoss = s;
45
+ } catch (s) {
46
+ console.error("Validation error:", s);
49
47
  }
50
48
  if (l) {
51
- if (p) {
52
- const t = await L(this.tokenizer, this.model, p, 100, {
49
+ if (c) {
50
+ const s = await L(this.tokenizer, this.model, c, 100, {
53
51
  temperature: 0.8
54
52
  });
55
- a.example = t;
53
+ a.example = s;
56
54
  }
57
55
  await l(a);
58
56
  }
59
57
  }
60
- s.step >= m && this.stop();
58
+ t.step >= d && this.stop();
61
59
  }
62
60
  } catch (e) {
63
61
  throw console.error("Training error:", e), this.tf.dispose(), e;
64
62
  }
65
- return this.tf.dispose(), this.running = !1, { losses: s.losses, validationLosses: s.validationLosses };
63
+ return this.tf.dispose(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
66
64
  }
67
65
  }
68
66
  export {
69
- D as default
67
+ P as default
70
68
  };
@@ -31,8 +31,10 @@ export default abstract class GPTTrainer {
31
31
  protected tf: typeof TF;
32
32
  protected learningRate: number;
33
33
  protected running: boolean;
34
+ protected lastState?: TrainingState;
34
35
  constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
35
36
  setLearningRate(learningRate: number): void;
37
+ reset(): void;
36
38
  stop(): void;
37
39
  getOptimizer(): AdamExt;
38
40
  resetOptimizer(config?: AdamConfig): void;
@@ -1,8 +1,8 @@
1
1
  import { DatasetBuilder as d } from "./DatasetBuilder.js";
2
- import p from "./AdamExt.js";
2
+ import h from "./AdamExt.js";
3
3
  class u {
4
- constructor(t, s, e, i = 1e-3) {
5
- this.tokenizer = e, this.tf = t, this.model = s, this.learningRate = i, this.resetOptimizer(), this.datasetBuilder = new d(this.tf, e, s.config.blockSize);
4
+ constructor(t, e, s, i = 1e-3) {
5
+ this.tokenizer = s, this.tf = t, this.model = e, this.learningRate = i, this.resetOptimizer(), this.datasetBuilder = new d(this.tf, s, e.config.blockSize);
6
6
  }
7
7
  model;
8
8
  optimizer;
@@ -10,9 +10,13 @@ class u {
10
10
  tf;
11
11
  learningRate;
12
12
  running = !1;
13
+ lastState;
13
14
  setLearningRate(t) {
14
15
  this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
15
16
  }
17
+ reset() {
18
+ this.lastState = void 0, this.running = !1;
19
+ }
16
20
  stop() {
17
21
  this.running = !1;
18
22
  }
@@ -21,7 +25,7 @@ class u {
21
25
  }
22
26
  resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
23
27
  this.optimizer && this.optimizer.dispose();
24
- const s = new p(
28
+ const e = new h(
25
29
  t.learningRateFactor * this.learningRate,
26
30
  t.beta1,
27
31
  t.beta2,
@@ -33,53 +37,53 @@ class u {
33
37
  weightDecay: 0
34
38
  }
35
39
  );
36
- this.optimizer = s;
40
+ this.optimizer = e;
37
41
  }
38
42
  printGradients(t) {
39
- Object.keys(t).forEach((s) => {
40
- const e = t[s];
41
- console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${this.tf.mean(e).dataSync()[0]}`), console.log(` Std: ${this.tf.moments(e).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${this.tf.min(e).dataSync()[0]}`), console.log(` Max: ${this.tf.max(e).dataSync()[0]}`), console.log(` Norm: ${this.tf.norm(e).dataSync()[0]}`);
43
+ Object.keys(t).forEach((e) => {
44
+ const s = t[e];
45
+ console.log(`${e}:`), console.log(` Shape: ${s.shape}`), console.log(` Mean: ${this.tf.mean(s).dataSync()[0]}`), console.log(` Std: ${this.tf.moments(s).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${this.tf.min(s).dataSync()[0]}`), console.log(` Max: ${this.tf.max(s).dataSync()[0]}`), console.log(` Norm: ${this.tf.norm(s).dataSync()[0]}`);
42
46
  });
43
47
  }
44
- trainStep(t, s = !1, e = !1) {
48
+ trainStep(t, e = !1, s = !1) {
45
49
  return this.tf.tidy(() => {
46
50
  const { xs: i, ys: a } = t, o = () => {
47
51
  const { loss: l, logits: c } = this.model.forward(i, a, !0);
48
52
  return c.dispose(), l;
49
53
  }, { value: n, grads: r } = this.tf.variableGrads(o);
50
- return s || (e && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.tf.dispose(r)), n;
54
+ return e || (s && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.tf.dispose(r)), n;
51
55
  });
52
56
  }
53
57
  dummyPass() {
54
- const t = this.tf.zeros([1, this.model.config.blockSize], "int32"), s = this.tf.zeros([1, this.model.config.blockSize, this.model.config.vocabSize]);
58
+ const t = this.tf.zeros([1, this.model.config.blockSize], "int32"), e = this.tf.zeros([1, this.model.config.blockSize, this.model.config.vocabSize]);
55
59
  try {
56
- const e = this.trainStep({ xs: t, ys: s }, !0);
57
- e.dataSync(), e.dispose();
58
- } catch (e) {
59
- console.error("Error during dummy pass:", e);
60
+ const s = this.trainStep({ xs: t, ys: e }, !0);
61
+ s.dataSync(), s.dispose();
62
+ } catch (s) {
63
+ console.error("Error during dummy pass:", s);
60
64
  } finally {
61
- t.dispose(), s.dispose();
65
+ t.dispose(), e.dispose();
62
66
  }
63
67
  }
64
- async trainBatch(t, s) {
68
+ async trainBatch(t, e) {
65
69
  try {
66
- const e = this.trainStep(s, !1, !1);
67
- return s.xs.dispose(), s.ys.dispose(), t.step++, t.totalSteps++, e.array().then((i) => (t.lastLoss = i, t.losses.push(t.lastLoss), e.dispose(), t.lastLoss));
68
- } catch (e) {
69
- throw console.error(`Error processing batch at step ${t.step}:`, e), this.tf.dispose(), e;
70
+ const s = this.trainStep(e, !1, !1);
71
+ return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, s.array().then((i) => (t.lastLoss = i, t.losses.push(t.lastLoss), s.dispose(), t.lastLoss));
72
+ } catch (s) {
73
+ throw console.error(`Error processing batch at step ${t.step}:`, s), this.tf.dispose(), s;
70
74
  }
71
75
  }
72
- async createTrainValidationSplit(t, s = 32, e = 0.1) {
73
- const i = await this.datasetBuilder.createTextDataset(t, s, 0, 1 - e), a = await this.datasetBuilder.createTextDataset(
76
+ async createTrainValidationSplit(t, e = 32, s = 0.1) {
77
+ const i = await this.datasetBuilder.createTextDataset(t, e, 0, 1 - s), a = await this.datasetBuilder.createTextDataset(
74
78
  t,
75
- s,
76
- 1 - e,
79
+ e,
80
+ 1 - s,
77
81
  1
78
82
  );
79
83
  return { trainDataset: i, validationDataset: a };
80
84
  }
81
- async createDataset(t, s = 32) {
82
- return await this.datasetBuilder.createTextDataset(t, s);
85
+ async createDataset(t, e = 32) {
86
+ return await this.datasetBuilder.createTextDataset(t, e);
83
87
  }
84
88
  dispose() {
85
89
  this.optimizer && this.optimizer.dispose();
@@ -1,20 +1,20 @@
1
- async function w(n, t, r, s, g) {
2
- if (s <= 0)
1
+ async function h(r, t, a, c, g) {
2
+ if (c <= 0)
3
3
  throw new Error("Length must be a positive integer");
4
- if (r.length === 0)
4
+ if (a.length === 0)
5
5
  throw new Error("Prompt cannot be an empty string");
6
- const i = await n.tokenise([r], !0), a = t.tf.tidy(() => {
7
- let e = t.tf.tensor2d(i, [1, i[0].length], "int32");
8
- for (let d = 0; d < s; d++) {
9
- const { output: p } = t.generate(e, g), f = e;
10
- e = t.tf.concat([e, p], 1), f.dispose(), p.dispose();
6
+ const p = await r.tokenise([a], !0), s = t.config.useRope ? new Array(t.config.nLayer).fill(void 0) : void 0, u = t.tf.tidy(() => {
7
+ let e = t.tf.tensor2d(p, [1, p[0].length], "int32"), n = e;
8
+ for (let f = 0; f < c; f++) {
9
+ const { output: o } = t.generate(e, s, g), w = e, y = n;
10
+ n = t.tf.concat([n, o], 1), e = s ? o : t.tf.concat([e, o], 1), w.dispose(), y.dispose(), s || o.dispose();
11
11
  }
12
- return e;
13
- }), u = await a.array();
14
- a.dispose();
15
- const o = u[0], c = o.indexOf(n.eosToken);
16
- return c !== -1 && o.splice(c), await n.decode(o);
12
+ return n;
13
+ }), T = await u.array();
14
+ u.dispose();
15
+ const i = T[0], d = i.indexOf(r.eosToken);
16
+ return d !== -1 && i.splice(d), await r.decode(i);
17
17
  }
18
18
  export {
19
- w as generateText
19
+ h as generateText
20
20
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.1.8",
3
+ "version": "0.2.0",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",