@genai-fi/nanogpt 0.7.2 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +36 -4
- package/dist/Generator.js +183 -69
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
- package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +51 -50
- package/dist/Trainer.d.ts +19 -3
- package/dist/Trainer.js +71 -28
- package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
- package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
- package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
- package/dist/exports_initializers-DAKM8UO9.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
- package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
- package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +42 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
- package/dist/main.d.ts +5 -4
- package/dist/main.js +22 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
- package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
- package/dist/mod-uUuj4gSb.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +68 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +11 -11
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
- package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
- package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
- package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
- package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
- package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
- package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +16 -3
- package/dist/training/FullTrainer.js +91 -53
- package/dist/training/Trainer.d.ts +25 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +9 -9
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
package/dist/TeachableLLM.js
CHANGED
|
@@ -1,30 +1,21 @@
|
|
|
1
|
-
import { defaultConfig as
|
|
2
|
-
import
|
|
3
|
-
import {
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import p from "./
|
|
7
|
-
import { E as g } from "./index-Dwqa6Zy2.js";
|
|
1
|
+
import { defaultConfig as d } from "./models/config.js";
|
|
2
|
+
import { saveModel as l } from "./loader/save.js";
|
|
3
|
+
import { loadModel as _ } from "./loader/load.js";
|
|
4
|
+
import u from "./Generator.js";
|
|
5
|
+
import f from "./Trainer.js";
|
|
6
|
+
import { E as p } from "./index-Dwqa6Zy2.js";
|
|
8
7
|
import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
|
|
9
|
-
import
|
|
10
|
-
import k from "./tokeniser/bpe.js";
|
|
11
|
-
import "./papaparse.min-C8l2Kvo1.js";
|
|
12
|
-
import "./index-Tf7vU29b.js";
|
|
13
|
-
import "./jszip.min-CjP2V1VV.js";
|
|
14
|
-
import "./index-BoWRt-10.js";
|
|
15
|
-
import "./ops/cpu/scatterSub.js";
|
|
16
|
-
import "./ops/webgl/scatterSub.js";
|
|
17
|
-
import "./ops/cpu/gatherSub.js";
|
|
18
|
-
import "./ops/webgl/gatherSub.js";
|
|
8
|
+
import "./index-CUQrfsw_.js";
|
|
19
9
|
import "./ops/cpu/attentionMask.js";
|
|
20
10
|
import "./ops/webgl/attentionMask.js";
|
|
21
11
|
import "./ops/grads/attentionMask.js";
|
|
22
12
|
import "./ops/cpu/qkv.js";
|
|
23
13
|
import "./ops/webgl/qkv.js";
|
|
24
14
|
import "./ops/grads/qkv.js";
|
|
25
|
-
import "./random_width-
|
|
26
|
-
import "./register_all_kernels-
|
|
27
|
-
import "./
|
|
15
|
+
import "./random_width-D8Pwy_na.js";
|
|
16
|
+
import "./register_all_kernels-DUshvVWP.js";
|
|
17
|
+
import "./index-Tf7vU29b.js";
|
|
18
|
+
import "./dataset-CJmEGu6D.js";
|
|
28
19
|
import "./ops/cpu/rope.js";
|
|
29
20
|
import "./ops/webgl/rope.js";
|
|
30
21
|
import "./ops/grads/rope.js";
|
|
@@ -36,20 +27,29 @@ import "./ops/grads/fusedSoftmax.js";
|
|
|
36
27
|
import "./ops/cpu/matMulGelu.js";
|
|
37
28
|
import "./ops/webgl/matMulGelu.js";
|
|
38
29
|
import "./ops/grads/matMulGelu.js";
|
|
39
|
-
import "./ops/cpu/gelu.js";
|
|
40
|
-
import "./ops/webgl/gelu.js";
|
|
41
|
-
import "./gelu-C-dPj6Ku.js";
|
|
42
30
|
import "./ops/cpu/normRMS.js";
|
|
43
31
|
import "./ops/webgl/normRMS.js";
|
|
44
32
|
import "./ops/grads/normRMS.js";
|
|
33
|
+
import "./ops/cpu/gatherSub.js";
|
|
34
|
+
import "./ops/webgl/gatherSub.js";
|
|
35
|
+
import "./ops/cpu/scatterSub.js";
|
|
36
|
+
import "./ops/webgl/scatterSub.js";
|
|
37
|
+
import c from "./tokeniser/CharTokeniser.js";
|
|
38
|
+
import g from "./tokeniser/bpe.js";
|
|
39
|
+
import "./papaparse.min-C8l2Kvo1.js";
|
|
40
|
+
import "./jszip.min-CjP2V1VV.js";
|
|
41
|
+
import "./ops/cpu/gelu.js";
|
|
42
|
+
import "./ops/webgl/gelu.js";
|
|
43
|
+
import "./gelu-Bd3UBBxg.js";
|
|
45
44
|
import "./ops/webgl/log.js";
|
|
46
45
|
import "./ops/cpu/adamMoments.js";
|
|
47
46
|
import "./ops/webgl/adamMoments.js";
|
|
48
47
|
import "./ops/cpu/adamAdjust.js";
|
|
49
48
|
import "./ops/webgl/adamAdjust.js";
|
|
50
|
-
import
|
|
49
|
+
import k from "./utilities/profile.js";
|
|
50
|
+
import w from "./models/factory.js";
|
|
51
51
|
class a {
|
|
52
|
-
ee = new
|
|
52
|
+
ee = new p();
|
|
53
53
|
_config;
|
|
54
54
|
_model;
|
|
55
55
|
_tokeniser;
|
|
@@ -69,7 +69,7 @@ class a {
|
|
|
69
69
|
get config() {
|
|
70
70
|
if (!this._config)
|
|
71
71
|
throw new Error("configuration_not_initialized.");
|
|
72
|
-
return this._config
|
|
72
|
+
return this._config;
|
|
73
73
|
}
|
|
74
74
|
get model() {
|
|
75
75
|
if (!this._model)
|
|
@@ -92,8 +92,8 @@ class a {
|
|
|
92
92
|
return this._status === "busy" || this._status === "training";
|
|
93
93
|
}
|
|
94
94
|
estimateTrainingMemoryUsage(t) {
|
|
95
|
-
const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 },
|
|
96
|
-
return
|
|
95
|
+
const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, r = e.perBatch * t, o = e.gradients;
|
|
96
|
+
return r * 0.66 + o * 4;
|
|
97
97
|
}
|
|
98
98
|
setStatus(t) {
|
|
99
99
|
this._status !== t && (this._status = t, this.ee.emit("status", t));
|
|
@@ -101,32 +101,32 @@ class a {
|
|
|
101
101
|
saveModel(t) {
|
|
102
102
|
if (!this._model || !this._tokeniser)
|
|
103
103
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
104
|
-
return
|
|
104
|
+
return l(this._model, this._tokeniser, {
|
|
105
105
|
...t,
|
|
106
106
|
name: t?.name || this.meta.name
|
|
107
107
|
});
|
|
108
108
|
}
|
|
109
109
|
static loadModel(t) {
|
|
110
110
|
const e = new a();
|
|
111
|
-
return
|
|
112
|
-
e._model =
|
|
113
|
-
e._memoryRequirements =
|
|
114
|
-
}).catch((
|
|
115
|
-
e.setStatus("error"), e.ee.emit("error",
|
|
111
|
+
return _(t).then(({ model: r, tokeniser: o, name: s }) => {
|
|
112
|
+
e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
|
|
113
|
+
e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
|
|
114
|
+
}).catch((i) => {
|
|
115
|
+
e.setStatus("error"), e.ee.emit("error", i);
|
|
116
116
|
});
|
|
117
|
-
}).catch((
|
|
118
|
-
e.setStatus("error"), e.ee.emit("error",
|
|
117
|
+
}).catch((r) => {
|
|
118
|
+
e.setStatus("error"), e.ee.emit("error", r);
|
|
119
119
|
}), e;
|
|
120
120
|
}
|
|
121
121
|
static create(t, e = {}) {
|
|
122
|
-
const
|
|
123
|
-
return
|
|
124
|
-
|
|
125
|
-
h === "trained" &&
|
|
122
|
+
const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
|
|
123
|
+
return i.setStatus("warmup"), m(s).then((n) => {
|
|
124
|
+
i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
|
|
125
|
+
h === "trained" && i.setStatus("ready");
|
|
126
126
|
}));
|
|
127
127
|
}).catch((n) => {
|
|
128
|
-
|
|
129
|
-
}),
|
|
128
|
+
i.setStatus("error"), i.ee.emit("error", n);
|
|
129
|
+
}), i;
|
|
130
130
|
}
|
|
131
131
|
getProfiler() {
|
|
132
132
|
return this._model?.getProfiler();
|
|
@@ -138,9 +138,9 @@ class a {
|
|
|
138
138
|
if (t) {
|
|
139
139
|
if (!this._config)
|
|
140
140
|
return;
|
|
141
|
-
this.
|
|
141
|
+
this.model.getProfiler() || this.model.setProfiler(new k());
|
|
142
142
|
} else
|
|
143
|
-
this.
|
|
143
|
+
this.model.getProfiler() && this.model.setProfiler(null);
|
|
144
144
|
}
|
|
145
145
|
getNumParams() {
|
|
146
146
|
return this._model ? this._model.getNumParams() : 0;
|
|
@@ -148,15 +148,16 @@ class a {
|
|
|
148
148
|
trainer() {
|
|
149
149
|
if (!this._model || !this._tokeniser)
|
|
150
150
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
151
|
-
const t = new
|
|
152
|
-
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e,
|
|
151
|
+
const t = new f(this._model, this._tokeniser);
|
|
152
|
+
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
|
|
153
153
|
const o = this.ee.listeners("trainStep");
|
|
154
154
|
for (const s of o)
|
|
155
|
-
await s(e,
|
|
155
|
+
await s(e, r);
|
|
156
156
|
}), t;
|
|
157
157
|
}
|
|
158
|
-
train(t, e) {
|
|
159
|
-
|
|
158
|
+
async train(t, e) {
|
|
159
|
+
const r = this.trainer();
|
|
160
|
+
await r.prepare(t, e), await r.train(e);
|
|
160
161
|
}
|
|
161
162
|
async trainTokeniser(t) {
|
|
162
163
|
if (!this._tokeniser)
|
|
@@ -167,7 +168,7 @@ class a {
|
|
|
167
168
|
generator() {
|
|
168
169
|
if (!this._model || !this._tokeniser)
|
|
169
170
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
170
|
-
const t = new
|
|
171
|
+
const t = new u(this._model, this._tokeniser);
|
|
171
172
|
return t.on("start", () => {
|
|
172
173
|
this.status === "ready" && this.setStatus("busy");
|
|
173
174
|
}), t.on("stop", () => {
|
package/dist/Trainer.d.ts
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
import { default as NanoGPT } from './NanoGPTModel';
|
|
2
1
|
import { ITokeniser } from './tokeniser/type';
|
|
3
2
|
import { default as EE } from 'eventemitter3';
|
|
3
|
+
import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
|
|
4
|
+
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
4
5
|
export interface ITrainerOptions {
|
|
5
6
|
batchSize?: number;
|
|
6
7
|
learningRate?: number;
|
|
@@ -10,12 +11,27 @@ export interface ITrainerOptions {
|
|
|
10
11
|
prompt?: string;
|
|
11
12
|
validationSplit?: number;
|
|
12
13
|
advancedMetrics?: boolean;
|
|
14
|
+
gradientCheckpointing?: boolean;
|
|
15
|
+
}
|
|
16
|
+
interface ExtendedTrainingProgress extends TrainingProgress {
|
|
17
|
+
progress: number;
|
|
18
|
+
remaining: number;
|
|
13
19
|
}
|
|
14
20
|
export default class Trainer extends EE<'start' | 'stop' | 'log'> {
|
|
15
21
|
private trainer;
|
|
16
22
|
private hasTrained;
|
|
17
|
-
|
|
23
|
+
private trainDataset?;
|
|
24
|
+
private validationDataset?;
|
|
25
|
+
private totalSamples;
|
|
26
|
+
private log;
|
|
27
|
+
private progress;
|
|
28
|
+
constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
|
|
18
29
|
stop(): void;
|
|
19
30
|
reset(): void;
|
|
20
|
-
|
|
31
|
+
prepare(text: string[], options?: ITrainerOptions): Promise<void>;
|
|
32
|
+
train(options?: ITrainerOptions): Promise<void>;
|
|
33
|
+
step(options?: ITrainerOptions): Promise<void>;
|
|
34
|
+
getLog(): TrainingLogEntry[];
|
|
35
|
+
getProgress(): ExtendedTrainingProgress | null;
|
|
21
36
|
}
|
|
37
|
+
export {};
|
package/dist/Trainer.js
CHANGED
|
@@ -1,48 +1,91 @@
|
|
|
1
|
-
import { E as
|
|
2
|
-
import
|
|
3
|
-
class
|
|
1
|
+
import { E as l } from "./index-Dwqa6Zy2.js";
|
|
2
|
+
import h from "./training/FullTrainer.js";
|
|
3
|
+
class m extends l {
|
|
4
4
|
trainer;
|
|
5
5
|
hasTrained = !1;
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
trainDataset;
|
|
7
|
+
validationDataset;
|
|
8
|
+
totalSamples = 0;
|
|
9
|
+
log = [];
|
|
10
|
+
progress = null;
|
|
11
|
+
constructor(t, e) {
|
|
12
|
+
super(), this.trainer = new h(t, e, 1e-3);
|
|
8
13
|
}
|
|
9
14
|
stop() {
|
|
10
15
|
this.trainer.stop();
|
|
11
16
|
}
|
|
12
17
|
reset() {
|
|
13
|
-
this.hasTrained = !1, this.trainer.reset();
|
|
14
|
-
}
|
|
15
|
-
async
|
|
16
|
-
const { trainDataset:
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
),
|
|
21
|
-
this.
|
|
22
|
-
|
|
18
|
+
this.hasTrained = !1, this.log = [], this.trainer.reset();
|
|
19
|
+
}
|
|
20
|
+
async prepare(t, e) {
|
|
21
|
+
const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
|
|
22
|
+
t,
|
|
23
|
+
e?.batchSize || 32,
|
|
24
|
+
e?.validationSplit || 0.1
|
|
25
|
+
), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
|
|
26
|
+
this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
|
|
27
|
+
}
|
|
28
|
+
async train(t) {
|
|
29
|
+
if (!this.trainDataset || !this.validationDataset)
|
|
30
|
+
throw new Error("Datasets not prepared");
|
|
31
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), await this.trainer.trainOnDataset(
|
|
32
|
+
this.trainDataset,
|
|
23
33
|
{
|
|
24
34
|
prompt: t?.prompt,
|
|
25
35
|
logInterval: t?.logInterval || 10,
|
|
26
36
|
desiredLoss: t?.desiredLoss || 0.01,
|
|
27
37
|
maxSteps: t?.maxSteps || 1e3,
|
|
28
38
|
advancedMetrics: t?.advancedMetrics || !1,
|
|
29
|
-
onStep: async (
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
39
|
+
onStep: async (e, a) => {
|
|
40
|
+
this.log.push(e), this.progress = {
|
|
41
|
+
...a,
|
|
42
|
+
progress: a.totalSamples / this.totalSamples,
|
|
43
|
+
remaining: Math.max(
|
|
44
|
+
0,
|
|
45
|
+
(this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
|
|
46
|
+
)
|
|
47
|
+
};
|
|
48
|
+
const s = this.listeners("log");
|
|
49
|
+
for (const i of s)
|
|
50
|
+
await i(e, this.progress);
|
|
40
51
|
}
|
|
41
52
|
},
|
|
42
|
-
|
|
53
|
+
this.validationDataset
|
|
43
54
|
), this.emit("stop");
|
|
44
55
|
}
|
|
56
|
+
async step(t) {
|
|
57
|
+
if (!this.trainDataset || !this.validationDataset)
|
|
58
|
+
throw new Error("Datasets not prepared");
|
|
59
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
|
|
60
|
+
const { log: e, progress: a } = await this.trainer.stepDataset(
|
|
61
|
+
this.trainDataset,
|
|
62
|
+
{
|
|
63
|
+
prompt: t?.prompt,
|
|
64
|
+
logInterval: t?.logInterval || 10,
|
|
65
|
+
desiredLoss: t?.desiredLoss || 0.01,
|
|
66
|
+
maxSteps: t?.maxSteps || 1e3,
|
|
67
|
+
advancedMetrics: t?.advancedMetrics || !1
|
|
68
|
+
},
|
|
69
|
+
this.validationDataset
|
|
70
|
+
), s = this.listeners("log");
|
|
71
|
+
for (const i of s)
|
|
72
|
+
await i(e, {
|
|
73
|
+
...a,
|
|
74
|
+
progress: a.totalSamples / this.totalSamples,
|
|
75
|
+
remaining: Math.max(
|
|
76
|
+
0,
|
|
77
|
+
(this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
|
|
78
|
+
)
|
|
79
|
+
});
|
|
80
|
+
this.emit("stop");
|
|
81
|
+
}
|
|
82
|
+
getLog() {
|
|
83
|
+
return this.log;
|
|
84
|
+
}
|
|
85
|
+
getProgress() {
|
|
86
|
+
return this.progress;
|
|
87
|
+
}
|
|
45
88
|
}
|
|
46
89
|
export {
|
|
47
|
-
|
|
90
|
+
m as default
|
|
48
91
|
};
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { n as c } from "./index-CUQrfsw_.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
@@ -28,7 +28,7 @@ function a(e, n, t) {
|
|
|
28
28
|
t.indexOf(u) === -1 ? s.push(e[o++]) : s.push(n[f++]);
|
|
29
29
|
return s;
|
|
30
30
|
}
|
|
31
|
-
function
|
|
31
|
+
function l(e, n) {
|
|
32
32
|
const t = [], r = e.length;
|
|
33
33
|
for (let o = 0; o < r; o++)
|
|
34
34
|
n.indexOf(o) === -1 && t.push(e[o]);
|
|
@@ -62,7 +62,7 @@ function x(e, n) {
|
|
|
62
62
|
export {
|
|
63
63
|
x as a,
|
|
64
64
|
m as b,
|
|
65
|
-
|
|
65
|
+
l as c,
|
|
66
66
|
i as d,
|
|
67
67
|
h as e,
|
|
68
68
|
a as f,
|
package/dist/backend.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { g as a, s as i, r as o } from "./index-
|
|
1
|
+
import { g as a, s as i, r as o } from "./index-CUQrfsw_.js";
|
|
2
2
|
async function e(t) {
|
|
3
|
-
a() !== t && (t === "webgpu" && (await import("./index-
|
|
3
|
+
a() !== t && (t === "webgpu" && (await import("./index-BaPo_0H8.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
|
|
4
4
|
}
|
|
5
5
|
export {
|
|
6
6
|
e as selectBackend
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { j as m,
|
|
2
|
-
import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-
|
|
3
|
-
import { S as U, a as B, b as V, c as j, d as
|
|
4
|
-
import { c as re, v as oe, a as ae } from "./scatter_nd_util-
|
|
1
|
+
import { j as m, a2 as O, n as g, aM as $, aN as R, aO as M, l as _, ad as y, ay as D, aP as T, u as b, aQ as F } from "./index-CUQrfsw_.js";
|
|
2
|
+
import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-DubwyOhW.js";
|
|
3
|
+
import { S as U, a as B, b as V, c as j, d as G, e as H, f as k, g as q, h as Z, i as X, j as J, k as K, l as Q, m as Y, s as ee, n as te, o as ne, t as se } from "./selu_util-8vv5JxQV.js";
|
|
4
|
+
import { c as re, v as oe, a as ae } from "./scatter_nd_util-CUPPNLaA.js";
|
|
5
5
|
function ie(e, n) {
|
|
6
6
|
const r = e.shape.length, t = n.shape.length;
|
|
7
7
|
if (r < 1)
|
|
@@ -233,7 +233,7 @@ function Ie(e, n) {
|
|
|
233
233
|
r.push(e[t][0]);
|
|
234
234
|
return r;
|
|
235
235
|
}
|
|
236
|
-
function
|
|
236
|
+
function Se(e, n, r) {
|
|
237
237
|
const t = e.slice(0, 1);
|
|
238
238
|
for (let s = 0; s < r; ++s)
|
|
239
239
|
t.push(e[s + 1] - n[s][0] - n[s][1]);
|
|
@@ -255,7 +255,7 @@ function we(e, n, r) {
|
|
|
255
255
|
* limitations under the License.
|
|
256
256
|
* =============================================================================
|
|
257
257
|
*/
|
|
258
|
-
const
|
|
258
|
+
const we = 0.3275911, Ae = 0.254829592, Oe = -0.284496736, Re = 1.421413741, Me = -1.453152027, _e = 1.061405429;
|
|
259
259
|
/**
|
|
260
260
|
* @license
|
|
261
261
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -333,7 +333,7 @@ function ve(e, n, r) {
|
|
|
333
333
|
* limitations under the License.
|
|
334
334
|
* =============================================================================
|
|
335
335
|
*/
|
|
336
|
-
const E = "->", Ne = /->/g,
|
|
336
|
+
const E = "->", Ne = /->/g, S = ",", w = "...";
|
|
337
337
|
function xe(e, n) {
|
|
338
338
|
e = e.replace(/\s/g, "");
|
|
339
339
|
const r = (e.length - e.replace(Ne, "").length) / E.length;
|
|
@@ -342,8 +342,8 @@ function xe(e, n) {
|
|
|
342
342
|
if (r > 1)
|
|
343
343
|
throw new Error(`Equation must contain exactly one arrow ("${E}").`);
|
|
344
344
|
const [t, s] = e.split(E);
|
|
345
|
-
g(t.indexOf(
|
|
346
|
-
const o = t.split(
|
|
345
|
+
g(t.indexOf(w) === -1, () => `The ellipsis notation ("${w}") is not supported yet.`);
|
|
346
|
+
const o = t.split(S), a = o.length;
|
|
347
347
|
if (n !== a)
|
|
348
348
|
throw new Error(`Expected ${a} input tensors, received ${n}`);
|
|
349
349
|
if (a > 2)
|
|
@@ -357,7 +357,7 @@ function xe(e, n) {
|
|
|
357
357
|
}
|
|
358
358
|
for (let l = 0; l < t.length; ++l) {
|
|
359
359
|
const f = t[l];
|
|
360
|
-
u.indexOf(f) === -1 && f !==
|
|
360
|
+
u.indexOf(f) === -1 && f !== S && u.push(f);
|
|
361
361
|
}
|
|
362
362
|
const c = new Array(o.length);
|
|
363
363
|
for (let l = 0; l < a; ++l) {
|
|
@@ -449,10 +449,10 @@ function je(e) {
|
|
|
449
449
|
return `Received SparseTensor with denseShape[0] = 0 but
|
|
450
450
|
indices.shape[0] = ${e}`;
|
|
451
451
|
}
|
|
452
|
-
function
|
|
452
|
+
function Ge(e, n) {
|
|
453
453
|
return `indices(${e}, 0) is invalid: ${n} < 0`;
|
|
454
454
|
}
|
|
455
|
-
function
|
|
455
|
+
function He(e, n, r) {
|
|
456
456
|
return `indices(${e}, 0) is invalid: ${n} >= ${r}`;
|
|
457
457
|
}
|
|
458
458
|
/**
|
|
@@ -471,7 +471,7 @@ function Ge(e, n, r) {
|
|
|
471
471
|
* limitations under the License.
|
|
472
472
|
* =============================================================================
|
|
473
473
|
*/
|
|
474
|
-
function
|
|
474
|
+
function ke(e, n) {
|
|
475
475
|
return `only one output dimension may be -1, not both ${e} and ${n}`;
|
|
476
476
|
}
|
|
477
477
|
function qe(e, n) {
|
|
@@ -480,12 +480,12 @@ function qe(e, n) {
|
|
|
480
480
|
function Ze() {
|
|
481
481
|
return "reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero";
|
|
482
482
|
}
|
|
483
|
-
function
|
|
483
|
+
function Xe(e, n) {
|
|
484
484
|
const r = m(e), t = m(n);
|
|
485
485
|
return `Input to reshape is a SparseTensor with ${r}
|
|
486
486
|
dense values, but the requested shape requires a multiple of ${t}. inputShape=${e} outputShape= ${n}`;
|
|
487
487
|
}
|
|
488
|
-
function
|
|
488
|
+
function Je(e, n) {
|
|
489
489
|
const r = m(e), t = m(n);
|
|
490
490
|
return `Input to reshape is a tensor with ${r} dense values, but the requested shape has ${t}. inputShape=${e} outputShape=${n}`;
|
|
491
491
|
}
|
|
@@ -505,13 +505,13 @@ function Xe(e, n) {
|
|
|
505
505
|
* limitations under the License.
|
|
506
506
|
* =============================================================================
|
|
507
507
|
*/
|
|
508
|
-
function
|
|
508
|
+
function Ke() {
|
|
509
509
|
return "segment ids must be >= 0";
|
|
510
510
|
}
|
|
511
|
-
function
|
|
511
|
+
function Qe() {
|
|
512
512
|
return "segment ids are not increasing";
|
|
513
513
|
}
|
|
514
|
-
function
|
|
514
|
+
function Ye(e, n) {
|
|
515
515
|
return `Segment id ${e} out of range [0, ${n}), possibly because segmentIds input is not sorted.`;
|
|
516
516
|
}
|
|
517
517
|
function et(e, n, r) {
|
|
@@ -608,7 +608,7 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
608
608
|
ERF_A3: Re,
|
|
609
609
|
ERF_A4: Me,
|
|
610
610
|
ERF_A5: _e,
|
|
611
|
-
ERF_P:
|
|
611
|
+
ERF_P: we,
|
|
612
612
|
PARALLELIZE_THRESHOLD: I,
|
|
613
613
|
get RowPartitionType() {
|
|
614
614
|
return p;
|
|
@@ -628,18 +628,18 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
628
628
|
combineRaggedTensorToTensorShapes: ce,
|
|
629
629
|
complexWithEvenIndex: Te,
|
|
630
630
|
complexWithOddIndex: be,
|
|
631
|
-
computeConv2DInfo:
|
|
632
|
-
computeConv3DInfo:
|
|
633
|
-
computeDefaultPad:
|
|
631
|
+
computeConv2DInfo: G,
|
|
632
|
+
computeConv3DInfo: H,
|
|
633
|
+
computeDefaultPad: k,
|
|
634
634
|
computeDilation2DInfo: q,
|
|
635
635
|
computeOptimalWindowSize: ge,
|
|
636
636
|
computeOutAndReduceShapes: N,
|
|
637
637
|
computeOutShape: le,
|
|
638
638
|
computePool2DInfo: Z,
|
|
639
|
-
computePool3DInfo:
|
|
640
|
-
convertConv2DDataFormat:
|
|
639
|
+
computePool3DInfo: X,
|
|
640
|
+
convertConv2DDataFormat: J,
|
|
641
641
|
decodeEinsumEquation: xe,
|
|
642
|
-
eitherStridesOrDilationsAreOne:
|
|
642
|
+
eitherStridesOrDilationsAreOne: K,
|
|
643
643
|
expandShapeToKeepDim: x,
|
|
644
644
|
exponent: ve,
|
|
645
645
|
exponents: We,
|
|
@@ -650,8 +650,8 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
650
650
|
getComplexWithIndex: Fe,
|
|
651
651
|
getEinsumComputePath: ze,
|
|
652
652
|
getEinsumPermutation: Pe,
|
|
653
|
-
getFusedBiasGradient:
|
|
654
|
-
getFusedDyActivation:
|
|
653
|
+
getFusedBiasGradient: Q,
|
|
654
|
+
getFusedDyActivation: Y,
|
|
655
655
|
getImageCenter: de,
|
|
656
656
|
getInnerMostAxes: C,
|
|
657
657
|
getPermuted: Ee,
|
|
@@ -661,19 +661,19 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
661
661
|
getReshapedPermuted: $e,
|
|
662
662
|
getRowPartitionTypesHelper: he,
|
|
663
663
|
getSliceBeginCoords: Ie,
|
|
664
|
-
getSliceSize:
|
|
664
|
+
getSliceSize: Se,
|
|
665
665
|
getSparseFillEmptyRowsIndicesDenseShapeMismatch: je,
|
|
666
|
-
getSparseFillEmptyRowsNegativeIndexErrorMessage:
|
|
667
|
-
getSparseFillEmptyRowsOutOfRangeIndexErrorMessage:
|
|
666
|
+
getSparseFillEmptyRowsNegativeIndexErrorMessage: Ge,
|
|
667
|
+
getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: He,
|
|
668
668
|
getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: Ze,
|
|
669
|
-
getSparseReshapeInputOutputMismatchErrorMessage:
|
|
670
|
-
getSparseReshapeInputOutputMultipleErrorMessage:
|
|
671
|
-
getSparseReshapeMultipleNegativeOneOutputDimErrorMessage:
|
|
669
|
+
getSparseReshapeInputOutputMismatchErrorMessage: Je,
|
|
670
|
+
getSparseReshapeInputOutputMultipleErrorMessage: Xe,
|
|
671
|
+
getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: ke,
|
|
672
672
|
getSparseReshapeNegativeOutputDimErrorMessage: qe,
|
|
673
673
|
getSparseSegmentReductionIndicesOutOfRangeErrorMessage: et,
|
|
674
|
-
getSparseSegmentReductionNegativeSegmentIdsErrorMessage:
|
|
675
|
-
getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage:
|
|
676
|
-
getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage:
|
|
674
|
+
getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Ke,
|
|
675
|
+
getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Qe,
|
|
676
|
+
getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Ye,
|
|
677
677
|
getUndoAxesPermutation: z,
|
|
678
678
|
isIdentityPermutation: Ue,
|
|
679
679
|
log: T,
|
|
@@ -697,8 +697,8 @@ export {
|
|
|
697
697
|
Ee as B,
|
|
698
698
|
$e as C,
|
|
699
699
|
Ie as D,
|
|
700
|
-
|
|
701
|
-
|
|
700
|
+
we as E,
|
|
701
|
+
Se as F,
|
|
702
702
|
le as G,
|
|
703
703
|
ue as H,
|
|
704
704
|
xe as I,
|
|
@@ -728,17 +728,17 @@ export {
|
|
|
728
728
|
ot as f,
|
|
729
729
|
he as g,
|
|
730
730
|
je as h,
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
731
|
+
Ge as i,
|
|
732
|
+
He as j,
|
|
733
|
+
ke as k,
|
|
734
734
|
qe as l,
|
|
735
735
|
ye as m,
|
|
736
736
|
Ze as n,
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
737
|
+
Xe as o,
|
|
738
|
+
Je as p,
|
|
739
|
+
Ke as q,
|
|
740
|
+
Qe as r,
|
|
741
|
+
Ye as s,
|
|
742
742
|
et as t,
|
|
743
743
|
Ae as u,
|
|
744
744
|
pe as v,
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { B as h, C as f,
|
|
2
|
-
import { r as T } from "./reshape-
|
|
1
|
+
import { B as h, C as f, L as p, F as g, E as u, W as b } from "./index-CUQrfsw_.js";
|
|
2
|
+
import { r as T } from "./reshape-DEfQGSin.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
5
5
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { ai as S, T as h, af as k, d as v, aj as o, ak as p, al as g, n as N, t as y } from "./index-CUQrfsw_.js";
|
|
2
2
|
import { s as R } from "./index-C4L8Cm77.js";
|
|
3
|
-
import { s as $ } from "./stack-
|
|
4
|
-
import { t as B } from "./tensor-
|
|
3
|
+
import { s as $ } from "./stack-DkdFLq37.js";
|
|
4
|
+
import { t as B } from "./tensor-BAQdLqoU.js";
|
|
5
5
|
/**
|
|
6
6
|
* @license
|
|
7
7
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -75,7 +75,7 @@ function I(s) {
|
|
|
75
75
|
}
|
|
76
76
|
function c(s) {
|
|
77
77
|
let t = !1;
|
|
78
|
-
if (
|
|
78
|
+
if (k().get("IS_BROWSER"))
|
|
79
79
|
t = s instanceof TextDecoder;
|
|
80
80
|
else {
|
|
81
81
|
const { StringDecoder: e } = require("string_decoder");
|
|
@@ -930,7 +930,7 @@ class T {
|
|
|
930
930
|
*/
|
|
931
931
|
batch(t, e = !0) {
|
|
932
932
|
const r = this;
|
|
933
|
-
|
|
933
|
+
N(t > 0, () => `batchSize needs to be positive, but it is
|
|
934
934
|
${t}`);
|
|
935
935
|
let n;
|
|
936
936
|
return this.size === 1 / 0 || this.size == null ? n = this.size : e ? n = Math.ceil(this.size / t) : n = Math.floor(this.size / t), u(async () => (await r.iterator()).columnMajorBatch(t, e, st), n);
|