@genai-fi/nanogpt 0.9.0 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/training/FullTrainer.js +71 -60
- package/package.json +1 -1
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
1
|
+
import L from "./Trainer.js";
|
|
2
|
+
import w from "./Evaluator.js";
|
|
3
3
|
import { d as S } from "../index-BzFyqcy-.js";
|
|
4
|
-
import
|
|
5
|
-
const
|
|
4
|
+
import f from "../utilities/profile.js";
|
|
5
|
+
const y = {
|
|
6
6
|
desiredLoss: 0.01,
|
|
7
7
|
logInterval: 1,
|
|
8
8
|
maxSteps: 1e3
|
|
9
9
|
};
|
|
10
|
-
class
|
|
11
|
-
constructor(s,
|
|
12
|
-
super(s,
|
|
10
|
+
class x extends L {
|
|
11
|
+
constructor(s, e, a = 3e-4) {
|
|
12
|
+
super(s, e, a);
|
|
13
13
|
}
|
|
14
14
|
createEmptyState() {
|
|
15
15
|
return {
|
|
@@ -23,52 +23,55 @@ class b extends y {
|
|
|
23
23
|
...this.lastState || {}
|
|
24
24
|
};
|
|
25
25
|
}
|
|
26
|
-
createLogEntry(s,
|
|
26
|
+
createLogEntry(s, e, a, l) {
|
|
27
27
|
return {
|
|
28
28
|
loss: s.lastLoss,
|
|
29
29
|
step: s.step,
|
|
30
|
-
time: Date.now() -
|
|
30
|
+
time: Date.now() - e,
|
|
31
31
|
batchSize: a,
|
|
32
|
-
learningRate:
|
|
32
|
+
learningRate: l ? this.optimizer.lr : void 0
|
|
33
33
|
};
|
|
34
34
|
}
|
|
35
|
-
createProgress(s,
|
|
35
|
+
createProgress(s, e, a) {
|
|
36
36
|
return {
|
|
37
37
|
duration: s.trainingDuration,
|
|
38
|
-
totalSamples: s.totalSteps *
|
|
39
|
-
samplesPerSecond: s.totalSteps *
|
|
38
|
+
totalSamples: s.totalSteps * e.batchSize,
|
|
39
|
+
samplesPerSecond: s.totalSteps * e.batchSize / (s.trainingDuration / 1e3),
|
|
40
40
|
memory: a ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
41
41
|
};
|
|
42
42
|
}
|
|
43
|
-
async stepDataset(s,
|
|
44
|
-
const { logInterval:
|
|
45
|
-
...
|
|
46
|
-
...
|
|
47
|
-
},
|
|
48
|
-
this.lastState = r, await this.dummyPass(), this.model.trainable = !0,
|
|
49
|
-
const
|
|
43
|
+
async stepDataset(s, e, a) {
|
|
44
|
+
const { logInterval: l } = {
|
|
45
|
+
...y,
|
|
46
|
+
...e
|
|
47
|
+
}, c = Date.now(), r = this.createEmptyState();
|
|
48
|
+
this.lastState = r, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, r.logStartTime = c;
|
|
49
|
+
const d = a ? new w(this.model, a) : void 0, t = await s.iterator();
|
|
50
50
|
try {
|
|
51
51
|
for (; this.running; ) {
|
|
52
|
-
const i = await
|
|
52
|
+
const i = await t.next();
|
|
53
53
|
if (i.done) break;
|
|
54
|
-
const
|
|
55
|
-
if (
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
batchSize: g.xs.shape[0],
|
|
59
|
-
loss: r.lastLoss
|
|
60
|
-
}, r.step % n === 0) {
|
|
61
|
-
await o.data();
|
|
54
|
+
const m = i.value, o = this.trainBatch(r, m);
|
|
55
|
+
if (r.step % l === 0) {
|
|
56
|
+
const g = (await o.data())[0];
|
|
57
|
+
r.lastLoss = g;
|
|
62
58
|
const u = Date.now();
|
|
63
|
-
|
|
59
|
+
r.trainingDuration += u - r.logStartTime;
|
|
60
|
+
const p = this.createLogEntry(r, c, m.xs.shape[0], e?.advancedMetrics);
|
|
61
|
+
if (this.model.trainingState = {
|
|
62
|
+
steps: r.totalSteps,
|
|
63
|
+
learningRate: this.optimizer.lr,
|
|
64
|
+
batchSize: m.xs.shape[0],
|
|
65
|
+
loss: r.lastLoss
|
|
66
|
+
}, d)
|
|
64
67
|
try {
|
|
65
|
-
const
|
|
66
|
-
r.validationLosses.push(
|
|
67
|
-
} catch (
|
|
68
|
-
console.error("Validation error:",
|
|
68
|
+
const n = await d.evaluate(5);
|
|
69
|
+
r.validationLosses.push(n), p.valLoss = n;
|
|
70
|
+
} catch (n) {
|
|
71
|
+
console.error("Validation error:", n);
|
|
69
72
|
}
|
|
70
|
-
const
|
|
71
|
-
return o.dispose(), this.stop(), { log:
|
|
73
|
+
const v = this.createProgress(r, p, e?.advancedMetrics);
|
|
74
|
+
return o.dispose(), this.stop(), { log: p, progress: v };
|
|
72
75
|
}
|
|
73
76
|
o.dispose();
|
|
74
77
|
}
|
|
@@ -78,42 +81,50 @@ class b extends y {
|
|
|
78
81
|
throw S(), this.running = !1, new Error("No log returned before training stopped.");
|
|
79
82
|
}
|
|
80
83
|
// Train for multiple epochs using Dataset API - FIXED memory leaks
|
|
81
|
-
async trainOnDataset(s,
|
|
82
|
-
const { logInterval:
|
|
83
|
-
...
|
|
84
|
-
...
|
|
85
|
-
},
|
|
86
|
-
this.lastState =
|
|
87
|
-
const i = a ? new
|
|
84
|
+
async trainOnDataset(s, e, a) {
|
|
85
|
+
const { logInterval: l, onStep: c, maxSteps: r } = {
|
|
86
|
+
...y,
|
|
87
|
+
...e
|
|
88
|
+
}, d = Date.now(), t = this.createEmptyState();
|
|
89
|
+
this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, t.logStartTime = d;
|
|
90
|
+
const i = a ? new w(this.model, a) : void 0, m = await s.iterator();
|
|
88
91
|
try {
|
|
89
92
|
for (; this.running; ) {
|
|
90
|
-
const o = await
|
|
93
|
+
const o = await m.next();
|
|
91
94
|
if (o.done) break;
|
|
92
|
-
const
|
|
93
|
-
if (
|
|
94
|
-
await u.data();
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
const g = o.value, u = this.trainBatch(t, g);
|
|
96
|
+
if (t.step % l === 0) {
|
|
97
|
+
const p = (await u.data())[0];
|
|
98
|
+
t.lastLoss = p;
|
|
99
|
+
const v = Date.now();
|
|
100
|
+
t.trainingDuration += v - t.logStartTime;
|
|
101
|
+
const n = this.createLogEntry(t, d, g.xs.shape[0], e?.advancedMetrics);
|
|
102
|
+
if (this.model.trainingState = {
|
|
103
|
+
steps: t.totalSteps,
|
|
104
|
+
learningRate: this.optimizer.lr,
|
|
105
|
+
batchSize: g.xs.shape[0],
|
|
106
|
+
loss: t.lastLoss
|
|
107
|
+
}, i)
|
|
97
108
|
try {
|
|
98
|
-
const
|
|
99
|
-
|
|
100
|
-
} catch (
|
|
101
|
-
console.error("Validation error:",
|
|
109
|
+
const h = await i.evaluate(5);
|
|
110
|
+
t.validationLosses.push(h), n.valLoss = h;
|
|
111
|
+
} catch (h) {
|
|
112
|
+
console.error("Validation error:", h);
|
|
102
113
|
}
|
|
103
|
-
if (
|
|
104
|
-
const
|
|
105
|
-
await
|
|
114
|
+
if (c) {
|
|
115
|
+
const h = this.createProgress(t, n, e?.advancedMetrics);
|
|
116
|
+
await c(n, h);
|
|
106
117
|
}
|
|
107
|
-
|
|
118
|
+
t.logStartTime = Date.now();
|
|
108
119
|
}
|
|
109
|
-
u.dispose(),
|
|
120
|
+
u.dispose(), t.step >= r && this.stop();
|
|
110
121
|
}
|
|
111
122
|
} catch (o) {
|
|
112
123
|
throw console.error("Training error:", o), S(), o;
|
|
113
124
|
}
|
|
114
|
-
return S(), this.running = !1, { losses:
|
|
125
|
+
return S(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
|
|
115
126
|
}
|
|
116
127
|
}
|
|
117
128
|
export {
|
|
118
|
-
|
|
129
|
+
x as default
|
|
119
130
|
};
|