@genai-fi/nanogpt 0.7.3 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +25 -2
- package/dist/Generator.js +150 -49
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
- package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +31 -31
- package/dist/Trainer.d.ts +13 -2
- package/dist/Trainer.js +21 -12
- package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
- package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
- package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
- package/dist/exports_initializers-DAKM8UO9.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
- package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
- package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +42 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
- package/dist/main.d.ts +5 -4
- package/dist/main.js +22 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
- package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
- package/dist/mod-uUuj4gSb.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +68 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +11 -11
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
- package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
- package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
- package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
- package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
- package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
- package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +3 -3
- package/dist/training/FullTrainer.js +61 -69
- package/dist/training/Trainer.d.ts +15 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +9 -9
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
package/dist/Generator.d.ts
CHANGED
|
@@ -1,10 +1,23 @@
|
|
|
1
|
-
import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
|
|
2
1
|
import { ITokeniser } from './tokeniser/type';
|
|
3
2
|
import { default as EE } from 'eventemitter3';
|
|
3
|
+
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
4
|
+
export interface GenerateOptions {
|
|
5
|
+
temperature?: number;
|
|
6
|
+
topK?: number;
|
|
7
|
+
topP?: number;
|
|
8
|
+
usePadding?: boolean;
|
|
9
|
+
attentionScores?: boolean;
|
|
10
|
+
includeProbabilities?: boolean;
|
|
11
|
+
embeddings?: boolean;
|
|
12
|
+
}
|
|
4
13
|
export interface IGenerateOptions extends GenerateOptions {
|
|
5
14
|
maxLength?: number;
|
|
6
15
|
noCache?: boolean;
|
|
7
16
|
}
|
|
17
|
+
/**
|
|
18
|
+
* Text generator using a NanoGPT model and a tokeniser.
|
|
19
|
+
* This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
|
|
20
|
+
*/
|
|
8
21
|
export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
|
|
9
22
|
private readonly model;
|
|
10
23
|
private readonly tokeniser;
|
|
@@ -14,9 +27,16 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
|
|
|
14
27
|
private outputText;
|
|
15
28
|
private actualTokeniser;
|
|
16
29
|
private lastToken;
|
|
17
|
-
|
|
30
|
+
private attentionData;
|
|
31
|
+
private probabilitiesData;
|
|
32
|
+
private embeddingsData;
|
|
33
|
+
private tokens;
|
|
34
|
+
constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
|
|
18
35
|
private tokenisePrompt;
|
|
19
36
|
private processResponse;
|
|
37
|
+
/** Generate logits and select a token. */
|
|
38
|
+
private _generateToken;
|
|
39
|
+
/** Generate multiple tokens in a loop and produce text */
|
|
20
40
|
private _generate;
|
|
21
41
|
reset(): void;
|
|
22
42
|
dispose(): void;
|
|
@@ -25,4 +45,7 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
|
|
|
25
45
|
generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
|
|
26
46
|
stop(): void;
|
|
27
47
|
getText(): string;
|
|
48
|
+
getAttentionData(): number[][][][];
|
|
49
|
+
getProbabilitiesData(): number[][][];
|
|
50
|
+
getTokens(): number[];
|
|
28
51
|
}
|
package/dist/Generator.js
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
import { E as
|
|
2
|
-
import "./index-
|
|
1
|
+
import { E as z } from "./index-Dwqa6Zy2.js";
|
|
2
|
+
import { B as A, C as L, E as C, a5 as I, t as O, k as R } from "./index-CUQrfsw_.js";
|
|
3
3
|
import "./ops/cpu/attentionMask.js";
|
|
4
4
|
import "./ops/webgl/attentionMask.js";
|
|
5
5
|
import "./ops/grads/attentionMask.js";
|
|
6
6
|
import "./ops/cpu/qkv.js";
|
|
7
7
|
import "./ops/webgl/qkv.js";
|
|
8
8
|
import "./ops/grads/qkv.js";
|
|
9
|
-
import "./random_width-
|
|
10
|
-
import "./register_all_kernels-
|
|
9
|
+
import { p as _ } from "./random_width-D8Pwy_na.js";
|
|
10
|
+
import { t as K } from "./register_all_kernels-DUshvVWP.js";
|
|
11
11
|
import "./index-Tf7vU29b.js";
|
|
12
|
-
import "./dataset-
|
|
12
|
+
import "./dataset-CJmEGu6D.js";
|
|
13
13
|
import "./ops/cpu/rope.js";
|
|
14
14
|
import "./ops/webgl/rope.js";
|
|
15
15
|
import "./ops/grads/rope.js";
|
|
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
|
|
|
29
29
|
import "./ops/cpu/scatterSub.js";
|
|
30
30
|
import "./ops/webgl/scatterSub.js";
|
|
31
31
|
import "./jszip.min-CjP2V1VV.js";
|
|
32
|
-
import
|
|
32
|
+
import M from "./tokeniser/CharTokeniser.js";
|
|
33
33
|
import "./ops/cpu/adamAdjust.js";
|
|
34
34
|
import "./ops/webgl/adamAdjust.js";
|
|
35
35
|
import "./ops/cpu/adamMoments.js";
|
|
@@ -37,12 +37,42 @@ import "./ops/webgl/adamMoments.js";
|
|
|
37
37
|
import "./papaparse.min-C8l2Kvo1.js";
|
|
38
38
|
import "./ops/cpu/gelu.js";
|
|
39
39
|
import "./ops/webgl/gelu.js";
|
|
40
|
-
import "./gelu-
|
|
40
|
+
import "./gelu-Bd3UBBxg.js";
|
|
41
41
|
import "./ops/webgl/log.js";
|
|
42
|
-
import
|
|
43
|
-
import {
|
|
44
|
-
|
|
45
|
-
|
|
42
|
+
import $ from "./utilities/multinomialCPU.js";
|
|
43
|
+
import { r as x } from "./reshape-DEfQGSin.js";
|
|
44
|
+
import { t as P } from "./tensor2d-BHy261cI.js";
|
|
45
|
+
import { s as v } from "./softmax-DLoZWYBx.js";
|
|
46
|
+
import { g as q } from "./gather-C1siEkdp.js";
|
|
47
|
+
import { c as G } from "./concat-BmDqqFsa.js";
|
|
48
|
+
/**
|
|
49
|
+
* @license
|
|
50
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
51
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
52
|
+
* you may not use this file except in compliance with the License.
|
|
53
|
+
* You may obtain a copy of the License at
|
|
54
|
+
*
|
|
55
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
56
|
+
*
|
|
57
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
58
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
59
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
60
|
+
* See the License for the specific language governing permissions and
|
|
61
|
+
* limitations under the License.
|
|
62
|
+
* =============================================================================
|
|
63
|
+
*/
|
|
64
|
+
function N(h, t, e, i = !1) {
|
|
65
|
+
const o = L(h, "logits", "multinomial"), s = o.size, n = o.rank;
|
|
66
|
+
if (s < 2)
|
|
67
|
+
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
|
|
68
|
+
if (n > 2)
|
|
69
|
+
throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
|
|
70
|
+
e = e || Math.random();
|
|
71
|
+
const a = { logits: n === 1 ? x(o, [1, -1]) : o }, p = { numSamples: t, seed: e, normalized: i }, l = C.runKernel(I, a, p);
|
|
72
|
+
return n === 1 ? x(l, [l.size]) : l;
|
|
73
|
+
}
|
|
74
|
+
const S = /* @__PURE__ */ A({ multinomial_: N }), B = [
|
|
75
|
+
...Array.from({ length: 95 }, (h, t) => String.fromCharCode(t + 32)),
|
|
46
76
|
// ASCII
|
|
47
77
|
// Spanish accented letters and punctuation
|
|
48
78
|
..."áéíóúüñ¿¡",
|
|
@@ -53,12 +83,12 @@ const k = [
|
|
|
53
83
|
// Cyrillic letters
|
|
54
84
|
..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
|
|
55
85
|
];
|
|
56
|
-
function
|
|
57
|
-
return
|
|
86
|
+
function H(h, t) {
|
|
87
|
+
return h.length === t ? h : h.length > t ? h.slice(0, t) : h.concat(Array(t - h.length).fill(""));
|
|
58
88
|
}
|
|
59
|
-
class
|
|
60
|
-
constructor(t,
|
|
61
|
-
super(), this.model = t, this.tokeniser =
|
|
89
|
+
class Mt extends z {
|
|
90
|
+
constructor(t, e) {
|
|
91
|
+
super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
|
|
62
92
|
}
|
|
63
93
|
active = !1;
|
|
64
94
|
cache = null;
|
|
@@ -66,71 +96,133 @@ class nt extends l {
|
|
|
66
96
|
outputText = "";
|
|
67
97
|
actualTokeniser;
|
|
68
98
|
lastToken = -1;
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
99
|
+
attentionData = [];
|
|
100
|
+
probabilitiesData = [];
|
|
101
|
+
embeddingsData = [];
|
|
102
|
+
tokens = [];
|
|
103
|
+
async tokenisePrompt(t, e) {
|
|
104
|
+
const i = e ? await t.tokenise([e], !0) : [[t.eosToken]];
|
|
105
|
+
return P(i, [1, i[0].length], "int32");
|
|
72
106
|
}
|
|
73
|
-
async processResponse(t,
|
|
74
|
-
const s = (await
|
|
107
|
+
async processResponse(t, e, i, o) {
|
|
108
|
+
const s = (await e.array())[0][0];
|
|
75
109
|
if (this.lastToken = s, s === this.tokeniser.eosToken)
|
|
76
110
|
return null;
|
|
77
111
|
const n = await t.decode([s]);
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
112
|
+
if (i) {
|
|
113
|
+
const d = await Promise.all(i.map((a) => a.array().then((p) => p)));
|
|
114
|
+
i.forEach((a) => a.dispose()), this.attentionData.push(d);
|
|
115
|
+
}
|
|
116
|
+
if (o) {
|
|
117
|
+
const d = await o.array();
|
|
118
|
+
o.dispose(), this.probabilitiesData.push(d);
|
|
119
|
+
}
|
|
120
|
+
return this.tokens.push(s), this.emit("tokens", [s], n), n;
|
|
121
|
+
}
|
|
122
|
+
/** Generate logits and select a token. */
|
|
123
|
+
async _generateToken(t, e, i) {
|
|
124
|
+
const o = i?.temperature ?? 1, s = i?.topK, n = i?.topP, d = i?.usePadding ?? !1, a = {
|
|
125
|
+
training: !1,
|
|
126
|
+
attentionScores: i?.attentionScores ? {
|
|
127
|
+
attentionOut: []
|
|
128
|
+
} : void 0,
|
|
129
|
+
cache: e,
|
|
130
|
+
outputEmbeddings: i?.embeddings ?? !1
|
|
131
|
+
}, p = O(() => {
|
|
132
|
+
const r = t, m = r.shape[1], u = m <= this.model.config.blockSize ? r : r.slice(
|
|
133
|
+
[0, m - this.model.config.blockSize],
|
|
134
|
+
[r.shape[0], this.model.config.blockSize]
|
|
135
|
+
), k = d ? this.model.config.blockSize - u.shape[1] : 0, b = k > 0 ? _(u, [
|
|
136
|
+
[0, 0],
|
|
137
|
+
[0, k]
|
|
138
|
+
]) : u, [f] = this.model.forward(a, b), y = f.shape[1] - 1 - k, c = f.slice([0, y, 0], [f.shape[0], 1, f.shape[2]]);
|
|
139
|
+
return a.attentionScores?.attentionOut && a.attentionScores.attentionOut.forEach((T, E) => {
|
|
140
|
+
T.shape[1] !== 1 && (a.attentionScores.attentionOut[E] = R(
|
|
141
|
+
T.slice([0, y, 0], [T.shape[0], 1, T.shape[2]])
|
|
142
|
+
), T.dispose());
|
|
143
|
+
}), f.dispose(), c.div(o).squeeze([1]);
|
|
144
|
+
});
|
|
145
|
+
let l;
|
|
146
|
+
if (n) {
|
|
147
|
+
const r = v(p), m = await r.array();
|
|
148
|
+
r.dispose();
|
|
149
|
+
const u = m[0].map((c, g) => ({ prob: c, index: g })).sort((c, g) => g.prob - c.prob);
|
|
150
|
+
let k = 0;
|
|
151
|
+
const b = new Array(u.length).fill(0);
|
|
152
|
+
for (const c of u)
|
|
153
|
+
if (k += c.prob, b[c.index] = c.prob, k >= n)
|
|
154
|
+
break;
|
|
155
|
+
const f = b.reduce((c, g) => c + g, 0), y = b.map((c) => c / f);
|
|
156
|
+
l = $(y);
|
|
157
|
+
} else if (s) {
|
|
158
|
+
const { values: r, indices: m } = K(p, s), u = S(r, 1);
|
|
159
|
+
l = q(m, u, 1), r.dispose(), m.dispose(), u.dispose();
|
|
160
|
+
} else
|
|
161
|
+
l = S(p, 1);
|
|
162
|
+
let w;
|
|
163
|
+
i?.includeProbabilities && (w = v(p)), a.embeddings && this.embeddingsData.push(
|
|
164
|
+
await Promise.all(
|
|
165
|
+
a.embeddings.map(async (r) => {
|
|
166
|
+
const m = await r.array();
|
|
167
|
+
return r.dispose(), m;
|
|
168
|
+
})
|
|
169
|
+
)
|
|
170
|
+
);
|
|
171
|
+
const D = l.reshape([1, 1]);
|
|
172
|
+
return l.dispose(), l = D, p.dispose(), { output: l, probabilities: w, attention: a.attentionScores?.attentionOut };
|
|
82
173
|
}
|
|
174
|
+
/** Generate multiple tokens in a loop and produce text */
|
|
83
175
|
async _generate(t) {
|
|
84
|
-
let
|
|
85
|
-
const
|
|
86
|
-
for (let o = 0; o <
|
|
176
|
+
let e = this.lastToken >= 0 && this.cache ? P([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
|
|
177
|
+
const i = t?.maxLength ?? 1e3;
|
|
178
|
+
for (let o = 0; o < i && this.active; o++) {
|
|
87
179
|
const {
|
|
88
180
|
output: s,
|
|
89
181
|
probabilities: n,
|
|
90
|
-
attention:
|
|
91
|
-
} = await this.
|
|
182
|
+
attention: d
|
|
183
|
+
} = await this._generateToken(e, this.cache ? this.cache : void 0, {
|
|
92
184
|
...t,
|
|
93
185
|
usePadding: !this.cache
|
|
94
186
|
});
|
|
95
187
|
if (this.cache)
|
|
96
|
-
|
|
188
|
+
e.dispose(), e = s;
|
|
97
189
|
else {
|
|
98
|
-
const
|
|
99
|
-
|
|
190
|
+
const p = e;
|
|
191
|
+
e = G([e, s], 1), p.dispose();
|
|
100
192
|
}
|
|
101
|
-
const a = await this.processResponse(this.actualTokeniser, s,
|
|
193
|
+
const a = await this.processResponse(this.actualTokeniser, s, d, n);
|
|
102
194
|
if (this.cache || s.dispose(), a === null)
|
|
103
195
|
break;
|
|
104
196
|
this.outputText += a;
|
|
105
197
|
}
|
|
106
|
-
return
|
|
198
|
+
return e.dispose(), this.outputText;
|
|
107
199
|
}
|
|
108
200
|
reset() {
|
|
109
201
|
this.cache && (this.cache.forEach((t) => {
|
|
110
202
|
t && (t.k && t.k.dispose(), t.v && t.v.dispose());
|
|
111
|
-
}), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
|
|
203
|
+
}), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1, this.attentionData = [], this.probabilitiesData = [], this.tokens = [];
|
|
112
204
|
}
|
|
113
205
|
dispose() {
|
|
114
206
|
this.reset();
|
|
115
207
|
}
|
|
116
|
-
initialise(t,
|
|
117
|
-
const
|
|
118
|
-
if (this.cache &&
|
|
119
|
-
const s = new Array(this.model.config.
|
|
120
|
-
for (let n = 0; n < this.model.config.
|
|
208
|
+
initialise(t, e) {
|
|
209
|
+
const i = t && t.length > this.model.config.blockSize ? t.slice(-this.model.config.blockSize) : t ?? null;
|
|
210
|
+
if (this.cache && e?.noCache && this.reset(), this.initialPrompt = i || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !e?.noCache && this.model.config.useRope) {
|
|
211
|
+
const s = new Array(this.model.config.nLayer);
|
|
212
|
+
for (let n = 0; n < this.model.config.nLayer; n++)
|
|
121
213
|
s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
|
|
122
214
|
this.cache = s, this.lastToken = -1;
|
|
123
215
|
}
|
|
124
|
-
const o = this.tokeniser.trained ? this.tokeniser : new
|
|
216
|
+
const o = this.tokeniser.trained ? this.tokeniser : new M(H(B, this.tokeniser.vocabSize));
|
|
125
217
|
this.actualTokeniser = o;
|
|
126
218
|
}
|
|
127
|
-
async step(t,
|
|
128
|
-
const
|
|
129
|
-
return this.generate(t,
|
|
219
|
+
async step(t, e) {
|
|
220
|
+
const i = { ...e, maxLength: 1 };
|
|
221
|
+
return this.generate(t, i);
|
|
130
222
|
}
|
|
131
|
-
async generate(t,
|
|
132
|
-
this.initialise(t,
|
|
133
|
-
const o = await this._generate(
|
|
223
|
+
async generate(t, e) {
|
|
224
|
+
this.initialise(t, e), this.active = !0, this.emit("start");
|
|
225
|
+
const o = await this._generate(e);
|
|
134
226
|
return this.active = !1, this.emit("stop"), o;
|
|
135
227
|
}
|
|
136
228
|
stop() {
|
|
@@ -139,7 +231,16 @@ class nt extends l {
|
|
|
139
231
|
getText() {
|
|
140
232
|
return this.outputText;
|
|
141
233
|
}
|
|
234
|
+
getAttentionData() {
|
|
235
|
+
return this.attentionData;
|
|
236
|
+
}
|
|
237
|
+
getProbabilitiesData() {
|
|
238
|
+
return this.probabilitiesData;
|
|
239
|
+
}
|
|
240
|
+
getTokens() {
|
|
241
|
+
return this.tokens;
|
|
242
|
+
}
|
|
142
243
|
}
|
|
143
244
|
export {
|
|
144
|
-
|
|
245
|
+
Mt as default
|
|
145
246
|
};
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { r as $ } from "./Reshape-
|
|
3
|
-
import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-
|
|
4
|
-
import { t as
|
|
5
|
-
import { c as _ } from "./backend_util-
|
|
6
|
-
import { f as y } from "./gpgpu_math-
|
|
7
|
-
import { g as G, b as L } from "./kernel_funcs_utils-
|
|
1
|
+
import { as as T, af as E, p as O, j as V, aA as B, a0 as F, X as j, aB as K } from "./index-CUQrfsw_.js";
|
|
2
|
+
import { r as $ } from "./Reshape-Bo8HzP8V.js";
|
|
3
|
+
import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-DubwyOhW.js";
|
|
4
|
+
import { t as U, m as W } from "./shared-D1elLckx.js";
|
|
5
|
+
import { c as _ } from "./backend_util-BJ-_jSeK.js";
|
|
6
|
+
import { f as y } from "./gpgpu_math-TFLxaLkw.js";
|
|
7
|
+
import { g as G, b as L } from "./kernel_funcs_utils-P9aFa232.js";
|
|
8
8
|
/**
|
|
9
9
|
* @license
|
|
10
10
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -105,7 +105,7 @@ class w {
|
|
|
105
105
|
* limitations under the License.
|
|
106
106
|
* =============================================================================
|
|
107
107
|
*/
|
|
108
|
-
class
|
|
108
|
+
class X {
|
|
109
109
|
constructor(s, e) {
|
|
110
110
|
this.variableNames = ["x"];
|
|
111
111
|
const { windowSize: t, batchSize: n, inSize: l, outSize: r } = s;
|
|
@@ -229,7 +229,7 @@ class q {
|
|
|
229
229
|
* limitations under the License.
|
|
230
230
|
* =============================================================================
|
|
231
231
|
*/
|
|
232
|
-
function
|
|
232
|
+
function q(a) {
|
|
233
233
|
const s = [];
|
|
234
234
|
for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
|
|
235
235
|
const e = s.length ? s[s.length - 1].outSize : a[1], t = _(e);
|
|
@@ -242,12 +242,12 @@ function X(a) {
|
|
|
242
242
|
return s;
|
|
243
243
|
}
|
|
244
244
|
function P(a, s, e, t) {
|
|
245
|
-
const n =
|
|
245
|
+
const n = q(a.shape);
|
|
246
246
|
let l = a;
|
|
247
247
|
for (let r = 0; r < n.length; r++) {
|
|
248
248
|
const { inSize: i, windowSize: c, outSize: o } = n[r];
|
|
249
249
|
let u, p;
|
|
250
|
-
e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new
|
|
250
|
+
e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new X({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
|
|
251
251
|
}
|
|
252
252
|
return l;
|
|
253
253
|
}
|
|
@@ -459,7 +459,7 @@ function te(a) {
|
|
|
459
459
|
const I = e.texData.get(d.dataId).values, m = new Array(i);
|
|
460
460
|
for (let v = 0; v < m.length; v++)
|
|
461
461
|
m[v] = n.shape[u[v]];
|
|
462
|
-
const z =
|
|
462
|
+
const z = U(I, n.shape, n.dtype, u, m);
|
|
463
463
|
d = e.makeTensorInfo(m, n.dtype);
|
|
464
464
|
const M = e.texData.get(d.dataId);
|
|
465
465
|
M.values = z;
|
|
@@ -482,7 +482,7 @@ function te(a) {
|
|
|
482
482
|
return p && e.disposeIntermediateTensorInfo(d), x;
|
|
483
483
|
}
|
|
484
484
|
const he = {
|
|
485
|
-
kernelName:
|
|
485
|
+
kernelName: j,
|
|
486
486
|
backendName: "webgl",
|
|
487
487
|
kernelFunc: te
|
|
488
488
|
};
|
|
@@ -525,7 +525,7 @@ return a / b;`, se = `
|
|
|
525
525
|
|
|
526
526
|
return result;
|
|
527
527
|
`, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
|
|
528
|
-
kernelName:
|
|
528
|
+
kernelName: K,
|
|
529
529
|
backendName: "webgl",
|
|
530
530
|
kernelFunc: ne
|
|
531
531
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { j as c,
|
|
2
|
-
import { u as g, g as I, a as x, b as F, c as $, d as u, e as
|
|
1
|
+
import { j as c, a4 as C, n as f, U as R } from "./index-CUQrfsw_.js";
|
|
2
|
+
import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-TFLxaLkw.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
5
5
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -82,14 +82,14 @@ function v(s, t) {
|
|
|
82
82
|
function b(s, t, i) {
|
|
83
83
|
const a = [
|
|
84
84
|
u(s.shape),
|
|
85
|
-
...
|
|
85
|
+
...m(s.shape)
|
|
86
86
|
], e = {
|
|
87
87
|
dtype: s.dtype,
|
|
88
88
|
shape: a,
|
|
89
89
|
dataId: s.dataId
|
|
90
90
|
}, o = [
|
|
91
91
|
u(t),
|
|
92
|
-
...
|
|
92
|
+
...m(t)
|
|
93
93
|
], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
|
|
94
94
|
return { dataId: h.dataId, shape: t, dtype: h.dtype };
|
|
95
95
|
}
|
|
@@ -113,7 +113,7 @@ function y(s) {
|
|
|
113
113
|
const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
|
|
114
114
|
f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
|
|
115
115
|
const d = r.texData.get(e.dataId);
|
|
116
|
-
return d.isPacked && !
|
|
116
|
+
return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
|
|
117
117
|
}
|
|
118
118
|
const U = {
|
|
119
119
|
kernelName: R,
|
package/dist/TeachableLLM.d.ts
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
import { GPTConfig } from './config';
|
|
1
|
+
import { GPTConfig } from './models/config';
|
|
2
2
|
import { ITokeniser } from './tokeniser/type';
|
|
3
|
-
import {
|
|
4
|
-
import { SaveOptions } from './utilities/save';
|
|
3
|
+
import { SaveOptions } from './loader/save';
|
|
5
4
|
import { default as Generator, IGenerateOptions } from './Generator';
|
|
6
5
|
import { default as Trainer, ITrainerOptions } from './Trainer';
|
|
7
6
|
import { default as MemoryProfiler } from './utilities/profile';
|
|
8
|
-
import { TrainingProgress } from './training/Trainer';
|
|
7
|
+
import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
|
|
8
|
+
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
9
9
|
type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
|
|
10
10
|
interface TeachableLLMMeta {
|
|
11
11
|
name?: string;
|
|
@@ -20,12 +20,12 @@ export default class TeachableLLM {
|
|
|
20
20
|
private _status;
|
|
21
21
|
private _memoryRequirements?;
|
|
22
22
|
meta: TeachableLLMMeta;
|
|
23
|
-
constructor(tokeniser?: ITokeniser, model?:
|
|
23
|
+
constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes>);
|
|
24
24
|
get vocab(): string[];
|
|
25
25
|
/** Model is fully loaded */
|
|
26
26
|
get loaded(): boolean;
|
|
27
27
|
get config(): GPTConfig;
|
|
28
|
-
get model():
|
|
28
|
+
get model(): Model<ModelForwardAttributes>;
|
|
29
29
|
get tokeniser(): ITokeniser;
|
|
30
30
|
get status(): TeachableLLMStatus;
|
|
31
31
|
/** Model is both ready and not busy */
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -1,30 +1,21 @@
|
|
|
1
|
-
import { defaultConfig as
|
|
2
|
-
import
|
|
3
|
-
import {
|
|
4
|
-
import { loadModel as l } from "./loader/load.js";
|
|
1
|
+
import { defaultConfig as d } from "./models/config.js";
|
|
2
|
+
import { saveModel as l } from "./loader/save.js";
|
|
3
|
+
import { loadModel as _ } from "./loader/load.js";
|
|
5
4
|
import u from "./Generator.js";
|
|
6
|
-
import
|
|
7
|
-
import { E as
|
|
5
|
+
import f from "./Trainer.js";
|
|
6
|
+
import { E as p } from "./index-Dwqa6Zy2.js";
|
|
8
7
|
import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
|
|
9
|
-
import
|
|
10
|
-
import k from "./tokeniser/bpe.js";
|
|
11
|
-
import "./papaparse.min-C8l2Kvo1.js";
|
|
12
|
-
import "./index-Tf7vU29b.js";
|
|
13
|
-
import "./jszip.min-CjP2V1VV.js";
|
|
14
|
-
import "./index-BoWRt-10.js";
|
|
15
|
-
import "./ops/cpu/scatterSub.js";
|
|
16
|
-
import "./ops/webgl/scatterSub.js";
|
|
17
|
-
import "./ops/cpu/gatherSub.js";
|
|
18
|
-
import "./ops/webgl/gatherSub.js";
|
|
8
|
+
import "./index-CUQrfsw_.js";
|
|
19
9
|
import "./ops/cpu/attentionMask.js";
|
|
20
10
|
import "./ops/webgl/attentionMask.js";
|
|
21
11
|
import "./ops/grads/attentionMask.js";
|
|
22
12
|
import "./ops/cpu/qkv.js";
|
|
23
13
|
import "./ops/webgl/qkv.js";
|
|
24
14
|
import "./ops/grads/qkv.js";
|
|
25
|
-
import "./random_width-
|
|
26
|
-
import "./register_all_kernels-
|
|
27
|
-
import "./
|
|
15
|
+
import "./random_width-D8Pwy_na.js";
|
|
16
|
+
import "./register_all_kernels-DUshvVWP.js";
|
|
17
|
+
import "./index-Tf7vU29b.js";
|
|
18
|
+
import "./dataset-CJmEGu6D.js";
|
|
28
19
|
import "./ops/cpu/rope.js";
|
|
29
20
|
import "./ops/webgl/rope.js";
|
|
30
21
|
import "./ops/grads/rope.js";
|
|
@@ -36,20 +27,29 @@ import "./ops/grads/fusedSoftmax.js";
|
|
|
36
27
|
import "./ops/cpu/matMulGelu.js";
|
|
37
28
|
import "./ops/webgl/matMulGelu.js";
|
|
38
29
|
import "./ops/grads/matMulGelu.js";
|
|
39
|
-
import "./ops/cpu/gelu.js";
|
|
40
|
-
import "./ops/webgl/gelu.js";
|
|
41
|
-
import "./gelu-C-dPj6Ku.js";
|
|
42
30
|
import "./ops/cpu/normRMS.js";
|
|
43
31
|
import "./ops/webgl/normRMS.js";
|
|
44
32
|
import "./ops/grads/normRMS.js";
|
|
33
|
+
import "./ops/cpu/gatherSub.js";
|
|
34
|
+
import "./ops/webgl/gatherSub.js";
|
|
35
|
+
import "./ops/cpu/scatterSub.js";
|
|
36
|
+
import "./ops/webgl/scatterSub.js";
|
|
37
|
+
import c from "./tokeniser/CharTokeniser.js";
|
|
38
|
+
import g from "./tokeniser/bpe.js";
|
|
39
|
+
import "./papaparse.min-C8l2Kvo1.js";
|
|
40
|
+
import "./jszip.min-CjP2V1VV.js";
|
|
41
|
+
import "./ops/cpu/gelu.js";
|
|
42
|
+
import "./ops/webgl/gelu.js";
|
|
43
|
+
import "./gelu-Bd3UBBxg.js";
|
|
45
44
|
import "./ops/webgl/log.js";
|
|
46
45
|
import "./ops/cpu/adamMoments.js";
|
|
47
46
|
import "./ops/webgl/adamMoments.js";
|
|
48
47
|
import "./ops/cpu/adamAdjust.js";
|
|
49
48
|
import "./ops/webgl/adamAdjust.js";
|
|
50
|
-
import
|
|
49
|
+
import k from "./utilities/profile.js";
|
|
50
|
+
import w from "./models/factory.js";
|
|
51
51
|
class a {
|
|
52
|
-
ee = new
|
|
52
|
+
ee = new p();
|
|
53
53
|
_config;
|
|
54
54
|
_model;
|
|
55
55
|
_tokeniser;
|
|
@@ -69,7 +69,7 @@ class a {
|
|
|
69
69
|
get config() {
|
|
70
70
|
if (!this._config)
|
|
71
71
|
throw new Error("configuration_not_initialized.");
|
|
72
|
-
return this._config
|
|
72
|
+
return this._config;
|
|
73
73
|
}
|
|
74
74
|
get model() {
|
|
75
75
|
if (!this._model)
|
|
@@ -101,14 +101,14 @@ class a {
|
|
|
101
101
|
saveModel(t) {
|
|
102
102
|
if (!this._model || !this._tokeniser)
|
|
103
103
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
104
|
-
return
|
|
104
|
+
return l(this._model, this._tokeniser, {
|
|
105
105
|
...t,
|
|
106
106
|
name: t?.name || this.meta.name
|
|
107
107
|
});
|
|
108
108
|
}
|
|
109
109
|
static loadModel(t) {
|
|
110
110
|
const e = new a();
|
|
111
|
-
return
|
|
111
|
+
return _(t).then(({ model: r, tokeniser: o, name: s }) => {
|
|
112
112
|
e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
|
|
113
113
|
e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
|
|
114
114
|
}).catch((i) => {
|
|
@@ -119,7 +119,7 @@ class a {
|
|
|
119
119
|
}), e;
|
|
120
120
|
}
|
|
121
121
|
static create(t, e = {}) {
|
|
122
|
-
const r = { ...
|
|
122
|
+
const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
|
|
123
123
|
return i.setStatus("warmup"), m(s).then((n) => {
|
|
124
124
|
i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
|
|
125
125
|
h === "trained" && i.setStatus("ready");
|
|
@@ -138,9 +138,9 @@ class a {
|
|
|
138
138
|
if (t) {
|
|
139
139
|
if (!this._config)
|
|
140
140
|
return;
|
|
141
|
-
this.
|
|
141
|
+
this.model.getProfiler() || this.model.setProfiler(new k());
|
|
142
142
|
} else
|
|
143
|
-
this.
|
|
143
|
+
this.model.getProfiler() && this.model.setProfiler(null);
|
|
144
144
|
}
|
|
145
145
|
getNumParams() {
|
|
146
146
|
return this._model ? this._model.getNumParams() : 0;
|
|
@@ -148,7 +148,7 @@ class a {
|
|
|
148
148
|
trainer() {
|
|
149
149
|
if (!this._model || !this._tokeniser)
|
|
150
150
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
151
|
-
const t = new
|
|
151
|
+
const t = new f(this._model, this._tokeniser);
|
|
152
152
|
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
|
|
153
153
|
const o = this.ee.listeners("trainStep");
|
|
154
154
|
for (const s of o)
|
package/dist/Trainer.d.ts
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import { default as NanoGPT } from './NanoGPTModel';
|
|
2
1
|
import { ITokeniser } from './tokeniser/type';
|
|
3
2
|
import { default as EE } from 'eventemitter3';
|
|
3
|
+
import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
|
|
4
|
+
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
4
5
|
export interface ITrainerOptions {
|
|
5
6
|
batchSize?: number;
|
|
6
7
|
learningRate?: number;
|
|
@@ -10,6 +11,11 @@ export interface ITrainerOptions {
|
|
|
10
11
|
prompt?: string;
|
|
11
12
|
validationSplit?: number;
|
|
12
13
|
advancedMetrics?: boolean;
|
|
14
|
+
gradientCheckpointing?: boolean;
|
|
15
|
+
}
|
|
16
|
+
interface ExtendedTrainingProgress extends TrainingProgress {
|
|
17
|
+
progress: number;
|
|
18
|
+
remaining: number;
|
|
13
19
|
}
|
|
14
20
|
export default class Trainer extends EE<'start' | 'stop' | 'log'> {
|
|
15
21
|
private trainer;
|
|
@@ -17,10 +23,15 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
|
|
|
17
23
|
private trainDataset?;
|
|
18
24
|
private validationDataset?;
|
|
19
25
|
private totalSamples;
|
|
20
|
-
|
|
26
|
+
private log;
|
|
27
|
+
private progress;
|
|
28
|
+
constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
|
|
21
29
|
stop(): void;
|
|
22
30
|
reset(): void;
|
|
23
31
|
prepare(text: string[], options?: ITrainerOptions): Promise<void>;
|
|
24
32
|
train(options?: ITrainerOptions): Promise<void>;
|
|
25
33
|
step(options?: ITrainerOptions): Promise<void>;
|
|
34
|
+
getLog(): TrainingLogEntry[];
|
|
35
|
+
getProgress(): ExtendedTrainingProgress | null;
|
|
26
36
|
}
|
|
37
|
+
export {};
|