@genai-fi/nanogpt 0.15.13 → 0.15.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Trainer.js +10 -5
- package/dist/data/textLoader.js +47 -41
- package/dist/training/BasicTrainer.js +62 -62
- package/dist/training/Evaluator.d.ts +2 -1
- package/dist/training/Evaluator.js +19 -18
- package/dist/training/SFTDatasetBuilder.js +43 -36
- package/dist/training/tasks/ConversationTask.d.ts +2 -2
- package/dist/training/tasks/ConversationTask.js +13 -11
- package/dist/training/tasks/PretrainingTask.d.ts +1 -2
- package/dist/training/tasks/PretrainingTask.js +4 -14
- package/dist/training/tasks/StartSentenceTask.d.ts +1 -2
- package/dist/training/tasks/StartSentenceTask.js +2 -7
- package/dist/training/tasks/Task.d.ts +1 -2
- package/dist/training/tasks/splitter.d.ts +5 -0
- package/dist/training/tasks/splitter.js +21 -0
- package/dist/training/validation.js +1 -1
- package/package.json +1 -1
package/dist/Trainer.js
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import { E as g } from "./index-DvYrXKkX.js";
|
|
2
2
|
import o from "./training/PreTrainer.js";
|
|
3
|
-
import { createTrainValidationSplit as
|
|
3
|
+
import { createTrainValidationSplit as d } from "./training/validation.js";
|
|
4
4
|
import h from "./training/SFTTrainer.js";
|
|
5
|
+
import p from "./training/tasks/splitter.js";
|
|
5
6
|
class l extends g {
|
|
6
7
|
trainer;
|
|
7
8
|
trainingType = "pretraining";
|
|
@@ -81,7 +82,7 @@ class l extends g {
|
|
|
81
82
|
async prepare(t = []) {
|
|
82
83
|
const i = this.options;
|
|
83
84
|
if (this.trainingType === "pretraining" && this.trainer instanceof o) {
|
|
84
|
-
const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await
|
|
85
|
+
const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await d(
|
|
85
86
|
t,
|
|
86
87
|
this.trainer.tokenizer,
|
|
87
88
|
this.trainer.datasetBuilder,
|
|
@@ -92,12 +93,16 @@ class l extends g {
|
|
|
92
93
|
} else if (this.trainingType === "sft" && this.trainer instanceof h) {
|
|
93
94
|
if (t instanceof Uint16Array)
|
|
94
95
|
throw new Error("SFT training requires Task[] input");
|
|
95
|
-
const e = await this.trainer.datasetBuilder.createSFTDataset(
|
|
96
|
-
|
|
96
|
+
const e = p(t, i?.validationSplit || 0.1), a = await this.trainer.datasetBuilder.createSFTDataset(
|
|
97
|
+
[e.training],
|
|
98
|
+
i?.batchSize || 32,
|
|
99
|
+
-100
|
|
100
|
+
), r = await this.trainer.datasetBuilder.createSFTDataset(
|
|
101
|
+
[e.validation],
|
|
97
102
|
i?.batchSize || 32,
|
|
98
103
|
-100
|
|
99
104
|
);
|
|
100
|
-
this.trainDataset =
|
|
105
|
+
this.validationDataset = r, this.trainDataset = a, this.totalSamples = t.reduce((n, s) => n + s.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
|
|
101
106
|
}
|
|
102
107
|
}
|
|
103
108
|
configureModel(t) {
|
package/dist/data/textLoader.js
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
import { p as u } from "../papaparse.min-C0cScC2i.js";
|
|
2
|
-
import { loadParquet as
|
|
3
|
-
import { loadPDF as
|
|
2
|
+
import { loadParquet as f } from "./parquet.js";
|
|
3
|
+
import { loadPDF as d } from "./pdf.js";
|
|
4
4
|
import { loadDOCX as m } from "./docx.js";
|
|
5
5
|
import { z as x } from "../jszip.min-BZhlzntC.js";
|
|
6
|
-
function
|
|
7
|
-
const
|
|
8
|
-
return
|
|
6
|
+
function y(t, r) {
|
|
7
|
+
const a = t.findIndex((i) => i.toLowerCase() === r.toLowerCase());
|
|
8
|
+
return a === -1 ? 0 : a;
|
|
9
9
|
}
|
|
10
|
-
function
|
|
11
|
-
return t.every((
|
|
10
|
+
function w(t) {
|
|
11
|
+
return t.every((r) => r.length < 64);
|
|
12
12
|
}
|
|
13
13
|
function h(t) {
|
|
14
14
|
return t.split(".").pop() || "";
|
|
@@ -35,66 +35,72 @@ function g(t) {
|
|
|
35
35
|
return "unknown";
|
|
36
36
|
}
|
|
37
37
|
}
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
38
|
+
function j(t) {
|
|
39
|
+
if (!Array.isArray(t)) return !1;
|
|
40
|
+
const r = t[0];
|
|
41
|
+
return typeof r == "object" && r !== null && "role" in r && "content" in r && typeof r.role == "string" && typeof r.content == "string";
|
|
42
|
+
}
|
|
43
|
+
async function z(t, r) {
|
|
44
|
+
const a = t.type !== "" ? t.type : g(t.name);
|
|
45
|
+
if (a === "application/parquet")
|
|
46
|
+
return f(t, r?.maxSize, r?.column);
|
|
47
|
+
if (a === "application/pdf")
|
|
48
|
+
return d(t, r?.maxSize);
|
|
49
|
+
if (a === "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
|
|
45
50
|
return m(t);
|
|
46
|
-
if (
|
|
47
|
-
const i = await t.text(),
|
|
48
|
-
if (Array.isArray(
|
|
49
|
-
return
|
|
51
|
+
if (a === "application/json") {
|
|
52
|
+
const i = await t.text(), o = JSON.parse(i);
|
|
53
|
+
if (Array.isArray(o))
|
|
54
|
+
return o.map(
|
|
50
55
|
(e) => typeof e == "string" ? e : "text" in e ? e.text : JSON.stringify(e)
|
|
51
56
|
);
|
|
52
57
|
throw new Error("Expected JSON array");
|
|
53
58
|
}
|
|
54
|
-
if (
|
|
59
|
+
if (a === "application/jsonl")
|
|
55
60
|
return (await t.text()).split(`
|
|
56
|
-
`).filter((
|
|
61
|
+
`).filter((o) => o.trim() !== "").map((o) => {
|
|
57
62
|
try {
|
|
58
|
-
const e = JSON.parse(
|
|
59
|
-
return
|
|
63
|
+
const e = JSON.parse(o);
|
|
64
|
+
return j(e) ? e.map((n) => `${n.content}`).join(`
|
|
65
|
+
`) : typeof e == "string" ? e : "text" in e ? e.text : JSON.stringify(e);
|
|
60
66
|
} catch {
|
|
61
|
-
return
|
|
67
|
+
return o;
|
|
62
68
|
}
|
|
63
69
|
});
|
|
64
|
-
if (
|
|
65
|
-
const i = await x.loadAsync(t),
|
|
70
|
+
if (a === "application/zip") {
|
|
71
|
+
const i = await x.loadAsync(t), o = [];
|
|
66
72
|
for (const e of Object.keys(i.files)) {
|
|
67
|
-
const
|
|
68
|
-
if (
|
|
69
|
-
const
|
|
70
|
-
|
|
73
|
+
const n = i.file(e);
|
|
74
|
+
if (n) {
|
|
75
|
+
const s = await n.async("blob"), c = await z(new File([s], e), r);
|
|
76
|
+
o.push(...c);
|
|
71
77
|
}
|
|
72
78
|
}
|
|
73
|
-
return
|
|
79
|
+
return o;
|
|
74
80
|
}
|
|
75
|
-
if (
|
|
81
|
+
if (a === "text/csv") {
|
|
76
82
|
const i = await t.text();
|
|
77
|
-
return new Promise((
|
|
83
|
+
return new Promise((o, e) => {
|
|
78
84
|
u.parse(i, {
|
|
79
85
|
header: !1,
|
|
80
86
|
skipEmptyLines: !0,
|
|
81
87
|
delimiter: ",",
|
|
82
|
-
complete: (
|
|
83
|
-
if (
|
|
84
|
-
console.error(
|
|
88
|
+
complete: (n) => {
|
|
89
|
+
if (n.errors.length > 0)
|
|
90
|
+
console.error(n.errors), e(new Error("Error parsing file"));
|
|
85
91
|
else {
|
|
86
|
-
const
|
|
87
|
-
|
|
92
|
+
const s = y(n.data[0], r?.column || "text"), p = r?.hasHeader ?? w(n.data[0]) ? n.data.slice(1) : n.data;
|
|
93
|
+
o(p.map((l) => l[s]));
|
|
88
94
|
}
|
|
89
95
|
},
|
|
90
|
-
error: (
|
|
91
|
-
e(
|
|
96
|
+
error: (n) => {
|
|
97
|
+
e(n);
|
|
92
98
|
}
|
|
93
99
|
});
|
|
94
100
|
});
|
|
95
|
-
} else if (
|
|
101
|
+
} else if (a === "text/plain")
|
|
96
102
|
return [await t.text()];
|
|
97
|
-
throw new Error(`Unsupported file type: ${
|
|
103
|
+
throw new Error(`Unsupported file type: ${a}`);
|
|
98
104
|
}
|
|
99
105
|
export {
|
|
100
106
|
z as default
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import u from "./Evaluator.js";
|
|
2
|
-
import { t as
|
|
2
|
+
import { t as z, v as P, k as g, d as p, a as y } from "../index-CUXkjxiT.js";
|
|
3
3
|
import S from "../utilities/profile.js";
|
|
4
|
-
import { createTensorStatistics as
|
|
5
|
-
import { calculateLoss as
|
|
4
|
+
import { createTensorStatistics as k } from "../checks/weights.js";
|
|
5
|
+
import { calculateLoss as x, calculateAccuracy as T } from "./loss.js";
|
|
6
6
|
import { AdamWOptimizer as N } from "./AdamW.js";
|
|
7
7
|
import { z as w } from "../zeros-DvZpK8s6.js";
|
|
8
8
|
const v = {
|
|
@@ -23,11 +23,11 @@ const v = {
|
|
|
23
23
|
lossScaling: 1
|
|
24
24
|
};
|
|
25
25
|
class G {
|
|
26
|
-
constructor(
|
|
27
|
-
this.tokenizer = i, this.model =
|
|
26
|
+
constructor(s, i, o, c) {
|
|
27
|
+
this.tokenizer = i, this.model = s, this.optimizerConfig = {
|
|
28
28
|
...b,
|
|
29
29
|
...o,
|
|
30
|
-
lossScaling:
|
|
30
|
+
lossScaling: s.lossScaling
|
|
31
31
|
};
|
|
32
32
|
const l = c || new N(this.optimizerConfig);
|
|
33
33
|
c && c.updateConfig(this.optimizerConfig), this.optimizer = l;
|
|
@@ -44,26 +44,26 @@ class G {
|
|
|
44
44
|
_labelSmoothing = 0;
|
|
45
45
|
_layerDrop = 0;
|
|
46
46
|
_dropout = 0;
|
|
47
|
-
setGradientCheckpointing(
|
|
48
|
-
this._gradientCheckpointing =
|
|
47
|
+
setGradientCheckpointing(s) {
|
|
48
|
+
this._gradientCheckpointing = s;
|
|
49
49
|
}
|
|
50
|
-
setMixedPrecision(
|
|
51
|
-
this._mixedPrecision =
|
|
50
|
+
setMixedPrecision(s) {
|
|
51
|
+
this._mixedPrecision = s;
|
|
52
52
|
}
|
|
53
|
-
setLabelSmoothing(
|
|
54
|
-
this._labelSmoothing =
|
|
53
|
+
setLabelSmoothing(s) {
|
|
54
|
+
this._labelSmoothing = s;
|
|
55
55
|
}
|
|
56
|
-
setDropout(
|
|
57
|
-
this._dropout =
|
|
56
|
+
setDropout(s) {
|
|
57
|
+
this._dropout = s;
|
|
58
58
|
}
|
|
59
|
-
setLayerDrop(
|
|
60
|
-
this._layerDrop =
|
|
59
|
+
setLayerDrop(s) {
|
|
60
|
+
this._layerDrop = s;
|
|
61
61
|
}
|
|
62
|
-
setLearningRate(
|
|
63
|
-
this.optimizerConfig.learningRate =
|
|
62
|
+
setLearningRate(s) {
|
|
63
|
+
this.optimizerConfig.learningRate = s, this.updateOptimizer();
|
|
64
64
|
}
|
|
65
|
-
setMetrics(
|
|
66
|
-
this.metrics = new Set(
|
|
65
|
+
setMetrics(s) {
|
|
66
|
+
this.metrics = new Set(s);
|
|
67
67
|
}
|
|
68
68
|
reset() {
|
|
69
69
|
this.lastState = void 0, this.running = !1;
|
|
@@ -77,12 +77,12 @@ class G {
|
|
|
77
77
|
getOptimizer() {
|
|
78
78
|
return this.optimizer;
|
|
79
79
|
}
|
|
80
|
-
updateOptimizer(
|
|
81
|
-
|
|
80
|
+
updateOptimizer(s) {
|
|
81
|
+
s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
|
|
82
82
|
}
|
|
83
83
|
// A single forward pass, backward pass, and optimizer step
|
|
84
|
-
trainStep(
|
|
85
|
-
return
|
|
84
|
+
trainStep(s, i, o = !1, c = !1) {
|
|
85
|
+
return z(() => {
|
|
86
86
|
this.model.getProfiler()?.startMemory();
|
|
87
87
|
const { xs: l, ys: a } = i, d = () => {
|
|
88
88
|
const r = this.model.forward(
|
|
@@ -94,31 +94,31 @@ class G {
|
|
|
94
94
|
layerDrop: this._layerDrop
|
|
95
95
|
},
|
|
96
96
|
l
|
|
97
|
-
),
|
|
98
|
-
this.metrics.has("accuracy") && (
|
|
99
|
-
const m =
|
|
100
|
-
return
|
|
97
|
+
), e = x(r, a, this.maskedLoss, !1, this._labelSmoothing);
|
|
98
|
+
this.metrics.has("accuracy") && (s.accuracy = T(r, a), g(s.accuracy)), r.dispose();
|
|
99
|
+
const m = e.mul(y(this.optimizerConfig.lossScaling));
|
|
100
|
+
return e.dispose(), m;
|
|
101
101
|
}, { value: t, grads: n } = P(d);
|
|
102
102
|
if (o)
|
|
103
103
|
this.model.getProfiler()?.endMemory("Training");
|
|
104
104
|
else {
|
|
105
105
|
const r = this.optimizer.applyGradients(n);
|
|
106
|
-
this.metrics.has("gradientNorm") ? (
|
|
107
|
-
const
|
|
108
|
-
this.model.weightStore.touchVariables(
|
|
106
|
+
this.metrics.has("gradientNorm") ? (s.gradientNorm = r, g(r)) : (s.gradientNorm = void 0, r.dispose());
|
|
107
|
+
const e = Object.keys(n);
|
|
108
|
+
this.model.weightStore.touchVariables(e), this.model.getProfiler()?.endMemory("Training"), c ? (s.gradients = n, Object.values(n).forEach((m) => g(m))) : p(n);
|
|
109
109
|
}
|
|
110
110
|
return t.mul(y(1 / this.optimizerConfig.lossScaling));
|
|
111
111
|
});
|
|
112
112
|
}
|
|
113
113
|
async dummyPass() {
|
|
114
|
-
const
|
|
114
|
+
const s = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
|
|
115
115
|
try {
|
|
116
|
-
const o = this.trainStep({}, { xs:
|
|
116
|
+
const o = this.trainStep({}, { xs: s, ys: i }, !0);
|
|
117
117
|
await o.data(), o.dispose();
|
|
118
118
|
} catch (o) {
|
|
119
119
|
console.error("Error during dummy pass:", o);
|
|
120
120
|
} finally {
|
|
121
|
-
|
|
121
|
+
s.dispose(), i.dispose();
|
|
122
122
|
}
|
|
123
123
|
}
|
|
124
124
|
dispose() {
|
|
@@ -136,7 +136,7 @@ class G {
|
|
|
136
136
|
...this.lastState || {}
|
|
137
137
|
};
|
|
138
138
|
}
|
|
139
|
-
async stepDataset(
|
|
139
|
+
async stepDataset(s, i, o) {
|
|
140
140
|
const { logInterval: c = 10 } = {
|
|
141
141
|
...v,
|
|
142
142
|
...i
|
|
@@ -144,21 +144,21 @@ class G {
|
|
|
144
144
|
i.metrics && this.setMetrics(i.metrics);
|
|
145
145
|
const l = Date.now(), a = this.createEmptyState();
|
|
146
146
|
this.lastState = a, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, a.logStartTime = l;
|
|
147
|
-
const d = o ? new u(this.model, o) : void 0, t = await
|
|
147
|
+
const d = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, t = await s.iterator();
|
|
148
148
|
try {
|
|
149
149
|
for (; this.running; ) {
|
|
150
150
|
const n = await t.next();
|
|
151
151
|
if (n.done) break;
|
|
152
|
-
const r = n.value,
|
|
153
|
-
r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(
|
|
152
|
+
const r = n.value, e = this.trainStep(a, r, !1);
|
|
153
|
+
r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(e, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), e.dispose();
|
|
154
154
|
}
|
|
155
155
|
} catch (n) {
|
|
156
|
-
throw console.error("Training error:", n),
|
|
156
|
+
throw console.error("Training error:", n), p(), n;
|
|
157
157
|
}
|
|
158
|
-
throw
|
|
158
|
+
throw p(), this.running = !1, new Error("No log returned before training stopped.");
|
|
159
159
|
}
|
|
160
|
-
async performLogging(
|
|
161
|
-
const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await
|
|
160
|
+
async performLogging(s, i, o, c) {
|
|
161
|
+
const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
|
|
162
162
|
t.lastLoss = d;
|
|
163
163
|
const n = Date.now();
|
|
164
164
|
t.trainingDuration += n - t.logStartTime;
|
|
@@ -184,25 +184,25 @@ class G {
|
|
|
184
184
|
batchSize: i,
|
|
185
185
|
loss: t.lastLoss
|
|
186
186
|
}, a && t.gradients) {
|
|
187
|
-
const
|
|
188
|
-
for (const [m,
|
|
189
|
-
|
|
190
|
-
r.gradientMetrics =
|
|
187
|
+
const e = /* @__PURE__ */ new Map();
|
|
188
|
+
for (const [m, h] of Object.entries(t.gradients))
|
|
189
|
+
e.set(m, await k(h)), h.dispose();
|
|
190
|
+
r.gradientMetrics = e;
|
|
191
191
|
}
|
|
192
192
|
if (c)
|
|
193
193
|
try {
|
|
194
|
-
const
|
|
195
|
-
Array.isArray(
|
|
196
|
-
accuracy:
|
|
197
|
-
loss:
|
|
198
|
-
perplexity: this.metrics.has("perplexity") ? Math.exp(
|
|
194
|
+
const e = await c.evaluate(5);
|
|
195
|
+
Array.isArray(e) ? r.validationMetrics = { loss: e[0].loss, accuracy: e[0].accuracy } : (t.validationLosses.push(e.loss), r.validationMetrics = {
|
|
196
|
+
accuracy: e.accuracy,
|
|
197
|
+
loss: e.loss,
|
|
198
|
+
perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
|
|
199
199
|
});
|
|
200
|
-
} catch (
|
|
201
|
-
console.error("Validation error:",
|
|
200
|
+
} catch (e) {
|
|
201
|
+
console.error("Validation error:", e);
|
|
202
202
|
}
|
|
203
203
|
l && await l(r), t.logStartTime = Date.now();
|
|
204
204
|
}
|
|
205
|
-
async trainOnDataset(
|
|
205
|
+
async trainOnDataset(s, i, o) {
|
|
206
206
|
const { logInterval: c = 10, maxEpochs: l = 1 / 0 } = {
|
|
207
207
|
...v,
|
|
208
208
|
...i
|
|
@@ -210,18 +210,18 @@ class G {
|
|
|
210
210
|
i.metrics && this.setMetrics(i.metrics);
|
|
211
211
|
const d = Date.now(), t = this.createEmptyState();
|
|
212
212
|
this.lastState = t, await this.dummyPass(), i?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, t.logStartTime = d;
|
|
213
|
-
const n = o ? new u(this.model, o) : void 0, r = await
|
|
213
|
+
const n = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, r = await s.iterator();
|
|
214
214
|
try {
|
|
215
215
|
for (; this.running; ) {
|
|
216
|
-
const
|
|
217
|
-
if (
|
|
218
|
-
const m =
|
|
219
|
-
m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++,
|
|
216
|
+
const e = await r.next();
|
|
217
|
+
if (e.done) break;
|
|
218
|
+
const m = e.value, h = t.step % c === 0, L = (i?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, m, !1, L);
|
|
219
|
+
m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
|
|
220
220
|
}
|
|
221
|
-
} catch (
|
|
222
|
-
throw console.error("Training error:",
|
|
221
|
+
} catch (e) {
|
|
222
|
+
throw console.error("Training error:", e), p(), e;
|
|
223
223
|
}
|
|
224
|
-
return
|
|
224
|
+
return p(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
|
|
225
225
|
}
|
|
226
226
|
}
|
|
227
227
|
export {
|
|
@@ -11,7 +11,8 @@ export default class Evaluator {
|
|
|
11
11
|
private iterator?;
|
|
12
12
|
private xs?;
|
|
13
13
|
private ys?;
|
|
14
|
-
|
|
14
|
+
private masked;
|
|
15
|
+
constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer> | Conversation[][], tokeniser?: ITokeniser, masked?: boolean);
|
|
15
16
|
dispose(): void;
|
|
16
17
|
private calculateBatchLoss;
|
|
17
18
|
evaluate(maxBatches?: number): Promise<Result | Result[]>;
|
|
@@ -2,12 +2,12 @@ import { t as p } from "../index-CUXkjxiT.js";
|
|
|
2
2
|
import { calculateLoss as d, calculateAccuracy as m } from "./loss.js";
|
|
3
3
|
import { buildSFTExample as x } from "./SFTDatasetBuilder.js";
|
|
4
4
|
import { t as h } from "../tensor-BWFldCso.js";
|
|
5
|
-
class
|
|
6
|
-
constructor(
|
|
7
|
-
if (this.model = c, Array.isArray(t)) {
|
|
5
|
+
class k {
|
|
6
|
+
constructor(i, t, o, c) {
|
|
7
|
+
if (this.model = i, this.masked = !!c, Array.isArray(t)) {
|
|
8
8
|
if (!o)
|
|
9
9
|
throw new Error("Tokeniser is required when dataset is an array of conversations");
|
|
10
|
-
const a = t.map((s) => x(s, -100, o,
|
|
10
|
+
const a = t.map((s) => x(s, -100, o, i.config.blockSize)).filter((s) => s !== null);
|
|
11
11
|
if (a.length === 0)
|
|
12
12
|
return;
|
|
13
13
|
this.xs = h(a.map((s) => s.xs)), this.ys = h(a.map((s) => s.ys));
|
|
@@ -17,32 +17,33 @@ class g {
|
|
|
17
17
|
iterator;
|
|
18
18
|
xs;
|
|
19
19
|
ys;
|
|
20
|
+
masked = !1;
|
|
20
21
|
dispose() {
|
|
21
22
|
this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
|
|
22
23
|
}
|
|
23
|
-
async calculateBatchLoss(
|
|
24
|
-
const [
|
|
25
|
-
const r = this.model.forward({ training: !1 },
|
|
24
|
+
async calculateBatchLoss(i, t, o, c) {
|
|
25
|
+
const [a, s] = p(() => {
|
|
26
|
+
const r = this.model.forward({ training: !1 }, i), y = d(r, t, c, o), f = m(r, t);
|
|
26
27
|
return r.dispose(), [y, f];
|
|
27
|
-
}), n = await
|
|
28
|
-
return
|
|
28
|
+
}), n = await a.array(), u = await s.array(), e = n, l = u;
|
|
29
|
+
return s.dispose(), a.dispose(), Array.isArray(e) ? e.map((r) => ({ loss: r, accuracy: l })) : { loss: e, accuracy: l };
|
|
29
30
|
}
|
|
30
|
-
async evaluate(
|
|
31
|
-
let t = 0, o = 0,
|
|
31
|
+
async evaluate(i = 100) {
|
|
32
|
+
let t = 0, o = 0, c = 0;
|
|
32
33
|
if (this.iterator) {
|
|
33
|
-
const
|
|
34
|
-
for (let
|
|
35
|
-
const n = await
|
|
34
|
+
const a = await this.iterator;
|
|
35
|
+
for (let s = 0; s < i; s++) {
|
|
36
|
+
const n = await a.next();
|
|
36
37
|
if (n.done) break;
|
|
37
|
-
const u = n.value, { xs:
|
|
38
|
-
|
|
38
|
+
const u = n.value, { xs: e, ys: l } = u, r = await this.calculateBatchLoss(e, l, !1, this.masked);
|
|
39
|
+
e.dispose(), l.dispose(), t += r.loss, o += r.accuracy, c++;
|
|
39
40
|
}
|
|
40
|
-
return { loss: t /
|
|
41
|
+
return { loss: t / c, accuracy: o / c };
|
|
41
42
|
} else if (this.xs && this.ys)
|
|
42
43
|
return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
|
|
43
44
|
throw new Error("No data available for evaluation");
|
|
44
45
|
}
|
|
45
46
|
}
|
|
46
47
|
export {
|
|
47
|
-
|
|
48
|
+
k as default
|
|
48
49
|
};
|
|
@@ -1,50 +1,50 @@
|
|
|
1
|
-
import { t as
|
|
1
|
+
import { t as y } from "../index-CUXkjxiT.js";
|
|
2
2
|
import "../dataset-CGGp1z9P.js";
|
|
3
3
|
import { g as I } from "../readers-iz5u3HBo.js";
|
|
4
4
|
import "../index-Cp39cXWe.js";
|
|
5
|
-
function w(
|
|
6
|
-
const s = [t.bosToken],
|
|
5
|
+
function w(p, o, t, l) {
|
|
6
|
+
const s = [t.bosToken], a = [!1], u = {
|
|
7
7
|
user: t.getSpecialTokenIndex("<|user_start|>"),
|
|
8
8
|
assistant: t.getSpecialTokenIndex("<|assistant_start|>"),
|
|
9
9
|
system: t.getSpecialTokenIndex("<|system_start|>")
|
|
10
|
-
},
|
|
10
|
+
}, c = {
|
|
11
11
|
user: t.getSpecialTokenIndex("<|user_end|>"),
|
|
12
12
|
assistant: t.getSpecialTokenIndex("<|assistant_end|>"),
|
|
13
13
|
system: t.getSpecialTokenIndex("<|system_end|>")
|
|
14
14
|
};
|
|
15
|
-
for (const e of
|
|
16
|
-
const
|
|
17
|
-
if (!
|
|
15
|
+
for (const e of p) {
|
|
16
|
+
const r = u[e.role], h = c[e.role];
|
|
17
|
+
if (!r || !h)
|
|
18
18
|
throw new Error(`Missing special tokens for role: ${e.role}`);
|
|
19
|
-
s.push(
|
|
19
|
+
s.push(r), a.push(!1);
|
|
20
20
|
const m = e.role === "assistant", S = t.encode(e.content);
|
|
21
21
|
for (const T of S) {
|
|
22
22
|
s.push(T);
|
|
23
|
-
const
|
|
24
|
-
|
|
23
|
+
const x = t.isSpecialToken(T);
|
|
24
|
+
a.push(m && !x);
|
|
25
25
|
}
|
|
26
|
-
s.push(h),
|
|
26
|
+
s.push(h), a.push(m);
|
|
27
27
|
}
|
|
28
|
-
s.push(t.eosToken),
|
|
29
|
-
const
|
|
30
|
-
if (s.length <
|
|
31
|
-
const e =
|
|
28
|
+
s.push(t.eosToken), a.push(!1);
|
|
29
|
+
const n = l + 1;
|
|
30
|
+
if (s.length < n) {
|
|
31
|
+
const e = n - s.length, r = t.getSpecialTokenIndex("<pad>");
|
|
32
32
|
for (let h = 0; h < e; h++)
|
|
33
|
-
s.push(
|
|
34
|
-
} else s.length >
|
|
35
|
-
const
|
|
33
|
+
s.push(r), a.push(!1);
|
|
34
|
+
} else s.length > n && (s.length = n, a.length = n);
|
|
35
|
+
const f = new Int32Array(s.slice(0, l)), i = s.slice(1, l + 1), k = a.slice(1, l + 1), d = new Int32Array(i.length);
|
|
36
36
|
let g = !1;
|
|
37
|
-
for (let e = 0; e <
|
|
38
|
-
const
|
|
39
|
-
d[e] =
|
|
37
|
+
for (let e = 0; e < i.length; e++) {
|
|
38
|
+
const r = k[e] ? i[e] : o;
|
|
39
|
+
d[e] = r, r !== o && (g = !0);
|
|
40
40
|
}
|
|
41
|
-
return g ? { xs:
|
|
41
|
+
return g ? { xs: f, ys: d } : null;
|
|
42
42
|
}
|
|
43
|
-
class
|
|
43
|
+
class D {
|
|
44
44
|
tokenizer;
|
|
45
45
|
blockSize;
|
|
46
|
-
constructor(
|
|
47
|
-
this.tokenizer =
|
|
46
|
+
constructor(o, t = 128) {
|
|
47
|
+
this.tokenizer = o, this.blockSize = t;
|
|
48
48
|
}
|
|
49
49
|
/**
|
|
50
50
|
* Create SFT dataset from structured conversations.
|
|
@@ -52,25 +52,32 @@ class A {
|
|
|
52
52
|
* - Pads with eosToken and masks padding.
|
|
53
53
|
* - Masks non-assistant tokens in labels with ignoreIndex (default -100).
|
|
54
54
|
*/
|
|
55
|
-
async createSFTDataset(
|
|
56
|
-
if (!
|
|
55
|
+
async createSFTDataset(o, t = 32, l = -100) {
|
|
56
|
+
if (!o.length)
|
|
57
57
|
throw new Error("No conversations provided.");
|
|
58
|
-
const s = this.tokenizer,
|
|
58
|
+
const s = this.tokenizer, a = this.blockSize;
|
|
59
|
+
for (const c of o)
|
|
60
|
+
c.shuffle();
|
|
59
61
|
return I(function* () {
|
|
60
62
|
for (; ; ) {
|
|
61
|
-
const
|
|
62
|
-
|
|
63
|
+
const c = Math.floor(Math.random() * o.length), n = o[c], f = n.nextConversation();
|
|
64
|
+
if (!f) {
|
|
65
|
+
n.shuffle();
|
|
66
|
+
continue;
|
|
67
|
+
}
|
|
68
|
+
const i = w(f, l, s, a);
|
|
69
|
+
i && (yield i);
|
|
63
70
|
}
|
|
64
|
-
}).batch(t).map((
|
|
65
|
-
const
|
|
66
|
-
return
|
|
67
|
-
xs:
|
|
68
|
-
ys:
|
|
71
|
+
}).batch(t).map((c) => {
|
|
72
|
+
const n = c;
|
|
73
|
+
return y(() => ({
|
|
74
|
+
xs: n.xs.cast("int32"),
|
|
75
|
+
ys: n.ys.cast("int32")
|
|
69
76
|
}));
|
|
70
77
|
}).prefetch(2);
|
|
71
78
|
}
|
|
72
79
|
}
|
|
73
80
|
export {
|
|
74
|
-
|
|
81
|
+
D as SFTDatasetBuilder,
|
|
75
82
|
w as buildSFTExample
|
|
76
83
|
};
|
|
@@ -2,13 +2,13 @@ import { Conversation, ITokeniser } from '../../main';
|
|
|
2
2
|
import { Task } from './Task';
|
|
3
3
|
export default class ConversationTask extends Task {
|
|
4
4
|
private rawConvo;
|
|
5
|
+
private shuffledIndices;
|
|
5
6
|
private index;
|
|
6
7
|
get length(): number;
|
|
7
8
|
constructor(conversations: Conversation[][]);
|
|
8
9
|
hasMoreConversations(): boolean;
|
|
9
10
|
nextConversation(): Conversation[] | null;
|
|
10
11
|
nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
11
|
-
|
|
12
|
-
getRandomTokens(tokeniser: ITokeniser): number[];
|
|
12
|
+
shuffle(): void;
|
|
13
13
|
estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
14
14
|
}
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import { Task as t } from "./Task.js";
|
|
2
|
+
import { shuffle as s } from "../DatasetBuilder.js";
|
|
2
3
|
class a extends t {
|
|
3
4
|
rawConvo;
|
|
5
|
+
shuffledIndices = null;
|
|
4
6
|
index = 0;
|
|
5
7
|
get length() {
|
|
6
8
|
return this.rawConvo.length;
|
|
@@ -14,20 +16,20 @@ class a extends t {
|
|
|
14
16
|
nextConversation() {
|
|
15
17
|
if (this.index >= this.rawConvo.length)
|
|
16
18
|
return null;
|
|
17
|
-
const n = this.rawConvo[this.index];
|
|
19
|
+
const n = this.rawConvo[this.shuffledIndices ? this.shuffledIndices[this.index] : this.index];
|
|
18
20
|
return this.index++, n;
|
|
19
21
|
}
|
|
20
22
|
nextTokens(n) {
|
|
21
|
-
const
|
|
22
|
-
return
|
|
23
|
-
}
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
23
|
+
const e = this.nextConversation();
|
|
24
|
+
return e ? n.encodeConversation(e) : null;
|
|
25
|
+
}
|
|
26
|
+
shuffle() {
|
|
27
|
+
if (!this.shuffledIndices) {
|
|
28
|
+
this.shuffledIndices = new Uint32Array(this.rawConvo.length);
|
|
29
|
+
for (let n = 0; n < this.rawConvo.length; n++)
|
|
30
|
+
this.shuffledIndices[n] = n;
|
|
31
|
+
}
|
|
32
|
+
s(this.shuffledIndices), this.index = 0;
|
|
31
33
|
}
|
|
32
34
|
async estimateTokens(n) {
|
|
33
35
|
return (await n.encodeConversation(this.rawConvo[0])).length * this.length;
|
|
@@ -8,7 +8,6 @@ export default class PretrainingTask extends Task {
|
|
|
8
8
|
hasMoreConversations(): boolean;
|
|
9
9
|
nextConversation(): Conversation[] | null;
|
|
10
10
|
nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
11
|
-
|
|
12
|
-
getRandomTokens(tokeniser: ITokeniser): number[];
|
|
11
|
+
shuffle(): void;
|
|
13
12
|
estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
14
13
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Task as n } from "./Task.js";
|
|
2
|
-
class
|
|
2
|
+
class r extends n {
|
|
3
3
|
rawText;
|
|
4
4
|
index = 0;
|
|
5
5
|
get length() {
|
|
@@ -26,18 +26,8 @@ class i extends n {
|
|
|
26
26
|
const e = t.encodeSequence(this.rawText[this.index]);
|
|
27
27
|
return this.index++, e;
|
|
28
28
|
}
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
return [
|
|
32
|
-
{
|
|
33
|
-
role: "assistant",
|
|
34
|
-
content: this.rawText[t]
|
|
35
|
-
}
|
|
36
|
-
];
|
|
37
|
-
}
|
|
38
|
-
getRandomTokens(t) {
|
|
39
|
-
const e = Math.floor(Math.random() * this.rawText.length);
|
|
40
|
-
return t.encodeSequence(this.rawText[e]);
|
|
29
|
+
shuffle() {
|
|
30
|
+
this.index = 0;
|
|
41
31
|
}
|
|
42
32
|
async estimateTokens(t) {
|
|
43
33
|
return (await t.encodeConversation([
|
|
@@ -49,5 +39,5 @@ class i extends n {
|
|
|
49
39
|
}
|
|
50
40
|
}
|
|
51
41
|
export {
|
|
52
|
-
|
|
42
|
+
r as default
|
|
53
43
|
};
|
|
@@ -8,8 +8,7 @@ export default class StartSentenceTask extends Task {
|
|
|
8
8
|
hasMoreConversations(): boolean;
|
|
9
9
|
nextConversation(): Conversation[] | null;
|
|
10
10
|
nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
11
|
-
|
|
12
|
-
getRandomTokens(tokeniser: ITokeniser): number[];
|
|
11
|
+
shuffle(): void;
|
|
13
12
|
private conversationFromString;
|
|
14
13
|
estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
15
14
|
}
|
|
@@ -21,13 +21,8 @@ class a extends e {
|
|
|
21
21
|
const n = this.nextConversation();
|
|
22
22
|
return n ? t.encodeConversation(n) : null;
|
|
23
23
|
}
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
return this.conversationFromString(this.rawText[t]);
|
|
27
|
-
}
|
|
28
|
-
getRandomTokens(t) {
|
|
29
|
-
const n = this.getRandomConversation();
|
|
30
|
-
return t.encodeConversation(n);
|
|
24
|
+
shuffle() {
|
|
25
|
+
this.index = 0;
|
|
31
26
|
}
|
|
32
27
|
conversationFromString(t) {
|
|
33
28
|
const n = t.indexOf(".");
|
|
@@ -5,7 +5,6 @@ export declare abstract class Task {
|
|
|
5
5
|
abstract nextConversation(): Conversation[] | null;
|
|
6
6
|
abstract nextTokens(tokeniser: ITokeniser): number[] | null;
|
|
7
7
|
abstract estimateTokens(tokeniser: ITokeniser): Promise<number>;
|
|
8
|
-
abstract
|
|
9
|
-
abstract getRandomTokens(tokeniser: ITokeniser): number[];
|
|
8
|
+
abstract shuffle(): void;
|
|
10
9
|
}
|
|
11
10
|
export declare function tokensFromTasks(tasks: Task[], tokenizer: ITokeniser, cb?: (tokens: number) => void): Promise<Uint16Array>;
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import s from "./ConversationTask.js";
|
|
2
|
+
function f(e, o) {
|
|
3
|
+
if (o <= 0 || o >= 1)
|
|
4
|
+
throw new Error("validationSplit must be between 0 and 1");
|
|
5
|
+
e.forEach((n) => n.shuffle());
|
|
6
|
+
const r = [], a = [];
|
|
7
|
+
for (const n of e)
|
|
8
|
+
for (; n.hasMoreConversations(); ) {
|
|
9
|
+
const t = n.nextConversation();
|
|
10
|
+
if (!t)
|
|
11
|
+
break;
|
|
12
|
+
Math.random() < o ? a.push(t) : r.push(t);
|
|
13
|
+
}
|
|
14
|
+
return {
|
|
15
|
+
training: new s(r),
|
|
16
|
+
validation: new s(a)
|
|
17
|
+
};
|
|
18
|
+
}
|
|
19
|
+
export {
|
|
20
|
+
f as default
|
|
21
|
+
};
|
|
@@ -39,8 +39,8 @@ import "../ops/webgl/adamAdjust.js";
|
|
|
39
39
|
import "../ops/cpu/adamMoments.js";
|
|
40
40
|
import "../ops/webgl/adamMoments.js";
|
|
41
41
|
import { PAGE_FACTOR as m, shuffle as h } from "./DatasetBuilder.js";
|
|
42
|
-
import "../papaparse.min-C0cScC2i.js";
|
|
43
42
|
import { tokensFromTasks as k } from "./tasks/Task.js";
|
|
43
|
+
import "../papaparse.min-C0cScC2i.js";
|
|
44
44
|
import "../ops/cpu/matMulGelu.js";
|
|
45
45
|
import "../matMulGelu-JNLZqKQp.js";
|
|
46
46
|
import "../ops/grads/matMulGelu.js";
|