@genai-fi/nanogpt 0.7.2 → 0.7.3

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
package/dist/Generator.d.ts CHANGED
@@ -9,11 +9,20 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
      private readonly model;
      private readonly tokeniser;
      private active;
+     private cache;
+     private initialPrompt;
+     private outputText;
+     private actualTokeniser;
+     private lastToken;
      constructor(model: NanoGPT, tokeniser: ITokeniser);
      private tokenisePrompt;
-     private generateNoCache;
      private processResponse;
-     private generateCache;
+     private _generate;
+     reset(): void;
+     dispose(): void;
+     private initialise;
+     step(prompt?: string, options?: IGenerateOptions): Promise<string>;
      generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
      stop(): void;
+     getText(): string;
  }
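The declaration changes replace the internal generateCache/generateNoCache pair with a single cached _generate path and add an incremental public surface: step(), reset(), dispose() and getText(). A minimal usage sketch in TypeScript, assuming `gen` is the Generator returned by the package's generator() factory; the structural type, prompt and lengths are illustrative:

  // Sketch: the 0.7.3 incremental generation surface.
  async function streamFifty(gen: {
    generate(prompt?: string, options?: { maxLength?: number }): Promise<string>;
    step(prompt?: string, options?: { maxLength?: number }): Promise<string>;
    getText(): string;
    reset(): void;
    dispose(): void;
  }): Promise<string> {
    // One-shot generation, unchanged from 0.7.2:
    await gen.generate("Once upon a time", { maxLength: 200 });

    // New in 0.7.3: step() emits one token per call (it forwards to
    // generate() with maxLength forced to 1); getText() exposes the
    // accumulated output between calls.
    gen.reset();
    await gen.step("Once upon a time");
    for (let i = 0; i < 49; i++) await gen.step(); // continues from cached state
    const text = gen.getText();
    gen.dispose(); // currently an alias for reset(), per the JS diff below
    return text;
  }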
package/dist/Generator.js CHANGED
@@ -1,4 +1,4 @@
- import { E as u } from "./index-Dwqa6Zy2.js";
+ import { E as l } from "./index-Dwqa6Zy2.js";
  import "./index-BoWRt-10.js";
  import "./ops/cpu/attentionMask.js";
  import "./ops/webgl/attentionMask.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./jszip.min-CjP2V1VV.js";
- import f from "./tokeniser/CharTokeniser.js";
+ import u from "./tokeniser/CharTokeniser.js";
  import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
  import "./ops/cpu/adamMoments.js";
@@ -39,10 +39,10 @@ import "./ops/cpu/gelu.js";
  import "./ops/webgl/gelu.js";
  import "./gelu-C-dPj6Ku.js";
  import "./ops/webgl/log.js";
- import { t as d } from "./tensor2d-wxPAnDQy.js";
- import { c as g } from "./concat-CsxrgovM.js";
+ import { t as p } from "./tensor2d-wxPAnDQy.js";
+ import { c as f } from "./concat-CsxrgovM.js";
  const k = [
-   ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
+   ...Array.from({ length: 95 }, (r, t) => String.fromCharCode(t + 32)),
    // ASCII
    // Spanish accented letters and punctuation
    ..."áéíóúüñ¿¡",
@@ -53,80 +53,93 @@ const k = [
    // Cyrillic letters
    ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
  ];
- function w(a, t) {
-   return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
+ function d(r, t) {
+   return r.length === t ? r : r.length > t ? r.slice(0, t) : r.concat(Array(t - r.length).fill(""));
  }
- class pt extends u {
-   constructor(t, o) {
-     super(), this.model = t, this.tokeniser = o;
+ class nt extends l {
+   constructor(t, i) {
+     super(), this.model = t, this.tokeniser = i, this.actualTokeniser = i;
    }
    active = !1;
-   async tokenisePrompt(t, o) {
-     const r = o ? await t.tokenise([o], !0) : [[t.eosToken]];
-     return d(r, [1, r[0].length], "int32");
+   cache = null;
+   initialPrompt = null;
+   outputText = "";
+   actualTokeniser;
+   lastToken = -1;
+   async tokenisePrompt(t, i) {
+     const e = i ? await t.tokenise([i], !0) : [[t.eosToken]];
+     return p(e, [1, e[0].length], "int32");
    }
-   async generateNoCache(t, o, r) {
-     let i = await this.tokenisePrompt(t, o), s = o || "";
-     const n = r?.maxLength ?? 1e3;
-     for (let m = 0; m < n && this.active; m++) {
-       const {
-         output: e,
-         attention: p,
-         probabilities: c
-       } = await this.model.generate(i, void 0, r), h = i;
-       i = g([i, e], 1), h.dispose();
-       const l = await this.processResponse(t, e, p, c);
-       if (e.dispose(), l === null)
-         break;
-       s += l;
-     }
-     return i.dispose(), s;
-   }
-   async processResponse(t, o, r, i) {
-     const s = (await o.array())[0][0];
-     if (s === this.tokeniser.eosToken)
+   async processResponse(t, i, e, o) {
+     const s = (await i.array())[0][0];
+     if (this.lastToken = s, s === this.tokeniser.eosToken)
        return null;
      const n = await t.decode([s]);
-     let m;
-     r && (m = await Promise.all(r.map((p) => p.array().then((c) => c))), r.forEach((p) => p.dispose()));
-     let e;
-     return i && (e = await i.array(), i.dispose()), this.emit("tokens", [s], n, m, e), n;
+     let c;
+     e && (c = await Promise.all(e.map((h) => h.array().then((m) => m))), e.forEach((h) => h.dispose()));
+     let a;
+     return o && (a = await o.array(), o.dispose()), this.emit("tokens", [s], n, c, a), n;
    }
-   async generateCache(t, o, r) {
-     let i = await this.tokenisePrompt(t, o), s = o || "";
-     const n = new Array(this.model.config.gpt.nLayer);
-     for (let e = 0; e < this.model.config.gpt.nLayer; e++)
-       n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
-     const m = r?.maxLength ?? 1e3;
-     for (let e = 0; e < m && this.active; e++) {
+   async _generate(t) {
+     let i = this.lastToken >= 0 && this.cache ? p([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
+     const e = t?.maxLength ?? 1e3;
+     for (let o = 0; o < e && this.active; o++) {
        const {
-         output: p,
-         probabilities: c,
-         attention: h
-       } = await this.model.generate(i, n, {
-         ...r,
-         usePadding: !1
+         output: s,
+         probabilities: n,
+         attention: c
+       } = await this.model.generate(i, this.cache ? this.cache : void 0, {
+         ...t,
+         usePadding: !this.cache
        });
-       i.dispose(), i = p;
-       const l = await this.processResponse(t, p, h, c);
-       if (l === null)
+       if (this.cache)
+         i.dispose(), i = s;
+       else {
+         const h = i;
+         i = f([i, s], 1), h.dispose();
+       }
+       const a = await this.processResponse(this.actualTokeniser, s, c, n);
+       if (this.cache || s.dispose(), a === null)
         break;
-       s += l;
+       this.outputText += a;
+     }
+     return i.dispose(), this.outputText;
+   }
+   reset() {
+     this.cache && (this.cache.forEach((t) => {
+       t && (t.k && t.k.dispose(), t.v && t.v.dispose());
+     }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
+   }
+   dispose() {
+     this.reset();
+   }
+   initialise(t, i) {
+     const e = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t ?? null;
+     if (this.cache && i?.noCache && this.reset(), this.initialPrompt = e || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !i?.noCache && this.model.config.gpt.useRope) {
+       const s = new Array(this.model.config.gpt.nLayer);
+       for (let n = 0; n < this.model.config.gpt.nLayer; n++)
+         s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
+       this.cache = s, this.lastToken = -1;
    }
-     return n.forEach((e) => {
-       e && (e.k && e.k.dispose(), e.v && e.v.dispose());
-     }), i.dispose(), s;
+     const o = this.tokeniser.trained ? this.tokeniser : new u(d(k, this.tokeniser.vocabSize));
+     this.actualTokeniser = o;
    }
-   async generate(t, o) {
-     const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
-     this.active = !0, this.emit("start");
-     const i = this.tokeniser.trained ? this.tokeniser : new f(w(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
-     return this.active = !1, this.emit("stop"), n;
+   async step(t, i) {
+     const e = { ...i, maxLength: 1 };
+     return this.generate(t, e);
+   }
+   async generate(t, i) {
+     this.initialise(t, i), this.active = !0, this.emit("start");
+     const o = await this._generate(i);
+     return this.active = !1, this.emit("stop"), o;
    }
    stop() {
      this.active = !1;
    }
+   getText() {
+     return this.outputText;
+   }
  }
  export {
-   pt as default
+   nt as default
  };
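The new state is easier to read de-minified. A reconstructed sketch of the per-layer key/value cache that initialise() allocates when RoPE is enabled and noCache is not set; the identifier names are mine, and the Tensor type comes from @tensorflow/tfjs-core, which the package's typings already import:

  import type { Tensor } from "@tensorflow/tfjs-core";

  interface LayerCache {
    k?: Tensor;               // cached keys for this layer
    v?: Tensor;               // cached values for this layer
    length: number;           // tokens currently held in the cache
    cumulativeLength: number; // total tokens seen so far
  }

  // One entry per transformer layer, mirroring initialise():
  function makeCache(nLayer: number): LayerCache[] {
    return Array.from({ length: nLayer }, () => ({
      k: undefined, v: undefined, length: 0, cumulativeLength: 0,
    }));
  }

With a cache present, _generate() feeds only the last emitted token back into the model and disables padding (usePadding: !this.cache); without one it falls back to concatenating each output token onto the full input, as the old generateNoCache did. reset() disposes any cached k/v tensors, which is why dispose() can simply delegate to it.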
package/dist/main.js CHANGED
@@ -1,12 +1,12 @@
  import { defaultConfig as _ } from "./config.js";
  import f from "./NanoGPTModel.js";
- import { saveModel as u } from "./utilities/save.js";
- import { loadModel as d } from "./loader/load.js";
- import l from "./Generator.js";
+ import { saveModel as d } from "./utilities/save.js";
+ import { loadModel as l } from "./loader/load.js";
+ import u from "./Generator.js";
  import p from "./Trainer.js";
- import { E as g } from "./index-Dwqa6Zy2.js";
+ import { E as c } from "./index-Dwqa6Zy2.js";
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
- import c from "./tokeniser/CharTokeniser.js";
+ import g from "./tokeniser/CharTokeniser.js";
  import k from "./tokeniser/bpe.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./index-Tf7vU29b.js";
@@ -49,7 +49,7 @@ import "./ops/cpu/adamAdjust.js";
  import "./ops/webgl/adamAdjust.js";
  import w from "./utilities/profile.js";
  class a {
-   ee = new g();
+   ee = new c();
    _config;
    _model;
    _tokeniser;
@@ -92,8 +92,8 @@ class a {
      return this._status === "busy" || this._status === "training";
    }
    estimateTrainingMemoryUsage(t) {
-     const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, i = e.perBatch * t, o = e.gradients;
-     return i * 0.66 + o * 4;
+     const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, r = e.perBatch * t, o = e.gradients;
+     return r * 0.66 + o * 4;
    }
    setStatus(t) {
      this._status !== t && (this._status = t, this.ee.emit("status", t));
@@ -101,32 +101,32 @@ class a {
    saveModel(t) {
      if (!this._model || !this._tokeniser)
        throw new Error("model_or_tokeniser_not_initialized.");
-     return u(this._model, this._tokeniser, {
+     return d(this._model, this._tokeniser, {
        ...t,
        name: t?.name || this.meta.name
      });
    }
    static loadModel(t) {
      const e = new a();
-     return d(t).then(({ model: i, tokeniser: o, name: s }) => {
-       e._model = i, e._tokeniser = o, e._config = i.config, s && (e.meta.name = s), e.setStatus("warmup"), m(i).then((r) => {
-         e._memoryRequirements = r, e.setStatus("ready"), e.ee.emit("loaded");
-       }).catch((r) => {
-         e.setStatus("error"), e.ee.emit("error", r);
+     return l(t).then(({ model: r, tokeniser: o, name: s }) => {
+       e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
+         e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
+       }).catch((i) => {
+         e.setStatus("error"), e.ee.emit("error", i);
        });
-     }).catch((i) => {
-       e.setStatus("error"), e.ee.emit("error", i);
+     }).catch((r) => {
+       e.setStatus("error"), e.ee.emit("error", r);
      }), e;
    }
    static create(t, e = {}) {
-     const i = { ..._, ...e }, o = t === "char" ? new c(i.vocabSize) : new k(i.vocabSize), s = new f(i), r = new a(o, s);
-     return r.setStatus("warmup"), m(s).then((n) => {
-       r._memoryRequirements = n, r.tokeniser.trained ? (r.setStatus("ready"), r.ee.emit("loaded")) : (r.setStatus("awaitingTokens"), r.ee.emit("loaded"), r.tokeniser.once("trainStatus", (h) => {
-       h === "trained" && r.setStatus("ready");
+     const r = { ..._, ...e }, o = t === "char" ? new g(r.vocabSize) : new k(r.vocabSize), s = new f(r), i = new a(o, s);
+     return i.setStatus("warmup"), m(s).then((n) => {
+       i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
+       h === "trained" && i.setStatus("ready");
      }));
    }).catch((n) => {
-     r.setStatus("error"), r.ee.emit("error", n);
-   }), r;
+     i.setStatus("error"), i.ee.emit("error", n);
+   }), i;
    }
    getProfiler() {
      return this._model?.getProfiler();
@@ -149,14 +149,15 @@ class a {
      if (!this._model || !this._tokeniser)
        throw new Error("model_or_tokeniser_not_initialized.");
      const t = new p(this._model, this._tokeniser);
-     return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, i) => {
+     return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
        const o = this.ee.listeners("trainStep");
        for (const s of o)
-         await s(e, i);
+         await s(e, r);
      }), t;
    }
-   train(t, e) {
-     return this.trainer().train(t, e);
+   async train(t, e) {
+     const r = this.trainer();
+     await r.prepare(t, e), await r.train(e);
    }
    async trainTokeniser(t) {
      if (!this._tokeniser)
@@ -167,7 +168,7 @@ class a {
    generator() {
      if (!this._model || !this._tokeniser)
        throw new Error("model_or_tokeniser_not_initialized.");
-     const t = new l(this._model, this._tokeniser);
+     const t = new u(this._model, this._tokeniser);
      return t.on("start", () => {
        this.status === "ready" && this.setStatus("busy");
      }), t.on("stop", () => {
package/dist/Trainer.d.ts CHANGED
@@ -14,8 +14,13 @@ export interface ITrainerOptions {
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
      private trainer;
      private hasTrained;
+     private trainDataset?;
+     private validationDataset?;
+     private totalSamples;
      constructor(model: NanoGPT, tokeniser: ITokeniser);
      stop(): void;
      reset(): void;
-     train(text: string[], options?: ITrainerOptions): Promise<void>;
+     prepare(text: string[], options?: ITrainerOptions): Promise<void>;
+     train(options?: ITrainerOptions): Promise<void>;
+     step(options?: ITrainerOptions): Promise<void>;
  }
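The single train(text, options) call is now split in two: prepare() tokenises the corpus and builds the train/validation split, and train() runs the loop over the stored datasets, with step() as a one-log-interval variant. A sketch of the new flow; the structural type mirrors the declarations above, `trainer` would come from the package's trainer() factory, and the corpus and option values are illustrative:

  // Sketch: the split prepare/train/step flow new in 0.7.3.
  async function runTraining(trainer: {
    prepare(text: string[], options?: { batchSize?: number; validationSplit?: number }): Promise<void>;
    train(options?: { maxSteps?: number; logInterval?: number }): Promise<void>;
    step(options?: { logInterval?: number }): Promise<void>;
  }): Promise<void> {
    const corpus = ["first training document", "second training document"];

    // Build the datasets once; train() and step() both throw
    // "Datasets not prepared" if this has not been awaited first.
    await trainer.prepare(corpus, { batchSize: 32, validationSplit: 0.1 });

    // Either run the full loop, as the old train(text, options) did:
    await trainer.train({ maxSteps: 1000, logInterval: 10 });

    // ...or advance one log interval at a time:
    // await trainer.step({ logInterval: 10 });
  }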
package/dist/Trainer.js CHANGED
@@ -1,10 +1,13 @@
- import { E as h } from "./index-Dwqa6Zy2.js";
- import m from "./training/FullTrainer.js";
- class p extends h {
+ import { E as l } from "./index-Dwqa6Zy2.js";
+ import h from "./training/FullTrainer.js";
+ class p extends l {
    trainer;
    hasTrained = !1;
-   constructor(e, t) {
-     super(), this.trainer = new m(e, t, 1e-3);
+   trainDataset;
+   validationDataset;
+   totalSamples = 0;
+   constructor(t, e) {
+     super(), this.trainer = new h(t, e, 1e-3);
    }
    stop() {
      this.trainer.stop();
@@ -12,36 +15,67 @@ class p extends h {
    reset() {
      this.hasTrained = !1, this.trainer.reset();
    }
-   async train(e, t) {
-     const { trainDataset: s, validationDataset: n } = await this.trainer.createTrainValidationSplit(
-       e,
-       t?.batchSize || 32,
-       t?.validationSplit || 0.1
-     ), r = e.reduce((i, a) => i + a.length, 0) * (1 - (t?.validationSplit || 0));
+   async prepare(t, e) {
+     const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
+       t,
+       e?.batchSize || 32,
+       e?.validationSplit || 0.1
+     ), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
+     this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
+   }
+   async train(t) {
+     if (!this.trainDataset || !this.validationDataset)
+       throw new Error("Datasets not prepared");
      this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
-       s,
+       this.trainDataset,
       {
         prompt: t?.prompt,
         logInterval: t?.logInterval || 10,
         desiredLoss: t?.desiredLoss || 0.01,
         maxSteps: t?.maxSteps || 1e3,
         advancedMetrics: t?.advancedMetrics || !1,
-         onStep: async (i, a) => {
-           const l = this.listeners("log");
-           for (const d of l)
-             await d(i, {
+         onStep: async (e, a) => {
+           const s = this.listeners("log");
+           for (const i of s)
+             await i(e, {
               ...a,
-               progress: a.totalSamples / r,
+               progress: a.totalSamples / this.totalSamples,
               remaining: Math.max(
                 0,
-                 (r - a.totalSamples) / a.totalSamples * a.duration
+                 (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
               )
             });
         }
       },
-       n
+       this.validationDataset
     ), this.emit("stop");
   }
+   async step(t) {
+     if (!this.trainDataset || !this.validationDataset)
+       throw new Error("Datasets not prepared");
+     this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
+     const { log: e, progress: a } = await this.trainer.stepDataset(
+       this.trainDataset,
+       {
+         prompt: t?.prompt,
+         logInterval: t?.logInterval || 10,
+         desiredLoss: t?.desiredLoss || 0.01,
+         maxSteps: t?.maxSteps || 1e3,
+         advancedMetrics: t?.advancedMetrics || !1
+       },
+       this.validationDataset
+     ), s = this.listeners("log");
+     for (const i of s)
+       await i(e, {
+         ...a,
+         progress: a.totalSamples / this.totalSamples,
+         remaining: Math.max(
+           0,
+           (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
+         )
+       });
+     this.emit("stop");
+   }
  }
  export {
    p as default
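Both train() and step() derive the progress fields passed to 'log' listeners the same way. Restated plainly as a sketch (done = samples processed so far, total = this.totalSamples captured by prepare(), duration = elapsed training time in milliseconds):

  // Restates the progress arithmetic used in the log callbacks above.
  function progressFields(done: number, total: number, duration: number) {
    return {
      progress: done / total,
      // Linear ETA: elapsed time per sample so far, scaled by what's left.
      remaining: Math.max(0, ((total - done) / done) * duration),
    };
  }

  // e.g. 2500 of 10000 samples after 50000 ms gives progress 0.25 and
  // remaining = 3 * 50000 = 150000 ms.
  console.log(progressFields(2500, 10000, 50000));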
package/dist/training/FullTrainer.d.ts CHANGED
@@ -1,10 +1,23 @@
  import { ITokeniser } from '../tokeniser/type';
- import { default as NanoGPT } from '../NanoGPTModel';
- import { default as GPTTrainer, TrainingOptions } from './Trainer';
+ import { default as NanoGPT, TrainingLogEntry } from '../NanoGPTModel';
+ import { default as GPTTrainer, TrainingOptions, TrainingProgress } from './Trainer';
  import { Tensor } from '@tensorflow/tfjs-core';
  import { Dataset } from '@tensorflow/tfjs-data';
  export default class FullTrainer extends GPTTrainer {
      constructor(model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+     private createEmptyState;
+     private createLogEntry;
+     private createProgress;
+     stepDataset(dataset: Dataset<{
+         xs: Tensor;
+         ys: Tensor;
+     }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+         xs: Tensor;
+         ys: Tensor;
+     }>): Promise<{
+         log: TrainingLogEntry;
+         progress: TrainingProgress;
+     }>;
      trainOnDataset(dataset: Dataset<{
          xs: Tensor;
          ys: Tensor;
package/dist/training/FullTrainer.js CHANGED
@@ -1,81 +1,127 @@
- import { generateText as w } from "../utilities/generate.js";
- import T from "./Trainer.js";
- import L from "./Evaluator.js";
- import { d as h } from "../index-BoWRt-10.js";
- import x from "../utilities/profile.js";
- const y = {
+ import { generateText as v } from "../utilities/generate.js";
+ import x from "./Trainer.js";
+ import S from "./Evaluator.js";
+ import { d as w } from "../index-BoWRt-10.js";
+ import y from "../utilities/profile.js";
+ const T = {
    desiredLoss: 0.01,
    logInterval: 1,
    maxSteps: 1e3
  };
- class E extends T {
-   constructor(i, e, r = 3e-4) {
-     super(i, e, r);
+ class z extends x {
+   constructor(r, t, s = 3e-4) {
+     super(r, t, s);
    }
-   // Train for multiple epochs using Dataset API - FIXED memory leaks
-   async trainOnDataset(i, e, r) {
-     const { logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
-       ...y,
-       ...e
-     }, n = Date.now(), t = {
+   createEmptyState() {
+     return {
        step: 0,
        lastLoss: 1e6,
        totalSteps: 0,
        losses: [],
        validationLosses: [],
-       logStartTime: n,
+       logStartTime: 0,
        trainingDuration: 0,
        ...this.lastState || {}
      };
-     this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new x())), this.running = !0, t.logStartTime = n;
-     const m = r ? new L(this.model, r) : void 0, f = await i.iterator();
+   }
+   createLogEntry(r, t, s, h) {
+     return {
+       loss: r.lastLoss,
+       step: r.step,
+       time: Date.now() - t,
+       batchSize: s,
+       learningRate: h ? this.optimizer.lr : void 0
+     };
+   }
+   createProgress(r, t, s) {
+     return {
+       duration: r.trainingDuration,
+       totalSamples: r.totalSteps * t.batchSize,
+       samplesPerSecond: r.totalSteps * t.batchSize / (r.trainingDuration / 1e3),
+       memory: s ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
+     };
+   }
+   async stepDataset(r, t, s) {
+     const { logInterval: h, prompt: m } = {
+       ...T,
+       ...t
+     }, g = Date.now(), a = this.createEmptyState();
+     this.lastState = a, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, a.logStartTime = g;
+     const p = s ? new S(this.model, s) : void 0, e = await r.iterator();
+     try {
+       for (; this.running; ) {
+         const i = await e.next();
+         if (i.done) break;
+         const u = i.value, o = this.trainBatch(a, u), n = this.createLogEntry(a, g, u.xs.shape[0], t?.advancedMetrics);
+         if (this.model.log.push(n), a.step % h === 0) {
+           await o.data();
+           const f = Date.now();
+           if (a.trainingDuration += f - a.logStartTime, p)
+             try {
+               const l = await p.evaluate(5);
+               a.validationLosses.push(l), n.valLoss = l;
+             } catch (l) {
+               console.error("Validation error:", l);
+             }
+           if (m) {
+             const l = await v(this.tokenizer, this.model, m, 100, {
+               temperature: 0.8
+             });
+             n.example = l;
+           }
+           const c = this.createProgress(a, n, t?.advancedMetrics);
+           return o.dispose(), this.stop(), { log: n, progress: c };
+         }
+         o.dispose();
+       }
+     } catch (i) {
+       throw console.error("Training error:", i), w(), i;
+     }
+     throw w(), this.running = !1, new Error("No log returned before training stopped.");
+   }
+   // Train for multiple epochs using Dataset API - FIXED memory leaks
+   async trainOnDataset(r, t, s) {
+     const { logInterval: h, onStep: m, prompt: g, maxSteps: a } = {
+       ...T,
+       ...t
+     }, p = Date.now(), e = this.createEmptyState();
+     this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, e.logStartTime = p;
+     const i = s ? new S(this.model, s) : void 0, u = await r.iterator();
      try {
        for (; this.running; ) {
-         const o = await f.next();
+         const o = await u.next();
          if (o.done) break;
-         const d = o.value, p = this.trainBatch(t, d), s = {
-           loss: t.lastLoss,
-           step: t.step,
-           time: Date.now() - n,
-           batchSize: d.xs.shape[0],
-           learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0
-           //gradientNorm: options?.advancedMetrics ? await state.gradientNorm : undefined,
-         };
-         if (this.model.log.push(s), t.step % g === 0) {
-           await p.data();
-           const S = Date.now();
-           if (t.trainingDuration += S - t.logStartTime, m)
+         const n = o.value, f = this.trainBatch(e, n), c = this.createLogEntry(e, p, n.xs.shape[0], t?.advancedMetrics);
+         if (this.model.log.push(c), e.step % h === 0) {
+           await f.data();
+           const l = Date.now();
+           if (e.trainingDuration += l - e.logStartTime, i)
             try {
-               const a = await m.evaluate(5);
-               t.validationLosses.push(a), s.valLoss = a;
-             } catch (a) {
-               console.error("Validation error:", a);
+               const d = await i.evaluate(5);
+               e.validationLosses.push(d), c.valLoss = d;
+             } catch (d) {
+               console.error("Validation error:", d);
             }
-           if (l) {
-             if (c) {
-               const v = await w(this.tokenizer, this.model, c, 100, {
+           if (m) {
+             if (g) {
+               const L = await v(this.tokenizer, this.model, g, 100, {
                 temperature: 0.8
               });
-               s.example = v;
+               c.example = L;
             }
-             const a = {
-               duration: t.trainingDuration,
-               totalSamples: t.totalSteps * s.batchSize,
-               samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
-               memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
-             };
-             await l(s, a);
+             const d = this.createProgress(e, c, t?.advancedMetrics);
+             await m(c, d);
           }
-           t.logStartTime = Date.now();
+           e.logStartTime = Date.now();
         }
-         p.dispose(), t.step >= u && this.stop();
+         f.dispose(), e.step >= a && this.stop();
       }
     } catch (o) {
-       throw console.error("Training error:", o), h(), o;
+       throw console.error("Training error:", o), w(), o;
     }
-     return h(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
+     return w(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
   }
  }
  export {
-   E as default
+   z as default
  };
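stepDataset() drives the same loop as trainOnDataset() but returns at the first batch where step % logInterval === 0, stopping itself and handing back a single log/progress pair; if the loop ends without reaching that boundary it throws "No log returned before training stopped." Trainer.step() above is its packaged consumer, but a direct sketch may help (dataset construction omitted; the structural types abbreviate the d.ts signatures):

  // Sketch: consuming the stepDataset contract directly.
  async function oneInterval(
    trainer: {
      stepDataset(
        dataset: unknown,
        options: { logInterval?: number },
        validationDataset?: unknown
      ): Promise<{
        log: { step: number; loss: number };
        progress: { duration: number; samplesPerSecond: number };
      }>;
    },
    trainDataset: unknown,
    validationDataset?: unknown
  ): Promise<void> {
    const { log, progress } = await trainer.stepDataset(
      trainDataset,
      { logInterval: 10 },
      validationDataset // optional; enables valLoss via the Evaluator
    );
    console.log(`step ${log.step}: loss ${log.loss} (${progress.samplesPerSecond.toFixed(1)} samples/s)`);
  }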
package/dist/training/Trainer.d.ts CHANGED
@@ -66,6 +66,16 @@ export default abstract class GPTTrainer {
          losses: number[];
          validationLosses: number[];
      }>;
+     abstract stepDataset(dataset: Dataset<{
+         xs: Tensor;
+         ys: Tensor;
+     }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+         xs: Tensor;
+         ys: Tensor;
+     }>): Promise<{
+         log: TrainingLogEntry;
+         progress: TrainingProgress;
+     }>;
      createTrainValidationSplit(textData: string[], batchSize?: number, validationSplit?: number): Promise<{
          trainDataset: Dataset<{
              xs: Tensor;
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "@genai-fi/nanogpt",
-   "version": "0.7.2",
+   "version": "0.7.3",
    "type": "module",
    "main": "dist/main.js",
    "types": "dist/main.d.ts",