@genai-fi/nanogpt 0.1.2 → 0.1.4

@@ -33,6 +33,7 @@ export default class NanoGPT
     set trainable(value: boolean);
     private validateInput;
     private calculateLoss;
+    private computeAttentionRollout;
     forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean): {
         logits: TF.Tensor;
         loss?: TF.Tensor;
@@ -54,8 +54,8 @@ class $ {
   }
   inputPhase(t, e = !1) {
     return this.tf.tidy(() => {
-      const [, s] = t.shape, n = this.wte.embed(t), i = this.tf.range(0, s, 1, "int32"), o = this.wpe.apply(i), h = n.add(o);
-      return this.drop.apply(h, { training: e });
+      const [, s] = t.shape, i = this.wte.embed(t), n = this.tf.range(0, s, 1, "int32"), h = this.wpe.apply(n), o = i.add(h);
+      return this.drop.apply(o, { training: e });
     });
   }
   setSkipMask(t) {
@@ -90,44 +90,60 @@ class $ {
       throw console.error("Error computing loss:", s), new Error(`Loss computation failed: ${s}`);
     }
   }
-  forward(t, e, s = !1, n = !1) {
+  // Attention rollout per Abnar & Zuidema (2020)
+  // Expects list of (B, T, T) attention matrices already averaged over heads.
+  computeAttentionRollout(t) {
+    return this.tf.tidy(() => {
+      if (t.length === 0)
+        throw new Error("No attentions for rollout");
+      const e = t[0].shape[0], s = t[0].shape[1], i = this.tf.eye(s, s).expandDims(0);
+      let n = i.tile([e, 1, 1]);
+      for (const h of t) {
+        let o = h.add(i);
+        o = o.div(o.sum(-1, !0)), n = o.matMul(n);
+      }
+      return n;
+    });
+  }
+  forward(t, e, s = !1, i = !1) {
     return this.validateInput(t), this.tf.tidy(() => {
-      let i = this.inputPhase(t, s), o;
-      n && (o = this.tf.zeros([i.shape[0], i.shape[1], i.shape[1]]));
-      for (const l of this.blocks) {
-        const { output: r, attention: f } = l.call(i, s, n);
-        i = r, f && o && (o = o.add(f));
+      let n = this.inputPhase(t, s);
+      const h = [];
+      for (const c of this.blocks) {
+        const { output: p, attention: a } = c.call(n, s, i);
+        n = p, i && a && h.push(a);
       }
-      o && (o = o.div(this.blocks.length)), i = this.lnF.apply(i);
-      const h = this.wte.project(i);
-      let a;
-      return e && (a = this.calculateLoss(h, e)), { logits: h, loss: a, attention: n ? o : void 0 };
+      let o;
+      i && h.length > 0 && (o = this.computeAttentionRollout(h)), n = this.lnF.apply(n);
+      const l = this.wte.project(n);
+      let r;
+      return e && (r = this.calculateLoss(l, e)), { logits: l, loss: r, attention: i ? o : void 0 };
     });
   }
   generate(t, e) {
-    const s = e?.temperature ?? 1, n = e?.topK, i = e?.usePadding ?? !1, o = e?.includeAttention ?? !1;
+    const s = e?.temperature ?? 1, i = e?.topK, n = e?.usePadding ?? !1, h = e?.includeAttention ?? !1;
     return this.tf.tidy(() => {
-      const h = t, a = h.shape[1], l = a <= this.config.blockSize ? h : h.slice(
-        [0, a - this.config.blockSize],
-        [h.shape[0], this.config.blockSize]
-      ), r = i ? this.config.blockSize - l.shape[1] : 0, f = r > 0 ? this.tf.pad(l, [
+      const o = t, l = o.shape[1], r = l <= this.config.blockSize ? o : o.slice(
+        [0, l - this.config.blockSize],
+        [o.shape[0], this.config.blockSize]
+      ), c = n ? this.config.blockSize - r.shape[1] : 0, p = c > 0 ? this.tf.pad(r, [
         [0, 0],
-        [0, r]
-      ]) : l, { logits: g, attention: p } = this.forward(f, void 0, !1, o), d = g.shape[1] - 1 - r, m = g.slice([0, d, 0], [g.shape[0], 1, g.shape[2]]), u = p ? p.slice([0, d, 0], [p.shape[0], 1, p.shape[2]]) : void 0, b = m.div(s);
-      let c;
-      if (n) {
-        const { values: k, indices: w } = this.tf.topk(b, n), E = this.tf.multinomial(k.squeeze([1]), 1);
-        c = this.tf.gather(w.squeeze([1]), E, 1);
+        [0, c]
+      ]) : r, { logits: a, attention: g } = this.forward(p, void 0, !1, h), d = a.shape[1] - 1 - c, m = a.slice([0, d, 0], [a.shape[0], 1, a.shape[2]]), u = g ? g.slice([0, d, 0], [g.shape[0], 1, g.shape[2]]) : void 0, b = m.div(s);
+      let f;
+      if (i) {
+        const { values: k, indices: w } = this.tf.topk(b, i), E = this.tf.multinomial(k.squeeze([1]), 1);
+        f = this.tf.gather(w.squeeze([1]), E, 1);
       } else
-        c = this.tf.multinomial(b.squeeze([1]), 1);
-      return c = c.reshape([1, 1]), { output: c, attention: u?.squeeze([1]) };
+        f = this.tf.multinomial(b.squeeze([1]), 1);
+      return f = f.reshape([1, 1]), { output: f, attention: u?.squeeze([1]) };
     });
   }
   getNumParams() {
     const t = this.config.vocabSize * this.config.nEmbed + this.config.blockSize * this.config.nEmbed, e = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // qkv + proj
     2 * this.config.nEmbed), s = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // fc
-    this.config.nEmbed * 4 * this.config.nEmbed), n = this.config.nEmbed + this.config.vocabSize * this.config.nEmbed;
-    return t + e + s + n;
+    this.config.nEmbed * 4 * this.config.nEmbed), i = this.config.nEmbed + this.config.vocabSize * this.config.nEmbed;
+    return t + e + s + i;
   }
 }
 export {
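
The headline change in this hunk: forward no longer averages attention maps across layers but collects the per-layer maps and feeds them to the new computeAttentionRollout, which implements attention rollout (Abnar & Zuidema, 2020). Each layer's head-averaged attention gets an identity matrix added to model the residual connection, is row-normalised back to a probability distribution, and is multiplied into a running product, so the result traces attention flow through the whole stack rather than one layer at a time. A readable sketch of the same computation, assuming @tensorflow/tfjs and head-averaged (B, T, T) inputs; the names here are illustrative, not the package's API:

    import * as tf from "@tensorflow/tfjs";

    // Attention rollout: compose per-layer attention maps into one (B, T, T) map.
    function attentionRollout(attentions: tf.Tensor3D[]): tf.Tensor {
      return tf.tidy(() => {
        if (attentions.length === 0) throw new Error("No attentions for rollout");
        const [batch, seqLen] = attentions[0].shape;
        const identity = tf.eye(seqLen).expandDims(0);   // (1, T, T), the residual path
        let rollout = identity.tile([batch, 1, 1]);      // start from I per batch item
        for (const layerAttn of attentions) {
          let a = layerAttn.add(identity);               // A + I: each token also attends to itself
          a = a.div(a.sum(-1, true));                    // re-normalise rows to sum to 1
          rollout = a.matMul(rollout);                   // later layers end up outermost in the product
        }
        return rollout;
      });
    }
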
@@ -1,13 +1,13 @@
-import m from "./NanoGPTModel.js";
-import { defaultConfig as d } from "./config.js";
-import { saveModel as l } from "./utilities/save.js";
-import { loadModel as u } from "./utilities/load.js";
-import _ from "./Generator.js";
-import c from "./Trainer.js";
-import { E as f } from "./index-SOhdqzHq.js";
+import d from "./NanoGPTModel.js";
+import { defaultConfig as m } from "./config.js";
+import { saveModel as u } from "./utilities/save.js";
+import { loadModel as l } from "./utilities/load.js";
+import f from "./Generator.js";
+import _ from "./Trainer.js";
+import { E as c } from "./index-SOhdqzHq.js";
 import { dummyPassAsync as a } from "./utilities/dummy.js";
 import g from "./tokeniser/CharTokeniser.js";
-class n extends f {
+class n extends c {
   _config;
   _model;
   tf;
@@ -43,27 +43,27 @@ class n extends f {
   saveModel() {
     if (!this._model || !this._tokeniser)
       throw new Error("Model or tokeniser is not initialized.");
-    return l(this._model, this._tokeniser);
+    return u(this._model, this._tokeniser);
   }
   static loadModel(t, r) {
     const e = new n(t);
-    return u(t, r).then(({ model: i, tokeniser: s }) => {
-      e._model = i, e._tokeniser = s, e._config = i.config, e.setStatus("warmup"), a(i).then(() => {
+    return l(t, r).then(({ model: i, tokeniser: o }) => {
+      e._model = i, e._tokeniser = o, e._config = i.config, e.setStatus("warmup"), a(i).then(() => {
        e.setStatus("ready");
-      }).catch((o) => {
-        e.setStatus("error"), e.emit("error", o);
+      }).catch((s) => {
+        e.setStatus("error"), e.emit("error", s);
      });
    }).catch((i) => {
      e.setStatus("error"), e.emit("error", i);
    }), e;
  }
  static create(t, r = {}) {
-    const e = { ...d, ...r }, i = new g(e.vocabSize), s = new m(t, e), o = new n(t, i, s);
-    return o.setStatus("warmup"), a(s).then(() => {
-      o.setStatus("ready");
+    const e = { ...m, ...r }, i = new g(e.vocabSize), o = new d(t, e), s = new n(t, i, o);
+    return s.setStatus("warmup"), a(o).then(() => {
+      s.setStatus("ready");
    }).catch((h) => {
-      o.setStatus("error"), o.emit("error", h);
-    }), o;
+      s.setStatus("error"), s.emit("error", h);
+    }), s;
  }
  getNumParams() {
    if (!this._model)
@@ -73,7 +73,7 @@ class n extends f {
  trainer() {
    if (!this._model || !this._tokeniser)
      throw new Error("Model or tokeniser is not initialized.");
-    const t = new c(this._model, this._tokeniser);
+    const t = new _(this._model, this._tokeniser);
    return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (r) => {
      const e = this.listeners("trainStep");
      for (const i of e)
@@ -86,8 +86,12 @@ class n extends f {
  generator() {
    if (!this._model || !this._tokeniser)
      throw new Error("Model or tokeniser is not initialized.");
-    const t = new _(this._model, this._tokeniser);
-    return t.on("start", () => this.setStatus("busy")), t.on("stop", () => this.setStatus("ready")), t;
+    const t = new f(this._model, this._tokeniser);
+    return t.on("start", () => {
+      this.status === "ready" && this.setStatus("busy");
+    }), t.on("stop", () => {
+      this.status === "busy" && this.setStatus("ready");
+    }), t;
  }
  generateText(t, r) {
    return this.generator().generate(t, r);
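
The generator wiring is the behavioural fix in this hunk: 0.1.2 unconditionally flipped the status to "busy" on start and "ready" on stop, so a generation that ran while a trainer held the status at "training" would clobber it. The new handlers only toggle between "ready" and "busy". A hedged sketch of the resulting behaviour, where gpt is an instance of the class above and the prompt and options are illustrative:

    // Generation while idle: ready -> busy -> ready.
    // Generation while training: status stays "training" throughout, because the
    // guards above only fire from "ready" (on start) or "busy" (on stop).
    const text = await gpt.generateText("Once upon a", { temperature: 0.8 });
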
package/dist/Trainer.js CHANGED
@@ -1,6 +1,6 @@
 import { E as l } from "./index-SOhdqzHq.js";
 import o from "./training/FullTrainer.js";
-class m extends l {
+class d extends l {
   trainer;
   constructor(a, t) {
     super(), this.trainer = new o(a.tf, a, t, 1e-3);
@@ -8,27 +8,28 @@ class m extends l {
  stop() {
  }
  async train(a, t) {
-    const { trainDataset: r, validationDataset: e } = await this.trainer.createTrainValidationSplit(
+    const { trainDataset: e, validationDataset: r } = await this.trainer.createTrainValidationSplit(
      a,
      t?.batchSize || 32,
      t?.validationSplit || 0.1
    );
-    this.emit("start"), await this.trainer.trainOnDataset(
-      r,
+    this.trainer.setLearningRate(t?.learningRate || 1e-3), this.emit("start"), await this.trainer.trainOnDataset(
+      e,
      {
        prompt: t?.prompt,
        logInterval: t?.logInterval || 10,
        desiredLoss: t?.desiredLoss || 0.01,
+        maxSteps: t?.maxSteps || 1e3,
        onStep: async (i) => {
          const s = this.listeners("log");
          for (const n of s)
            await n(i);
        }
      },
-      e
+      r
    ), this.emit("stop");
  }
 }
 export {
-  m as default
+  d as default
 };
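
Two additions surface here: train now resets the optimiser with a caller-supplied learning rate before emitting "start", and it threads the new maxSteps cap through to the training loop. A hedged usage sketch; the option names come from the destructuring above, while the values and the text argument are illustrative:

    const trainer = gpt.trainer();
    await trainer.train(text, {
      batchSize: 32,         // default 32
      validationSplit: 0.1,  // default 0.1
      learningRate: 1e-3,    // new in 0.1.4: rebuilds the Adam optimiser
      logInterval: 10,
      desiredLoss: 0.01,
      maxSteps: 1e3          // new in 0.1.4: hard cap on training steps
    });
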
@@ -1,17 +1,18 @@
 import { generateText as L } from "../utilities/generate.js";
-import f from "./Trainer.js";
-const w = {
+import w from "./Trainer.js";
+const g = {
   desiredLoss: 0.01,
-  logInterval: 1
+  logInterval: 1,
+  maxSteps: 1e3
 };
-class g extends f {
+class S extends w {
   constructor(r, i, o, n = 3e-4) {
     super(r, i, o, n);
   }
   // Train for multiple epochs using Dataset API - FIXED memory leaks
   async trainOnDataset(r, i, o) {
-    const { desiredLoss: n, logInterval: h, onStep: l, prompt: c } = {
-      ...w,
+    const { desiredLoss: n, logInterval: c, onStep: l, prompt: p, maxSteps: d } = {
+      ...g,
      ...i
    }, s = {
      pass: 0,
@@ -24,19 +25,21 @@ class g extends f {
      validationLosses: []
    };
    this.dummyPass(), this.model.trainable = !0;
-    const d = Date.now(), m = await r.iterator();
+    const m = Date.now();
+    this.running = !0;
+    const u = await r.iterator();
    try {
-      for (; !(s.lastLoss < n); ) {
-        const e = await m.next();
+      for (; this.running && !(s.lastLoss < n); ) {
+        const e = await u.next();
        if (e.done) break;
-        const p = e.value, u = this.trainBatch(s, p), a = {
+        const h = e.value, f = this.trainBatch(s, h), a = {
          loss: s.lastLoss,
          step: s.step,
-          time: Date.now() - d,
-          batchSize: p.xs.shape[0]
+          time: Date.now() - m,
+          batchSize: h.xs.shape[0]
        };
-        if (this.model.log.push(a), s.step % h === 0) {
-          if (await u, o)
+        if (this.model.log.push(a), s.step % c === 0) {
+          if (await f, o)
            try {
              const t = await this.evaluateOnDataset(o, 5);
              s.validationLosses.push(t), a.valLoss = t;
@@ -44,8 +47,8 @@ class g extends f {
              console.error("Validation error:", t);
            }
          if (l) {
-            if (c) {
-              const t = await L(this.tokenizer, this.model, c, 100, {
+            if (p) {
+              const t = await L(this.tokenizer, this.model, p, 100, {
                temperature: 0.8
              });
              a.example = t;
@@ -53,13 +56,14 @@ class g extends f {
            await l(a);
          }
        }
+        s.step >= d && this.stop();
      }
    } catch (e) {
      throw console.error("Training error:", e), this.tf.dispose(), e;
    }
-    return this.tf.dispose(), { losses: s.losses, validationLosses: s.validationLosses };
+    return this.tf.dispose(), this.running = !1, { losses: s.losses, validationLosses: s.validationLosses };
  }
 }
 export {
-  g as default
+  S as default
 };
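
The loop in trainOnDataset now has three exits instead of one: the desired loss is reached, the dataset iterator is exhausted, or the new running flag is cleared, either by stop() or by hitting maxSteps. De-minified, and with logging and validation elided, the control flow is:

    // Sketch of the loop above with descriptive names substituted for minified ones.
    this.running = true;
    const iterator = await dataset.iterator();
    while (this.running && !(state.lastLoss < desiredLoss)) {  // exit 1: loss target met
      const batch = await iterator.next();
      if (batch.done) break;                     // exit 2: dataset exhausted
      this.trainBatch(state, batch.value);       // updates state.lastLoss and state.step
      if (state.step >= maxSteps) this.stop();   // exit 3: clears running for the next check
    }
    this.running = false;
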
@@ -5,7 +5,8 @@ const w = {
   desiredLoss: 0.01,
   logInterval: 1,
   stepsPerLayer: 400,
-  maxPasses: 3
+  maxPasses: 3,
+  maxSteps: 1e3
 };
 class b extends S {
   trainingPattern = [];
@@ -37,20 +38,20 @@ class b extends S {
      validationLosses: []
    };
    this.dummyPass();
-    const f = Date.now();
+    const m = Date.now();
    this.startPass = 0, this.startLayer = 0;
-    const m = await r.iterator();
+    const f = await r.iterator();
    this.applyTrainingPattern(t.layerStep % this.trainingPattern.length);
    try {
      for (; !(t.lastLoss < p); ) {
-        const n = await m.next();
+        const n = await f.next();
        if (n.done) break;
        const y = n.value, P = this.trainBatch(t, y);
        t.stepSinceLayerChange++;
        const o = {
          loss: t.lastLoss,
          step: t.step,
-          time: Date.now() - f,
+          time: Date.now() - m,
          batchSize: y.xs.shape[0],
          pass: t.pass,
          layer: t.layerStep % this.model.config.nLayer
@@ -20,6 +20,7 @@ export interface TrainingOptions {
     desiredLoss: number;
     logInterval: number;
     prompt?: string;
+    maxSteps: number;
     onStep?: (log: TrainingLogEntry) => Promise<void> | void;
 }
 export default abstract class GPTTrainer {
@@ -29,7 +30,10 @@ export default abstract class GPTTrainer {
     protected datasetBuilder: DatasetBuilder;
     protected tf: typeof TF;
     protected learningRate: number;
+    protected running: boolean;
     constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+    setLearningRate(learningRate: number): void;
+    stop(): void;
     getOptimizer(): AdamExt;
     resetOptimizer(config?: AdamConfig): void;
     private printGradients;
@@ -9,6 +9,13 @@ class y {
   datasetBuilder;
   tf;
   learningRate;
+  running = !1;
+  setLearningRate(t) {
+    this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
+  }
+  stop() {
+    this.running = !1;
+  }
  getOptimizer() {
    return this.optimizer;
  }
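
setLearningRate deliberately rebuilds the optimiser rather than mutating the rate in place, since Adam's moment estimates were accumulated under the old learning rate; stop() only clears the flag, so cancellation is cooperative and takes effect when the loop next tests this.running. Note that the public Trainer wrapper's stop() shown earlier is still an empty body in 0.1.4; only this base-class stop() actually halts the loop. An illustrative sketch, assuming direct access to a GPTTrainer subclass instance named fullTrainer:

    // Cooperative cancellation: the loop finishes its current batch, then exits.
    fullTrainer.setLearningRate(3e-4);  // fresh Adam moments at the new rate
    const handle = setTimeout(() => fullTrainer.stop(), 60_000);
    await fullTrainer.trainOnDataset(trainDs, { maxSteps: 1e3, desiredLoss: 0.01, logInterval: 10 });
    clearTimeout(handle);
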
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@genai-fi/nanogpt",
-  "version": "0.1.2",
+  "version": "0.1.4",
   "type": "module",
   "main": "dist/main.js",
   "types": "dist/main.d.ts",