@genai-fi/nanogpt 0.7.2 → 0.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +11 -2
- package/dist/Generator.js +76 -63
- package/dist/TeachableLLM.js +28 -27
- package/dist/Trainer.d.ts +6 -1
- package/dist/Trainer.js +53 -19
- package/dist/training/FullTrainer.d.ts +15 -2
- package/dist/training/FullTrainer.js +97 -51
- package/dist/training/Trainer.d.ts +10 -0
- package/package.json +1 -1
package/dist/Generator.d.ts
CHANGED
@@ -9,11 +9,20 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
     private readonly model;
     private readonly tokeniser;
     private active;
+    private cache;
+    private initialPrompt;
+    private outputText;
+    private actualTokeniser;
+    private lastToken;
     constructor(model: NanoGPT, tokeniser: ITokeniser);
     private tokenisePrompt;
-    private generateNoCache;
     private processResponse;
-    private
+    private _generate;
+    reset(): void;
+    dispose(): void;
+    private initialise;
+    step(prompt?: string, options?: IGenerateOptions): Promise<string>;
     generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
     stop(): void;
+    getText(): string;
 }
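
The declaration changes above turn the Generator into a stateful object: the new private fields hold a KV cache, the accumulated output text, and the last emitted token, and the public surface gains single-token stepping plus explicit lifecycle control. A minimal usage sketch of that surface, assuming `llm` is a loaded TeachableLLM (its `generator()` factory appears in the TeachableLLM.js diff below):

    const gen = llm.generator();
    await gen.generate("Once upon a time", { maxLength: 50 }); // multi-token generation
    await gen.step();           // emits exactly one further token (maxLength: 1 internally)
    console.log(gen.getText()); // full accumulated output, prompt included
    gen.reset();                // frees the KV cache and clears accumulated text
    gen.dispose();              // currently equivalent to reset()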
package/dist/Generator.js
CHANGED
@@ -1,4 +1,4 @@
-import { E as
+import { E as l } from "./index-Dwqa6Zy2.js";
 import "./index-BoWRt-10.js";
 import "./ops/cpu/attentionMask.js";
 import "./ops/webgl/attentionMask.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
 import "./ops/cpu/scatterSub.js";
 import "./ops/webgl/scatterSub.js";
 import "./jszip.min-CjP2V1VV.js";
-import
+import u from "./tokeniser/CharTokeniser.js";
 import "./ops/cpu/adamAdjust.js";
 import "./ops/webgl/adamAdjust.js";
 import "./ops/cpu/adamMoments.js";
@@ -39,10 +39,10 @@ import "./ops/cpu/gelu.js";
 import "./ops/webgl/gelu.js";
 import "./gelu-C-dPj6Ku.js";
 import "./ops/webgl/log.js";
-import { t as
-import { c as
+import { t as p } from "./tensor2d-wxPAnDQy.js";
+import { c as f } from "./concat-CsxrgovM.js";
 const k = [
-  ...Array.from({ length: 95 }, (
+  ...Array.from({ length: 95 }, (r, t) => String.fromCharCode(t + 32)),
   // ASCII
   // Spanish accented letters and punctuation
   ..."áéíóúüñ¿¡",
@@ -53,80 +53,93 @@ const k = [
   // Cyrillic letters
   ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
 ];
-function
-  return
+function d(r, t) {
+  return r.length === t ? r : r.length > t ? r.slice(0, t) : r.concat(Array(t - r.length).fill(""));
 }
-class
-  constructor(t,
-    super(), this.model = t, this.tokeniser =
+class nt extends l {
+  constructor(t, i) {
+    super(), this.model = t, this.tokeniser = i, this.actualTokeniser = i;
   }
   active = !1;
-
-
-
+  cache = null;
+  initialPrompt = null;
+  outputText = "";
+  actualTokeniser;
+  lastToken = -1;
+  async tokenisePrompt(t, i) {
+    const e = i ? await t.tokenise([i], !0) : [[t.eosToken]];
+    return p(e, [1, e[0].length], "int32");
   }
-  async
-
-
-    for (let m = 0; m < n && this.active; m++) {
-      const {
-        output: e,
-        attention: p,
-        probabilities: c
-      } = await this.model.generate(i, void 0, r), h = i;
-      i = g([i, e], 1), h.dispose();
-      const l = await this.processResponse(t, e, p, c);
-      if (e.dispose(), l === null)
-        break;
-      s += l;
-    }
-    return i.dispose(), s;
-  }
-  async processResponse(t, o, r, i) {
-    const s = (await o.array())[0][0];
-    if (s === this.tokeniser.eosToken)
+  async processResponse(t, i, e, o) {
+    const s = (await i.array())[0][0];
+    if (this.lastToken = s, s === this.tokeniser.eosToken)
       return null;
     const n = await t.decode([s]);
-    let
-
-    let
-    return
+    let c;
+    e && (c = await Promise.all(e.map((h) => h.array().then((m) => m))), e.forEach((h) => h.dispose()));
+    let a;
+    return o && (a = await o.array(), o.dispose()), this.emit("tokens", [s], n, c, a), n;
   }
-  async
-    let i =
-    const
-    for (let
-      n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
-    const m = r?.maxLength ?? 1e3;
-    for (let e = 0; e < m && this.active; e++) {
+  async _generate(t) {
+    let i = this.lastToken >= 0 && this.cache ? p([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
+    const e = t?.maxLength ?? 1e3;
+    for (let o = 0; o < e && this.active; o++) {
       const {
-        output:
-        probabilities:
-        attention:
-      } = await this.model.generate(i,
-        ...
-        usePadding: !
+        output: s,
+        probabilities: n,
+        attention: c
+      } = await this.model.generate(i, this.cache ? this.cache : void 0, {
+        ...t,
+        usePadding: !this.cache
       });
-
-
-
+      if (this.cache)
+        i.dispose(), i = s;
+      else {
+        const h = i;
+        i = f([i, s], 1), h.dispose();
+      }
+      const a = await this.processResponse(this.actualTokeniser, s, c, n);
+      if (this.cache || s.dispose(), a === null)
        break;
-
+      this.outputText += a;
+    }
+    return i.dispose(), this.outputText;
+  }
+  reset() {
+    this.cache && (this.cache.forEach((t) => {
+      t && (t.k && t.k.dispose(), t.v && t.v.dispose());
+    }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
+  }
+  dispose() {
+    this.reset();
+  }
+  initialise(t, i) {
+    const e = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t ?? null;
+    if (this.cache && i?.noCache && this.reset(), this.initialPrompt = e || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !i?.noCache && this.model.config.gpt.useRope) {
+      const s = new Array(this.model.config.gpt.nLayer);
+      for (let n = 0; n < this.model.config.gpt.nLayer; n++)
+        s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
+      this.cache = s, this.lastToken = -1;
    }
-
-
-    }), i.dispose(), s;
+    const o = this.tokeniser.trained ? this.tokeniser : new u(d(k, this.tokeniser.vocabSize));
+    this.actualTokeniser = o;
   }
-  async
-    const
-    this.
-
-
+  async step(t, i) {
+    const e = { ...i, maxLength: 1 };
+    return this.generate(t, e);
+  }
+  async generate(t, i) {
+    this.initialise(t, i), this.active = !0, this.emit("start");
+    const o = await this._generate(i);
+    return this.active = !1, this.emit("stop"), o;
   }
   stop() {
     this.active = !1;
   }
+  getText() {
+    return this.outputText;
+  }
 }
 export {
-
+  nt as default
 };
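
The implementation shows how the cache is used: once a cache exists, `_generate` feeds back only the last emitted token as a 1×1 int32 tensor and disables padding, instead of re-concatenating the whole sequence each step; `initialise` builds one cache slot per transformer layer, but only when `gpt.useRope` is enabled. A sketch of the per-layer slot shape as created above (the interface name is ours, for clarity only):

    import { Tensor } from "@tensorflow/tfjs-core";

    interface LayerKVSlot {
      k?: Tensor;               // cached keys, filled in by model.generate
      v?: Tensor;               // cached values
      length: number;           // tokens currently held for this layer
      cumulativeLength: number; // total tokens processed so far
    }
    // The Generator keeps an array of nLayer such slots and disposes
    // the k/v tensors again in reset().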
package/dist/TeachableLLM.js
CHANGED
@@ -1,12 +1,12 @@
 import { defaultConfig as _ } from "./config.js";
 import f from "./NanoGPTModel.js";
-import { saveModel as
-import { loadModel as
-import
+import { saveModel as d } from "./utilities/save.js";
+import { loadModel as l } from "./loader/load.js";
+import u from "./Generator.js";
 import p from "./Trainer.js";
-import { E as
+import { E as c } from "./index-Dwqa6Zy2.js";
 import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
-import
+import g from "./tokeniser/CharTokeniser.js";
 import k from "./tokeniser/bpe.js";
 import "./papaparse.min-C8l2Kvo1.js";
 import "./index-Tf7vU29b.js";
@@ -49,7 +49,7 @@ import "./ops/cpu/adamAdjust.js";
 import "./ops/webgl/adamAdjust.js";
 import w from "./utilities/profile.js";
 class a {
-  ee = new
+  ee = new c();
   _config;
   _model;
   _tokeniser;
@@ -92,8 +92,8 @@ class a {
     return this._status === "busy" || this._status === "training";
   }
   estimateTrainingMemoryUsage(t) {
-    const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 },
-    return
+    const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, r = e.perBatch * t, o = e.gradients;
+    return r * 0.66 + o * 4;
   }
   setStatus(t) {
     this._status !== t && (this._status = t, this.ee.emit("status", t));
@@ -101,32 +101,32 @@ class a {
   saveModel(t) {
     if (!this._model || !this._tokeniser)
       throw new Error("model_or_tokeniser_not_initialized.");
-    return
+    return d(this._model, this._tokeniser, {
       ...t,
       name: t?.name || this.meta.name
     });
   }
   static loadModel(t) {
     const e = new a();
-    return
-      e._model =
-        e._memoryRequirements =
-      }).catch((
-        e.setStatus("error"), e.ee.emit("error",
+    return l(t).then(({ model: r, tokeniser: o, name: s }) => {
+      e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
+        e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
+      }).catch((i) => {
+        e.setStatus("error"), e.ee.emit("error", i);
      });
-    }).catch((
-      e.setStatus("error"), e.ee.emit("error",
+    }).catch((r) => {
+      e.setStatus("error"), e.ee.emit("error", r);
    }), e;
  }
  static create(t, e = {}) {
-    const
-    return
-
-      h === "trained" &&
+    const r = { ..._, ...e }, o = t === "char" ? new g(r.vocabSize) : new k(r.vocabSize), s = new f(r), i = new a(o, s);
+    return i.setStatus("warmup"), m(s).then((n) => {
+      i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
+        h === "trained" && i.setStatus("ready");
      }));
    }).catch((n) => {
-
-    }),
+      i.setStatus("error"), i.ee.emit("error", n);
+    }), i;
  }
  getProfiler() {
    return this._model?.getProfiler();
@@ -149,14 +149,15 @@ class a {
    if (!this._model || !this._tokeniser)
      throw new Error("model_or_tokeniser_not_initialized.");
    const t = new p(this._model, this._tokeniser);
-    return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e,
+    return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
      const o = this.ee.listeners("trainStep");
      for (const s of o)
-        await s(e,
+        await s(e, r);
    }), t;
  }
-  train(t, e) {
-
+  async train(t, e) {
+    const r = this.trainer();
+    await r.prepare(t, e), await r.train(e);
  }
  async trainTokeniser(t) {
    if (!this._tokeniser)
@@ -167,7 +168,7 @@ class a {
  generator() {
    if (!this._model || !this._tokeniser)
      throw new Error("model_or_tokeniser_not_initialized.");
-    const t = new
+    const t = new u(this._model, this._tokeniser);
    return t.on("start", () => {
      this.status === "ready" && this.setStatus("busy");
    }), t.on("stop", () => {
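
`train(t, e)` is now async and composes the new two-phase Trainer API, so callers can await completion instead of firing and forgetting. A sketch of both forms, assuming the class `a` above is the default TeachableLLM export and that `trainer()` is reachable as `train()` itself calls it:

    const llm = TeachableLLM.create("char");     // or TeachableLLM.loadModel(...)
    await llm.train(["some text", "more text"]); // prepare + train in one call

    // Long-hand equivalent, matching the new train() body:
    const trainer = llm.trainer();
    await trainer.prepare(["some text", "more text"], { batchSize: 32 });
    await trainer.train({ maxSteps: 1e3 });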
package/dist/Trainer.d.ts
CHANGED
@@ -14,8 +14,13 @@ export interface ITrainerOptions {
 export default class Trainer extends EE<'start' | 'stop' | 'log'> {
     private trainer;
     private hasTrained;
+    private trainDataset?;
+    private validationDataset?;
+    private totalSamples;
     constructor(model: NanoGPT, tokeniser: ITokeniser);
     stop(): void;
     reset(): void;
-
+    prepare(text: string[], options?: ITrainerOptions): Promise<void>;
+    train(options?: ITrainerOptions): Promise<void>;
+    step(options?: ITrainerOptions): Promise<void>;
 }
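
The Trainer now separates dataset preparation from the training loop and adds a resumable `step()`. A sketch of the step-wise loop this declaration enables (option names from ITrainerOptions as used in Trainer.js below; the loop condition is caller-defined and hypothetical):

    await trainer.prepare(corpus, { batchSize: 32, validationSplit: 0.1 });
    trainer.on("log", (entry, progress) => console.log(entry, progress));
    while (stillTraining) {
      await trainer.step(); // advances training to the next logged step
    }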
package/dist/Trainer.js
CHANGED
@@ -1,10 +1,13 @@
-import { E as
-import
-class p extends
+import { E as l } from "./index-Dwqa6Zy2.js";
+import h from "./training/FullTrainer.js";
+class p extends l {
   trainer;
   hasTrained = !1;
-
-
+  trainDataset;
+  validationDataset;
+  totalSamples = 0;
+  constructor(t, e) {
+    super(), this.trainer = new h(t, e, 1e-3);
   }
   stop() {
     this.trainer.stop();
@@ -12,36 +15,67 @@ class p extends h {
   reset() {
     this.hasTrained = !1, this.trainer.reset();
   }
-  async
-    const { trainDataset:
-
-
-
-    ),
+  async prepare(t, e) {
+    const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
+      t,
+      e?.batchSize || 32,
+      e?.validationSplit || 0.1
+    ), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
+    this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
+  }
+  async train(t) {
+    if (!this.trainDataset || !this.validationDataset)
+      throw new Error("Datasets not prepared");
     this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
-
+      this.trainDataset,
       {
         prompt: t?.prompt,
         logInterval: t?.logInterval || 10,
         desiredLoss: t?.desiredLoss || 0.01,
         maxSteps: t?.maxSteps || 1e3,
         advancedMetrics: t?.advancedMetrics || !1,
-        onStep: async (
-          const
-          for (const
-            await
+        onStep: async (e, a) => {
+          const s = this.listeners("log");
+          for (const i of s)
+            await i(e, {
              ...a,
-              progress: a.totalSamples /
+              progress: a.totalSamples / this.totalSamples,
              remaining: Math.max(
                0,
-                (
+                (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
              )
            });
        }
      },
-
+      this.validationDataset
    ), this.emit("stop");
  }
+  async step(t) {
+    if (!this.trainDataset || !this.validationDataset)
+      throw new Error("Datasets not prepared");
+    this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
+    const { log: e, progress: a } = await this.trainer.stepDataset(
+      this.trainDataset,
+      {
+        prompt: t?.prompt,
+        logInterval: t?.logInterval || 10,
+        desiredLoss: t?.desiredLoss || 0.01,
+        maxSteps: t?.maxSteps || 1e3,
+        advancedMetrics: t?.advancedMetrics || !1
+      },
+      this.validationDataset
+    ), s = this.listeners("log");
+    for (const i of s)
+      await i(e, {
+        ...a,
+        progress: a.totalSamples / this.totalSamples,
+        remaining: Math.max(
+          0,
+          (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
+        )
+      });
+    this.emit("stop");
+  }
 }
 export {
   p as default
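
Both `train()` and `step()` report the same derived figures: progress is the fraction of samples consumed, and the remaining-time estimate is a linear extrapolation of elapsed time over the samples still to process. Restated as a hypothetical helper (not part of the package API):

    function remainingMs(totalSamples: number, samplesSeen: number, elapsedMs: number): number {
      // samples still to process, scaled by the observed time per sample so far
      return Math.max(0, ((totalSamples - samplesSeen) / samplesSeen) * elapsedMs);
    }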
package/dist/training/FullTrainer.d.ts
CHANGED
@@ -1,10 +1,23 @@
 import { ITokeniser } from '../tokeniser/type';
-import { default as NanoGPT } from '../NanoGPTModel';
-import { default as GPTTrainer, TrainingOptions } from './Trainer';
+import { default as NanoGPT, TrainingLogEntry } from '../NanoGPTModel';
+import { default as GPTTrainer, TrainingOptions, TrainingProgress } from './Trainer';
 import { Tensor } from '@tensorflow/tfjs-core';
 import { Dataset } from '@tensorflow/tfjs-data';
 export default class FullTrainer extends GPTTrainer {
     constructor(model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+    private createEmptyState;
+    private createLogEntry;
+    private createProgress;
+    stepDataset(dataset: Dataset<{
+        xs: Tensor;
+        ys: Tensor;
+    }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+        xs: Tensor;
+        ys: Tensor;
+    }>): Promise<{
+        log: TrainingLogEntry;
+        progress: TrainingProgress;
+    }>;
     trainOnDataset(dataset: Dataset<{
         xs: Tensor;
         ys: Tensor;
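
`stepDataset` is the new primitive underpinning Trainer.step(): it resolves with a single log/progress pair rather than running to completion. A direct-use sketch (the dataset pair comes from `createTrainValidationSplit`, declared on GPTTrainer in the training/Trainer.d.ts diff below):

    const { trainDataset, validationDataset } =
      await fullTrainer.createTrainValidationSplit(corpus, 32, 0.1);
    const { log, progress } = await fullTrainer.stepDataset(
      trainDataset,
      { logInterval: 10 },
      validationDataset
    );
    console.log(log.loss, progress.totalSamples);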
package/dist/training/FullTrainer.js
CHANGED
@@ -1,81 +1,127 @@
-import { generateText as
-import
-import
-import { d as
-import
-const
+import { generateText as v } from "../utilities/generate.js";
+import x from "./Trainer.js";
+import S from "./Evaluator.js";
+import { d as w } from "../index-BoWRt-10.js";
+import y from "../utilities/profile.js";
+const T = {
   desiredLoss: 0.01,
   logInterval: 1,
   maxSteps: 1e3
 };
-class
-  constructor(
-    super(
+class z extends x {
+  constructor(r, t, s = 3e-4) {
+    super(r, t, s);
   }
-
-
-    const { logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
-      ...y,
-      ...e
-    }, n = Date.now(), t = {
+  createEmptyState() {
+    return {
      step: 0,
      lastLoss: 1e6,
      totalSteps: 0,
      losses: [],
      validationLosses: [],
-      logStartTime:
+      logStartTime: 0,
      trainingDuration: 0,
      ...this.lastState || {}
    };
-
-
+  }
+  createLogEntry(r, t, s, h) {
+    return {
+      loss: r.lastLoss,
+      step: r.step,
+      time: Date.now() - t,
+      batchSize: s,
+      learningRate: h ? this.optimizer.lr : void 0
+    };
+  }
+  createProgress(r, t, s) {
+    return {
+      duration: r.trainingDuration,
+      totalSamples: r.totalSteps * t.batchSize,
+      samplesPerSecond: r.totalSteps * t.batchSize / (r.trainingDuration / 1e3),
+      memory: s ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
+    };
+  }
+  async stepDataset(r, t, s) {
+    const { logInterval: h, prompt: m } = {
+      ...T,
+      ...t
+    }, g = Date.now(), a = this.createEmptyState();
+    this.lastState = a, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, a.logStartTime = g;
+    const p = s ? new S(this.model, s) : void 0, e = await r.iterator();
+    try {
+      for (; this.running; ) {
+        const i = await e.next();
+        if (i.done) break;
+        const u = i.value, o = this.trainBatch(a, u), n = this.createLogEntry(a, g, u.xs.shape[0], t?.advancedMetrics);
+        if (this.model.log.push(n), a.step % h === 0) {
+          await o.data();
+          const f = Date.now();
+          if (a.trainingDuration += f - a.logStartTime, p)
+            try {
+              const l = await p.evaluate(5);
+              a.validationLosses.push(l), n.valLoss = l;
+            } catch (l) {
+              console.error("Validation error:", l);
+            }
+          if (m) {
+            const l = await v(this.tokenizer, this.model, m, 100, {
+              temperature: 0.8
+            });
+            n.example = l;
+          }
+          const c = this.createProgress(a, n, t?.advancedMetrics);
+          return o.dispose(), this.stop(), { log: n, progress: c };
+        }
+        o.dispose();
+      }
+    } catch (i) {
+      throw console.error("Training error:", i), w(), i;
+    }
+    throw w(), this.running = !1, new Error("No log returned before training stopped.");
+  }
+  // Train for multiple epochs using Dataset API - FIXED memory leaks
+  async trainOnDataset(r, t, s) {
+    const { logInterval: h, onStep: m, prompt: g, maxSteps: a } = {
+      ...T,
+      ...t
+    }, p = Date.now(), e = this.createEmptyState();
+    this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, e.logStartTime = p;
+    const i = s ? new S(this.model, s) : void 0, u = await r.iterator();
    try {
      for (; this.running; ) {
-        const o = await
+        const o = await u.next();
        if (o.done) break;
-        const
-
-
-
-
-          learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0
-          //gradientNorm: options?.advancedMetrics ? await state.gradientNorm : undefined,
-        };
-        if (this.model.log.push(s), t.step % g === 0) {
-          await p.data();
-          const S = Date.now();
-          if (t.trainingDuration += S - t.logStartTime, m)
+        const n = o.value, f = this.trainBatch(e, n), c = this.createLogEntry(e, p, n.xs.shape[0], t?.advancedMetrics);
+        if (this.model.log.push(c), e.step % h === 0) {
+          await f.data();
+          const l = Date.now();
+          if (e.trainingDuration += l - e.logStartTime, i)
            try {
-              const
-
-            } catch (
-              console.error("Validation error:",
+              const d = await i.evaluate(5);
+              e.validationLosses.push(d), c.valLoss = d;
+            } catch (d) {
+              console.error("Validation error:", d);
            }
-          if (
-            if (
-              const
+          if (m) {
+            if (g) {
+              const L = await v(this.tokenizer, this.model, g, 100, {
                temperature: 0.8
              });
-
+              c.example = L;
            }
-            const
-
-              totalSamples: t.totalSteps * s.batchSize,
-              samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
-              memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
-            };
-            await l(s, a);
+            const d = this.createProgress(e, c, t?.advancedMetrics);
+            await m(c, d);
          }
-
+          e.logStartTime = Date.now();
        }
-
+        f.dispose(), e.step >= a && this.stop();
      }
    } catch (o) {
-      throw console.error("Training error:", o),
+      throw console.error("Training error:", o), w(), o;
    }
-    return
+    return w(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
  }
 }
 export {
-
+  z as default
 };
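
Two implementation details worth noting in `stepDataset`: `createEmptyState()` spreads `this.lastState`, so successive calls resume step counters and loss history rather than starting over, and the method throws if the dataset is exhausted before a `logInterval` boundary produces a log. A sketch of an interactive loop built on that resumability (loop condition hypothetical):

    let converged = false;
    while (!converged) {
      const { log } = await fullTrainer.stepDataset(trainDataset, {}, validationDataset);
      converged = log.loss < 0.01; // same threshold as the desiredLoss default above
    }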
package/dist/training/Trainer.d.ts
CHANGED
@@ -66,6 +66,16 @@ export default abstract class GPTTrainer {
         losses: number[];
         validationLosses: number[];
     }>;
+    abstract stepDataset(dataset: Dataset<{
+        xs: Tensor;
+        ys: Tensor;
+    }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+        xs: Tensor;
+        ys: Tensor;
+    }>): Promise<{
+        log: TrainingLogEntry;
+        progress: TrainingProgress;
+    }>;
     createTrainValidationSplit(textData: string[], batchSize?: number, validationSplit?: number): Promise<{
         trainDataset: Dataset<{
             xs: Tensor;