@genai-fi/nanogpt 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +25 -2
- package/dist/Generator.js +152 -49
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
- package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +33 -31
- package/dist/Trainer.d.ts +13 -2
- package/dist/Trainer.js +21 -12
- package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +22 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +37 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +20 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +18 -0
- package/dist/checks/index.d.ts +19 -0
- package/dist/checks/index.js +21 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +16 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +12 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +25 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +21 -0
- package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
- package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
- package/dist/exports_initializers-DKk7-bsx.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
- package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
- package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +44 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
- package/dist/main.d.ts +6 -4
- package/dist/main.js +24 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
- package/dist/mod-CbibJi3D.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +70 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +7 -7
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/ops-542ai2vG.js +1525 -0
- package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
- package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
- package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
- package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
- package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
- package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
- package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
- package/dist/tensor1d-CtJq5BOv.js +27 -0
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
- package/dist/tensor3d-BOukqWwr.js +30 -0
- package/dist/tensor4d-DLtk7Nxh.js +30 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +3 -3
- package/dist/training/FullTrainer.js +61 -69
- package/dist/training/Trainer.d.ts +15 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +12 -13
- package/dist/utilities/arrayClose.d.ts +1 -1
- package/dist/utilities/arrayClose.js +16 -7
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/ops-BFGCx8Ri.js +0 -1202
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
|
@@ -1,16 +1,15 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
3
|
-
import S from "
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
const T = {
|
|
1
|
+
import y from "./Trainer.js";
|
|
2
|
+
import v from "./Evaluator.js";
|
|
3
|
+
import { d as S } from "../index-DdmHGZjq.js";
|
|
4
|
+
import w from "../utilities/profile.js";
|
|
5
|
+
const f = {
|
|
7
6
|
desiredLoss: 0.01,
|
|
8
7
|
logInterval: 1,
|
|
9
8
|
maxSteps: 1e3
|
|
10
9
|
};
|
|
11
|
-
class
|
|
12
|
-
constructor(
|
|
13
|
-
super(
|
|
10
|
+
class b extends y {
|
|
11
|
+
constructor(s, t, a = 3e-4) {
|
|
12
|
+
super(s, t, a);
|
|
14
13
|
}
|
|
15
14
|
createEmptyState() {
|
|
16
15
|
return {
|
|
@@ -24,104 +23,97 @@ class z extends x {
|
|
|
24
23
|
...this.lastState || {}
|
|
25
24
|
};
|
|
26
25
|
}
|
|
27
|
-
createLogEntry(
|
|
26
|
+
createLogEntry(s, t, a, n) {
|
|
28
27
|
return {
|
|
29
|
-
loss:
|
|
30
|
-
step:
|
|
28
|
+
loss: s.lastLoss,
|
|
29
|
+
step: s.step,
|
|
31
30
|
time: Date.now() - t,
|
|
32
|
-
batchSize:
|
|
33
|
-
learningRate:
|
|
31
|
+
batchSize: a,
|
|
32
|
+
learningRate: n ? this.optimizer.lr : void 0
|
|
34
33
|
};
|
|
35
34
|
}
|
|
36
|
-
createProgress(
|
|
35
|
+
createProgress(s, t, a) {
|
|
37
36
|
return {
|
|
38
|
-
duration:
|
|
39
|
-
totalSamples:
|
|
40
|
-
samplesPerSecond:
|
|
41
|
-
memory:
|
|
37
|
+
duration: s.trainingDuration,
|
|
38
|
+
totalSamples: s.totalSteps * t.batchSize,
|
|
39
|
+
samplesPerSecond: s.totalSteps * t.batchSize / (s.trainingDuration / 1e3),
|
|
40
|
+
memory: a ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
42
41
|
};
|
|
43
42
|
}
|
|
44
|
-
async stepDataset(
|
|
45
|
-
const { logInterval:
|
|
46
|
-
...
|
|
43
|
+
async stepDataset(s, t, a) {
|
|
44
|
+
const { logInterval: n } = {
|
|
45
|
+
...f,
|
|
47
46
|
...t
|
|
48
|
-
},
|
|
49
|
-
this.lastState =
|
|
50
|
-
const
|
|
47
|
+
}, l = Date.now(), r = this.createEmptyState();
|
|
48
|
+
this.lastState = r, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, r.logStartTime = l;
|
|
49
|
+
const m = a ? new v(this.model, a) : void 0, e = await s.iterator();
|
|
51
50
|
try {
|
|
52
51
|
for (; this.running; ) {
|
|
53
52
|
const i = await e.next();
|
|
54
53
|
if (i.done) break;
|
|
55
|
-
const
|
|
56
|
-
if (this.model.
|
|
54
|
+
const g = i.value, o = this.trainBatch(r, g), c = this.createLogEntry(r, l, g.xs.shape[0], t?.advancedMetrics);
|
|
55
|
+
if (this.model.trainingState = {
|
|
56
|
+
steps: r.totalSteps,
|
|
57
|
+
learningRate: this.optimizer.lr,
|
|
58
|
+
batchSize: g.xs.shape[0],
|
|
59
|
+
loss: r.lastLoss
|
|
60
|
+
}, r.step % n === 0) {
|
|
57
61
|
await o.data();
|
|
58
|
-
const
|
|
59
|
-
if (
|
|
62
|
+
const u = Date.now();
|
|
63
|
+
if (r.trainingDuration += u - r.logStartTime, m)
|
|
60
64
|
try {
|
|
61
|
-
const
|
|
62
|
-
|
|
63
|
-
} catch (
|
|
64
|
-
console.error("Validation error:",
|
|
65
|
+
const h = await m.evaluate(5);
|
|
66
|
+
r.validationLosses.push(h), c.valLoss = h;
|
|
67
|
+
} catch (h) {
|
|
68
|
+
console.error("Validation error:", h);
|
|
65
69
|
}
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
temperature: 0.8
|
|
69
|
-
});
|
|
70
|
-
n.example = l;
|
|
71
|
-
}
|
|
72
|
-
const c = this.createProgress(a, n, t?.advancedMetrics);
|
|
73
|
-
return o.dispose(), this.stop(), { log: n, progress: c };
|
|
70
|
+
const p = this.createProgress(r, c, t?.advancedMetrics);
|
|
71
|
+
return o.dispose(), this.stop(), { log: c, progress: p };
|
|
74
72
|
}
|
|
75
73
|
o.dispose();
|
|
76
74
|
}
|
|
77
75
|
} catch (i) {
|
|
78
|
-
throw console.error("Training error:", i),
|
|
76
|
+
throw console.error("Training error:", i), S(), i;
|
|
79
77
|
}
|
|
80
|
-
throw
|
|
78
|
+
throw S(), this.running = !1, new Error("No log returned before training stopped.");
|
|
81
79
|
}
|
|
82
80
|
// Train for multiple epochs using Dataset API - FIXED memory leaks
|
|
83
|
-
async trainOnDataset(
|
|
84
|
-
const { logInterval:
|
|
85
|
-
...
|
|
81
|
+
async trainOnDataset(s, t, a) {
|
|
82
|
+
const { logInterval: n, onStep: l, maxSteps: r } = {
|
|
83
|
+
...f,
|
|
86
84
|
...t
|
|
87
|
-
},
|
|
88
|
-
this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() ||
|
|
89
|
-
const i =
|
|
85
|
+
}, m = Date.now(), e = this.createEmptyState();
|
|
86
|
+
this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, e.logStartTime = m;
|
|
87
|
+
const i = a ? new v(this.model, a) : void 0, g = await s.iterator();
|
|
90
88
|
try {
|
|
91
89
|
for (; this.running; ) {
|
|
92
|
-
const o = await
|
|
90
|
+
const o = await g.next();
|
|
93
91
|
if (o.done) break;
|
|
94
|
-
const
|
|
95
|
-
if (
|
|
96
|
-
await
|
|
97
|
-
const
|
|
98
|
-
if (e.trainingDuration +=
|
|
92
|
+
const c = o.value, u = this.trainBatch(e, c), p = this.createLogEntry(e, m, c.xs.shape[0], t?.advancedMetrics);
|
|
93
|
+
if (e.step % n === 0) {
|
|
94
|
+
await u.data();
|
|
95
|
+
const h = Date.now();
|
|
96
|
+
if (e.trainingDuration += h - e.logStartTime, i)
|
|
99
97
|
try {
|
|
100
98
|
const d = await i.evaluate(5);
|
|
101
|
-
e.validationLosses.push(d),
|
|
99
|
+
e.validationLosses.push(d), p.valLoss = d;
|
|
102
100
|
} catch (d) {
|
|
103
101
|
console.error("Validation error:", d);
|
|
104
102
|
}
|
|
105
|
-
if (
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
temperature: 0.8
|
|
109
|
-
});
|
|
110
|
-
c.example = L;
|
|
111
|
-
}
|
|
112
|
-
const d = this.createProgress(e, c, t?.advancedMetrics);
|
|
113
|
-
await m(c, d);
|
|
103
|
+
if (l) {
|
|
104
|
+
const d = this.createProgress(e, p, t?.advancedMetrics);
|
|
105
|
+
await l(p, d);
|
|
114
106
|
}
|
|
115
107
|
e.logStartTime = Date.now();
|
|
116
108
|
}
|
|
117
|
-
|
|
109
|
+
u.dispose(), e.step >= r && this.stop();
|
|
118
110
|
}
|
|
119
111
|
} catch (o) {
|
|
120
|
-
throw console.error("Training error:", o),
|
|
112
|
+
throw console.error("Training error:", o), S(), o;
|
|
121
113
|
}
|
|
122
|
-
return
|
|
114
|
+
return S(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
|
|
123
115
|
}
|
|
124
116
|
}
|
|
125
117
|
export {
|
|
126
|
-
|
|
118
|
+
b as default
|
|
127
119
|
};
|
|
@@ -1,10 +1,20 @@
|
|
|
1
1
|
import { ITokeniser } from '../tokeniser/type';
|
|
2
2
|
import { DatasetBuilder } from './DatasetBuilder';
|
|
3
|
-
import { default as NanoGPT, TrainingLogEntry } from '../NanoGPTModel';
|
|
4
3
|
import { default as AdamExt } from './AdamExt';
|
|
5
4
|
import { TensorContainer } from '@tensorflow/tfjs-core/dist/tensor_types';
|
|
6
5
|
import { Scalar, Tensor } from '@tensorflow/tfjs-core';
|
|
7
6
|
import { Dataset } from '@tensorflow/tfjs-data';
|
|
7
|
+
import { default as Model, ModelForwardAttributes } from '../models/model';
|
|
8
|
+
export interface TrainingLogEntry {
|
|
9
|
+
loss: number;
|
|
10
|
+
valLoss?: number;
|
|
11
|
+
step: number;
|
|
12
|
+
time: number;
|
|
13
|
+
example?: string;
|
|
14
|
+
batchSize: number;
|
|
15
|
+
gradientNorm?: number;
|
|
16
|
+
learningRate?: number;
|
|
17
|
+
}
|
|
8
18
|
export interface TrainingState {
|
|
9
19
|
step: number;
|
|
10
20
|
lastLoss: number;
|
|
@@ -35,13 +45,15 @@ export interface TrainingOptions {
|
|
|
35
45
|
}
|
|
36
46
|
export default abstract class GPTTrainer {
|
|
37
47
|
protected tokenizer: ITokeniser;
|
|
38
|
-
protected model:
|
|
48
|
+
protected model: Model<ModelForwardAttributes>;
|
|
39
49
|
protected optimizer: AdamExt;
|
|
40
50
|
protected datasetBuilder: DatasetBuilder;
|
|
41
51
|
protected learningRate: number;
|
|
42
52
|
protected running: boolean;
|
|
43
53
|
protected lastState?: TrainingState;
|
|
44
|
-
|
|
54
|
+
protected _gradientCheckpointing: boolean;
|
|
55
|
+
constructor(model: Model<ModelForwardAttributes>, tokenizer: ITokeniser, learningRate?: number);
|
|
56
|
+
setGradientCheckpointing(enabled: boolean): void;
|
|
45
57
|
setLearningRate(learningRate: number): void;
|
|
46
58
|
reset(): void;
|
|
47
59
|
stop(): void;
|
package/dist/training/Trainer.js
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import { DatasetBuilder as
|
|
1
|
+
import { DatasetBuilder as m, flattenTokens as c, PAGE_FACTOR as g } from "./DatasetBuilder.js";
|
|
2
2
|
import u from "./AdamExt.js";
|
|
3
|
-
import { t as f, v as y, d as
|
|
4
|
-
import { z as
|
|
3
|
+
import { t as f, v as y, d as p } from "../index-DdmHGZjq.js";
|
|
4
|
+
import { z as h } from "../zeros-Dnwix0p4.js";
|
|
5
5
|
class x {
|
|
6
|
-
constructor(t, e,
|
|
7
|
-
this.tokenizer = e, this.model = t, this.learningRate =
|
|
6
|
+
constructor(t, e, i = 1e-3) {
|
|
7
|
+
this.tokenizer = e, this.model = t, this.learningRate = i, this.resetOptimizer(), this.datasetBuilder = new m(e, t.config.blockSize);
|
|
8
8
|
}
|
|
9
9
|
model;
|
|
10
10
|
optimizer;
|
|
@@ -12,6 +12,10 @@ class x {
|
|
|
12
12
|
learningRate;
|
|
13
13
|
running = !1;
|
|
14
14
|
lastState;
|
|
15
|
+
_gradientCheckpointing = !1;
|
|
16
|
+
setGradientCheckpointing(t) {
|
|
17
|
+
this._gradientCheckpointing = t;
|
|
18
|
+
}
|
|
15
19
|
setLearningRate(t) {
|
|
16
20
|
this.learningRate = t, this.resetOptimizer({ learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 });
|
|
17
21
|
}
|
|
@@ -40,71 +44,59 @@ class x {
|
|
|
40
44
|
);
|
|
41
45
|
this.optimizer = e;
|
|
42
46
|
}
|
|
43
|
-
|
|
44
|
-
let maxNorm = 0;
|
|
45
|
-
// Print all gradients
|
|
46
|
-
await Promise.all(
|
|
47
|
-
Object.keys(grads).map(async (varName) => {
|
|
48
|
-
const grad = grads[varName];
|
|
49
|
-
const temp = norm(grad);
|
|
50
|
-
const gradNorm = (await temp.data())[0];
|
|
51
|
-
temp.dispose();
|
|
52
|
-
if (gradNorm > maxNorm) {
|
|
53
|
-
maxNorm = gradNorm;
|
|
54
|
-
}
|
|
55
|
-
})
|
|
56
|
-
);
|
|
57
|
-
return maxNorm;
|
|
58
|
-
}*/
|
|
59
|
-
trainStep(t, e, a = !1) {
|
|
47
|
+
trainStep(t, e, i = !1) {
|
|
60
48
|
return f(() => {
|
|
61
49
|
this.model.getProfiler()?.startMemory();
|
|
62
|
-
const { xs:
|
|
63
|
-
const [l, d] = this.model.forward(
|
|
50
|
+
const { xs: a, ys: s } = e, n = () => {
|
|
51
|
+
const [l, d] = this.model.forward(
|
|
52
|
+
{ training: !0, checkpointing: this._gradientCheckpointing },
|
|
53
|
+
a,
|
|
54
|
+
s
|
|
55
|
+
);
|
|
64
56
|
return l.dispose(), d;
|
|
65
|
-
}, { value:
|
|
66
|
-
return
|
|
57
|
+
}, { value: o, grads: r } = y(n);
|
|
58
|
+
return i ? this.model.getProfiler()?.endMemory("Training") : (this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), p(r)), o;
|
|
67
59
|
});
|
|
68
60
|
}
|
|
69
61
|
async dummyPass() {
|
|
70
|
-
const t =
|
|
62
|
+
const t = h([1, this.model.config.blockSize], "int32"), e = h([1, this.model.config.blockSize], "int32");
|
|
71
63
|
try {
|
|
72
|
-
const
|
|
73
|
-
await
|
|
74
|
-
} catch (
|
|
75
|
-
console.error("Error during dummy pass:",
|
|
64
|
+
const i = this.trainStep({}, { xs: t, ys: e }, !0);
|
|
65
|
+
await i.data(), i.dispose();
|
|
66
|
+
} catch (i) {
|
|
67
|
+
console.error("Error during dummy pass:", i);
|
|
76
68
|
} finally {
|
|
77
69
|
t.dispose(), e.dispose();
|
|
78
70
|
}
|
|
79
71
|
}
|
|
80
72
|
trainBatch(t, e) {
|
|
81
73
|
try {
|
|
82
|
-
const
|
|
83
|
-
return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++,
|
|
84
|
-
} catch (
|
|
85
|
-
throw console.error(`Error processing batch at step ${t.step}:`,
|
|
74
|
+
const i = this.trainStep(t, e, !1);
|
|
75
|
+
return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, i;
|
|
76
|
+
} catch (i) {
|
|
77
|
+
throw console.error(`Error processing batch at step ${t.step}:`, i), p(), i;
|
|
86
78
|
}
|
|
87
79
|
}
|
|
88
|
-
async createTrainValidationSplit(t, e = 32,
|
|
89
|
-
const
|
|
90
|
-
if (
|
|
91
|
-
const r = Math.floor(
|
|
92
|
-
for (;
|
|
80
|
+
async createTrainValidationSplit(t, e = 32, i = 0.1) {
|
|
81
|
+
const a = await c(t, this.tokenizer), s = /* @__PURE__ */ new Set();
|
|
82
|
+
if (i > 0) {
|
|
83
|
+
const r = Math.floor(a.length / (this.datasetBuilder.blockSize * g)), l = Math.max(1, Math.floor(r * i));
|
|
84
|
+
for (; s.size < l; ) {
|
|
93
85
|
const d = Math.floor(Math.random() * r);
|
|
94
|
-
|
|
86
|
+
s.add(d);
|
|
95
87
|
}
|
|
96
88
|
}
|
|
97
|
-
const
|
|
98
|
-
|
|
89
|
+
const n = await this.datasetBuilder.createTextDataset(a, e, s, !1), o = await this.datasetBuilder.createTextDataset(
|
|
90
|
+
a,
|
|
99
91
|
e,
|
|
100
|
-
|
|
92
|
+
s,
|
|
101
93
|
!0
|
|
102
94
|
);
|
|
103
|
-
return { trainDataset:
|
|
95
|
+
return { trainDataset: n, validationDataset: o };
|
|
104
96
|
}
|
|
105
97
|
async createDataset(t, e = 32) {
|
|
106
|
-
const
|
|
107
|
-
return await this.datasetBuilder.createTextDataset(
|
|
98
|
+
const i = await c(t, this.tokenizer);
|
|
99
|
+
return await this.datasetBuilder.createTextDataset(i, e);
|
|
108
100
|
}
|
|
109
101
|
dispose() {
|
|
110
102
|
this.optimizer && this.optimizer.dispose();
|
|
@@ -1,28 +1,27 @@
|
|
|
1
1
|
import { gatherSub as x } from "../ops/gatherSub.js";
|
|
2
2
|
import { scatterSub as L } from "../ops/scatterSub.js";
|
|
3
|
-
import {
|
|
4
|
-
import { s as
|
|
5
|
-
import { m as z } from "../
|
|
6
|
-
import { l as v } from "../log_sum_exp-DbjkV734.js";
|
|
3
|
+
import { J as C, t as u, L as E, c as G } from "../index-DdmHGZjq.js";
|
|
4
|
+
import { s as y } from "../softmax-BfsyI4As.js";
|
|
5
|
+
import { m as z, l as v } from "../log_sum_exp-C8yFJfZz.js";
|
|
7
6
|
function k(t, s) {
|
|
8
7
|
return u(() => {
|
|
9
|
-
const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a =
|
|
10
|
-
return x(
|
|
8
|
+
const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a = G(h, r), d = v(a, -1);
|
|
9
|
+
return x(d, p, a);
|
|
11
10
|
});
|
|
12
11
|
}
|
|
13
|
-
function
|
|
14
|
-
return
|
|
12
|
+
function q() {
|
|
13
|
+
return C(
|
|
15
14
|
// @ts-expect-error Invalid params
|
|
16
|
-
(s, n,
|
|
17
|
-
const c = s.shape[s.shape.length - 1], p = s.shape.slice(0, -1).reduce((o, e) => o * e, 1), r = s.reshape([p, c]), a = n.reshape([p]).cast("int32"),
|
|
18
|
-
return
|
|
19
|
-
const S = e[0], f = e[1], b =
|
|
15
|
+
(s, n, m) => {
|
|
16
|
+
const c = s.shape[s.shape.length - 1], p = s.shape.slice(0, -1).reduce((o, e) => o * e, 1), r = s.reshape([p, c]), a = n.reshape([p]).cast("int32"), d = k(r, a);
|
|
17
|
+
return m([r, a]), r.dispose(), a.dispose(), { value: d, gradFunc: (o, e) => u(() => {
|
|
18
|
+
const S = e[0], f = e[1], b = y(S), l = L(b, f, o), g = E(n);
|
|
20
19
|
return [l.reshape(s.shape), g];
|
|
21
20
|
}) };
|
|
22
21
|
}
|
|
23
22
|
);
|
|
24
23
|
}
|
|
25
24
|
export {
|
|
26
|
-
|
|
25
|
+
q as createSoftmaxCrossEntropyWithGrad,
|
|
27
26
|
k as sparseSoftmaxCrossEntropy
|
|
28
27
|
};
|
|
@@ -1 +1 @@
|
|
|
1
|
-
export declare function arraysClose(a: unknown, b: unknown
|
|
1
|
+
export declare function arraysClose(a: unknown, b: unknown): number;
|
|
@@ -1,11 +1,20 @@
|
|
|
1
|
-
function
|
|
1
|
+
function n(r, e) {
|
|
2
|
+
let t = 0;
|
|
2
3
|
if (Array.isArray(r) && Array.isArray(e)) {
|
|
3
|
-
if (r.length !== e.length) return
|
|
4
|
-
for (let
|
|
5
|
-
|
|
6
|
-
return
|
|
7
|
-
} else
|
|
4
|
+
if (r.length !== e.length) return Number.POSITIVE_INFINITY;
|
|
5
|
+
for (let i = 0; i < r.length; ++i)
|
|
6
|
+
t = Math.max(t, n(r[i], e[i]));
|
|
7
|
+
return t;
|
|
8
|
+
} else if (typeof r == "number" && typeof e == "number") {
|
|
9
|
+
if (isNaN(r) && isNaN(e))
|
|
10
|
+
return 0;
|
|
11
|
+
if (!isFinite(r) || !isFinite(e))
|
|
12
|
+
return r === e ? 0 : Number.POSITIVE_INFINITY;
|
|
13
|
+
const i = Math.abs(r - e);
|
|
14
|
+
return t = Math.max(t, i), t;
|
|
15
|
+
} else
|
|
16
|
+
return Number.POSITIVE_INFINITY;
|
|
8
17
|
}
|
|
9
18
|
export {
|
|
10
|
-
|
|
19
|
+
n as arraysClose
|
|
11
20
|
};
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import { default as
|
|
2
|
-
export declare function dummyPassAsync(model:
|
|
1
|
+
import { default as Model, ModelForwardAttributes } from '../models/model';
|
|
2
|
+
export declare function dummyPassAsync(model: Model<ModelForwardAttributes>): Promise<void>;
|
|
3
3
|
export interface MemoryRequirements {
|
|
4
4
|
perBatch: number;
|
|
5
5
|
tapeSize: number;
|
|
6
6
|
gradients: number;
|
|
7
7
|
}
|
|
8
|
-
export declare function dummyPassTrainAsync(model:
|
|
9
|
-
export declare function dummyPass(model:
|
|
8
|
+
export declare function dummyPassTrainAsync(model: Model<ModelForwardAttributes>): Promise<MemoryRequirements>;
|
|
9
|
+
export declare function dummyPass(model: Model<ModelForwardAttributes>): void;
|
package/dist/utilities/dummy.js
CHANGED
|
@@ -1,31 +1,31 @@
|
|
|
1
|
-
import { m as y, v as P, e as S } from "../index-
|
|
2
|
-
import { z as i } from "../zeros
|
|
1
|
+
import { m as y, v as P, e as S } from "../index-DdmHGZjq.js";
|
|
2
|
+
import { z as i } from "../zeros-Dnwix0p4.js";
|
|
3
3
|
async function w(s) {
|
|
4
|
-
const t = i([1, s.config.
|
|
4
|
+
const t = i([1, s.config.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
|
|
5
5
|
await e.data(), e.dispose(), n && n.dispose(), t.dispose();
|
|
6
6
|
}
|
|
7
7
|
async function k(s) {
|
|
8
8
|
const t = y(), e = t.numBytesInGPUAllocated ?? t.numBytesAllocatedInGPU ?? t.numBytes;
|
|
9
9
|
await w(s);
|
|
10
|
-
const n = i([1, s.config.
|
|
10
|
+
const n = i([1, s.config.blockSize], "int32"), r = i([1, s.config.blockSize], "int32"), o = {
|
|
11
11
|
perBatch: 0,
|
|
12
12
|
tapeSize: 0,
|
|
13
13
|
gradients: s.getNumParams() * 4
|
|
14
14
|
}, f = () => {
|
|
15
|
-
const [c,
|
|
16
|
-
let
|
|
17
|
-
if (
|
|
18
|
-
for (const z of
|
|
19
|
-
|
|
20
|
-
return o.tapeSize =
|
|
21
|
-
}, { value: m, grads: d } = P(f), a = y(),
|
|
22
|
-
o.perBatch =
|
|
15
|
+
const [c, g] = s.forward({ training: !0 }, n, r), u = S().state.activeTape;
|
|
16
|
+
let p = 0;
|
|
17
|
+
if (u)
|
|
18
|
+
for (const z of u)
|
|
19
|
+
p += z.saved?.reduce((B, I) => B + I.size * 4, 0) || 0;
|
|
20
|
+
return o.tapeSize = p, c.dispose(), g;
|
|
21
|
+
}, { value: m, grads: d } = P(f), a = y(), l = a.numBytesInGPUAllocated ?? a.numBytesAllocatedInGPU ?? a.numBytes;
|
|
22
|
+
o.perBatch = l - e - o.gradients, console.log("Dummy training memory requirements:", o), await m.data(), m.dispose();
|
|
23
23
|
for (const c in d)
|
|
24
24
|
d[c].dispose();
|
|
25
25
|
return n.dispose(), r.dispose(), o;
|
|
26
26
|
}
|
|
27
27
|
function v(s) {
|
|
28
|
-
const t = i([1, s.config.
|
|
28
|
+
const t = i([1, s.config.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
|
|
29
29
|
e.dispose(), n && n.dispose(), t.dispose();
|
|
30
30
|
}
|
|
31
31
|
export {
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { GPTConfig } from '../config';
|
|
1
|
+
import { GPTConfig } from '../models/config';
|
|
2
2
|
export declare function estimateParameterCount(config: GPTConfig): number;
|
|
3
3
|
export declare function estimateMemoryUsage(config: GPTConfig): number;
|
|
4
4
|
export declare function estimateTrainingMemoryUsage(config: GPTConfig, batchSize: number): number;
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { t as p } from "../tensor-
|
|
1
|
+
import "../index-DdmHGZjq.js";
|
|
2
|
+
import { t as p } from "../tensor-DbqgIV9B.js";
|
|
3
3
|
function h(n) {
|
|
4
4
|
const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
|
|
5
5
|
let t = 0;
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { n as u } from "./index-DdmHGZjq.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -15,57 +15,57 @@ import { l as u } from "./index-BoWRt-10.js";
|
|
|
15
15
|
* limitations under the License.
|
|
16
16
|
* =============================================================================
|
|
17
17
|
*/
|
|
18
|
-
const e = (
|
|
18
|
+
const e = (n) => {
|
|
19
19
|
let t = 1;
|
|
20
|
-
for (let
|
|
21
|
-
t *= r
|
|
20
|
+
for (let r = 0; r < n.length; r++)
|
|
21
|
+
t *= n[r];
|
|
22
22
|
return t;
|
|
23
23
|
};
|
|
24
|
-
function m(
|
|
24
|
+
function m(n, t, r = [1, 1, 1], a = [1, 1, 1]) {
|
|
25
25
|
const [o, i, f] = [
|
|
26
|
-
Math.ceil(e(
|
|
27
|
-
|
|
28
|
-
|
|
26
|
+
Math.ceil(e(n.x.map((c) => t[c])) / (r[0] * a[0])),
|
|
27
|
+
n.y ? Math.ceil(e(n.y.map((c) => t[c])) / (r[1] * a[1])) : 1,
|
|
28
|
+
n.z ? Math.ceil(e(n.z.map((c) => t[c])) / (r[2] * a[2])) : 1
|
|
29
29
|
];
|
|
30
30
|
return [o, i, f];
|
|
31
31
|
}
|
|
32
|
-
function d(
|
|
32
|
+
function d(n, t, r, a = !1) {
|
|
33
33
|
const o = [8, 8, 1], i = [4, 4, 1];
|
|
34
|
-
return a || (
|
|
34
|
+
return a || (n <= 8 && (i[1] = 1), t <= 16 && r <= 16 && (o[0] = 4)), { workgroupSize: o, elementsPerThread: i };
|
|
35
35
|
}
|
|
36
|
-
function p(
|
|
37
|
-
if (
|
|
36
|
+
function p(n, t, r = !1) {
|
|
37
|
+
if (r)
|
|
38
38
|
return [8, 8, 1];
|
|
39
|
-
const a = e(
|
|
39
|
+
const a = e(n.x.map((i) => t[i])), o = e(n.y.map((i) => t[i]));
|
|
40
40
|
return a <= 4 ? [4, 16, 1] : o <= 4 ? [16, 4, 1] : [16, 16, 1];
|
|
41
41
|
}
|
|
42
|
-
function M(
|
|
43
|
-
if (
|
|
42
|
+
function M(n, t, r = !1) {
|
|
43
|
+
if (r)
|
|
44
44
|
return [4, 4, 1];
|
|
45
|
-
const a = e(
|
|
45
|
+
const a = e(n.x.map((i) => t[i])), o = e(n.y.map((i) => t[i]));
|
|
46
46
|
return a <= 4 ? [1, 2, 1] : o <= 4 ? [2, 1, 1] : [2, 2, 1];
|
|
47
47
|
}
|
|
48
|
-
function h(
|
|
49
|
-
return { x:
|
|
48
|
+
function h(n) {
|
|
49
|
+
return { x: n.map((t, r) => r) };
|
|
50
50
|
}
|
|
51
|
-
function x(
|
|
52
|
-
if (
|
|
51
|
+
function x(n) {
|
|
52
|
+
if (n === "float32" || n === "int32" || n === "bool" || n === "string")
|
|
53
53
|
return 4;
|
|
54
|
-
if (
|
|
54
|
+
if (n === "complex64")
|
|
55
55
|
return 8;
|
|
56
|
-
throw new Error(`Unknown dtype ${
|
|
56
|
+
throw new Error(`Unknown dtype ${n}`);
|
|
57
57
|
}
|
|
58
58
|
function g() {
|
|
59
59
|
return !!(typeof globalThis < "u" && globalThis.navigator && globalThis.navigator.gpu);
|
|
60
60
|
}
|
|
61
|
-
function b(
|
|
62
|
-
Array.isArray(
|
|
63
|
-
|
|
61
|
+
function b(n, t) {
|
|
62
|
+
Array.isArray(n) || (n = [n]), n.forEach((r) => {
|
|
63
|
+
r != null && u(r.dtype !== "complex64", () => `${t} does not support complex64 tensors in the WebGPU backend.`);
|
|
64
64
|
});
|
|
65
65
|
}
|
|
66
66
|
var s;
|
|
67
|
-
(function(
|
|
68
|
-
|
|
67
|
+
(function(n) {
|
|
68
|
+
n[n.MatMulReduceProgram = 0] = "MatMulReduceProgram", n[n.MatMulSplitKProgram = 1] = "MatMulSplitKProgram", n[n.MatMulSmallOutputSizeProgram = 2] = "MatMulSmallOutputSizeProgram", n[n.MatMulPackedProgram = 3] = "MatMulPackedProgram", n[n.MatMulMax = 4] = "MatMulMax";
|
|
69
69
|
})(s || (s = {}));
|
|
70
70
|
export {
|
|
71
71
|
x as G,
|