@genai-fi/nanogpt 0.5.2 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,14 @@
 import { Tensor } from '@tensorflow/tfjs-core';
 import { ITokeniser } from '../tokeniser/type';
 import { Dataset } from '@tensorflow/tfjs-data';
+export declare const PAGE_FACTOR = 8;
+export declare function flattenTokens(textData: string[], tokenizer: ITokeniser): Promise<number[]>;
 export declare class DatasetBuilder {
     tokenizer: ITokeniser;
     blockSize: number;
+    private pageSize;
     constructor(tokenizer: ITokeniser, blockSize?: number);
-    createTextDataset(
+    createTextDataset(flatTokens: number[], batchSize?: number, masked?: Set<number>, invertMask?: boolean): Promise<Dataset<{
         xs: Tensor;
         ys: Tensor;
     }>>;
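Read together, the new declarations split the old one-step API in two: flattenTokens turns raw text into a single flat token array (appending the tokeniser's EOS token after each document when one is defined), and createTextDataset now consumes that flat array plus an optional set of masked page indices. A minimal usage sketch follows; it assumes these symbols are reachable from the package entry point (the diff only shows the compiled module) and uses a placeholder tokeniser.

import { DatasetBuilder, flattenTokens, PAGE_FACTOR } from '@genai-fi/nanogpt'; // import path assumed
import type { ITokeniser } from '@genai-fi/nanogpt';                            // type re-export assumed

async function buildDatasets(texts: string[], tokeniser: ITokeniser) {
    // Tokenise once, up front, into a flat stream of token ids.
    const tokens = await flattenTokens(texts, tokeniser);

    const blockSize = 128;
    const builder = new DatasetBuilder(tokeniser, blockSize);

    // A page spans blockSize * PAGE_FACTOR tokens (1024 here); hold page 0 out of training.
    const heldOut = new Set<number>([0]);

    const train = await builder.createTextDataset(tokens, 32, heldOut, false); // sample outside held-out pages
    const valid = await builder.createTextDataset(tokens, 32, heldOut, true);  // sample only held-out pages
    return { train, valid, pageTokens: blockSize * PAGE_FACTOR };
}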
@@ -1,5 +1,5 @@
-import { t as
-import { d as
+import { t as u } from "../index-CnHyhpKc.js";
+import { d as z, i as f } from "../dataset-ZHEPJmED.js";
 import "../index-Tf7vU29b.js";
 /**
  * @license
@@ -18,39 +18,64 @@ import "../index-Tf7vU29b.js";
  *
  * =============================================================================
  */
-function
-  return
-    const
-    return
+function S(c) {
+  return z(async () => {
+    const t = await c();
+    return f(() => t.next());
   });
 }
-
+const p = 8;
+async function y(c, t) {
+  const s = await Promise.all(c.map((n) => t.encode(n))), i = t.eosToken >= 0;
+  return s.map((n) => i ? [...n, t.eosToken] : n).flat();
+}
+class w {
   tokenizer;
   blockSize;
-
-
+  pageSize;
+  constructor(t, s = 128) {
+    this.tokenizer = t, this.blockSize = s, this.pageSize = s * p;
   }
   // Create dataset from text files
-  async createTextDataset(
-
-
-
-
-
-
-
-
+  async createTextDataset(t, s = 32, i, r) {
+    if (t.length < this.blockSize + 1)
+      throw new Error(`Not enough tokens (${t.length}) for block size ${this.blockSize}`);
+    if (i && i.size > t.length / this.pageSize / 2)
+      throw new Error("Too many masked pages - would leave insufficient training data");
+    const n = (function* () {
+      if (i && r) {
+        const e = Array.from(i);
+        for (; ; ) {
+          const a = Math.floor(Math.random() * e.length), l = Math.floor(Math.random() * this.pageSize), o = e[a] * this.pageSize + l;
+          if (o + this.blockSize + 1 > t.length)
+            continue;
+          const h = t.slice(o, o + this.blockSize), g = t.slice(o + 1, o + this.blockSize + 1);
+          yield { xs: h, ys: g };
+        }
+      } else
+        for (; ; ) {
+          const e = Math.floor(Math.random() * (t.length - this.blockSize - 1));
+          if (i) {
+            const o = Math.floor(e / this.pageSize), h = i.has(o);
+            if (h && !r || !h && r)
+              continue;
+          }
+          const a = t.slice(e, e + this.blockSize), l = t.slice(e + 1, e + this.blockSize + 1);
+          yield { xs: a, ys: l };
+        }
    }).bind(this);
-    return
-      const
-      return
-        xs:
-        ys:
+    return S(n).batch(s).map((e) => {
+      const a = e;
+      return u(() => ({
+        xs: a.xs.cast("int32"),
+        ys: a.ys.cast("int32")
        // this.tf.oneHot(batchData.ys.cast('int32'), this.tokenizer.vocabSize),
      }));
    }).prefetch(2);
  }
 }
 export {
-
+  w as DatasetBuilder,
+  p as PAGE_FACTOR,
+  y as flattenTokens
 };
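The minified body makes the new sampling scheme hard to follow. Below is a readable re-expression of the generator added inside createTextDataset; the parameter names flatTokens, masked and invertMask come from the .d.ts above, everything else is my own naming and not the author's source. Blocks of blockSize tokens are drawn at random; when invertMask is set the draw is restricted to the masked pages, otherwise positions that fall inside a masked page are skipped.

// Sketch of the page-masked block sampler (reconstructed names, not the library source).
// pageSize is blockSize * PAGE_FACTOR; xs is a token window, ys the same window shifted by one.
function* sampleBlocks(
    flatTokens: number[],
    blockSize: number,
    pageSize: number,
    masked?: Set<number>,
    invertMask?: boolean
): Generator<{ xs: number[]; ys: number[] }> {
    if (masked && invertMask) {
        // Validation path: draw starting positions only from the masked pages.
        const pages = Array.from(masked);
        while (true) {
            const page = pages[Math.floor(Math.random() * pages.length)];
            const offset = Math.floor(Math.random() * pageSize);
            const start = page * pageSize + offset;
            if (start + blockSize + 1 > flatTokens.length) continue;
            yield {
                xs: flatTokens.slice(start, start + blockSize),
                ys: flatTokens.slice(start + 1, start + blockSize + 1),
            };
        }
    } else {
        // Training path: sample anywhere, skipping positions whose page is masked out.
        while (true) {
            const start = Math.floor(Math.random() * (flatTokens.length - blockSize - 1));
            if (masked) {
                const page = Math.floor(start / pageSize);
                const isMasked = masked.has(page);
                if ((isMasked && !invertMask) || (!isMasked && invertMask)) continue;
            }
            yield {
                xs: flatTokens.slice(start, start + blockSize),
                ys: flatTokens.slice(start + 1, start + blockSize + 1),
            };
        }
    }
}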
package/dist/training/Trainer.js  CHANGED
@@ -1,13 +1,13 @@
-import { DatasetBuilder as d } from "./DatasetBuilder.js";
-import
-import { t as
-import { m as
-import { m as
-import { m as
-import { z as
+import { DatasetBuilder as h, flattenTokens as d, PAGE_FACTOR as g } from "./DatasetBuilder.js";
+import u from "./AdamExt.js";
+import { t as f, v as y, a as m } from "../index-CnHyhpKc.js";
+import { m as S, n as z } from "../norm-BpWsOapl.js";
+import { m as w, a as T } from "../moments-DLTE6-1p.js";
+import { m as x } from "../max-CcnEArWK.js";
+import { z as p } from "../zeros-CYMicyqz.js";
 class G {
   constructor(t, s, e = 1e-3) {
-    this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new
+    this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new h(s, t.config.gpt.blockSize);
   }
   model;
   optimizer;
@@ -29,7 +29,7 @@ class G {
   }
   resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
     this.optimizer && this.optimizer.dispose();
-    const s = new
+    const s = new u(
       t.learningRateFactor * this.learningRate,
       t.beta1,
       t.beta2,
@@ -46,21 +46,21 @@ class G {
   printGradients(t) {
     Object.keys(t).forEach((s) => {
       const e = t[s];
-      console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${
+      console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${w(e).dataSync()[0]}`), console.log(` Std: ${T(e).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${S(e).dataSync()[0]}`), console.log(` Max: ${x(e).dataSync()[0]}`), console.log(` Norm: ${z(e).dataSync()[0]}`);
     });
   }
   trainStep(t, s = !1, e = !1) {
-    return
+    return f(() => {
       this.model.getProfiler()?.startMemory();
-      const { xs: a, ys:
-      const [
-        return
-      }, { value:
-      return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(
+      const { xs: a, ys: i } = t, o = () => {
+        const [l, c] = this.model.forward({ training: !0 }, a, i);
+        return l.dispose(), c;
+      }, { value: n, grads: r } = y(o);
+      return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), m(r)), n;
     });
   }
   dummyPass() {
-    const t =
+    const t = p([1, this.model.config.gpt.blockSize], "int32"), s = p([1, this.model.config.gpt.blockSize], "int32");
     try {
       const e = this.trainStep({ xs: t, ys: s }, !0);
       e.dataSync(), e.dispose();
@@ -75,20 +75,29 @@ class G {
       const e = this.trainStep(s, !1, !1);
       return s.xs.dispose(), s.ys.dispose(), t.step++, t.totalSteps++, e.array().then((a) => (t.lastLoss = a, t.losses.push(t.lastLoss), e.dispose(), t.lastLoss));
     } catch (e) {
-      throw console.error(`Error processing batch at step ${t.step}:`, e),
+      throw console.error(`Error processing batch at step ${t.step}:`, e), m(), e;
     }
   }
   async createTrainValidationSplit(t, s = 32, e = 0.1) {
-    const a = await
-
+    const a = await d(t, this.tokenizer), i = /* @__PURE__ */ new Set();
+    if (e > 0) {
+      const r = Math.floor(a.length / (this.datasetBuilder.blockSize * g)), l = Math.max(1, Math.floor(r * e));
+      for (; i.size < l; ) {
+        const c = Math.floor(Math.random() * r);
+        i.add(c);
+      }
+    }
+    const o = await this.datasetBuilder.createTextDataset(a, s, i, !1), n = await this.datasetBuilder.createTextDataset(
+      a,
       s,
-
-
+      i,
+      !0
     );
-    return { trainDataset:
+    return { trainDataset: o, validationDataset: n };
   }
   async createDataset(t, s = 32) {
-
+    const e = await d(t, this.tokenizer);
+    return await this.datasetBuilder.createTextDataset(e, s);
   }
   dispose() {
     this.optimizer && this.optimizer.dispose();
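The behavioural change in Trainer.js is in createTrainValidationSplit: the 0.5.2 body is truncated in this view, but in 0.5.3 the trainer flattens the tokens once, divides them into pages of blockSize * PAGE_FACTOR tokens, reserves a random subset of pages for validation, and builds both datasets from the same token array with the mask inverted for the validation side. A small sketch of just the page-selection step, with hypothetical names:

// Sketch of the validation-page selection in 0.5.3 (names are mine, not the library's).
function pickValidationPages(tokenCount: number, blockSize: number, ratio: number): Set<number> {
    const PAGE_FACTOR = 8;                                          // mirrors the DatasetBuilder constant
    const pages = Math.floor(tokenCount / (blockSize * PAGE_FACTOR)); // total pages in the flat token stream
    const wanted = Math.max(1, Math.floor(pages * ratio));            // at least one validation page
    const picked = new Set<number>();
    while (picked.size < wanted) {
        picked.add(Math.floor(Math.random() * pages));                // duplicates collapse in the Set
    }
    return picked;
}

// Example: 1,000,000 tokens with blockSize 128 gives 1024-token pages,
// so 976 pages in total; a 10% ratio reserves 97 of them for validation.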
@@ -1,12 +1,21 @@
-function
-  const r = Array.from(
-  let
-  for (let
-    const
-
+function c(l) {
+  const r = Array.from(l), s = [], o = new RegExp("(\\p{P}|\\p{S}|\\s)", "gu");
+  let t = "";
+  for (let e = 0; e < r.length; e++) {
+    const n = r[e];
+    if (n === " ")
+      (r[e + 1] ?? "") !== " " ? (s.push(t), t = n) : t += n;
+    else if (n.match(o)) {
+      s.push(t);
+      let h = n;
+      for (; e + 1 < r.length && r[e + 1] === n; )
+        h += r[e + 1], e++;
+      s.push(h), t = "";
+    } else
+      t += n;
   }
-  return
+  return t.length > 0 && s.push(t), s.filter((e) => e.length > 0);
 }
 export {
-
+  c as default
 };
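The final hunk rewrites a text-splitting helper (the source file name is not shown in this view). Read de-minified, it splits a string into word-like chunks: a single space is attached to the start of the following word, runs of the same punctuation or symbol character are collapsed into one chunk, and empty chunks are dropped. A readable re-expression with my own names, not the author's source:

// Readable interpretation of the minified splitter above (names are mine).
function splitWords(text: string): string[] {
    const chars = Array.from(text);
    const out: string[] = [];
    const boundary = new RegExp('(\\p{P}|\\p{S}|\\s)', 'gu'); // punctuation, symbols, whitespace
    let current = '';
    for (let i = 0; i < chars.length; i++) {
        const ch = chars[i];
        if (ch === ' ') {
            // A single space starts the next chunk; within a run of spaces, keep accumulating.
            if ((chars[i + 1] ?? '') !== ' ') {
                out.push(current);
                current = ch;
            } else {
                current += ch;
            }
        } else if (ch.match(boundary)) {
            // Punctuation/symbol: flush the current chunk, then group a run of the same character.
            out.push(current);
            let run = ch;
            while (i + 1 < chars.length && chars[i + 1] === ch) {
                run += chars[i + 1];
                i++;
            }
            out.push(run);
            current = '';
        } else {
            current += ch;
        }
    }
    if (current.length > 0) out.push(current);
    return out.filter((chunk) => chunk.length > 0);
}

// splitWords("hello world!!") -> ["hello", " world", "!!"]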