@genai-fi/nanogpt 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/TeachableLLM.js
CHANGED
@@ -1,13 +1,13 @@
-import …
-import { defaultConfig as …
-import { saveModel as …
-import { loadModel as …
-import …
-import …
-import { E as …
+import d from "./NanoGPTModel.js";
+import { defaultConfig as m } from "./config.js";
+import { saveModel as u } from "./utilities/save.js";
+import { loadModel as l } from "./utilities/load.js";
+import f from "./Generator.js";
+import _ from "./Trainer.js";
+import { E as c } from "./index-SOhdqzHq.js";
 import { dummyPassAsync as a } from "./utilities/dummy.js";
 import g from "./tokeniser/CharTokeniser.js";
-class n extends …
+class n extends c {
 _config;
 _model;
 tf;
@@ -43,27 +43,27 @@ class n extends f {
 saveModel() {
 if (!this._model || !this._tokeniser)
 throw new Error("Model or tokeniser is not initialized.");
-return …
+return u(this._model, this._tokeniser);
 }
 static loadModel(t, r) {
 const e = new n(t);
-return …
-e._model = i, e._tokeniser = …
+return l(t, r).then(({ model: i, tokeniser: o }) => {
+e._model = i, e._tokeniser = o, e._config = i.config, e.setStatus("warmup"), a(i).then(() => {
 e.setStatus("ready");
-}).catch(( …
-e.setStatus("error"), e.emit("error", …
+}).catch((s) => {
+e.setStatus("error"), e.emit("error", s);
 });
 }).catch((i) => {
 e.setStatus("error"), e.emit("error", i);
 }), e;
 }
 static create(t, r = {}) {
-const e = { ... …
-return …
-…
+const e = { ...m, ...r }, i = new g(e.vocabSize), o = new d(t, e), s = new n(t, i, o);
+return s.setStatus("warmup"), a(o).then(() => {
+s.setStatus("ready");
 }).catch((h) => {
-…
-}), …
+s.setStatus("error"), s.emit("error", h);
+}), s;
 }
 getNumParams() {
 if (!this._model)
@@ -73,7 +73,7 @@ class n extends f {
 trainer() {
 if (!this._model || !this._tokeniser)
 throw new Error("Model or tokeniser is not initialized.");
-const t = new …
+const t = new _(this._model, this._tokeniser);
 return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (r) => {
 const e = this.listeners("trainStep");
 for (const i of e)
@@ -86,8 +86,12 @@ class n extends f {
 generator() {
 if (!this._model || !this._tokeniser)
 throw new Error("Model or tokeniser is not initialized.");
-const t = new …
-return t.on("start", () => …
+const t = new f(this._model, this._tokeniser);
+return t.on("start", () => {
+this.status === "ready" && this.setStatus("busy");
+}), t.on("stop", () => {
+this.status === "busy" && this.setStatus("ready");
+}), t;
 }
 generateText(t, r) {
 return this.generator().generate(t, r);
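Decoded from the minified identifiers: d is the NanoGPT model, g the character tokeniser, f the generator, _ the trainer, and c an event-emitter base class. The substantive 0.1.3 changes here are that create() and loadModel() now set a "warmup" status and run a dummy forward pass before reporting "ready", and generator() now toggles a "busy" status around generation. A rough usage sketch, assuming TeachableLLM is the package's default export and that create() takes a TensorFlow.js instance first (both inferred from the minified code, not confirmed by this diff):

import * as tf from "@tensorflow/tfjs";
import TeachableLLM from "@genai-fi/nanogpt"; // assumed export path

// create() returns immediately in "warmup"; the dummy pass promotes the
// instance to "ready", or emits "error" (a new catch handler in 0.1.3).
const llm = TeachableLLM.create(tf, { vocabSize: 256 });
llm.on("error", (e: Error) => console.error("warmup failed:", e));

// New in 0.1.3: status flips "ready" -> "busy" while generating, then back.
const text = await llm.generateText("Once upon a time", { temperature: 0.8 });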
package/dist/Trainer.js
CHANGED
@@ -1,6 +1,6 @@
 import { E as l } from "./index-SOhdqzHq.js";
 import o from "./training/FullTrainer.js";
-class …
+class d extends l {
 trainer;
 constructor(a, t) {
 super(), this.trainer = new o(a.tf, a, t, 1e-3);
@@ -8,27 +8,28 @@ class m extends l {
 stop() {
 }
 async train(a, t) {
-const { trainDataset: …
+const { trainDataset: e, validationDataset: r } = await this.trainer.createTrainValidationSplit(
 a,
 t?.batchSize || 32,
 t?.validationSplit || 0.1
 );
-this.emit("start"), await this.trainer.trainOnDataset(
-…
+this.trainer.setLearningRate(t?.learningRate || 1e-3), this.emit("start"), await this.trainer.trainOnDataset(
+e,
 {
 prompt: t?.prompt,
 logInterval: t?.logInterval || 10,
 desiredLoss: t?.desiredLoss || 0.01,
+maxSteps: t?.maxSteps || 1e3,
 onStep: async (i) => {
 const s = this.listeners("log");
 for (const n of s)
 await n(i);
 }
 },
-…
+r
 ), this.emit("stop");
 }
 }
 export {
-…
+d as default
 };
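The wrapper's train() gains two pass-through options in 0.1.3: learningRate (applied via the new setLearningRate() before training starts) and maxSteps (default 1000). A hedged sketch of a call, with option names read from the diff and the dataset argument's type assumed (the diff only shows it forwarded to createTrainValidationSplit):

declare const corpusText: string; // stand-in for the dataset input; exact type not shown in this diff

const trainer = llm.trainer(); // "llm" as in the previous sketch
trainer.on("log", (entry: { step: number; loss: number }) => {
  console.log(`step ${entry.step}: loss ${entry.loss}`);
});

await trainer.train(corpusText, {
  batchSize: 32,        // default when omitted
  validationSplit: 0.1, // default when omitted
  learningRate: 1e-3,   // new in 0.1.3: forwarded via setLearningRate()
  maxSteps: 1000,       // new in 0.1.3: hard cap on training steps
  desiredLoss: 0.01,
  logInterval: 10,
});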
package/dist/training/FullTrainer.js
CHANGED
@@ -1,17 +1,18 @@
 import { generateText as L } from "../utilities/generate.js";
-import …
-const …
+import w from "./Trainer.js";
+const g = {
 desiredLoss: 0.01,
-logInterval: 1
+logInterval: 1,
+maxSteps: 1e3
 };
-class …
+class S extends w {
 constructor(r, i, o, n = 3e-4) {
 super(r, i, o, n);
 }
 // Train for multiple epochs using Dataset API - FIXED memory leaks
 async trainOnDataset(r, i, o) {
-const { desiredLoss: n, logInterval: …
-... …
+const { desiredLoss: n, logInterval: c, onStep: l, prompt: p, maxSteps: d } = {
+...g,
 ...i
 }, s = {
 pass: 0,
@@ -24,19 +25,21 @@ class g extends f {
 validationLosses: []
 };
 this.dummyPass(), this.model.trainable = !0;
-const …
+const m = Date.now();
+this.running = !0;
+const u = await r.iterator();
 try {
-for (; !(s.lastLoss < n); ) {
-const e = await …
+for (; this.running && !(s.lastLoss < n); ) {
+const e = await u.next();
 if (e.done) break;
-const …
+const h = e.value, f = this.trainBatch(s, h), a = {
 loss: s.lastLoss,
 step: s.step,
-time: Date.now() - …
-batchSize: …
+time: Date.now() - m,
+batchSize: h.xs.shape[0]
 };
-if (this.model.log.push(a), s.step % …
-if (await …
+if (this.model.log.push(a), s.step % c === 0) {
+if (await f, o)
 try {
 const t = await this.evaluateOnDataset(o, 5);
 s.validationLosses.push(t), a.valLoss = t;
@@ -44,8 +47,8 @@ class g extends f {
 console.error("Validation error:", t);
 }
 if (l) {
-if ( …
-const t = await L(this.tokenizer, this.model, …
+if (p) {
+const t = await L(this.tokenizer, this.model, p, 100, {
 temperature: 0.8
 });
 a.example = t;
@@ -53,13 +56,14 @@ class g extends f {
 await l(a);
 }
 }
+s.step >= d && this.stop();
 }
 } catch (e) {
 throw console.error("Training error:", e), this.tf.dispose(), e;
 }
-return this.tf.dispose(), { losses: s.losses, validationLosses: s.validationLosses };
+return this.tf.dispose(), this.running = !1, { losses: s.losses, validationLosses: s.validationLosses };
 }
 }
 export {
-…
+S as default
 };
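Functionally, FullTrainer's loop now honours a running flag so stop() can interrupt it between batches, and calls stop() itself once s.step >= maxSteps. A minimal sketch of that cooperative-stop pattern with illustrative names (the real loop is the minified trainOnDataset above):

class CooperativeLoop {
  private running = false;

  stop(): void {
    this.running = false; // polled once per batch; the in-flight batch still completes
  }

  async run(batches: AsyncIterator<unknown>, maxSteps: number): Promise<void> {
    this.running = true;
    let step = 0;
    while (this.running) {
      const next = await batches.next();
      if (next.done) break; // dataset exhausted
      step += 1;
      if (step >= maxSteps) this.stop(); // the hard cap added in 0.1.3
    }
    this.running = false;
  }
}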
@@ -5,7 +5,8 @@ const w = {
 desiredLoss: 0.01,
 logInterval: 1,
 stepsPerLayer: 400,
-maxPasses: 3
+maxPasses: 3,
+maxSteps: 1e3
 };
 class b extends S {
 trainingPattern = [];
@@ -37,20 +38,20 @@ class b extends S {
 validationLosses: []
 };
 this.dummyPass();
-const …
+const m = Date.now();
 this.startPass = 0, this.startLayer = 0;
-const …
+const f = await r.iterator();
 this.applyTrainingPattern(t.layerStep % this.trainingPattern.length);
 try {
 for (; !(t.lastLoss < p); ) {
-const n = await …
+const n = await f.next();
 if (n.done) break;
 const y = n.value, P = this.trainBatch(t, y);
 t.stepSinceLayerChange++;
 const o = {
 loss: t.lastLoss,
 step: t.step,
-time: Date.now() - …
+time: Date.now() - m,
 batchSize: y.xs.shape[0],
 pass: t.pass,
 layer: t.layerStep % this.model.config.nLayer
package/dist/training/Trainer.d.ts
CHANGED
@@ -20,6 +20,7 @@ export interface TrainingOptions {
 desiredLoss: number;
 logInterval: number;
 prompt?: string;
+maxSteps: number;
 onStep?: (log: TrainingLogEntry) => Promise<void> | void;
 }
 export default abstract class GPTTrainer {
@@ -29,7 +30,10 @@ export default abstract class GPTTrainer {
 protected datasetBuilder: DatasetBuilder;
 protected tf: typeof TF;
 protected learningRate: number;
+protected running: boolean;
 constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+setLearningRate(learningRate: number): void;
+stop(): void;
 getOptimizer(): AdamExt;
 resetOptimizer(config?: AdamConfig): void;
 private printGradients;
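Because maxSteps is declared as a required member of TrainingOptions, callers constructing the options object directly must now supply it. A literal matching the updated declaration (the import path is an assumption; the log fields used in onStep are taken from the log objects built in the diffs above):

import type { TrainingOptions } from "@genai-fi/nanogpt"; // assumed export path

const opts: TrainingOptions = {
  desiredLoss: 0.01,
  logInterval: 10,
  maxSteps: 1000, // new required member in 0.1.3
  prompt: "Once upon a time", // optional: drives periodic sample generations
  onStep: async (log) => console.log(log.step, log.loss),
};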
package/dist/training/Trainer.js
CHANGED
@@ -9,6 +9,13 @@ class y {
 datasetBuilder;
 tf;
 learningRate;
+running = !1;
+setLearningRate(t) {
+this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
+}
+stop() {
+this.running = !1;
+}
 getOptimizer() {
 return this.optimizer;
 }
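The base-class changes are self-contained: setLearningRate stores the new rate and rebuilds the Adam optimizer with fixed hyperparameters, and stop() clears the running flag polled by FullTrainer's loop. Hypothetical usage against the GPTTrainer API from the declaration diff above (the import path is an assumption):

import type GPTTrainer from "@genai-fi/nanogpt/dist/training/Trainer"; // assumed path
declare const fullTrainer: GPTTrainer; // any concrete subclass instance

// Rebuilds Adam with learningRateFactor=1, beta1=0.9, beta2=0.99, epsilon=1e-8.
fullTrainer.setLearningRate(3e-4);

// Requests a graceful stop; the loop exits after the in-flight batch.
fullTrainer.stop();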