@genai-fi/nanogpt 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +25 -2
- package/dist/Generator.js +152 -49
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
- package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +33 -31
- package/dist/Trainer.d.ts +13 -2
- package/dist/Trainer.js +21 -12
- package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +22 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +37 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +20 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +18 -0
- package/dist/checks/index.d.ts +19 -0
- package/dist/checks/index.js +21 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +16 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +12 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +25 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +21 -0
- package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
- package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
- package/dist/exports_initializers-DKk7-bsx.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
- package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
- package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +44 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
- package/dist/main.d.ts +6 -4
- package/dist/main.js +24 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
- package/dist/mod-CbibJi3D.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +70 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +7 -7
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/ops-542ai2vG.js +1525 -0
- package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
- package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
- package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
- package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
- package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
- package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
- package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
- package/dist/tensor1d-CtJq5BOv.js +27 -0
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
- package/dist/tensor3d-BOukqWwr.js +30 -0
- package/dist/tensor4d-DLtk7Nxh.js +30 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +3 -3
- package/dist/training/FullTrainer.js +61 -69
- package/dist/training/Trainer.d.ts +15 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +12 -13
- package/dist/utilities/arrayClose.d.ts +1 -1
- package/dist/utilities/arrayClose.js +16 -7
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/ops-BFGCx8Ri.js +0 -1202
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
package/dist/Generator.d.ts
CHANGED
@@ -1,10 +1,23 @@
-import { default as NanoGPT, GenerateOptions } from './NanoGPTModel';
 import { ITokeniser } from './tokeniser/type';
 import { default as EE } from 'eventemitter3';
+import { default as Model, ModelForwardAttributes } from './models/model';
+export interface GenerateOptions {
+    temperature?: number;
+    topK?: number;
+    topP?: number;
+    usePadding?: boolean;
+    attentionScores?: boolean;
+    includeProbabilities?: boolean;
+    embeddings?: boolean;
+}
 export interface IGenerateOptions extends GenerateOptions {
     maxLength?: number;
     noCache?: boolean;
 }
+/**
+ * Text generator using a NanoGPT model and a tokeniser.
+ * This uses the forward method of the model to generate text token by token, including options for temperature, top-k, and top-p sampling.
+ */
 export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
     private readonly model;
     private readonly tokeniser;
@@ -14,9 +27,16 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
     private outputText;
     private actualTokeniser;
     private lastToken;
-
+    private attentionData;
+    private probabilitiesData;
+    private embeddingsData;
+    private tokens;
+    constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
     private tokenisePrompt;
     private processResponse;
+    /** Generate logits and select a token. */
+    private _generateToken;
+    /** Generate multiple tokens in a loop and produce text */
     private _generate;
     reset(): void;
     dispose(): void;
@@ -25,4 +45,7 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
     generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
     stop(): void;
     getText(): string;
+    getAttentionData(): number[][][][];
+    getProbabilitiesData(): number[][][];
+    getTokens(): number[];
 }
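The expanded GenerateOptions interface and the new getters suggest a usage pattern along these lines. This is a hedged sketch based only on the declarations above: the prompt, the option values, and the way the Generator instance is obtained are illustrative assumptions, not part of the diff.

// Sketch: exercising the new generation options and per-step data getters.
// Assumes a Generator was constructed with a model and tokeniser as in the
// constructor declared above; the import specifier is not shown in this diff.
async function sampleWithDiagnostics(generator: Generator): Promise<void> {
    const text = await generator.generate('Once upon a time', {
        maxLength: 64,
        temperature: 0.8,           // scales logits before sampling
        topK: 40,                   // alternatively topP: 0.9 for nucleus sampling
        attentionScores: true,      // populates getAttentionData()
        includeProbabilities: true  // populates getProbabilitiesData()
    });
    console.log(text);
    console.log(generator.getTokens());            // sampled token ids
    console.log(generator.getProbabilitiesData()); // per-step probability distributions
    console.log(generator.getAttentionData());     // per-step attention scores
}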
package/dist/Generator.js
CHANGED
@@ -1,15 +1,15 @@
-import { E as
-import "./index-
+import { E as z } from "./index-Dwqa6Zy2.js";
+import { C as A, D as L, E as C, a6 as I, t as O, k as R } from "./index-DdmHGZjq.js";
 import "./ops/cpu/attentionMask.js";
 import "./ops/webgl/attentionMask.js";
 import "./ops/grads/attentionMask.js";
 import "./ops/cpu/qkv.js";
 import "./ops/webgl/qkv.js";
 import "./ops/grads/qkv.js";
-import "./random_width-
-import "./register_all_kernels-
+import { p as _ } from "./random_width-DKGeiFuR.js";
+import { t as K } from "./register_all_kernels-Do9VvZmo.js";
 import "./index-Tf7vU29b.js";
-import "./dataset-
+import "./dataset-DPPl-iLT.js";
 import "./ops/cpu/rope.js";
 import "./ops/webgl/rope.js";
 import "./ops/grads/rope.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
 import "./ops/cpu/scatterSub.js";
 import "./ops/webgl/scatterSub.js";
 import "./jszip.min-CjP2V1VV.js";
-import
+import M from "./tokeniser/CharTokeniser.js";
 import "./ops/cpu/adamAdjust.js";
 import "./ops/webgl/adamAdjust.js";
 import "./ops/cpu/adamMoments.js";
@@ -37,12 +37,44 @@ import "./ops/webgl/adamMoments.js";
 import "./papaparse.min-C8l2Kvo1.js";
 import "./ops/cpu/gelu.js";
 import "./ops/webgl/gelu.js";
-import "./gelu-
+import "./gelu-BkcmEEyD.js";
 import "./ops/webgl/log.js";
-import
-import
-
-
+import "./checks/normRMS.js";
+import "./checks/normRMSGrad.js";
+import $ from "./utilities/multinomialCPU.js";
+import { r as x } from "./reshape-WeJkT3ja.js";
+import { t as P } from "./tensor2d-CObBWBkW.js";
+import { s as v } from "./softmax-BfsyI4As.js";
+import { g as q } from "./gather-CPg6ZlQA.js";
+import { c as G } from "./concat-pHiVqR3L.js";
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+function N(m, t, e, i = !1) {
+const o = L(m, "logits", "multinomial"), s = o.size, n = o.rank;
+if (s < 2)
+throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
+if (n > 2)
+throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
+e = e || Math.random();
+const a = { logits: n === 1 ? x(o, [1, -1]) : o }, p = { numSamples: t, seed: e, normalized: i }, l = C.runKernel(I, a, p);
+return n === 1 ? x(l, [l.size]) : l;
+}
+const S = /* @__PURE__ */ A({ multinomial_: N }), H = [
+...Array.from({ length: 95 }, (m, t) => String.fromCharCode(t + 32)),
 // ASCII
 // Spanish accented letters and punctuation
 ..."áéíóúüñ¿¡",
@@ -53,12 +85,12 @@ const k = [
 // Cyrillic letters
 ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
 ];
-function
-return
+function U(m, t) {
+return m.length === t ? m : m.length > t ? m.slice(0, t) : m.concat(Array(t - m.length).fill(""));
 }
-class
-constructor(t,
-super(), this.model = t, this.tokeniser =
+class qt extends z {
+constructor(t, e) {
+super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
 }
 active = !1;
 cache = null;
@@ -66,71 +98,133 @@ class nt extends l {
 outputText = "";
 actualTokeniser;
 lastToken = -1;
-
-
-
+attentionData = [];
+probabilitiesData = [];
+embeddingsData = [];
+tokens = [];
+async tokenisePrompt(t, e) {
+const i = e ? await t.tokenise([e], !0) : [[t.eosToken]];
+return P(i, [1, i[0].length], "int32");
 }
-async processResponse(t,
-const s = (await
+async processResponse(t, e, i, o) {
+const s = (await e.array())[0][0];
 if (this.lastToken = s, s === this.tokeniser.eosToken)
 return null;
 const n = await t.decode([s]);
-
-
-
-
+if (i) {
+const d = await Promise.all(i.map((a) => a.array().then((p) => p)));
+i.forEach((a) => a.dispose()), this.attentionData.push(d);
+}
+if (o) {
+const d = await o.array();
+o.dispose(), this.probabilitiesData.push(d);
+}
+return this.tokens.push(s), this.emit("tokens", [s], n), n;
+}
+/** Generate logits and select a token. */
+async _generateToken(t, e, i) {
+const o = i?.temperature ?? 1, s = i?.topK, n = i?.topP, d = i?.usePadding ?? !1, a = {
+training: !1,
+attentionScores: i?.attentionScores ? {
+attentionOut: []
+} : void 0,
+cache: e,
+outputEmbeddings: i?.embeddings ?? !1
+}, p = O(() => {
+const r = t, h = r.shape[1], u = h <= this.model.config.blockSize ? r : r.slice(
+[0, h - this.model.config.blockSize],
+[r.shape[0], this.model.config.blockSize]
+), k = d ? this.model.config.blockSize - u.shape[1] : 0, b = k > 0 ? _(u, [
+[0, 0],
+[0, k]
+]) : u, [f] = this.model.forward(a, b), y = f.shape[1] - 1 - k, c = f.slice([0, y, 0], [f.shape[0], 1, f.shape[2]]);
+return a.attentionScores?.attentionOut && a.attentionScores.attentionOut.forEach((T, E) => {
+T.shape[1] !== 1 && (a.attentionScores.attentionOut[E] = R(
+T.slice([0, y, 0], [T.shape[0], 1, T.shape[2]])
+), T.dispose());
+}), f.dispose(), c.div(o).squeeze([1]);
+});
+let l;
+if (n) {
+const r = v(p), h = await r.array();
+r.dispose();
+const u = h[0].map((c, g) => ({ prob: c, index: g })).sort((c, g) => g.prob - c.prob);
+let k = 0;
+const b = new Array(u.length).fill(0);
+for (const c of u)
+if (k += c.prob, b[c.index] = c.prob, k >= n)
+break;
+const f = b.reduce((c, g) => c + g, 0), y = b.map((c) => c / f);
+l = $(y);
+} else if (s) {
+const { values: r, indices: h } = K(p, s), u = S(r, 1);
+l = q(h, u, 1), r.dispose(), h.dispose(), u.dispose();
+} else
+l = S(p, 1);
+let w;
+i?.includeProbabilities && (w = v(p)), a.embeddings && this.embeddingsData.push(
+await Promise.all(
+a.embeddings.map(async (r) => {
+const h = await r.array();
+return r.dispose(), h;
+})
+)
+);
+const D = l.reshape([1, 1]);
+return l.dispose(), l = D, p.dispose(), { output: l, probabilities: w, attention: a.attentionScores?.attentionOut };
 }
+/** Generate multiple tokens in a loop and produce text */
 async _generate(t) {
-let
-const
-for (let o = 0; o <
+let e = this.lastToken >= 0 && this.cache ? P([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
+const i = t?.maxLength ?? 1e3;
+for (let o = 0; o < i && this.active; o++) {
 const {
 output: s,
 probabilities: n,
-attention:
-} = await this.
+attention: d
+} = await this._generateToken(e, this.cache ? this.cache : void 0, {
 ...t,
 usePadding: !this.cache
 });
 if (this.cache)
-
+e.dispose(), e = s;
 else {
-const
-
+const p = e;
+e = G([e, s], 1), p.dispose();
 }
-const a = await this.processResponse(this.actualTokeniser, s,
+const a = await this.processResponse(this.actualTokeniser, s, d, n);
 if (this.cache || s.dispose(), a === null)
 break;
 this.outputText += a;
 }
-return
+return e.dispose(), this.outputText;
 }
 reset() {
 this.cache && (this.cache.forEach((t) => {
 t && (t.k && t.k.dispose(), t.v && t.v.dispose());
-}), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
+}), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1, this.attentionData = [], this.probabilitiesData = [], this.tokens = [];
 }
 dispose() {
 this.reset();
 }
-initialise(t,
-const
-if (this.cache &&
-const s = new Array(this.model.config.
-for (let n = 0; n < this.model.config.
+initialise(t, e) {
+const i = t && t.length > this.model.config.blockSize ? t.slice(-this.model.config.blockSize) : t ?? null;
+if (this.cache && e?.noCache && this.reset(), this.initialPrompt = i || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !e?.noCache && this.model.config.useRope) {
+const s = new Array(this.model.config.nLayer);
+for (let n = 0; n < this.model.config.nLayer; n++)
 s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
 this.cache = s, this.lastToken = -1;
 }
-const o = this.tokeniser.trained ? this.tokeniser : new
+const o = this.tokeniser.trained ? this.tokeniser : new M(U(H, this.tokeniser.vocabSize));
 this.actualTokeniser = o;
 }
-async step(t,
-const
-return this.generate(t,
+async step(t, e) {
+const i = { ...e, maxLength: 1 };
+return this.generate(t, i);
 }
-async generate(t,
-this.initialise(t,
-const o = await this._generate(
+async generate(t, e) {
+this.initialise(t, e), this.active = !0, this.emit("start");
+const o = await this._generate(e);
 return this.active = !1, this.emit("stop"), o;
 }
 stop() {
@@ -139,7 +233,16 @@ class nt extends l {
 getText() {
 return this.outputText;
 }
+getAttentionData() {
+return this.attentionData;
+}
+getProbabilitiesData() {
+return this.probabilitiesData;
+}
+getTokens() {
+return this.tokens;
+}
 }
 export {
-
+qt as default
 };
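The top-p branch of _generateToken above is easier to follow de-minified. The sketch below restates that branch with descriptive names; it is a reading aid rather than code from the package, and the sampling callback stands in for the multinomialCPU utility that the bundle imports as $.

// Nucleus (top-p) sampling as implemented above: softmax the logits, sort token
// probabilities in descending order, keep mass until it reaches topP, renormalise
// what was kept, then sample an index from the renormalised distribution.
function sampleTopP(
    probs: number[],
    topP: number,
    sampleIndex: (p: number[]) => number // stands in for multinomialCPU
): number {
    const sorted = probs
        .map((prob, index) => ({ prob, index }))
        .sort((a, b) => b.prob - a.prob);
    const kept = new Array<number>(probs.length).fill(0);
    let cumulative = 0;
    for (const { prob, index } of sorted) {
        cumulative += prob;
        kept[index] = prob;
        if (cumulative >= topP) break;
    }
    const total = kept.reduce((a, b) => a + b, 0);
    return sampleIndex(kept.map((p) => p / total));
}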
package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js}
RENAMED
@@ -1,10 +1,10 @@
-import { aq as T,
-import { r as $ } from "./Reshape-
-import { g as A, a as
-import { t as
-import { c as _ } from "./backend_util-
-import { f as y } from "./gpgpu_math-
-import { g as G, b as L } from "./kernel_funcs_utils-
+import { aq as T, ag as E, p as O, j as V, aB as B, a1 as F, ah as j, aC as K } from "./index-DdmHGZjq.js";
+import { r as $ } from "./Reshape-Bh_jzKzV.js";
+import { g as A, a as C, b as k, c as N, e as R } from "./axis_util-Did9235A.js";
+import { t as U, m as W } from "./shared-CZiWmQCI.js";
+import { c as _ } from "./backend_util-yC3YH1jo.js";
+import { f as y } from "./gpgpu_math-D_ODOLix.js";
+import { g as G, b as L } from "./kernel_funcs_utils-CDfFpUab.js";
 /**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
@@ -381,7 +381,7 @@ function Q(a, s, e, t) {
 let i = r;
 const c = A(i, l), o = c != null;
 let u = a;
-o && (u = D(a, c, t), i =
+o && (u = D(a, c, t), i = C(i.length, l)), k("sum", i, l);
 const [p, h] = N(u.shape, i);
 let d = p;
 e && (d = R(p, r));
@@ -459,15 +459,15 @@ function te(a) {
 const I = e.texData.get(d.dataId).values, m = new Array(i);
 for (let v = 0; v < m.length; v++)
 m[v] = n.shape[u[v]];
-const z =
+const z = U(I, n.shape, n.dtype, u, m);
 d = e.makeTensorInfo(m, n.dtype);
 const M = e.texData.get(d.dataId);
 M.values = z;
 } else
 d = D(n, u, e);
-o =
+o = C(o.length, i);
 }
-
+k("max", o, i);
 const [f, S] = N(d.shape, o);
 let g = f;
 r && (g = R(f, c));
@@ -482,7 +482,7 @@ function te(a) {
 return p && e.disposeIntermediateTensorInfo(d), x;
 }
 const he = {
-kernelName:
+kernelName: j,
 backendName: "webgl",
 kernelFunc: te
 };
@@ -525,7 +525,7 @@ return a / b;`, se = `
 
 return result;
 `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
-kernelName:
+kernelName: K,
 backendName: "webgl",
 kernelFunc: ne
 };
package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js}
RENAMED
@@ -1,5 +1,5 @@
-import { j as c,
-import { u as g, g as I, a as x, b as F, c as $, d as u, e as
+import { j as c, a5 as C, n as f, V as R } from "./index-DdmHGZjq.js";
+import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-D_ODOLix.js";
 /**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
@@ -82,14 +82,14 @@ function v(s, t) {
 function b(s, t, i) {
 const a = [
 u(s.shape),
-...
+...m(s.shape)
 ], e = {
 dtype: s.dtype,
 shape: a,
 dataId: s.dataId
 }, o = [
 u(t),
-...
+...m(t)
 ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
 return { dataId: h.dataId, shape: t, dtype: h.dtype };
 }
@@ -113,7 +113,7 @@ function y(s) {
 const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = C(o, p), h = c(n);
 f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
 const d = r.texData.get(e.dataId);
-return d.isPacked && !
+return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
 }
 const U = {
 kernelName: R,
package/dist/TeachableLLM.d.ts
CHANGED
@@ -1,11 +1,11 @@
-import { GPTConfig } from './config';
+import { GPTConfig } from './models/config';
 import { ITokeniser } from './tokeniser/type';
-import {
-import { SaveOptions } from './utilities/save';
+import { SaveOptions } from './loader/save';
 import { default as Generator, IGenerateOptions } from './Generator';
 import { default as Trainer, ITrainerOptions } from './Trainer';
 import { default as MemoryProfiler } from './utilities/profile';
-import { TrainingProgress } from './training/Trainer';
+import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+import { default as Model, ModelForwardAttributes } from './models/model';
 type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
 interface TeachableLLMMeta {
     name?: string;
@@ -20,12 +20,12 @@ export default class TeachableLLM {
     private _status;
     private _memoryRequirements?;
     meta: TeachableLLMMeta;
-    constructor(tokeniser?: ITokeniser, model?:
+    constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes>);
     get vocab(): string[];
     /** Model is fully loaded */
     get loaded(): boolean;
     get config(): GPTConfig;
-    get model():
+    get model(): Model<ModelForwardAttributes>;
     get tokeniser(): ITokeniser;
     get status(): TeachableLLMStatus;
     /** Model is both ready and not busy */
package/dist/TeachableLLM.js
CHANGED
@@ -1,30 +1,21 @@
-import { defaultConfig as
-import
-import {
-import { loadModel as l } from "./loader/load.js";
+import { defaultConfig as d } from "./models/config.js";
+import { saveModel as l } from "./loader/save.js";
+import { loadModel as _ } from "./loader/load.js";
 import u from "./Generator.js";
-import
-import { E as
+import f from "./Trainer.js";
+import { E as p } from "./index-Dwqa6Zy2.js";
 import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
-import
-import k from "./tokeniser/bpe.js";
-import "./papaparse.min-C8l2Kvo1.js";
-import "./index-Tf7vU29b.js";
-import "./jszip.min-CjP2V1VV.js";
-import "./index-BoWRt-10.js";
-import "./ops/cpu/scatterSub.js";
-import "./ops/webgl/scatterSub.js";
-import "./ops/cpu/gatherSub.js";
-import "./ops/webgl/gatherSub.js";
+import "./index-DdmHGZjq.js";
 import "./ops/cpu/attentionMask.js";
 import "./ops/webgl/attentionMask.js";
 import "./ops/grads/attentionMask.js";
 import "./ops/cpu/qkv.js";
 import "./ops/webgl/qkv.js";
 import "./ops/grads/qkv.js";
-import "./random_width-
-import "./register_all_kernels-
-import "./
+import "./random_width-DKGeiFuR.js";
+import "./register_all_kernels-Do9VvZmo.js";
+import "./index-Tf7vU29b.js";
+import "./dataset-DPPl-iLT.js";
 import "./ops/cpu/rope.js";
 import "./ops/webgl/rope.js";
 import "./ops/grads/rope.js";
@@ -36,20 +27,31 @@ import "./ops/grads/fusedSoftmax.js";
 import "./ops/cpu/matMulGelu.js";
 import "./ops/webgl/matMulGelu.js";
 import "./ops/grads/matMulGelu.js";
-import "./ops/cpu/gelu.js";
-import "./ops/webgl/gelu.js";
-import "./gelu-C-dPj6Ku.js";
 import "./ops/cpu/normRMS.js";
 import "./ops/webgl/normRMS.js";
 import "./ops/grads/normRMS.js";
+import "./ops/cpu/gatherSub.js";
+import "./ops/webgl/gatherSub.js";
+import "./ops/cpu/scatterSub.js";
+import "./ops/webgl/scatterSub.js";
+import c from "./tokeniser/CharTokeniser.js";
+import g from "./tokeniser/bpe.js";
+import "./papaparse.min-C8l2Kvo1.js";
+import "./jszip.min-CjP2V1VV.js";
+import "./ops/cpu/gelu.js";
+import "./ops/webgl/gelu.js";
+import "./gelu-BkcmEEyD.js";
 import "./ops/webgl/log.js";
 import "./ops/cpu/adamMoments.js";
 import "./ops/webgl/adamMoments.js";
 import "./ops/cpu/adamAdjust.js";
 import "./ops/webgl/adamAdjust.js";
-import
+import "./checks/normRMS.js";
+import "./checks/normRMSGrad.js";
+import k from "./utilities/profile.js";
+import w from "./models/factory.js";
 class a {
-ee = new
+ee = new p();
 _config;
 _model;
 _tokeniser;
@@ -69,7 +71,7 @@ class a {
 get config() {
 if (!this._config)
 throw new Error("configuration_not_initialized.");
-return this._config
+return this._config;
 }
 get model() {
 if (!this._model)
@@ -101,14 +103,14 @@ class a {
 saveModel(t) {
 if (!this._model || !this._tokeniser)
 throw new Error("model_or_tokeniser_not_initialized.");
-return
+return l(this._model, this._tokeniser, {
 ...t,
 name: t?.name || this.meta.name
 });
 }
 static loadModel(t) {
 const e = new a();
-return
+return _(t).then(({ model: r, tokeniser: o, name: s }) => {
 e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
 e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
 }).catch((i) => {
@@ -119,7 +121,7 @@ class a {
 }), e;
 }
 static create(t, e = {}) {
-const r = { ...
+const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
 return i.setStatus("warmup"), m(s).then((n) => {
 i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
 h === "trained" && i.setStatus("ready");
@@ -138,9 +140,9 @@ class a {
 if (t) {
 if (!this._config)
 return;
-this.
+this.model.getProfiler() || this.model.setProfiler(new k());
 } else
-this.
+this.model.getProfiler() && this.model.setProfiler(null);
 }
 getNumParams() {
 return this._model ? this._model.getNumParams() : 0;
@@ -148,7 +150,7 @@ class a {
 trainer() {
 if (!this._model || !this._tokeniser)
 throw new Error("model_or_tokeniser_not_initialized.");
-const t = new
+const t = new f(this._model, this._tokeniser);
 return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
 const o = this.ee.listeners("trainStep");
 for (const s of o)
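Taken together with the factory import above, model construction now goes through models/factory rather than the removed NanoGPTModel. A rough usage sketch follows, using only members visible in this diff; the import specifier and the config values are assumptions.

// Sketch: 'char' selects the CharTokeniser path in TeachableLLM.create, any other
// value falls through to the BPE tokeniser, and the model itself is built by
// models/factory. The import specifier below is illustrative only.
import TeachableLLM from '@genai-fi/nanogpt/dist/TeachableLLM.js';

const llm = TeachableLLM.create('char', { vocabSize: 256 });

console.log(llm.status);        // 'warmup' until the warm-up pass resolves, then 'ready' or 'awaitingTokens'
console.log(llm.getNumParams());

const trainer = llm.trainer();  // re-emits 'start'/'stop' as status changes and forwards 'log' to trainStep listeners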
package/dist/Trainer.d.ts
CHANGED
@@ -1,6 +1,7 @@
-import { default as NanoGPT } from './NanoGPTModel';
 import { ITokeniser } from './tokeniser/type';
 import { default as EE } from 'eventemitter3';
+import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
+import { default as Model, ModelForwardAttributes } from './models/model';
 export interface ITrainerOptions {
     batchSize?: number;
     learningRate?: number;
@@ -10,6 +11,11 @@ export interface ITrainerOptions {
     prompt?: string;
     validationSplit?: number;
     advancedMetrics?: boolean;
+    gradientCheckpointing?: boolean;
+}
+interface ExtendedTrainingProgress extends TrainingProgress {
+    progress: number;
+    remaining: number;
 }
 export default class Trainer extends EE<'start' | 'stop' | 'log'> {
     private trainer;
@@ -17,10 +23,15 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
     private trainDataset?;
     private validationDataset?;
     private totalSamples;
-
+    private log;
+    private progress;
+    constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
     stop(): void;
     reset(): void;
     prepare(text: string[], options?: ITrainerOptions): Promise<void>;
     train(options?: ITrainerOptions): Promise<void>;
     step(options?: ITrainerOptions): Promise<void>;
+    getLog(): TrainingLogEntry[];
+    getProgress(): ExtendedTrainingProgress | null;
 }
+export {};
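The Trainer surface now exposes its accumulated log and an extended progress record. Below is a minimal sketch against the declarations above; the option values and corpus are placeholders, the import specifier is illustrative, and only fields visible in this diff are used.

// Sketch: prepare a dataset, train, then read back the log and progress.
import Trainer, { ITrainerOptions } from '@genai-fi/nanogpt/dist/Trainer.js';

async function trainAndReport(trainer: Trainer, corpus: string[]): Promise<void> {
    const options: ITrainerOptions = {
        batchSize: 16,
        learningRate: 3e-4,
        validationSplit: 0.1,
        gradientCheckpointing: true // new option in this release
    };
    await trainer.prepare(corpus, options); // builds the train/validation datasets
    await trainer.train(options);           // runs the training loop
    console.log(trainer.getLog());          // TrainingLogEntry[]
    console.log(trainer.getProgress());     // ExtendedTrainingProgress | null (progress, remaining)
}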