@genai-fi/nanogpt 0.1.8 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +6 -8
- package/dist/Generator.js +55 -52
- package/dist/NanoGPTModel.d.ts +5 -3
- package/dist/NanoGPTModel.js +69 -53
- package/dist/TeachableLLM.js +15 -15
- package/dist/Trainer.d.ts +2 -0
- package/dist/Trainer.js +10 -5
- package/dist/config.d.ts +1 -0
- package/dist/config.js +5 -3
- package/dist/layers/CausalSelfAttention.d.ts +12 -2
- package/dist/layers/CausalSelfAttention.js +73 -40
- package/dist/layers/RMSNorm.d.ts +13 -0
- package/dist/layers/RMSNorm.js +32 -0
- package/dist/layers/RoPECache.d.ts +16 -0
- package/dist/layers/RoPECache.js +39 -0
- package/dist/layers/TransformerBlock.d.ts +5 -2
- package/dist/layers/TransformerBlock.js +14 -10
- package/dist/training/FullTrainer.js +27 -29
- package/dist/training/Trainer.d.ts +2 -0
- package/dist/training/Trainer.js +31 -27
- package/dist/utilities/generate.js +14 -14
- package/package.json +1 -1
package/dist/Generator.d.ts
CHANGED
|
@@ -1,18 +1,16 @@
|
|
|
1
|
-
import { default as NanoGPT } from './NanoGPTModel';
|
|
1
|
+
import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
|
|
2
2
|
import { ITokeniser } from './tokeniser/type';
|
|
3
3
|
import { default as EE } from 'eventemitter3';
|
|
4
|
-
export interface IGenerateOptions {
|
|
4
|
+
export interface IGenerateOptions extends GenerateOptions {
|
|
5
5
|
maxLength?: number;
|
|
6
|
-
temperature?: number;
|
|
7
|
-
topK?: number;
|
|
8
|
-
usePadding?: boolean;
|
|
9
|
-
includeAttention?: boolean;
|
|
10
|
-
includeProbabilities?: boolean;
|
|
11
6
|
}
|
|
12
7
|
export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
|
|
13
8
|
private readonly model;
|
|
14
9
|
private readonly tokeniser;
|
|
15
10
|
constructor(model: NanoGPT, tokeniser: ITokeniser);
|
|
16
|
-
private
|
|
11
|
+
private tokenisePrompt;
|
|
12
|
+
private generateNoCache;
|
|
13
|
+
private processResponse;
|
|
14
|
+
private generateCache;
|
|
17
15
|
generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
|
|
18
16
|
}
|
package/dist/Generator.js
CHANGED
|
@@ -1,62 +1,65 @@
|
|
|
1
|
-
import { E as
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
super(), this.model = a, this.tokeniser = t;
|
|
1
|
+
import { E as u } from "./index-SOhdqzHq.js";
|
|
2
|
+
class k extends u {
|
|
3
|
+
constructor(s, e) {
|
|
4
|
+
super(), this.model = s, this.tokeniser = e;
|
|
6
5
|
}
|
|
7
|
-
|
|
8
|
-
const
|
|
9
|
-
|
|
10
|
-
|
|
6
|
+
async tokenisePrompt(s) {
|
|
7
|
+
const e = s ? await this.tokeniser.tokenise([s], !0) : [[this.tokeniser.eosToken]];
|
|
8
|
+
return this.model.tf.tensor2d(e, [1, e[0].length], "int32");
|
|
9
|
+
}
|
|
10
|
+
async generateNoCache(s, e) {
|
|
11
|
+
let t = await this.tokenisePrompt(s), n = s || "";
|
|
12
|
+
const a = e?.maxLength ?? 1e3;
|
|
13
|
+
for (let i = 0; i < a; i++) {
|
|
11
14
|
const {
|
|
12
|
-
output:
|
|
13
|
-
attention:
|
|
14
|
-
probabilities:
|
|
15
|
-
} = this.model.generate(
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
}), p = i;
|
|
22
|
-
if (i = this.model.tf.concat([i, u], 1), n && l) {
|
|
23
|
-
const o = n;
|
|
24
|
-
n = this.model.tf.concat([n, l], 0), o.dispose();
|
|
25
|
-
} else l && (n = l);
|
|
26
|
-
if (s && r) {
|
|
27
|
-
const o = s;
|
|
28
|
-
s = this.model.tf.concat([s, r], 0), o.dispose();
|
|
29
|
-
} else r && (s = r);
|
|
30
|
-
p.dispose(), u.dispose();
|
|
15
|
+
output: o,
|
|
16
|
+
attention: c,
|
|
17
|
+
probabilities: l
|
|
18
|
+
} = this.model.generate(t, void 0, e), h = t;
|
|
19
|
+
t = this.model.tf.concat([t, o], 1), h.dispose();
|
|
20
|
+
const r = await this.processResponse(o, c, l);
|
|
21
|
+
if (o.dispose(), r === null)
|
|
22
|
+
break;
|
|
23
|
+
n += r;
|
|
31
24
|
}
|
|
32
|
-
return
|
|
25
|
+
return t.dispose(), n;
|
|
33
26
|
}
|
|
34
|
-
async
|
|
35
|
-
const
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
27
|
+
async processResponse(s, e, t) {
|
|
28
|
+
const n = (await s.array())[0][0];
|
|
29
|
+
if (n === this.tokeniser.eosToken)
|
|
30
|
+
return null;
|
|
31
|
+
const a = await this.tokeniser.decode([n]);
|
|
32
|
+
let i;
|
|
33
|
+
e && (i = await e.array(), e.dispose());
|
|
34
|
+
let o;
|
|
35
|
+
return t && (o = await t.array(), t.dispose()), this.emit("tokens", [n], a, i, o), a;
|
|
36
|
+
}
|
|
37
|
+
async generateCache(s, e) {
|
|
38
|
+
let t = await this.tokenisePrompt(s), n = s || "";
|
|
39
|
+
const a = new Array(this.model.config.nLayer).fill(void 0), i = e?.maxLength ?? 1e3;
|
|
40
|
+
for (let o = 0; o < i; o++) {
|
|
41
|
+
const {
|
|
42
|
+
output: c,
|
|
43
|
+
attention: l,
|
|
44
|
+
probabilities: h
|
|
45
|
+
} = this.model.generate(t, a, {
|
|
46
|
+
...e,
|
|
47
|
+
usePadding: !1
|
|
48
|
+
});
|
|
49
|
+
t.dispose(), t = c;
|
|
50
|
+
const r = await this.processResponse(c, l, h);
|
|
51
|
+
if (r === null)
|
|
55
52
|
break;
|
|
53
|
+
n += r;
|
|
56
54
|
}
|
|
57
|
-
return
|
|
55
|
+
return t.dispose(), n;
|
|
56
|
+
}
|
|
57
|
+
async generate(s, e) {
|
|
58
|
+
this.emit("start");
|
|
59
|
+
const t = this.model.config.useRope ? this.generateCache(s, e) : this.generateNoCache(s, e);
|
|
60
|
+
return this.emit("stop"), t;
|
|
58
61
|
}
|
|
59
62
|
}
|
|
60
63
|
export {
|
|
61
|
-
|
|
64
|
+
k as default
|
|
62
65
|
};
|
package/dist/NanoGPTModel.d.ts
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { default as TF } from '@tensorflow/tfjs';
|
|
2
2
|
import { GPTConfig } from './config';
|
|
3
|
+
import { KVCache } from './layers/CausalSelfAttention';
|
|
3
4
|
export interface TrainingLogEntry {
|
|
4
5
|
loss: number;
|
|
5
6
|
valLoss?: number;
|
|
@@ -18,10 +19,11 @@ export interface GenerateOptions {
|
|
|
18
19
|
export default class NanoGPT {
|
|
19
20
|
readonly config: GPTConfig;
|
|
20
21
|
private wte;
|
|
21
|
-
private wpe
|
|
22
|
+
private wpe?;
|
|
22
23
|
private drop;
|
|
23
24
|
private blocks;
|
|
24
25
|
private lnF;
|
|
26
|
+
private ropeCache?;
|
|
25
27
|
readonly tf: typeof TF;
|
|
26
28
|
log: TrainingLogEntry[];
|
|
27
29
|
constructor(tf: typeof TF, config?: Partial<GPTConfig>);
|
|
@@ -35,12 +37,12 @@ export default class NanoGPT {
|
|
|
35
37
|
private validateInput;
|
|
36
38
|
private calculateLoss;
|
|
37
39
|
private computeAttentionRollout;
|
|
38
|
-
forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean): {
|
|
40
|
+
forward(idx: TF.Tensor, targets?: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
|
|
39
41
|
logits: TF.Tensor;
|
|
40
42
|
loss?: TF.Tensor;
|
|
41
43
|
attention?: TF.Tensor;
|
|
42
44
|
};
|
|
43
|
-
generate(idx: TF.Tensor, options?: GenerateOptions): {
|
|
45
|
+
generate(idx: TF.Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
|
|
44
46
|
output: TF.Tensor;
|
|
45
47
|
attention?: TF.Tensor;
|
|
46
48
|
probabilities?: TF.Tensor;
|
package/dist/NanoGPTModel.js
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
import { defaultConfig as
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
|
|
1
|
+
import { defaultConfig as v } from "./config.js";
|
|
2
|
+
import S from "./layers/TransformerBlock.js";
|
|
3
|
+
import _ from "./layers/TiedEmbedding.js";
|
|
4
|
+
import L from "./layers/RoPECache.js";
|
|
5
|
+
import I from "./layers/RMSNorm.js";
|
|
6
|
+
class F {
|
|
6
7
|
config;
|
|
7
8
|
wte;
|
|
8
9
|
// Token embeddings
|
|
@@ -13,27 +14,28 @@ class $ {
|
|
|
13
14
|
blocks;
|
|
14
15
|
lnF;
|
|
15
16
|
// Final layer norm
|
|
17
|
+
ropeCache;
|
|
16
18
|
tf;
|
|
17
19
|
log = [];
|
|
18
20
|
// Training log
|
|
19
21
|
constructor(t, e = {}) {
|
|
20
|
-
this.tf = t, this.config = { ...
|
|
22
|
+
this.tf = t, this.config = { ...v, ...e }, this.wte = new _(t, {
|
|
21
23
|
vocabSize: this.config.vocabSize,
|
|
22
24
|
embedDim: this.config.nEmbed,
|
|
23
25
|
name: "token_embedding"
|
|
24
|
-
}), this.wpe = this.tf.layers.embedding({
|
|
26
|
+
}), this.config.useRope === !1 ? this.wpe = this.tf.layers.embedding({
|
|
25
27
|
inputDim: this.config.blockSize,
|
|
26
28
|
outputDim: this.config.nEmbed,
|
|
27
29
|
name: "positional_embedding",
|
|
28
30
|
embeddingsInitializer: this.tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
|
|
29
|
-
}), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
|
|
31
|
+
}) : this.ropeCache = new L(t, this.config), this.drop = this.tf.layers.dropout({ rate: this.config.dropout }), this.blocks = [];
|
|
30
32
|
for (let s = 0; s < this.config.nLayer; s++)
|
|
31
|
-
this.blocks.push(new
|
|
32
|
-
this.lnF = new
|
|
33
|
+
this.blocks.push(new S(this.tf, s, this.config, this.ropeCache));
|
|
34
|
+
this.lnF = new I(t, [this.config.nEmbed], 1e-8, "final_rms_norm");
|
|
33
35
|
}
|
|
34
36
|
get variables() {
|
|
35
37
|
return [
|
|
36
|
-
|
|
38
|
+
//...this.wpe.trainableWeights.map((v) => v.read() as TF.Variable),
|
|
37
39
|
...this.blocks.flatMap((t) => t.variables),
|
|
38
40
|
...this.lnF.trainableWeights.map((t) => t),
|
|
39
41
|
...this.wte.variables
|
|
@@ -41,21 +43,28 @@ class $ {
|
|
|
41
43
|
}
|
|
42
44
|
saveWeights() {
|
|
43
45
|
const t = /* @__PURE__ */ new Map();
|
|
44
|
-
t.set("token_embedding", this.wte.getWeights()), t.set("positional_embedding", this.wpe.getWeights());
|
|
46
|
+
t.set("token_embedding", this.wte.getWeights()), this.wpe && t.set("positional_embedding", this.wpe.getWeights());
|
|
45
47
|
for (let e = 0; e < this.blocks.length; e++)
|
|
46
48
|
this.blocks[e].saveWeights(t);
|
|
47
|
-
return t.set("
|
|
49
|
+
return t.set("final_rms_norm", this.lnF.getWeights()), t;
|
|
48
50
|
}
|
|
49
51
|
loadWeights(t) {
|
|
50
|
-
this.wte.setWeights(t.get("token_embedding") || []), this.wpe.setWeights(t.get("positional_embedding") || []);
|
|
52
|
+
this.wte.setWeights(t.get("token_embedding") || []), this.wpe && this.wpe.setWeights(t.get("positional_embedding") || []);
|
|
51
53
|
for (let e = 0; e < this.blocks.length; e++)
|
|
52
54
|
this.blocks[e].loadWeights(t);
|
|
53
|
-
this.lnF.setWeights(t.get("
|
|
55
|
+
this.lnF.setWeights(t.get("final_rms_norm") || []);
|
|
54
56
|
}
|
|
55
|
-
inputPhase(t, e = !1) {
|
|
57
|
+
inputPhase(t, e, s = !1) {
|
|
56
58
|
return this.tf.tidy(() => {
|
|
57
|
-
const
|
|
58
|
-
|
|
59
|
+
const o = this.wte.embed(t);
|
|
60
|
+
if (this.config.useRope === !1) {
|
|
61
|
+
const [, i] = t.shape, a = this.config.blockSize, n = this.tf.range(0, i, 1, "int32"), h = this.tf.mod(
|
|
62
|
+
this.tf.add(n, this.tf.scalar(e, "int32")),
|
|
63
|
+
this.tf.scalar(a, "int32")
|
|
64
|
+
), c = this.wpe.apply(h), r = o.add(c);
|
|
65
|
+
return this.drop.apply(r, { training: s });
|
|
66
|
+
} else
|
|
67
|
+
return this.drop.apply(o, { training: s });
|
|
59
68
|
});
|
|
60
69
|
}
|
|
61
70
|
setSkipMask(t) {
|
|
@@ -73,7 +82,7 @@ class $ {
|
|
|
73
82
|
set trainable(t) {
|
|
74
83
|
for (const e of this.blocks)
|
|
75
84
|
e.trainable = t;
|
|
76
|
-
this.
|
|
85
|
+
this.lnF.trainable = t;
|
|
77
86
|
}
|
|
78
87
|
validateInput(t) {
|
|
79
88
|
if (t.shape.length !== 2)
|
|
@@ -96,60 +105,67 @@ class $ {
|
|
|
96
105
|
return this.tf.tidy(() => {
|
|
97
106
|
if (t.length === 0)
|
|
98
107
|
throw new Error("No attentions for rollout");
|
|
99
|
-
const e = t[0].shape[0], s = t[0].shape[1],
|
|
100
|
-
let
|
|
101
|
-
for (const
|
|
102
|
-
let
|
|
103
|
-
|
|
108
|
+
const e = t[0].shape[0], s = t[0].shape[1], o = this.tf.eye(s, s).expandDims(0);
|
|
109
|
+
let i = o.tile([e, 1, 1]);
|
|
110
|
+
for (const a of t) {
|
|
111
|
+
let n = a.add(o);
|
|
112
|
+
n = n.div(n.sum(-1, !0)), i = n.matMul(i);
|
|
104
113
|
}
|
|
105
|
-
return
|
|
114
|
+
return i;
|
|
106
115
|
});
|
|
107
116
|
}
|
|
108
|
-
forward(t, e, s = !1,
|
|
117
|
+
forward(t, e, s = !1, o = !1, i) {
|
|
109
118
|
return this.validateInput(t), this.tf.tidy(() => {
|
|
110
|
-
|
|
119
|
+
const a = i?.[0]?.length ?? 0;
|
|
120
|
+
let n = this.inputPhase(t, a, s);
|
|
111
121
|
const h = [];
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
122
|
+
if (i && i.length !== this.blocks.length)
|
|
123
|
+
throw console.error("Cache", i), new Error(`Cache length ${i.length} does not match number of blocks ${this.blocks.length}`);
|
|
124
|
+
for (let l = 0; l < this.blocks.length; l++) {
|
|
125
|
+
const d = this.blocks[l], {
|
|
126
|
+
output: g,
|
|
127
|
+
attention: b,
|
|
128
|
+
cache: p
|
|
129
|
+
} = d.call(n, s, o, i ? i[l] : void 0);
|
|
130
|
+
n = g, o && b && h.push(b), i && p ? (i[l]?.k.dispose(), i[l]?.v.dispose(), i[l] = p) : p && (p.k.dispose(), p.v.dispose());
|
|
115
131
|
}
|
|
116
|
-
let
|
|
117
|
-
|
|
118
|
-
const
|
|
119
|
-
let
|
|
120
|
-
return e && (
|
|
132
|
+
let c;
|
|
133
|
+
o && h.length > 0 && (c = this.computeAttentionRollout(h)), n = this.lnF.apply(n);
|
|
134
|
+
const r = this.wte.project(n);
|
|
135
|
+
let f;
|
|
136
|
+
return e && (f = this.calculateLoss(r, e)), { logits: r, loss: f, attention: o ? c : void 0 };
|
|
121
137
|
});
|
|
122
138
|
}
|
|
123
|
-
generate(t, e) {
|
|
124
|
-
const
|
|
139
|
+
generate(t, e, s) {
|
|
140
|
+
const o = s?.temperature ?? 1, i = s?.topK, a = s?.usePadding ?? !1, n = s?.includeAttention ?? !1;
|
|
125
141
|
return this.tf.tidy(() => {
|
|
126
|
-
const
|
|
127
|
-
[0,
|
|
128
|
-
[
|
|
129
|
-
),
|
|
142
|
+
const h = t, c = h.shape[1], r = c <= this.config.blockSize ? h : h.slice(
|
|
143
|
+
[0, c - this.config.blockSize],
|
|
144
|
+
[h.shape[0], this.config.blockSize]
|
|
145
|
+
), f = a ? this.config.blockSize - r.shape[1] : 0, l = f > 0 ? this.tf.pad(r, [
|
|
130
146
|
[0, 0],
|
|
131
|
-
[0,
|
|
132
|
-
]) : r, { logits:
|
|
133
|
-
let
|
|
147
|
+
[0, f]
|
|
148
|
+
]) : r, { logits: d, attention: g } = this.forward(l, void 0, !1, n, e), b = d.shape[1] - 1 - f, p = d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]), w = g ? g.slice([0, b, 0], [g.shape[0], 1, g.shape[2]]) : void 0, u = p.div(o);
|
|
149
|
+
let m;
|
|
134
150
|
if (i) {
|
|
135
|
-
const { values:
|
|
136
|
-
|
|
151
|
+
const { values: E, indices: y } = this.tf.topk(u, i), z = this.tf.multinomial(E.squeeze([1]), 1);
|
|
152
|
+
m = this.tf.gather(y.squeeze([1]), z, 1);
|
|
137
153
|
} else
|
|
138
|
-
|
|
139
|
-
let
|
|
140
|
-
return
|
|
154
|
+
m = this.tf.multinomial(u.squeeze([1]), 1);
|
|
155
|
+
let k;
|
|
156
|
+
return s?.includeProbabilities && (k = this.tf.softmax(u.squeeze([1]))), m = m.reshape([1, 1]), { output: m, attention: w?.squeeze([1]), probabilities: k };
|
|
141
157
|
});
|
|
142
158
|
}
|
|
143
159
|
getNumParams() {
|
|
144
160
|
const t = this.config.vocabSize * this.config.nEmbed + this.config.blockSize * this.config.nEmbed, e = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // qkv + proj
|
|
145
161
|
2 * this.config.nEmbed), s = this.config.nLayer * (4 * this.config.nEmbed * this.config.nEmbed + // fc
|
|
146
|
-
this.config.nEmbed * 4 * this.config.nEmbed),
|
|
147
|
-
return t + e + s +
|
|
162
|
+
this.config.nEmbed * 4 * this.config.nEmbed), o = this.config.nEmbed + this.config.vocabSize * this.config.nEmbed;
|
|
163
|
+
return t + e + s + o;
|
|
148
164
|
}
|
|
149
165
|
dispose() {
|
|
150
|
-
this.wte.dispose(), this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
|
|
166
|
+
this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
|
|
151
167
|
}
|
|
152
168
|
}
|
|
153
169
|
export {
|
|
154
|
-
|
|
170
|
+
F as default
|
|
155
171
|
};
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -47,25 +47,25 @@ class a extends c {
|
|
|
47
47
|
}
|
|
48
48
|
static loadModel(t, r) {
|
|
49
49
|
const e = new a(t);
|
|
50
|
-
return l(t, r).then(({ model:
|
|
51
|
-
e._model =
|
|
50
|
+
return l(t, r).then(({ model: s, tokeniser: o }) => {
|
|
51
|
+
e._model = s, e._tokeniser = o, e._config = s.config, e.setStatus("warmup"), h(s).then(() => {
|
|
52
52
|
e.setStatus("ready");
|
|
53
|
-
}).catch((
|
|
54
|
-
e.setStatus("error"), e.emit("error",
|
|
53
|
+
}).catch((i) => {
|
|
54
|
+
e.setStatus("error"), e.emit("error", i);
|
|
55
55
|
});
|
|
56
|
-
}).catch((
|
|
57
|
-
e.setStatus("error"), e.emit("error",
|
|
56
|
+
}).catch((s) => {
|
|
57
|
+
e.setStatus("error"), e.emit("error", s);
|
|
58
58
|
}), e;
|
|
59
59
|
}
|
|
60
60
|
static create(t, r = {}) {
|
|
61
|
-
const e = { ...u, ...r },
|
|
62
|
-
return
|
|
63
|
-
|
|
64
|
-
n === "trained" &&
|
|
65
|
-
});
|
|
61
|
+
const e = { ...u, ...r }, s = new g(e.vocabSize), o = new d(t, e), i = new a(t, s, o);
|
|
62
|
+
return i.setStatus("warmup"), h(o).then(() => {
|
|
63
|
+
i.tokeniser.trained ? i.setStatus("ready") : (i.setStatus("awaitingTokens"), i.tokeniser.once("trainStatus", (n) => {
|
|
64
|
+
n === "trained" && i.setStatus("ready");
|
|
65
|
+
}));
|
|
66
66
|
}).catch((n) => {
|
|
67
|
-
|
|
68
|
-
}),
|
|
67
|
+
i.setStatus("error"), i.emit("error", n);
|
|
68
|
+
}), i;
|
|
69
69
|
}
|
|
70
70
|
getNumParams() {
|
|
71
71
|
if (!this._model)
|
|
@@ -78,8 +78,8 @@ class a extends c {
|
|
|
78
78
|
const t = new _(this._model, this._tokeniser);
|
|
79
79
|
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (r) => {
|
|
80
80
|
const e = this.listeners("trainStep");
|
|
81
|
-
for (const
|
|
82
|
-
await
|
|
81
|
+
for (const s of e)
|
|
82
|
+
await s(r);
|
|
83
83
|
}), t;
|
|
84
84
|
}
|
|
85
85
|
train(t, r) {
|
package/dist/Trainer.d.ts
CHANGED
|
@@ -12,7 +12,9 @@ export interface ITrainerOptions {
|
|
|
12
12
|
}
|
|
13
13
|
export default class Trainer extends EE<'start' | 'stop' | 'log'> {
|
|
14
14
|
private trainer;
|
|
15
|
+
private hasTrained;
|
|
15
16
|
constructor(model: NanoGPT, tokeniser: ITokeniser);
|
|
16
17
|
stop(): void;
|
|
18
|
+
reset(): void;
|
|
17
19
|
train(text: string[], options?: ITrainerOptions): Promise<void>;
|
|
18
20
|
}
|
package/dist/Trainer.js
CHANGED
|
@@ -1,11 +1,16 @@
|
|
|
1
1
|
import { E as l } from "./index-SOhdqzHq.js";
|
|
2
|
-
import
|
|
3
|
-
class
|
|
2
|
+
import h from "./training/FullTrainer.js";
|
|
3
|
+
class m extends l {
|
|
4
4
|
trainer;
|
|
5
|
+
hasTrained = !1;
|
|
5
6
|
constructor(a, t) {
|
|
6
|
-
super(), this.trainer = new
|
|
7
|
+
super(), this.trainer = new h(a.tf, a, t, 1e-3);
|
|
7
8
|
}
|
|
8
9
|
stop() {
|
|
10
|
+
this.trainer.stop();
|
|
11
|
+
}
|
|
12
|
+
reset() {
|
|
13
|
+
this.hasTrained = !1, this.trainer.reset();
|
|
9
14
|
}
|
|
10
15
|
async train(a, t) {
|
|
11
16
|
const { trainDataset: e, validationDataset: r } = await this.trainer.createTrainValidationSplit(
|
|
@@ -13,7 +18,7 @@ class d extends l {
|
|
|
13
18
|
t?.batchSize || 32,
|
|
14
19
|
t?.validationSplit || 0.1
|
|
15
20
|
);
|
|
16
|
-
this.trainer.setLearningRate(t?.learningRate || 1e-3), this.emit("start"), await this.trainer.trainOnDataset(
|
|
21
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
|
|
17
22
|
e,
|
|
18
23
|
{
|
|
19
24
|
prompt: t?.prompt,
|
|
@@ -31,5 +36,5 @@ class d extends l {
|
|
|
31
36
|
}
|
|
32
37
|
}
|
|
33
38
|
export {
|
|
34
|
-
|
|
39
|
+
m as default
|
|
35
40
|
};
|
package/dist/config.d.ts
CHANGED
package/dist/config.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const
|
|
1
|
+
const e = {
|
|
2
2
|
vocabSize: 50304,
|
|
3
3
|
// GPT-2 vocab size
|
|
4
4
|
blockSize: 1024,
|
|
@@ -13,8 +13,10 @@ const a = {
|
|
|
13
13
|
// Dropout probability
|
|
14
14
|
biasInLinear: !1,
|
|
15
15
|
biasInLayerNorm: !1,
|
|
16
|
-
mlpFactor: 4
|
|
16
|
+
mlpFactor: 4,
|
|
17
|
+
useRope: !1
|
|
18
|
+
// Use Rotary Position Embeddings
|
|
17
19
|
};
|
|
18
20
|
export {
|
|
19
|
-
|
|
21
|
+
e as defaultConfig
|
|
20
22
|
};
|
|
@@ -1,6 +1,14 @@
|
|
|
1
1
|
import { default as TF } from '@tensorflow/tfjs';
|
|
2
2
|
import { GPTConfig } from '../config';
|
|
3
|
+
import { default as RoPECache } from './RoPECache';
|
|
4
|
+
export type KVCache = {
|
|
5
|
+
k: TF.Tensor;
|
|
6
|
+
v: TF.Tensor;
|
|
7
|
+
length: number;
|
|
8
|
+
cumulativeLength: number;
|
|
9
|
+
};
|
|
3
10
|
export default class CausalSelfAttention {
|
|
11
|
+
private readonly ropeCache?;
|
|
4
12
|
private config;
|
|
5
13
|
private cAttn;
|
|
6
14
|
private cProj;
|
|
@@ -12,18 +20,20 @@ export default class CausalSelfAttention {
|
|
|
12
20
|
private divisor;
|
|
13
21
|
private index;
|
|
14
22
|
private _trainable;
|
|
15
|
-
constructor(tf: typeof TF, index: number, config: GPTConfig);
|
|
23
|
+
constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache | undefined);
|
|
16
24
|
get variables(): TF.Variable[];
|
|
17
25
|
get trainable(): boolean;
|
|
18
26
|
set trainable(value: boolean);
|
|
19
27
|
saveWeights(map: Map<string, TF.Tensor[]>): void;
|
|
20
28
|
loadWeights(weights: Map<string, TF.Tensor[]>): void;
|
|
21
29
|
private getAttentionScores;
|
|
30
|
+
private getAttentionScoresWithPast;
|
|
22
31
|
private getQKV;
|
|
23
32
|
private getOutputProjection;
|
|
24
|
-
call(x: TF.Tensor, training?: boolean, includeAttention?: boolean): {
|
|
33
|
+
call(x: TF.Tensor, training?: boolean, includeAttention?: boolean, pastKV?: KVCache): {
|
|
25
34
|
output: TF.Tensor;
|
|
26
35
|
attention?: TF.Tensor;
|
|
36
|
+
presentKV?: KVCache;
|
|
27
37
|
};
|
|
28
38
|
dispose(): void;
|
|
29
39
|
}
|
|
@@ -1,20 +1,9 @@
|
|
|
1
|
-
class
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
cProj;
|
|
5
|
-
attnDropout;
|
|
6
|
-
residDropout;
|
|
7
|
-
bias;
|
|
8
|
-
maskInf;
|
|
9
|
-
tf;
|
|
10
|
-
divisor;
|
|
11
|
-
index;
|
|
12
|
-
_trainable = !0;
|
|
13
|
-
constructor(t, e, s) {
|
|
14
|
-
this.config = s, this.tf = t, this.index = e, this.cAttn = this.tf.layers.dense({
|
|
1
|
+
class S {
|
|
2
|
+
constructor(t, i, s, e) {
|
|
3
|
+
this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.cAttn = this.tf.layers.dense({
|
|
15
4
|
units: 3 * s.nEmbed,
|
|
16
5
|
useBias: s.biasInLinear,
|
|
17
|
-
name: `block_${
|
|
6
|
+
name: `block_${i}_attn_cAttn`,
|
|
18
7
|
kernelInitializer: this.tf.initializers.randomNormal({
|
|
19
8
|
mean: 0,
|
|
20
9
|
stddev: 0.02
|
|
@@ -23,14 +12,27 @@ class m {
|
|
|
23
12
|
}), this.cProj = this.tf.layers.dense({
|
|
24
13
|
units: s.nEmbed,
|
|
25
14
|
useBias: s.biasInLinear,
|
|
26
|
-
name: `block_${
|
|
15
|
+
name: `block_${i}_attn_cProj`,
|
|
27
16
|
kernelInitializer: this.tf.initializers.randomNormal({
|
|
28
17
|
mean: 0,
|
|
29
18
|
stddev: 0.02 / Math.sqrt(2 * s.nLayer)
|
|
30
19
|
}),
|
|
31
20
|
biasInitializer: "zeros"
|
|
32
|
-
}), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = this.tf.scalar(1 / Math.sqrt(s.nEmbed / s.nHead))
|
|
21
|
+
}), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = this.tf.scalar(1 / Math.sqrt(s.nEmbed / s.nHead));
|
|
22
|
+
const a = this.tf.zeros([s.blockSize, s.blockSize]), h = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
|
|
23
|
+
this.maskInf = this.tf.where(this.bias, a, h);
|
|
33
24
|
}
|
|
25
|
+
config;
|
|
26
|
+
cAttn;
|
|
27
|
+
cProj;
|
|
28
|
+
attnDropout;
|
|
29
|
+
residDropout;
|
|
30
|
+
bias;
|
|
31
|
+
maskInf;
|
|
32
|
+
tf;
|
|
33
|
+
divisor;
|
|
34
|
+
index;
|
|
35
|
+
_trainable = !0;
|
|
34
36
|
get variables() {
|
|
35
37
|
return [
|
|
36
38
|
...this.cAttn.trainableWeights.map((t) => t.read()),
|
|
@@ -49,34 +51,65 @@ class m {
|
|
|
49
51
|
loadWeights(t) {
|
|
50
52
|
this.cAttn.setWeights(t.get(`block_${this.index}_cAttn`) || []), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
|
|
51
53
|
}
|
|
52
|
-
getAttentionScores(t,
|
|
53
|
-
const
|
|
54
|
-
return this.attnDropout.apply(
|
|
54
|
+
getAttentionScores(t, i, s) {
|
|
55
|
+
const e = t.shape[2], h = this.tf.matMul(t, i, !1, !0).mul(this.divisor), n = this.maskInf.slice([0, 0], [e, e]).expandDims(0).expandDims(0), r = h.add(n), o = this.tf.softmax(r, -1);
|
|
56
|
+
return this.attnDropout.apply(o, { training: s });
|
|
57
|
+
}
|
|
58
|
+
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
59
|
+
getAttentionScoresWithPast(t, i, s, e) {
|
|
60
|
+
const a = t.shape[2];
|
|
61
|
+
let n = this.tf.matMul(t, i, !1, !0).mul(this.divisor);
|
|
62
|
+
if (a > 1 && e > 0)
|
|
63
|
+
throw new Error("Cannot use past with T_cur > 1");
|
|
64
|
+
if (a > 1) {
|
|
65
|
+
const o = this.maskInf.slice([0, 0], [a, a]).expandDims(0).expandDims(0);
|
|
66
|
+
n = n.add(o);
|
|
67
|
+
}
|
|
68
|
+
const r = this.tf.softmax(n, -1);
|
|
69
|
+
return this.attnDropout.apply(r, { training: s });
|
|
55
70
|
}
|
|
56
71
|
getQKV(t) {
|
|
57
|
-
const [
|
|
58
|
-
|
|
59
|
-
const
|
|
60
|
-
|
|
61
|
-
const
|
|
62
|
-
|
|
63
|
-
const d = this.tf.reshape(
|
|
64
|
-
i.dispose();
|
|
65
|
-
const u = d.transpose([0, 2, 1, 3]);
|
|
66
|
-
d.dispose();
|
|
67
|
-
const p = this.tf.reshape(n, [e, s, this.config.nHead, h]);
|
|
72
|
+
const [i, s, e] = t.shape, a = this.cAttn.apply(t), [h, n, r] = this.tf.split(a, 3, -1);
|
|
73
|
+
a.dispose();
|
|
74
|
+
const o = e / this.config.nHead, u = this.tf.reshape(h, [i, s, this.config.nHead, o]);
|
|
75
|
+
h.dispose();
|
|
76
|
+
const f = u.transpose([0, 2, 1, 3]);
|
|
77
|
+
u.dispose();
|
|
78
|
+
const d = this.tf.reshape(n, [i, s, this.config.nHead, o]);
|
|
68
79
|
n.dispose();
|
|
69
|
-
const
|
|
70
|
-
|
|
80
|
+
const c = d.transpose([0, 2, 1, 3]);
|
|
81
|
+
d.dispose();
|
|
82
|
+
const l = this.tf.reshape(r, [i, s, this.config.nHead, o]);
|
|
83
|
+
r.dispose();
|
|
84
|
+
const p = l.transpose([0, 2, 1, 3]);
|
|
85
|
+
return l.dispose(), [f, c, p];
|
|
71
86
|
}
|
|
72
|
-
getOutputProjection(t,
|
|
73
|
-
const s = t.shape[0],
|
|
74
|
-
return this.residDropout.apply(
|
|
87
|
+
getOutputProjection(t, i) {
|
|
88
|
+
const s = t.shape[0], e = t.shape[2], a = this.config.nEmbed, h = t.transpose([0, 2, 1, 3]), n = this.tf.reshape(h, [s, e, a]), r = this.cProj.apply(n);
|
|
89
|
+
return this.residDropout.apply(r, { training: i });
|
|
75
90
|
}
|
|
76
|
-
|
|
91
|
+
// Added optional KV cache support (pastKV). Returns presentKV for chaining.
|
|
92
|
+
call(t, i = !1, s = !1, e) {
|
|
93
|
+
if (e && !this.config.useRope)
|
|
94
|
+
throw new Error("Cannot use pastKV without RoPE enabled");
|
|
77
95
|
return this.tf.tidy(() => {
|
|
78
|
-
const [a,
|
|
79
|
-
|
|
96
|
+
const [a, h, n] = this.getQKV(t), r = a.shape[2], o = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(a, h, u) : [a, h];
|
|
97
|
+
let c = d, l = n, p = 0;
|
|
98
|
+
e && (p = e.length, c = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, n], 2));
|
|
99
|
+
const b = c.shape[2];
|
|
100
|
+
if (b > o) {
|
|
101
|
+
const k = b - o, g = c.shape[0], v = c.shape[1], I = c.shape[3];
|
|
102
|
+
c = c.slice([0, 0, k, 0], [g, v, o, I]), l = l.slice([0, 0, k, 0], [g, v, o, I]), p = o - r;
|
|
103
|
+
}
|
|
104
|
+
let m;
|
|
105
|
+
p > 0 ? m = this.getAttentionScoresWithPast(f, c, i, p) : m = this.getAttentionScores(f, c, i);
|
|
106
|
+
const _ = this.tf.matMul(m, l), A = this.getOutputProjection(_, i), P = {
|
|
107
|
+
k: this.tf.keep(c),
|
|
108
|
+
v: this.tf.keep(l),
|
|
109
|
+
length: p + r,
|
|
110
|
+
cumulativeLength: e ? e.cumulativeLength + r : r
|
|
111
|
+
};
|
|
112
|
+
return { output: A, attention: s ? m.mean(1) : void 0, presentKV: P };
|
|
80
113
|
});
|
|
81
114
|
}
|
|
82
115
|
dispose() {
|
|
@@ -84,5 +117,5 @@ class m {
|
|
|
84
117
|
}
|
|
85
118
|
}
|
|
86
119
|
export {
|
|
87
|
-
|
|
120
|
+
S as default
|
|
88
121
|
};
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { default as TF } from '@tensorflow/tfjs';
|
|
2
|
+
export default class RMSNorm {
|
|
3
|
+
private gamma;
|
|
4
|
+
private epsilon;
|
|
5
|
+
private tf;
|
|
6
|
+
constructor(tf: typeof TF, shape: number[], epsilon?: number, name?: string);
|
|
7
|
+
get trainableWeights(): TF.Variable[];
|
|
8
|
+
set trainable(value: boolean);
|
|
9
|
+
getWeights(): TF.Tensor[];
|
|
10
|
+
setWeights(weights: TF.Tensor[]): void;
|
|
11
|
+
apply(x: TF.Tensor): TF.Tensor;
|
|
12
|
+
dispose(): void;
|
|
13
|
+
}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
class m {
|
|
2
|
+
gamma;
|
|
3
|
+
epsilon;
|
|
4
|
+
tf;
|
|
5
|
+
constructor(a, s, t = 1e-8, e = "") {
|
|
6
|
+
this.tf = a, this.epsilon = t, this.gamma = a.variable(a.ones(s), !0, `${e}_gamma`, "float32");
|
|
7
|
+
}
|
|
8
|
+
get trainableWeights() {
|
|
9
|
+
return [this.gamma];
|
|
10
|
+
}
|
|
11
|
+
set trainable(a) {
|
|
12
|
+
this.gamma.trainable = a;
|
|
13
|
+
}
|
|
14
|
+
getWeights() {
|
|
15
|
+
return [this.gamma];
|
|
16
|
+
}
|
|
17
|
+
setWeights(a) {
|
|
18
|
+
this.gamma.assign(a[0]);
|
|
19
|
+
}
|
|
20
|
+
apply(a) {
|
|
21
|
+
return this.tf.tidy(() => {
|
|
22
|
+
const t = a.square().mean(-1, !0).add(this.epsilon).rsqrt();
|
|
23
|
+
return a.mul(t).mul(this.gamma);
|
|
24
|
+
});
|
|
25
|
+
}
|
|
26
|
+
dispose() {
|
|
27
|
+
this.gamma.dispose();
|
|
28
|
+
}
|
|
29
|
+
}
|
|
30
|
+
export {
|
|
31
|
+
m as default
|
|
32
|
+
};
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import { default as TF } from '@tensorflow/tfjs';
|
|
2
|
+
import { GPTConfig } from '../config';
|
|
3
|
+
export default class RoPECache {
|
|
4
|
+
private readonly tf;
|
|
5
|
+
private readonly config;
|
|
6
|
+
private rotaryDim;
|
|
7
|
+
private ropeBase;
|
|
8
|
+
private ropeInvFreq;
|
|
9
|
+
private ropeCos;
|
|
10
|
+
private ropeSin;
|
|
11
|
+
private ropeCacheLen;
|
|
12
|
+
constructor(tf: typeof TF, config: GPTConfig);
|
|
13
|
+
private ensureRopeCache;
|
|
14
|
+
applyRoPE(q: TF.Tensor, k: TF.Tensor, pastLen: number): [TF.Tensor, TF.Tensor];
|
|
15
|
+
dispose(): void;
|
|
16
|
+
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
class RoPECache {
  // Precomputed inverse frequencies plus grow-only cos/sin lookup tables
  // for rotary position embeddings (RoPE), shared by all attention layers.
  rotaryDim;
  ropeBase;
  ropeInvFreq;
  ropeCos = null;
  // [cacheLen, rotaryDim/2, 1]
  ropeSin = null;
  // [cacheLen, rotaryDim/2, 1]
  ropeCacheLen = 0;
  constructor(t, c) {
    this.tf = t;
    this.config = c;
    // One rotation pair covers two channels of a head, so head size must be even.
    const headDim = this.config.nEmbed / this.config.nHead;
    this.rotaryDim = headDim;
    if (this.rotaryDim % 2 !== 0)
      throw new Error("rotaryDim must be even");
    this.ropeBase = 1e4;
    // invFreq[j] = base^(-2j/rotaryDim): the standard RoPE frequency spectrum.
    const exponents = this.tf
      .range(0, this.rotaryDim, 2, "float32")
      .div(this.tf.scalar(this.rotaryDim, "float32"));
    const wavelengths = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), exponents);
    this.ropeInvFreq = this.tf.reciprocal(wavelengths);
    if (this.config.useRope === false) {
      this.ropeCos = null;
      this.ropeSin = null;
      this.ropeCacheLen = 0;
    } else {
      // Pre-warm well past blockSize so generation rarely has to regrow the table.
      this.ensureRopeCache(this.config.blockSize * 4);
    }
  }
  ensureRopeCache(t) {
    // Grow-only: rebuild the cos/sin tables when t exceeds the cached length.
    if (t <= this.ropeCacheLen) return;
    if (this.ropeCos) this.ropeCos.dispose();
    if (this.ropeSin) this.ropeSin.dispose();
    // angles[pos][j] = pos * invFreq[j]
    const angles = this.tf
      .range(0, t, 1, "float32")
      .expandDims(1)
      .mul(this.ropeInvFreq.expandDims(0));
    // keep() so the tables survive any enclosing tidy() scope.
    this.ropeCos = this.tf.keep(this.tf.cos(angles).expandDims(-1));
    this.ropeSin = this.tf.keep(this.tf.sin(angles).expandDims(-1));
    this.ropeCacheLen = t;
  }
  applyRoPE(t, c, e) {
    // Rotate query t and key c (assumed [batch, heads, seq, headDim] — shape
    // is read from t.shape) by positions starting at offset e (past tokens).
    const headDim = t.shape[3];
    const rot = this.rotaryDim;
    if (rot > headDim) return [t, c];
    const seqLen = t.shape[2];
    this.ensureRopeCache(e + seqLen);
    const half = rot / 2;
    // Cos/sin rows for positions [e, e + seqLen), broadcastable over batch/heads.
    const cos = this.ropeCos.slice([e, 0, 0], [seqLen, half, 1]).reshape([1, 1, seqLen, half, 1]);
    const sin = this.ropeSin.slice([e, 0, 0], [seqLen, half, 1]).reshape([1, 1, seqLen, half, 1]);
    // Rotate q and k in a single pass by stacking them on the batch axis.
    const stacked = this.tf.concat([t, c], 0);
    const batch = stacked.shape[0];
    const heads = stacked.shape[1];
    const rotPart = stacked.slice([0, 0, 0, 0], [batch, heads, seqLen, rot]);
    const passPart = rot < headDim
      ? stacked.slice([0, 0, 0, rot], [batch, heads, seqLen, headDim - rot])
      : null;
    // View channel pairs as (real, imag) and apply the complex rotation.
    const pairs = stacked.slice([0, 0, 0, 0], [batch, heads, seqLen, rot]).reshape([batch, heads, seqLen, half, 2]);
    const real = pairs.slice([0, 0, 0, 0, 0], [batch, heads, seqLen, half, 1]);
    const imag = pairs.slice([0, 0, 0, 0, 1], [batch, heads, seqLen, half, 1]);
    const rotReal = real.mul(cos).sub(imag.mul(sin));
    const rotImag = imag.mul(cos).add(real.mul(sin));
    const rotated = this.tf.concat([rotReal, rotImag], -1).reshape([batch, heads, seqLen, rot]);
    const merged = passPart ? this.tf.concat([rotated, passPart], 3) : rotated;
    // Split the stacked result back into the rotated q and k halves.
    const perInput = batch / 2;
    const qOut = merged.slice([0, 0, 0, 0], [perInput, heads, seqLen, headDim]);
    const kOut = merged.slice([perInput, 0, 0, 0], [perInput, heads, seqLen, headDim]);
    return [qOut, kOut];
  }
  dispose() {
    if (this.ropeCos) this.ropeCos.dispose();
    if (this.ropeSin) this.ropeSin.dispose();
    this.ropeInvFreq.dispose();
  }
}
export {
  RoPECache as default
};
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import { default as TF } from '@tensorflow/tfjs';
|
|
2
2
|
import { GPTConfig } from '../config';
|
|
3
|
+
import { KVCache } from './CausalSelfAttention';
|
|
4
|
+
import { default as RoPECache } from './RoPECache';
|
|
3
5
|
export default class Block {
|
|
4
6
|
private ln1;
|
|
5
7
|
private attn;
|
|
@@ -9,16 +11,17 @@ export default class Block {
|
|
|
9
11
|
private index;
|
|
10
12
|
private _trainable;
|
|
11
13
|
skipped: boolean;
|
|
12
|
-
constructor(tf: typeof TF, index: number, config: GPTConfig);
|
|
14
|
+
constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache);
|
|
13
15
|
get variables(): TF.Variable[];
|
|
14
16
|
get trainable(): boolean;
|
|
15
17
|
set trainable(value: boolean);
|
|
16
18
|
saveWeights(map: Map<string, TF.Tensor[]>): void;
|
|
17
19
|
loadWeights(weights: Map<string, TF.Tensor[]>): void;
|
|
18
20
|
private getMLPOutput;
|
|
19
|
-
call(x: TF.Tensor, training?: boolean, includeAttention?: boolean): {
|
|
21
|
+
call(x: TF.Tensor, training?: boolean, includeAttention?: boolean, cache?: KVCache): {
|
|
20
22
|
output: TF.Tensor;
|
|
21
23
|
attention?: TF.Tensor;
|
|
24
|
+
cache?: KVCache;
|
|
22
25
|
};
|
|
23
26
|
dispose(): void;
|
|
24
27
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
3
|
-
import
|
|
1
|
+
import r from "./CausalSelfAttention.js";
|
|
2
|
+
import o from "./MLP.js";
|
|
3
|
+
import a from "./RMSNorm.js";
|
|
4
4
|
class u {
|
|
5
5
|
ln1;
|
|
6
6
|
attn;
|
|
@@ -10,8 +10,8 @@ class u {
|
|
|
10
10
|
index;
|
|
11
11
|
_trainable = !0;
|
|
12
12
|
skipped = !1;
|
|
13
|
-
constructor(t, i, s) {
|
|
14
|
-
this.tf = t, this.index = i, this.ln1 = new
|
|
13
|
+
constructor(t, i, s, e) {
|
|
14
|
+
this.tf = t, this.index = i, this.ln1 = new a(t, [s.nEmbed], 1e-8, `block_${this.index}_rms1`), this.attn = new r(this.tf, this.index, s, e), this.ln2 = new a(t, [s.nEmbed], 1e-8, `block_${this.index}_rms2`), this.mlp = new o(this.tf, this.index, s);
|
|
15
15
|
}
|
|
16
16
|
get variables() {
|
|
17
17
|
return [
|
|
@@ -28,21 +28,25 @@ class u {
|
|
|
28
28
|
this._trainable = t, this.ln1.trainable = t, this.ln2.trainable = t, this.attn.trainable = t, this.mlp.trainable = t;
|
|
29
29
|
}
|
|
30
30
|
saveWeights(t) {
|
|
31
|
-
this.attn.saveWeights(t), this.mlp.saveWeights(t), t.set(`block_${this.index}
|
|
31
|
+
this.attn.saveWeights(t), this.mlp.saveWeights(t), t.set(`block_${this.index}_rms1`, this.ln1.getWeights()), t.set(`block_${this.index}_rms2`, this.ln2.getWeights());
|
|
32
32
|
}
|
|
33
33
|
loadWeights(t) {
|
|
34
|
-
this.attn.loadWeights(t), this.mlp.loadWeights(t), this.ln1.setWeights(t.get(`block_${this.index}
|
|
34
|
+
this.attn.loadWeights(t), this.mlp.loadWeights(t), this.ln1.setWeights(t.get(`block_${this.index}_rms1`) || []), this.ln2.setWeights(t.get(`block_${this.index}_rms2`) || []);
|
|
35
35
|
}
|
|
36
36
|
getMLPOutput(t, i) {
|
|
37
37
|
const s = this.ln2.apply(t), e = this.mlp.call(s, i);
|
|
38
38
|
return t.add(e);
|
|
39
39
|
}
|
|
40
|
-
call(t, i = !1, s = !1) {
|
|
40
|
+
call(t, i = !1, s = !1, e) {
|
|
41
41
|
return this.tf.tidy(() => {
|
|
42
42
|
if (this.skipped)
|
|
43
43
|
return { output: t };
|
|
44
|
-
const
|
|
45
|
-
return {
|
|
44
|
+
const l = this.ln1.apply(t), n = this.attn.call(l, i, s, e), h = t.add(n.output);
|
|
45
|
+
return {
|
|
46
|
+
output: this.getMLPOutput(h, i),
|
|
47
|
+
attention: n.attention,
|
|
48
|
+
cache: n.presentKV
|
|
49
|
+
};
|
|
46
50
|
});
|
|
47
51
|
}
|
|
48
52
|
dispose() {
|
|
@@ -1,70 +1,68 @@
|
|
|
1
1
|
import { generateText as L } from "../utilities/generate.js";
|
|
2
2
|
import w from "./Trainer.js";
|
|
3
|
-
import
|
|
4
|
-
const
|
|
3
|
+
import x from "./Evaluator.js";
|
|
4
|
+
// Defaults merged under caller-supplied training options.
const TRAIN_DEFAULTS = {
  desiredLoss: 0.01,
  logInterval: 1,
  maxSteps: 1000
};
class FullTrainer extends w {
  constructor(r, i, o, n = 3e-4) {
    super(r, i, o, n);
  }
  // Train for multiple epochs using Dataset API - FIXED memory leaks
  async trainOnDataset(r, i, o) {
    // r: dataset, i: options, o: optional validation data for an Evaluator.
    const { desiredLoss, logInterval, onStep, prompt, maxSteps } = {
      ...TRAIN_DEFAULTS,
      ...i
    };
    // Resume from this.lastState when present, otherwise start fresh.
    const state = {
      step: 0,
      lastLoss: 1e6,
      totalSteps: 0,
      losses: [],
      validationLosses: [],
      ...this.lastState || {}
    };
    this.lastState = state;
    this.dummyPass();
    this.model.trainable = true;
    const startTime = Date.now();
    this.running = true;
    const evaluator = o ? new x(this.model, o) : void 0;
    const iterator = await r.iterator();
    try {
      while (this.running && !(state.lastLoss < desiredLoss)) {
        const next = await iterator.next();
        if (next.done) break;
        const batch = next.value;
        // Kick off the step; only awaited on logging iterations below.
        const pending = this.trainBatch(state, batch);
        const logEntry = {
          loss: state.lastLoss,
          step: state.step,
          time: Date.now() - startTime,
          batchSize: batch.xs.shape[0]
        };
        this.model.log.push(logEntry);
        if (state.step % logInterval === 0) {
          await pending;
          if (evaluator)
            try {
              const valLoss = await evaluator.evaluate(5);
              state.validationLosses.push(valLoss);
              logEntry.valLoss = valLoss;
            } catch (err) {
              console.error("Validation error:", err);
            }
          if (onStep) {
            if (prompt) {
              // Sample a short continuation so progress is visible to the caller.
              const example = await L(this.tokenizer, this.model, prompt, 100, {
                temperature: 0.8
              });
              logEntry.example = example;
            }
            await onStep(logEntry);
          }
        }
        if (state.step >= maxSteps) this.stop();
      }
    } catch (err) {
      console.error("Training error:", err);
      this.tf.dispose();
      throw err;
    }
    this.tf.dispose();
    this.running = false;
    return { losses: state.losses, validationLosses: state.validationLosses };
  }
}
export {
  FullTrainer as default
};
|
|
@@ -31,8 +31,10 @@ export default abstract class GPTTrainer {
|
|
|
31
31
|
protected tf: typeof TF;
|
|
32
32
|
protected learningRate: number;
|
|
33
33
|
protected running: boolean;
|
|
34
|
+
protected lastState?: TrainingState;
|
|
34
35
|
constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
|
|
35
36
|
setLearningRate(learningRate: number): void;
|
|
37
|
+
reset(): void;
|
|
36
38
|
stop(): void;
|
|
37
39
|
getOptimizer(): AdamExt;
|
|
38
40
|
resetOptimizer(config?: AdamConfig): void;
|
package/dist/training/Trainer.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import { DatasetBuilder as d } from "./DatasetBuilder.js";
|
|
2
|
-
import
|
|
2
|
+
import h from "./AdamExt.js";
|
|
3
3
|
class u {
|
|
4
|
-
constructor(t,
|
|
5
|
-
this.tokenizer =
|
|
4
|
+
constructor(t, e, s, i = 1e-3) {
|
|
5
|
+
this.tokenizer = s, this.tf = t, this.model = e, this.learningRate = i, this.resetOptimizer(), this.datasetBuilder = new d(this.tf, s, e.config.blockSize);
|
|
6
6
|
}
|
|
7
7
|
model;
|
|
8
8
|
optimizer;
|
|
@@ -10,9 +10,13 @@ class u {
|
|
|
10
10
|
tf;
|
|
11
11
|
learningRate;
|
|
12
12
|
running = !1;
|
|
13
|
+
lastState;
|
|
13
14
|
setLearningRate(t) {
|
|
14
15
|
this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
|
|
15
16
|
}
|
|
17
|
+
reset() {
|
|
18
|
+
this.lastState = void 0, this.running = !1;
|
|
19
|
+
}
|
|
16
20
|
stop() {
|
|
17
21
|
this.running = !1;
|
|
18
22
|
}
|
|
@@ -21,7 +25,7 @@ class u {
|
|
|
21
25
|
}
|
|
22
26
|
resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
|
|
23
27
|
this.optimizer && this.optimizer.dispose();
|
|
24
|
-
const
|
|
28
|
+
const e = new h(
|
|
25
29
|
t.learningRateFactor * this.learningRate,
|
|
26
30
|
t.beta1,
|
|
27
31
|
t.beta2,
|
|
@@ -33,53 +37,53 @@ class u {
|
|
|
33
37
|
weightDecay: 0
|
|
34
38
|
}
|
|
35
39
|
);
|
|
36
|
-
this.optimizer =
|
|
40
|
+
this.optimizer = e;
|
|
37
41
|
}
|
|
38
42
|
printGradients(t) {
|
|
39
|
-
Object.keys(t).forEach((
|
|
40
|
-
const
|
|
41
|
-
console.log(`${
|
|
43
|
+
Object.keys(t).forEach((e) => {
|
|
44
|
+
const s = t[e];
|
|
45
|
+
console.log(`${e}:`), console.log(` Shape: ${s.shape}`), console.log(` Mean: ${this.tf.mean(s).dataSync()[0]}`), console.log(` Std: ${this.tf.moments(s).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${this.tf.min(s).dataSync()[0]}`), console.log(` Max: ${this.tf.max(s).dataSync()[0]}`), console.log(` Norm: ${this.tf.norm(s).dataSync()[0]}`);
|
|
42
46
|
});
|
|
43
47
|
}
|
|
44
|
-
trainStep(t,
|
|
48
|
+
trainStep(t, e = !1, s = !1) {
|
|
45
49
|
return this.tf.tidy(() => {
|
|
46
50
|
const { xs: i, ys: a } = t, o = () => {
|
|
47
51
|
const { loss: l, logits: c } = this.model.forward(i, a, !0);
|
|
48
52
|
return c.dispose(), l;
|
|
49
53
|
}, { value: n, grads: r } = this.tf.variableGrads(o);
|
|
50
|
-
return
|
|
54
|
+
return e || (s && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.tf.dispose(r)), n;
|
|
51
55
|
});
|
|
52
56
|
}
|
|
53
57
|
dummyPass() {
|
|
54
|
-
const t = this.tf.zeros([1, this.model.config.blockSize], "int32"),
|
|
58
|
+
const t = this.tf.zeros([1, this.model.config.blockSize], "int32"), e = this.tf.zeros([1, this.model.config.blockSize, this.model.config.vocabSize]);
|
|
55
59
|
try {
|
|
56
|
-
const
|
|
57
|
-
|
|
58
|
-
} catch (
|
|
59
|
-
console.error("Error during dummy pass:",
|
|
60
|
+
const s = this.trainStep({ xs: t, ys: e }, !0);
|
|
61
|
+
s.dataSync(), s.dispose();
|
|
62
|
+
} catch (s) {
|
|
63
|
+
console.error("Error during dummy pass:", s);
|
|
60
64
|
} finally {
|
|
61
|
-
t.dispose(),
|
|
65
|
+
t.dispose(), e.dispose();
|
|
62
66
|
}
|
|
63
67
|
}
|
|
64
|
-
async trainBatch(t,
|
|
68
|
+
async trainBatch(t, e) {
|
|
65
69
|
try {
|
|
66
|
-
const
|
|
67
|
-
return
|
|
68
|
-
} catch (
|
|
69
|
-
throw console.error(`Error processing batch at step ${t.step}:`,
|
|
70
|
+
const s = this.trainStep(e, !1, !1);
|
|
71
|
+
return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, s.array().then((i) => (t.lastLoss = i, t.losses.push(t.lastLoss), s.dispose(), t.lastLoss));
|
|
72
|
+
} catch (s) {
|
|
73
|
+
throw console.error(`Error processing batch at step ${t.step}:`, s), this.tf.dispose(), s;
|
|
70
74
|
}
|
|
71
75
|
}
|
|
72
|
-
async createTrainValidationSplit(t,
|
|
73
|
-
const i = await this.datasetBuilder.createTextDataset(t,
|
|
76
|
+
async createTrainValidationSplit(t, e = 32, s = 0.1) {
|
|
77
|
+
const i = await this.datasetBuilder.createTextDataset(t, e, 0, 1 - s), a = await this.datasetBuilder.createTextDataset(
|
|
74
78
|
t,
|
|
75
|
-
|
|
76
|
-
1 -
|
|
79
|
+
e,
|
|
80
|
+
1 - s,
|
|
77
81
|
1
|
|
78
82
|
);
|
|
79
83
|
return { trainDataset: i, validationDataset: a };
|
|
80
84
|
}
|
|
81
|
-
async createDataset(t,
|
|
82
|
-
return await this.datasetBuilder.createTextDataset(t,
|
|
85
|
+
async createDataset(t, e = 32) {
|
|
86
|
+
return await this.datasetBuilder.createTextDataset(t, e);
|
|
83
87
|
}
|
|
84
88
|
dispose() {
|
|
85
89
|
this.optimizer && this.optimizer.dispose();
|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
async function
|
|
2
|
-
if (
|
|
1
|
+
async function generateText(r, t, a, c, g) {
  // Autoregressively sample c tokens from model t continuing prompt a,
  // using tokeniser r and generate options g. When the model uses RoPE,
  // per-layer KV caches let each step feed only the newest token back in.
  if (c <= 0)
    throw new Error("Length must be a positive integer");
  if (a.length === 0)
    throw new Error("Prompt cannot be an empty string");
  const tokens = await r.tokenise([a], true);
  const kvCaches = t.config.useRope ? new Array(t.config.nLayer).fill(void 0) : void 0;
  const generated = t.tf.tidy(() => {
    let input = t.tf.tensor2d(tokens, [1, tokens[0].length], "int32");
    let full = input;
    for (let step = 0; step < c; step++) {
      const { output } = t.generate(input, kvCaches, g);
      const prevInput = input;
      const prevFull = full;
      full = t.tf.concat([full, output], 1);
      // With a KV cache only the new token is fed back; otherwise regrow the context.
      input = kvCaches ? output : t.tf.concat([input, output], 1);
      prevInput.dispose();
      prevFull.dispose();
      if (!kvCaches) output.dispose();
    }
    return full;
  });
  const arr = await generated.array();
  generated.dispose();
  const ids = arr[0];
  // Truncate at the first end-of-sequence token, if any.
  const eosIndex = ids.indexOf(r.eosToken);
  if (eosIndex !== -1) ids.splice(eosIndex);
  return await r.decode(ids);
}
export {
  generateText
};
|