@genai-fi/nanogpt 0.2.12 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +30 -25
- package/dist/NanoGPTModel.d.ts +13 -14
- package/dist/NanoGPTModel.js +142 -70
- package/dist/TeachableLLM.d.ts +16 -7
- package/dist/TeachableLLM.js +81 -44
- package/dist/Trainer.js +8 -8
- package/dist/concat-BIZS_td9.js +33 -0
- package/dist/data/parquet.js +1 -1
- package/dist/exports_layers-tbTBcwMM.js +25 -0
- package/dist/{sum-D7fu15XL.js → gather-BPGW8RsB.js} +6 -8
- package/dist/index-C4L8Cm77.js +349 -0
- package/dist/{index-YPKosni4.js → index-pWA4_lUh.js} +1020 -782
- package/dist/layers/CausalSelfAttention.d.ts +11 -11
- package/dist/layers/CausalSelfAttention.js +71 -63
- package/dist/layers/MLP.d.ts +6 -7
- package/dist/layers/MLP.js +18 -16
- package/dist/layers/RMSNorm.d.ts +6 -7
- package/dist/layers/RMSNorm.js +15 -13
- package/dist/layers/RoPECache.d.ts +4 -5
- package/dist/layers/RoPECache.js +36 -12
- package/dist/layers/TiedEmbedding.d.ts +7 -8
- package/dist/layers/TiedEmbedding.js +16 -418
- package/dist/layers/TransformerBlock.d.ts +8 -9
- package/dist/layers/TransformerBlock.js +12 -12
- package/dist/main.d.ts +2 -0
- package/dist/main.js +35 -21
- package/dist/{mat_mul-Bu7bhLms.js → mat_mul-D7_a4KJn.js} +5 -5
- package/dist/moments-DfcpfwKi.js +132 -0
- package/dist/ones-Cog-G2ag.js +29 -0
- package/dist/ops/appendCache.d.ts +2 -0
- package/dist/ops/appendCache.js +9 -0
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +7 -85
- package/dist/ops/cpu/appendCache.d.ts +2 -0
- package/dist/ops/cpu/appendCache.js +28 -0
- package/dist/ops/cpu/attentionMask.js +18 -0
- package/dist/ops/cpu/gatherSub.d.ts +1 -0
- package/dist/ops/cpu/gatherSub.js +34 -0
- package/dist/ops/cpu/qkv.d.ts +5 -0
- package/dist/ops/cpu/qkv.js +38 -0
- package/dist/ops/cpu/rope.d.ts +6 -0
- package/dist/ops/cpu/rope.js +38 -0
- package/dist/ops/cpu/scatterSub.d.ts +1 -0
- package/dist/ops/cpu/scatterSub.js +70 -0
- package/dist/ops/gatherSub.d.ts +1 -1
- package/dist/ops/gatherSub.js +6 -63
- package/dist/ops/grads/attentionMask.d.ts +1 -0
- package/dist/ops/grads/attentionMask.js +21 -0
- package/dist/ops/grads/qkv.d.ts +1 -0
- package/dist/ops/grads/qkv.js +20 -0
- package/dist/ops/grads/rope.d.ts +1 -0
- package/dist/ops/grads/rope.js +14 -0
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.d.ts +1 -6
- package/dist/ops/qkv.js +7 -124
- package/dist/ops/rope.d.ts +0 -5
- package/dist/ops/rope.js +7 -151
- package/dist/ops/scatterSub.d.ts +1 -1
- package/dist/ops/scatterSub.js +6 -147
- package/dist/ops/webgl/appendCache.d.ts +1 -0
- package/dist/ops/webgl/appendCache.js +43 -0
- package/dist/ops/webgl/attentionMask.d.ts +1 -0
- package/dist/ops/webgl/attentionMask.js +43 -0
- package/dist/ops/webgl/gatherSub.d.ts +1 -0
- package/dist/ops/webgl/gatherSub.js +27 -0
- package/dist/ops/webgl/qkv.d.ts +1 -0
- package/dist/ops/webgl/qkv.js +46 -0
- package/dist/ops/webgl/rope.d.ts +1 -0
- package/dist/ops/webgl/rope.js +56 -0
- package/dist/ops/webgl/scatterSub.d.ts +1 -0
- package/dist/ops/webgl/scatterSub.js +27 -0
- package/dist/{parquet-BRl5lE_I.js → parquet-C0Tlmv9c.js} +3045 -3048
- package/dist/random_width-oeUIlUZj.js +15487 -0
- package/dist/range-CcDl05lo.js +26 -0
- package/dist/{reshape-DmnmKT6r.js → reshape-C8CR_Bad.js} +3 -3
- package/dist/sin-BJIrfnj7.js +47 -0
- package/dist/softmax-Be_lsqUc.js +105 -0
- package/dist/{complex-CJ-qCcLB.js → split-DZbvruEP.js} +6 -8
- package/dist/stack-BMm-efee.js +27 -0
- package/dist/sum-C7Mgy9Bw.js +104 -0
- package/dist/tensor-DJVbYhh1.js +24 -0
- package/dist/tensor2d-ZuQSh2D-.js +30 -0
- package/dist/tokeniser/bpe.d.ts +17 -6
- package/dist/tokeniser/bpe.js +89 -61
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.d.ts +6 -6
- package/dist/training/DatasetBuilder.js +1262 -17
- package/dist/training/Evaluator.d.ts +3 -2
- package/dist/training/FullTrainer.d.ts +9 -8
- package/dist/training/FullTrainer.js +26 -25
- package/dist/training/LayerTrainer.d.ts +9 -8
- package/dist/training/LayerTrainer.js +34 -33
- package/dist/training/Trainer.d.ts +22 -21
- package/dist/training/Trainer.js +21 -18
- package/dist/training/sparseCrossEntropy.js +22 -166
- package/dist/utilities/dummy.js +10 -8
- package/dist/utilities/generate.js +14 -11
- package/dist/utilities/load.d.ts +1 -2
- package/dist/utilities/load.js +37 -35
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +14 -9
- package/dist/utilities/tokenParse.d.ts +1 -1
- package/dist/utilities/tokenParse.js +7 -61
- package/dist/utilities/weights.d.ts +3 -3
- package/dist/utilities/weights.js +21 -19
- package/dist/variable-Dl_ub3pk.js +23 -0
- package/dist/{stack-BtKpB0Ry.js → zeros-CCy9C3uU.js} +18 -16
- package/package.json +2 -1
- package/dist/assets/worker-BYeSPNkq.js +0 -1
- package/dist/tokeniser/NodeTokeniser.d.ts +0 -20
- package/dist/tokeniser/NodeTokeniser.js +0 -46
- package/dist/tokeniser/WebTokeniser.d.ts +0 -18
- package/dist/tokeniser/WebTokeniser.js +0 -96
- package/dist/tokeniser/worker.js +0 -53
- /package/dist/{tokeniser/worker.d.ts → ops/cpu/attentionMask.d.ts} +0 -0
package/dist/Generator.js
CHANGED
|
@@ -1,49 +1,52 @@
|
|
|
1
1
|
import { E as u } from "./index-Dwqa6Zy2.js";
|
|
2
|
-
|
|
2
|
+
import "./index-pWA4_lUh.js";
|
|
3
|
+
import { t as d } from "./tensor2d-ZuQSh2D-.js";
|
|
4
|
+
import { c as k } from "./concat-BIZS_td9.js";
|
|
5
|
+
class w extends u {
|
|
3
6
|
constructor(s, e) {
|
|
4
7
|
super(), this.model = s, this.tokeniser = e;
|
|
5
8
|
}
|
|
6
9
|
active = !1;
|
|
7
10
|
async tokenisePrompt(s) {
|
|
8
11
|
const e = s ? await this.tokeniser.tokenise([s], !0) : [[this.tokeniser.eosToken]];
|
|
9
|
-
return
|
|
12
|
+
return d(e, [1, e[0].length], "int32");
|
|
10
13
|
}
|
|
11
14
|
async generateNoCache(s, e) {
|
|
12
|
-
let t = await this.tokenisePrompt(s),
|
|
13
|
-
const
|
|
14
|
-
for (let a = 0; a <
|
|
15
|
+
let t = await this.tokenisePrompt(s), o = s || "";
|
|
16
|
+
const n = e?.maxLength ?? 1e3;
|
|
17
|
+
for (let a = 0; a < n && this.active; a++) {
|
|
15
18
|
const {
|
|
16
|
-
output:
|
|
19
|
+
output: i,
|
|
17
20
|
attention: c,
|
|
18
21
|
probabilities: l
|
|
19
22
|
} = this.model.generate(t, void 0, e), h = t;
|
|
20
|
-
t =
|
|
21
|
-
const r = await this.processResponse(
|
|
22
|
-
if (
|
|
23
|
+
t = k([t, i], 1), h.dispose();
|
|
24
|
+
const r = await this.processResponse(i, c, l);
|
|
25
|
+
if (i.dispose(), r === null)
|
|
23
26
|
break;
|
|
24
|
-
|
|
27
|
+
o += r;
|
|
25
28
|
}
|
|
26
|
-
return t.dispose(),
|
|
29
|
+
return t.dispose(), o;
|
|
27
30
|
}
|
|
28
31
|
async processResponse(s, e, t) {
|
|
29
|
-
const
|
|
30
|
-
if (
|
|
32
|
+
const o = (await s.array())[0][0];
|
|
33
|
+
if (o === this.tokeniser.eosToken)
|
|
31
34
|
return null;
|
|
32
|
-
const
|
|
35
|
+
const n = await this.tokeniser.decode([o]);
|
|
33
36
|
let a;
|
|
34
37
|
e && (a = await e.array(), e.dispose());
|
|
35
|
-
let
|
|
36
|
-
return t && (
|
|
38
|
+
let i;
|
|
39
|
+
return t && (i = await t.array(), t.dispose()), this.emit("tokens", [o], n, a, i), n;
|
|
37
40
|
}
|
|
38
41
|
async generateCache(s, e) {
|
|
39
|
-
let t = await this.tokenisePrompt(s),
|
|
40
|
-
const
|
|
41
|
-
for (let
|
|
42
|
+
let t = await this.tokenisePrompt(s), o = s || "";
|
|
43
|
+
const n = new Array(this.model.config.nLayer).fill(void 0), a = e?.maxLength ?? 1e3;
|
|
44
|
+
for (let i = 0; i < a && this.active; i++) {
|
|
42
45
|
const {
|
|
43
46
|
output: c,
|
|
44
47
|
attention: l,
|
|
45
48
|
probabilities: h
|
|
46
|
-
} = this.model.generate(t,
|
|
49
|
+
} = this.model.generate(t, n, {
|
|
47
50
|
...e,
|
|
48
51
|
usePadding: !1
|
|
49
52
|
});
|
|
@@ -51,20 +54,22 @@ class f extends u {
|
|
|
51
54
|
const r = await this.processResponse(c, l, h);
|
|
52
55
|
if (r === null)
|
|
53
56
|
break;
|
|
54
|
-
|
|
57
|
+
o += r;
|
|
55
58
|
}
|
|
56
|
-
return
|
|
59
|
+
return n.forEach((i) => {
|
|
60
|
+
i && (i.k.dispose(), i.v.dispose());
|
|
61
|
+
}), t.dispose(), o;
|
|
57
62
|
}
|
|
58
63
|
async generate(s, e) {
|
|
59
64
|
const t = s && s.length > this.model.config.blockSize ? s.slice(-this.model.config.blockSize) : s;
|
|
60
65
|
this.active = !0, this.emit("start");
|
|
61
|
-
const
|
|
62
|
-
return this.active = !1, this.emit("stop"),
|
|
66
|
+
const n = await (this.model.config.useRope && !e?.noCache ? this.generateCache(t, e) : this.generateNoCache(t, e));
|
|
67
|
+
return this.active = !1, this.emit("stop"), n;
|
|
63
68
|
}
|
|
64
69
|
stop() {
|
|
65
70
|
this.active = !1;
|
|
66
71
|
}
|
|
67
72
|
}
|
|
68
73
|
export {
|
|
69
|
-
|
|
74
|
+
w as default
|
|
70
75
|
};
|
package/dist/NanoGPTModel.d.ts
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
import { default as TF } from '@tensorflow/tfjs';
|
|
2
1
|
import { GPTConfig } from './config';
|
|
3
2
|
import { KVCache } from './layers/CausalSelfAttention';
|
|
4
3
|
import { default as MemoryProfiler } from './utilities/profile';
|
|
5
4
|
import { default as BaseLayer } from './layers/BaseLayer';
|
|
5
|
+
import { Tensor, Variable } from '@tensorflow/tfjs-core';
|
|
6
6
|
export interface TrainingLogEntry {
|
|
7
7
|
loss: number;
|
|
8
8
|
valLoss?: number;
|
|
@@ -26,12 +26,11 @@ export default class NanoGPT extends BaseLayer {
|
|
|
26
26
|
private blocks;
|
|
27
27
|
private lnF;
|
|
28
28
|
private ropeCache?;
|
|
29
|
-
readonly tf: typeof TF;
|
|
30
29
|
log: TrainingLogEntry[];
|
|
31
|
-
constructor(
|
|
32
|
-
get variables():
|
|
33
|
-
saveWeights(): Map<string,
|
|
34
|
-
loadWeights(weights: Map<string,
|
|
30
|
+
constructor(config?: Partial<GPTConfig>);
|
|
31
|
+
get variables(): Variable[];
|
|
32
|
+
saveWeights(): Map<string, Tensor[]>;
|
|
33
|
+
loadWeights(weights: Map<string, Tensor[]>): void;
|
|
35
34
|
private inputPhase;
|
|
36
35
|
setSkipMask(mask: boolean[]): void;
|
|
37
36
|
setTrainableMask(mask: boolean[]): void;
|
|
@@ -40,15 +39,15 @@ export default class NanoGPT extends BaseLayer {
|
|
|
40
39
|
private validateInput;
|
|
41
40
|
private calculateLoss;
|
|
42
41
|
private computeAttentionRollout;
|
|
43
|
-
forward(idx:
|
|
44
|
-
logits:
|
|
45
|
-
loss?:
|
|
46
|
-
attention?:
|
|
42
|
+
forward(idx: Tensor, targets?: Tensor, training?: boolean, includeAttention?: boolean, cache?: (KVCache | undefined)[]): {
|
|
43
|
+
logits: Tensor;
|
|
44
|
+
loss?: Tensor;
|
|
45
|
+
attention?: Tensor;
|
|
47
46
|
};
|
|
48
|
-
generate(idx:
|
|
49
|
-
output:
|
|
50
|
-
attention?:
|
|
51
|
-
probabilities?:
|
|
47
|
+
generate(idx: Tensor, cache?: (KVCache | undefined)[], options?: GenerateOptions): {
|
|
48
|
+
output: Tensor;
|
|
49
|
+
attention?: Tensor;
|
|
50
|
+
probabilities?: Tensor;
|
|
52
51
|
};
|
|
53
52
|
getNumParams(): number;
|
|
54
53
|
dispose(): void;
|
package/dist/NanoGPTModel.js
CHANGED
|
@@ -1,12 +1,98 @@
|
|
|
1
|
-
import { defaultConfig as
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import { estimateParameterCount as
|
|
7
|
-
import { createSoftmaxCrossEntropyWithGrad as
|
|
8
|
-
import
|
|
9
|
-
|
|
1
|
+
import { defaultConfig as F } from "./config.js";
|
|
2
|
+
import L from "./layers/TransformerBlock.js";
|
|
3
|
+
import P from "./layers/TiedEmbedding.js";
|
|
4
|
+
import C from "./layers/RoPECache.js";
|
|
5
|
+
import q from "./layers/RMSNorm.js";
|
|
6
|
+
import { estimateParameterCount as K } from "./utilities/parameters.js";
|
|
7
|
+
import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
|
|
8
|
+
import T from "./layers/BaseLayer.js";
|
|
9
|
+
import { r as R, e as D, p as A } from "./random_width-oeUIlUZj.js";
|
|
10
|
+
import { o as y, h as E, p as B, E as z, W as G, X as O, Y as Q, t as w, Z as X, f as _ } from "./index-pWA4_lUh.js";
|
|
11
|
+
import { e as j, a as U } from "./exports_layers-tbTBcwMM.js";
|
|
12
|
+
import { r as S } from "./reshape-C8CR_Bad.js";
|
|
13
|
+
import { r as V } from "./range-CcDl05lo.js";
|
|
14
|
+
import { g as Y } from "./gather-BPGW8RsB.js";
|
|
15
|
+
import { s as Z } from "./softmax-Be_lsqUc.js";
|
|
16
|
+
/**
|
|
17
|
+
* @license
|
|
18
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
19
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
20
|
+
* you may not use this file except in compliance with the License.
|
|
21
|
+
* You may obtain a copy of the License at
|
|
22
|
+
*
|
|
23
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
24
|
+
*
|
|
25
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
26
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
27
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
28
|
+
* See the License for the specific language governing permissions and
|
|
29
|
+
* limitations under the License.
|
|
30
|
+
* =============================================================================
|
|
31
|
+
*/
|
|
32
|
+
function H(m, t) {
|
|
33
|
+
let e = E(m, "a", "mod"), o = E(t, "b", "mod");
|
|
34
|
+
[e, o] = B(e, o);
|
|
35
|
+
const i = { a: e, b: o };
|
|
36
|
+
return z.runKernel(G, i);
|
|
37
|
+
}
|
|
38
|
+
const J = /* @__PURE__ */ y({ mod_: H });
|
|
39
|
+
/**
|
|
40
|
+
* @license
|
|
41
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
42
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
43
|
+
* you may not use this file except in compliance with the License.
|
|
44
|
+
* You may obtain a copy of the License at
|
|
45
|
+
*
|
|
46
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
47
|
+
*
|
|
48
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
49
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
50
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
51
|
+
* See the License for the specific language governing permissions and
|
|
52
|
+
* limitations under the License.
|
|
53
|
+
* =============================================================================
|
|
54
|
+
*/
|
|
55
|
+
function tt(m, t, e, o = !1) {
|
|
56
|
+
const i = E(m, "logits", "multinomial"), s = i.size, r = i.rank;
|
|
57
|
+
if (s < 2)
|
|
58
|
+
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
|
|
59
|
+
if (r > 2)
|
|
60
|
+
throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
|
|
61
|
+
e = e || Math.random();
|
|
62
|
+
const n = { logits: r === 1 ? S(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = z.runKernel(O, n, h);
|
|
63
|
+
return r === 1 ? S(a, [a.size]) : a;
|
|
64
|
+
}
|
|
65
|
+
const I = /* @__PURE__ */ y({ multinomial_: tt });
|
|
66
|
+
/**
|
|
67
|
+
* @license
|
|
68
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
69
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
70
|
+
* you may not use this file except in compliance with the License.
|
|
71
|
+
* You may obtain a copy of the License at
|
|
72
|
+
*
|
|
73
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
74
|
+
*
|
|
75
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
76
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
77
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
78
|
+
* See the License for the specific language governing permissions and
|
|
79
|
+
* limitations under the License.
|
|
80
|
+
* =============================================================================
|
|
81
|
+
*/
|
|
82
|
+
function et(m, t = 1, e = !0) {
|
|
83
|
+
const o = E(m, "x", "topk");
|
|
84
|
+
if (o.rank === 0)
|
|
85
|
+
throw new Error("topk() expects the input to be of rank 1 or higher");
|
|
86
|
+
const i = o.shape[o.shape.length - 1];
|
|
87
|
+
if (t < 0)
|
|
88
|
+
throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
|
|
89
|
+
if (t > i)
|
|
90
|
+
throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
|
|
91
|
+
const s = { x: o }, r = { k: t, sorted: e }, [l, n] = z.runKernel(Q, s, r);
|
|
92
|
+
return { values: l, indices: n };
|
|
93
|
+
}
|
|
94
|
+
const ot = /* @__PURE__ */ y({ topk_: et });
|
|
95
|
+
class kt extends T {
|
|
10
96
|
config;
|
|
11
97
|
wte;
|
|
12
98
|
// Token embeddings
|
|
@@ -18,23 +104,22 @@ class A extends P {
|
|
|
18
104
|
lnF;
|
|
19
105
|
// Final layer norm
|
|
20
106
|
ropeCache;
|
|
21
|
-
tf;
|
|
22
107
|
log = [];
|
|
23
108
|
// Training log
|
|
24
|
-
constructor(t
|
|
25
|
-
super(), this.
|
|
109
|
+
constructor(t = {}) {
|
|
110
|
+
super(), this.config = { ...F, ...t }, this.wte = new P({
|
|
26
111
|
vocabSize: this.config.vocabSize,
|
|
27
112
|
embedDim: this.config.nEmbed,
|
|
28
113
|
name: "token_embedding"
|
|
29
|
-
}), this.config.useRope === !1 ? this.wpe =
|
|
114
|
+
}), this.config.useRope === !1 ? this.wpe = j({
|
|
30
115
|
inputDim: this.config.blockSize,
|
|
31
116
|
outputDim: this.config.nEmbed,
|
|
32
117
|
name: "positional_embedding",
|
|
33
|
-
embeddingsInitializer:
|
|
34
|
-
}) : this.ropeCache = new
|
|
35
|
-
for (let
|
|
36
|
-
this.blocks.push(new
|
|
37
|
-
this.lnF = new
|
|
118
|
+
embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
|
|
119
|
+
}) : this.ropeCache = new C(this.config), this.drop = U({ rate: this.config.dropout }), this.blocks = [];
|
|
120
|
+
for (let e = 0; e < this.config.nLayer; e++)
|
|
121
|
+
this.blocks.push(new L(e, this.config, this.ropeCache));
|
|
122
|
+
this.lnF = new q([this.config.nEmbed], 1e-8, "final_rms_norm");
|
|
38
123
|
}
|
|
39
124
|
get variables() {
|
|
40
125
|
return [
|
|
@@ -58,14 +143,11 @@ class A extends P {
|
|
|
58
143
|
this.lnF.setWeights(t.get("final_rms_norm") || []);
|
|
59
144
|
}
|
|
60
145
|
inputPhase(t, e, o = !1) {
|
|
61
|
-
return
|
|
146
|
+
return w(() => {
|
|
62
147
|
const i = this.wte.embed(t);
|
|
63
148
|
if (this.config.useRope === !1) {
|
|
64
|
-
const [, s] = t.shape,
|
|
65
|
-
|
|
66
|
-
this.tf.scalar(l, "int32")
|
|
67
|
-
), h = this.wpe.apply(n), c = i.add(h);
|
|
68
|
-
return this.drop.apply(c, { training: o });
|
|
149
|
+
const [, s] = t.shape, r = this.config.blockSize, l = V(0, s, 1, "int32"), n = J(X(l, _(e, "int32")), _(r, "int32")), h = this.wpe.apply(n), a = i.add(h);
|
|
150
|
+
return this.drop.apply(a, { training: o });
|
|
69
151
|
} else
|
|
70
152
|
return this.drop.apply(i, { training: o });
|
|
71
153
|
});
|
|
@@ -103,7 +185,7 @@ class A extends P {
|
|
|
103
185
|
}
|
|
104
186
|
calculateLoss(t, e) {
|
|
105
187
|
try {
|
|
106
|
-
return
|
|
188
|
+
return N()(t, e).mean();
|
|
107
189
|
} catch (o) {
|
|
108
190
|
throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
|
|
109
191
|
}
|
|
@@ -111,89 +193,79 @@ class A extends P {
|
|
|
111
193
|
// Attention rollout per Abnar & Zuidema (2020)
|
|
112
194
|
// Expects list of (B, T, T) attention matrices already averaged over heads.
|
|
113
195
|
computeAttentionRollout(t) {
|
|
114
|
-
return
|
|
196
|
+
return w(() => {
|
|
115
197
|
if (t.length === 0)
|
|
116
198
|
throw new Error("No attentions for rollout");
|
|
117
199
|
const [e, o, i] = t[0].shape;
|
|
118
200
|
for (const s of t) {
|
|
119
|
-
const [
|
|
120
|
-
if (
|
|
201
|
+
const [r, l, n] = s.shape;
|
|
202
|
+
if (r !== e || l !== o || n !== i)
|
|
121
203
|
throw new Error(
|
|
122
|
-
`Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${
|
|
204
|
+
`Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${l},${n}]`
|
|
123
205
|
);
|
|
124
206
|
}
|
|
125
207
|
if (o === i) {
|
|
126
|
-
const s =
|
|
127
|
-
let
|
|
128
|
-
for (const
|
|
129
|
-
const n =
|
|
130
|
-
|
|
131
|
-
}
|
|
132
|
-
return l;
|
|
133
|
-
}
|
|
134
|
-
if (o === 1) {
|
|
135
|
-
let s = null;
|
|
136
|
-
const l = this.tf.tensor1d([i - 1], "int32"), r = this.tf.oneHot(l, i).reshape([1, 1, i]).tile([e, 1, 1]);
|
|
137
|
-
l.dispose();
|
|
138
|
-
for (const n of t) {
|
|
139
|
-
let h = n.add(r);
|
|
140
|
-
h = h.div(h.sum(-1, !0)), s == null ? s = h : (s = s.mul(h), s = s.div(s.sum(-1, !0)));
|
|
208
|
+
const s = D(i, i).expandDims(0);
|
|
209
|
+
let r = s.tile([e, 1, 1]);
|
|
210
|
+
for (const l of t) {
|
|
211
|
+
const n = l.add(s);
|
|
212
|
+
r = n.div(n.sum(-1, !0)).matMul(r);
|
|
141
213
|
}
|
|
142
|
-
return
|
|
214
|
+
return r;
|
|
143
215
|
}
|
|
144
216
|
throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${i}]`);
|
|
145
217
|
});
|
|
146
218
|
}
|
|
147
219
|
forward(t, e, o = !1, i = !1, s) {
|
|
148
|
-
return this.validateInput(t),
|
|
220
|
+
return this.validateInput(t), w(() => {
|
|
149
221
|
this.startMemory();
|
|
150
|
-
const
|
|
151
|
-
let
|
|
222
|
+
const r = s?.[0]?.length ?? 0;
|
|
223
|
+
let l = this.inputPhase(t, r, o);
|
|
152
224
|
const n = [];
|
|
153
225
|
if (s && s.length !== this.blocks.length)
|
|
154
226
|
throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
|
|
155
|
-
for (let
|
|
156
|
-
const
|
|
157
|
-
output:
|
|
158
|
-
attention:
|
|
227
|
+
for (let c = 0; c < this.blocks.length; c++) {
|
|
228
|
+
const u = l, d = this.blocks[c], {
|
|
229
|
+
output: b,
|
|
230
|
+
attention: k,
|
|
159
231
|
cache: f
|
|
160
|
-
} =
|
|
161
|
-
|
|
232
|
+
} = d.call(l, o, i, s ? s[c] : void 0);
|
|
233
|
+
l = b, u.dispose(), i && k && n.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
|
|
162
234
|
}
|
|
163
235
|
let h;
|
|
164
|
-
i && n.length > 0 && (h = this.computeAttentionRollout(n)),
|
|
165
|
-
const
|
|
236
|
+
i && n.length > 0 && (h = this.computeAttentionRollout(n)), l = this.lnF.apply(l);
|
|
237
|
+
const a = this.wte.project(l);
|
|
166
238
|
let p;
|
|
167
|
-
return e && (p = this.calculateLoss(
|
|
239
|
+
return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
|
|
168
240
|
});
|
|
169
241
|
}
|
|
170
242
|
generate(t, e, o) {
|
|
171
|
-
const i = o?.temperature ?? 1, s = o?.topK,
|
|
172
|
-
return
|
|
173
|
-
const n = t, h = n.shape[1],
|
|
243
|
+
const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, l = o?.includeAttention ?? !1;
|
|
244
|
+
return w(() => {
|
|
245
|
+
const n = t, h = n.shape[1], a = h <= this.config.blockSize ? n : n.slice(
|
|
174
246
|
[0, h - this.config.blockSize],
|
|
175
247
|
[n.shape[0], this.config.blockSize]
|
|
176
|
-
), p =
|
|
248
|
+
), p = r ? this.config.blockSize - a.shape[1] : 0, c = p > 0 ? A(a, [
|
|
177
249
|
[0, 0],
|
|
178
250
|
[0, p]
|
|
179
|
-
]) :
|
|
180
|
-
let
|
|
251
|
+
]) : a, { logits: u, attention: d } = this.forward(c, void 0, !1, l, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = d ? d.slice([0, b, 0], [d.shape[0], 1, d.shape[2]]) : void 0, $ = k.div(i);
|
|
252
|
+
let g;
|
|
181
253
|
if (s) {
|
|
182
|
-
const { values:
|
|
183
|
-
|
|
254
|
+
const { values: M, indices: x } = ot($, s), W = I(M.squeeze([1]), 1);
|
|
255
|
+
g = Y(x.squeeze([1]), W, 1);
|
|
184
256
|
} else
|
|
185
|
-
|
|
186
|
-
let
|
|
187
|
-
return o?.includeProbabilities && (
|
|
257
|
+
g = I($.squeeze([1]), 1);
|
|
258
|
+
let v;
|
|
259
|
+
return o?.includeProbabilities && (v = Z($.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: f?.squeeze([1]), probabilities: v };
|
|
188
260
|
});
|
|
189
261
|
}
|
|
190
262
|
getNumParams() {
|
|
191
|
-
return
|
|
263
|
+
return K(this.config);
|
|
192
264
|
}
|
|
193
265
|
dispose() {
|
|
194
266
|
this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
|
|
195
267
|
}
|
|
196
268
|
}
|
|
197
269
|
export {
|
|
198
|
-
|
|
270
|
+
kt as default
|
|
199
271
|
};
|
package/dist/TeachableLLM.d.ts
CHANGED
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
import { default as TF } from '@tensorflow/tfjs';
|
|
2
1
|
import { GPTConfig } from './config';
|
|
3
2
|
import { ITokeniser } from './tokeniser/type';
|
|
4
3
|
import { default as NanoGPT } from './NanoGPTModel';
|
|
5
4
|
import { SaveOptions } from './utilities/save';
|
|
6
5
|
import { default as Generator, IGenerateOptions } from './Generator';
|
|
7
6
|
import { default as Trainer, ITrainerOptions } from './Trainer';
|
|
8
|
-
import { default as EE } from 'eventemitter3';
|
|
9
7
|
import { default as MemoryProfiler } from './utilities/profile';
|
|
10
8
|
type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
|
|
11
|
-
export default class TeachableLLM
|
|
9
|
+
export default class TeachableLLM {
|
|
10
|
+
private ee;
|
|
12
11
|
private _config?;
|
|
13
12
|
private _model?;
|
|
14
|
-
readonly tf: typeof TF;
|
|
15
13
|
private _tokeniser?;
|
|
16
14
|
private _status;
|
|
17
|
-
constructor(
|
|
15
|
+
constructor(tokeniser?: ITokeniser, model?: NanoGPT);
|
|
16
|
+
get vocab(): string[];
|
|
17
|
+
get loaded(): boolean;
|
|
18
18
|
get config(): GPTConfig;
|
|
19
19
|
get model(): NanoGPT;
|
|
20
20
|
get tokeniser(): ITokeniser;
|
|
@@ -22,16 +22,25 @@ export default class TeachableLLM extends EE<'status' | 'error' | 'trainStep'> {
|
|
|
22
22
|
get ready(): boolean;
|
|
23
23
|
private setStatus;
|
|
24
24
|
saveModel(options?: SaveOptions): Promise<Blob>;
|
|
25
|
-
static loadModel(
|
|
26
|
-
static create(
|
|
25
|
+
static loadModel(data: Blob | Buffer | string): TeachableLLM;
|
|
26
|
+
static create(tokeniserType: 'char' | 'bpe', config?: Partial<GPTConfig>): TeachableLLM;
|
|
27
27
|
getProfiler(): MemoryProfiler | undefined;
|
|
28
28
|
get enableProfiler(): boolean;
|
|
29
29
|
set enableProfiler(value: boolean);
|
|
30
30
|
getNumParams(): number;
|
|
31
31
|
trainer(): Trainer;
|
|
32
32
|
train(text: string[], options?: ITrainerOptions): Promise<void>;
|
|
33
|
+
trainTokeniser(text: string[]): Promise<number>;
|
|
33
34
|
generator(): Generator;
|
|
34
35
|
generateText(prompt?: string, options?: IGenerateOptions): Promise<string>;
|
|
35
36
|
dispose(): void;
|
|
37
|
+
on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
38
|
+
on(event: 'error', listener: (error: Error) => void): void;
|
|
39
|
+
on(event: 'trainStep', listener: (step: number) => void): void;
|
|
40
|
+
on(event: 'loaded', listener: () => void): void;
|
|
41
|
+
off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
42
|
+
off(event: 'error', listener: (error: Error) => void): void;
|
|
43
|
+
off(event: 'trainStep', listener: (step: number) => void): void;
|
|
44
|
+
off(event: 'loaded', listener: () => void): void;
|
|
36
45
|
}
|
|
37
46
|
export {};
|