@genai-fi/nanogpt 0.7.1 → 0.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +11 -2
- package/dist/Generator.js +81 -68
- package/dist/NanoGPTModel.js +8 -8
- package/dist/{RealDiv-CVYNbZxu.js → RealDiv-Dy0p8Bvo.js} +7 -7
- package/dist/{Reshape-CEsEp0AI.js → Reshape-DH5srBP0.js} +2 -2
- package/dist/{Reshape-Do18N3gO.js → Reshape-DvudQDvJ.js} +1 -1
- package/dist/TeachableLLM.js +33 -32
- package/dist/{TiedEmbedding-ccLBFiZi.js → TiedEmbedding-BxOerUmB.js} +4 -4
- package/dist/Trainer.d.ts +6 -1
- package/dist/Trainer.js +53 -19
- package/dist/{axis_util-5DTW2tFV.js → axis_util-BzbKo31C.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-C9Ut8n0Q.js → backend_util-TE7aTPhZ.js} +4 -4
- package/dist/{broadcast_to-Ba9h_8DO.js → broadcast_to-CdbwV-Dj.js} +2 -2
- package/dist/{concat-CbXTetof.js → concat-CsxrgovM.js} +1 -1
- package/dist/{dataset-U3PrjwgU.js → dataset-CtdBYwjo.js} +3 -3
- package/dist/{dropout-DPfPgWWe.js → dropout-DYs5QFGQ.js} +1 -1
- package/dist/{gather-Bbh8DHhM.js → gather-CMMy2KEG.js} +1 -1
- package/dist/{gelu-BFwVnd1r.js → gelu-C-dPj6Ku.js} +1 -1
- package/dist/{gpgpu_math-DffelNS-.js → gpgpu_math-DGNLNL4I.js} +2 -2
- package/dist/{index-UdZhlibC.js → index-BoWRt-10.js} +4 -4
- package/dist/{index-DYD_yPa-.js → index-CLthM0TO.js} +10 -10
- package/dist/{kernel_funcs_utils-CXDy3EN7.js → kernel_funcs_utils-BYKWV8Aa.js} +3 -3
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +5 -5
- package/dist/{log_sum_exp-BnmCkHWl.js → log_sum_exp-DbjkV734.js} +5 -5
- package/dist/main.js +5 -5
- package/dist/{mat_mul-dwmZz69e.js → mat_mul-8m8pfdcx.js} +1 -1
- package/dist/{max-ByjEGoFx.js → max-Ddnnb5xe.js} +1 -1
- package/dist/{mulmat_packed_gpu-IGPBp6h9.js → mulmat_packed_gpu-VSekgsNv.js} +1 -1
- package/dist/{ones-C8Mfln6-.js → ones-Dj0SDhHf.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +1 -1
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +5 -5
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +7 -5
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +15 -13
- package/dist/ops/webgpu/adamMoments.js +18 -11
- package/dist/ops/webgpu/appendCache.js +18 -15
- package/dist/ops/webgpu/attentionMask.js +24 -18
- package/dist/ops/webgpu/gatherSub.js +17 -30
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +16 -8
- package/dist/ops/webgpu/normRMSGrad.js +25 -20
- package/dist/ops/webgpu/qkv.js +23 -19
- package/dist/ops/webgpu/rope.js +37 -24
- package/dist/ops/webgpu/scatterSub.js +16 -14
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-aRTXR2Sr.js → ops-BFGCx8Ri.js} +15 -15
- package/dist/{random_width-DbSpgl4o.js → random_width-sZORGo5k.js} +22 -22
- package/dist/{range-D9CZhVlR.js → range-CRuAh-gd.js} +1 -1
- package/dist/{reciprocal-CGB48wZB.js → reciprocal-BvGAyKyu.js} +1 -1
- package/dist/{register_all_kernels-DnbAyBXt.js → register_all_kernels-BwDSRN-f.js} +30 -30
- package/dist/{reshape-BR0eoLYN.js → reshape-CdBq1WJ6.js} +1 -1
- package/dist/{scatter_nd_util-OjyAxku2.js → scatter_nd_util-DUstGbU1.js} +1 -1
- package/dist/{selu_util-Ce6pu9IM.js → selu_util-BJEXVvjX.js} +3 -3
- package/dist/{shared-Czipaeb6.js → shared-B8ztnyEk.js} +6 -6
- package/dist/{shared-DS5waSIY.js → shared-wS99K7_n.js} +1 -1
- package/dist/{sin-CiBxrDqX.js → sin-BeA3tsEd.js} +1 -1
- package/dist/{slice-BHbDHObE.js → slice-BiOsknYS.js} +1 -1
- package/dist/{softmax-JMEIUo2J.js → softmax-Bv_6lyMX.js} +1 -1
- package/dist/{split-CRU0PjVV.js → split-B-dikLRw.js} +1 -1
- package/dist/{stack-ikk2Y8_P.js → stack-B17UN2nn.js} +1 -1
- package/dist/{sum-NLYbiDag.js → sum-66ew2byf.js} +1 -1
- package/dist/{tensor-Do9PKbIE.js → tensor-JwS7ZYY6.js} +1 -1
- package/dist/{tensor2d-CWHxHpLh.js → tensor2d-wxPAnDQy.js} +1 -1
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +35 -32
- package/dist/training/FullTrainer.d.ts +15 -2
- package/dist/training/FullTrainer.js +97 -51
- package/dist/training/Trainer.d.ts +10 -0
- package/dist/training/Trainer.js +2 -2
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BTBkayv_.js → variable-BuddVFLa.js} +1 -1
- package/dist/{webgpu_program-WaoMq-WD.js → webgpu_program-PFzf1hAQ.js} +1 -1
- package/dist/{webgpu_util-DhSeP4b6.js → webgpu_util-D____QpY.js} +1 -1
- package/dist/{zeros-DnPT2nD4.js → zeros--BdLQ3oG.js} +1 -1
- package/package.json +1 -1
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { t as
|
|
2
|
-
import { d as
|
|
1
|
+
import { t as g } from "../index-BoWRt-10.js";
|
|
2
|
+
import { d as u, i as d } from "../dataset-CtdBYwjo.js";
|
|
3
3
|
import "../index-Tf7vU29b.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
@@ -18,57 +18,60 @@ import "../index-Tf7vU29b.js";
|
|
|
18
18
|
*
|
|
19
19
|
* =============================================================================
|
|
20
20
|
*/
|
|
21
|
-
function
|
|
22
|
-
return
|
|
23
|
-
const t = await
|
|
24
|
-
return
|
|
21
|
+
function z(r) {
|
|
22
|
+
return u(async () => {
|
|
23
|
+
const t = await r();
|
|
24
|
+
return d(() => t.next());
|
|
25
25
|
});
|
|
26
26
|
}
|
|
27
|
-
const
|
|
28
|
-
async function y(
|
|
29
|
-
const s = await Promise.all(
|
|
30
|
-
|
|
27
|
+
const S = 8;
|
|
28
|
+
async function y(r, t) {
|
|
29
|
+
const s = await Promise.all(r.map((e) => t.encode(e))), o = t.eosToken >= 0, a = s.map((e) => o ? [...e, t.eosToken] : e).flat();
|
|
30
|
+
for (const e of a)
|
|
31
|
+
if (e < 0 || e >= t.vocabSize)
|
|
32
|
+
throw new Error(`Invalid token index ${e} found in tokenised data`);
|
|
33
|
+
return a;
|
|
31
34
|
}
|
|
32
35
|
class w {
|
|
33
36
|
tokenizer;
|
|
34
37
|
blockSize;
|
|
35
38
|
pageSize;
|
|
36
39
|
constructor(t, s = 128) {
|
|
37
|
-
this.tokenizer = t, this.blockSize = s, this.pageSize = s *
|
|
40
|
+
this.tokenizer = t, this.blockSize = s, this.pageSize = s * S;
|
|
38
41
|
}
|
|
39
42
|
// Create dataset from text files
|
|
40
|
-
async createTextDataset(t, s = 32,
|
|
43
|
+
async createTextDataset(t, s = 32, o, a) {
|
|
41
44
|
if (t.length < this.blockSize + 1)
|
|
42
45
|
throw new Error(`Not enough tokens (${t.length}) for block size ${this.blockSize}`);
|
|
43
|
-
if (
|
|
46
|
+
if (o && o.size > t.length / this.pageSize / 2)
|
|
44
47
|
throw new Error("Too many masked pages - would leave insufficient training data");
|
|
45
|
-
const
|
|
46
|
-
if (
|
|
47
|
-
const
|
|
48
|
+
const e = (function* () {
|
|
49
|
+
if (o && a) {
|
|
50
|
+
const i = Array.from(o);
|
|
48
51
|
for (; ; ) {
|
|
49
|
-
const
|
|
50
|
-
if (
|
|
52
|
+
const c = Math.floor(Math.random() * i.length), l = Math.floor(Math.random() * this.pageSize), n = i[c] * this.pageSize + l;
|
|
53
|
+
if (n + this.blockSize + 1 > t.length)
|
|
51
54
|
continue;
|
|
52
|
-
const h = t.slice(
|
|
53
|
-
yield { xs: h, ys:
|
|
55
|
+
const h = t.slice(n, n + this.blockSize), f = t.slice(n + 1, n + this.blockSize + 1);
|
|
56
|
+
yield { xs: h, ys: f };
|
|
54
57
|
}
|
|
55
58
|
} else
|
|
56
59
|
for (; ; ) {
|
|
57
|
-
const
|
|
58
|
-
if (
|
|
59
|
-
const
|
|
60
|
-
if (h && !
|
|
60
|
+
const i = Math.floor(Math.random() * (t.length - this.blockSize - 1));
|
|
61
|
+
if (o) {
|
|
62
|
+
const n = Math.floor(i / this.pageSize), h = o.has(n);
|
|
63
|
+
if (h && !a || !h && a)
|
|
61
64
|
continue;
|
|
62
65
|
}
|
|
63
|
-
const
|
|
64
|
-
yield { xs:
|
|
66
|
+
const c = t.slice(i, i + this.blockSize), l = t.slice(i + 1, i + this.blockSize + 1);
|
|
67
|
+
yield { xs: c, ys: l };
|
|
65
68
|
}
|
|
66
69
|
}).bind(this);
|
|
67
|
-
return
|
|
68
|
-
const
|
|
69
|
-
return
|
|
70
|
-
xs:
|
|
71
|
-
ys:
|
|
70
|
+
return z(e).batch(s).map((i) => {
|
|
71
|
+
const c = i;
|
|
72
|
+
return g(() => ({
|
|
73
|
+
xs: c.xs.cast("int32"),
|
|
74
|
+
ys: c.ys.cast("int32")
|
|
72
75
|
// this.tf.oneHot(batchData.ys.cast('int32'), this.tokenizer.vocabSize),
|
|
73
76
|
}));
|
|
74
77
|
}).prefetch(2);
|
|
@@ -76,6 +79,6 @@ class w {
|
|
|
76
79
|
}
|
|
77
80
|
export {
|
|
78
81
|
w as DatasetBuilder,
|
|
79
|
-
|
|
82
|
+
S as PAGE_FACTOR,
|
|
80
83
|
y as flattenTokens
|
|
81
84
|
};
|
|
@@ -1,10 +1,23 @@
|
|
|
1
1
|
import { ITokeniser } from '../tokeniser/type';
|
|
2
|
-
import { default as NanoGPT } from '../NanoGPTModel';
|
|
3
|
-
import { default as GPTTrainer, TrainingOptions } from './Trainer';
|
|
2
|
+
import { default as NanoGPT, TrainingLogEntry } from '../NanoGPTModel';
|
|
3
|
+
import { default as GPTTrainer, TrainingOptions, TrainingProgress } from './Trainer';
|
|
4
4
|
import { Tensor } from '@tensorflow/tfjs-core';
|
|
5
5
|
import { Dataset } from '@tensorflow/tfjs-data';
|
|
6
6
|
export default class FullTrainer extends GPTTrainer {
|
|
7
7
|
constructor(model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
|
|
8
|
+
private createEmptyState;
|
|
9
|
+
private createLogEntry;
|
|
10
|
+
private createProgress;
|
|
11
|
+
stepDataset(dataset: Dataset<{
|
|
12
|
+
xs: Tensor;
|
|
13
|
+
ys: Tensor;
|
|
14
|
+
}>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
|
|
15
|
+
xs: Tensor;
|
|
16
|
+
ys: Tensor;
|
|
17
|
+
}>): Promise<{
|
|
18
|
+
log: TrainingLogEntry;
|
|
19
|
+
progress: TrainingProgress;
|
|
20
|
+
}>;
|
|
8
21
|
trainOnDataset(dataset: Dataset<{
|
|
9
22
|
xs: Tensor;
|
|
10
23
|
ys: Tensor;
|
|
@@ -1,81 +1,127 @@
|
|
|
1
|
-
import { generateText as
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import { d as
|
|
5
|
-
import
|
|
6
|
-
const
|
|
1
|
+
import { generateText as v } from "../utilities/generate.js";
|
|
2
|
+
import x from "./Trainer.js";
|
|
3
|
+
import S from "./Evaluator.js";
|
|
4
|
+
import { d as w } from "../index-BoWRt-10.js";
|
|
5
|
+
import y from "../utilities/profile.js";
|
|
6
|
+
const T = {
|
|
7
7
|
desiredLoss: 0.01,
|
|
8
8
|
logInterval: 1,
|
|
9
9
|
maxSteps: 1e3
|
|
10
10
|
};
|
|
11
|
-
class
|
|
12
|
-
constructor(
|
|
13
|
-
super(
|
|
11
|
+
class z extends x {
|
|
12
|
+
constructor(r, t, s = 3e-4) {
|
|
13
|
+
super(r, t, s);
|
|
14
14
|
}
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
const { logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
|
|
18
|
-
...y,
|
|
19
|
-
...e
|
|
20
|
-
}, n = Date.now(), t = {
|
|
15
|
+
createEmptyState() {
|
|
16
|
+
return {
|
|
21
17
|
step: 0,
|
|
22
18
|
lastLoss: 1e6,
|
|
23
19
|
totalSteps: 0,
|
|
24
20
|
losses: [],
|
|
25
21
|
validationLosses: [],
|
|
26
|
-
logStartTime:
|
|
22
|
+
logStartTime: 0,
|
|
27
23
|
trainingDuration: 0,
|
|
28
24
|
...this.lastState || {}
|
|
29
25
|
};
|
|
30
|
-
|
|
31
|
-
|
|
26
|
+
}
|
|
27
|
+
createLogEntry(r, t, s, h) {
|
|
28
|
+
return {
|
|
29
|
+
loss: r.lastLoss,
|
|
30
|
+
step: r.step,
|
|
31
|
+
time: Date.now() - t,
|
|
32
|
+
batchSize: s,
|
|
33
|
+
learningRate: h ? this.optimizer.lr : void 0
|
|
34
|
+
};
|
|
35
|
+
}
|
|
36
|
+
createProgress(r, t, s) {
|
|
37
|
+
return {
|
|
38
|
+
duration: r.trainingDuration,
|
|
39
|
+
totalSamples: r.totalSteps * t.batchSize,
|
|
40
|
+
samplesPerSecond: r.totalSteps * t.batchSize / (r.trainingDuration / 1e3),
|
|
41
|
+
memory: s ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
42
|
+
};
|
|
43
|
+
}
|
|
44
|
+
async stepDataset(r, t, s) {
|
|
45
|
+
const { logInterval: h, prompt: m } = {
|
|
46
|
+
...T,
|
|
47
|
+
...t
|
|
48
|
+
}, g = Date.now(), a = this.createEmptyState();
|
|
49
|
+
this.lastState = a, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, a.logStartTime = g;
|
|
50
|
+
const p = s ? new S(this.model, s) : void 0, e = await r.iterator();
|
|
51
|
+
try {
|
|
52
|
+
for (; this.running; ) {
|
|
53
|
+
const i = await e.next();
|
|
54
|
+
if (i.done) break;
|
|
55
|
+
const u = i.value, o = this.trainBatch(a, u), n = this.createLogEntry(a, g, u.xs.shape[0], t?.advancedMetrics);
|
|
56
|
+
if (this.model.log.push(n), a.step % h === 0) {
|
|
57
|
+
await o.data();
|
|
58
|
+
const f = Date.now();
|
|
59
|
+
if (a.trainingDuration += f - a.logStartTime, p)
|
|
60
|
+
try {
|
|
61
|
+
const l = await p.evaluate(5);
|
|
62
|
+
a.validationLosses.push(l), n.valLoss = l;
|
|
63
|
+
} catch (l) {
|
|
64
|
+
console.error("Validation error:", l);
|
|
65
|
+
}
|
|
66
|
+
if (m) {
|
|
67
|
+
const l = await v(this.tokenizer, this.model, m, 100, {
|
|
68
|
+
temperature: 0.8
|
|
69
|
+
});
|
|
70
|
+
n.example = l;
|
|
71
|
+
}
|
|
72
|
+
const c = this.createProgress(a, n, t?.advancedMetrics);
|
|
73
|
+
return o.dispose(), this.stop(), { log: n, progress: c };
|
|
74
|
+
}
|
|
75
|
+
o.dispose();
|
|
76
|
+
}
|
|
77
|
+
} catch (i) {
|
|
78
|
+
throw console.error("Training error:", i), w(), i;
|
|
79
|
+
}
|
|
80
|
+
throw w(), this.running = !1, new Error("No log returned before training stopped.");
|
|
81
|
+
}
|
|
82
|
+
// Train for multiple epochs using Dataset API - FIXED memory leaks
|
|
83
|
+
async trainOnDataset(r, t, s) {
|
|
84
|
+
const { logInterval: h, onStep: m, prompt: g, maxSteps: a } = {
|
|
85
|
+
...T,
|
|
86
|
+
...t
|
|
87
|
+
}, p = Date.now(), e = this.createEmptyState();
|
|
88
|
+
this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, e.logStartTime = p;
|
|
89
|
+
const i = s ? new S(this.model, s) : void 0, u = await r.iterator();
|
|
32
90
|
try {
|
|
33
91
|
for (; this.running; ) {
|
|
34
|
-
const o = await
|
|
92
|
+
const o = await u.next();
|
|
35
93
|
if (o.done) break;
|
|
36
|
-
const
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0
|
|
42
|
-
//gradientNorm: options?.advancedMetrics ? await state.gradientNorm : undefined,
|
|
43
|
-
};
|
|
44
|
-
if (this.model.log.push(s), t.step % g === 0) {
|
|
45
|
-
await p.data();
|
|
46
|
-
const S = Date.now();
|
|
47
|
-
if (t.trainingDuration += S - t.logStartTime, m)
|
|
94
|
+
const n = o.value, f = this.trainBatch(e, n), c = this.createLogEntry(e, p, n.xs.shape[0], t?.advancedMetrics);
|
|
95
|
+
if (this.model.log.push(c), e.step % h === 0) {
|
|
96
|
+
await f.data();
|
|
97
|
+
const l = Date.now();
|
|
98
|
+
if (e.trainingDuration += l - e.logStartTime, i)
|
|
48
99
|
try {
|
|
49
|
-
const
|
|
50
|
-
|
|
51
|
-
} catch (
|
|
52
|
-
console.error("Validation error:",
|
|
100
|
+
const d = await i.evaluate(5);
|
|
101
|
+
e.validationLosses.push(d), c.valLoss = d;
|
|
102
|
+
} catch (d) {
|
|
103
|
+
console.error("Validation error:", d);
|
|
53
104
|
}
|
|
54
|
-
if (
|
|
55
|
-
if (
|
|
56
|
-
const
|
|
105
|
+
if (m) {
|
|
106
|
+
if (g) {
|
|
107
|
+
const L = await v(this.tokenizer, this.model, g, 100, {
|
|
57
108
|
temperature: 0.8
|
|
58
109
|
});
|
|
59
|
-
|
|
110
|
+
c.example = L;
|
|
60
111
|
}
|
|
61
|
-
const
|
|
62
|
-
|
|
63
|
-
totalSamples: t.totalSteps * s.batchSize,
|
|
64
|
-
samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
|
|
65
|
-
memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
66
|
-
};
|
|
67
|
-
await l(s, a);
|
|
112
|
+
const d = this.createProgress(e, c, t?.advancedMetrics);
|
|
113
|
+
await m(c, d);
|
|
68
114
|
}
|
|
69
|
-
|
|
115
|
+
e.logStartTime = Date.now();
|
|
70
116
|
}
|
|
71
|
-
|
|
117
|
+
f.dispose(), e.step >= a && this.stop();
|
|
72
118
|
}
|
|
73
119
|
} catch (o) {
|
|
74
|
-
throw console.error("Training error:", o),
|
|
120
|
+
throw console.error("Training error:", o), w(), o;
|
|
75
121
|
}
|
|
76
|
-
return
|
|
122
|
+
return w(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
|
|
77
123
|
}
|
|
78
124
|
}
|
|
79
125
|
export {
|
|
80
|
-
|
|
126
|
+
z as default
|
|
81
127
|
};
|
|
@@ -66,6 +66,16 @@ export default abstract class GPTTrainer {
|
|
|
66
66
|
losses: number[];
|
|
67
67
|
validationLosses: number[];
|
|
68
68
|
}>;
|
|
69
|
+
abstract stepDataset(dataset: Dataset<{
|
|
70
|
+
xs: Tensor;
|
|
71
|
+
ys: Tensor;
|
|
72
|
+
}>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
|
|
73
|
+
xs: Tensor;
|
|
74
|
+
ys: Tensor;
|
|
75
|
+
}>): Promise<{
|
|
76
|
+
log: TrainingLogEntry;
|
|
77
|
+
progress: TrainingProgress;
|
|
78
|
+
}>;
|
|
69
79
|
createTrainValidationSplit(textData: string[], batchSize?: number, validationSplit?: number): Promise<{
|
|
70
80
|
trainDataset: Dataset<{
|
|
71
81
|
xs: Tensor;
|
package/dist/training/Trainer.js
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { DatasetBuilder as h, flattenTokens as p, PAGE_FACTOR as g } from "./DatasetBuilder.js";
|
|
2
2
|
import u from "./AdamExt.js";
|
|
3
|
-
import { t as f, v as y, d as c } from "../index-
|
|
4
|
-
import { z as m } from "../zeros
|
|
3
|
+
import { t as f, v as y, d as c } from "../index-BoWRt-10.js";
|
|
4
|
+
import { z as m } from "../zeros--BdLQ3oG.js";
|
|
5
5
|
class x {
|
|
6
6
|
constructor(t, e, a = 1e-3) {
|
|
7
7
|
this.tokenizer = e, this.model = t, this.learningRate = a, this.resetOptimizer(), this.datasetBuilder = new h(e, t.config.gpt.blockSize);
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import { gatherSub as x } from "../ops/gatherSub.js";
|
|
2
2
|
import { scatterSub as L } from "../ops/scatterSub.js";
|
|
3
|
-
import { y, t as u, z as C, c as E } from "../index-
|
|
4
|
-
import { s as G } from "../softmax-
|
|
5
|
-
import { m as z } from "../max-
|
|
6
|
-
import { l as v } from "../log_sum_exp-
|
|
3
|
+
import { y, t as u, z as C, c as E } from "../index-BoWRt-10.js";
|
|
4
|
+
import { s as G } from "../softmax-Bv_6lyMX.js";
|
|
5
|
+
import { m as z } from "../max-Ddnnb5xe.js";
|
|
6
|
+
import { l as v } from "../log_sum_exp-DbjkV734.js";
|
|
7
7
|
function k(t, s) {
|
|
8
8
|
return u(() => {
|
|
9
9
|
const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a = E(h, r), m = v(a, -1);
|
package/dist/utilities/dummy.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { m as y, v as P, e as S } from "../index-
|
|
2
|
-
import { z as i } from "../zeros
|
|
1
|
+
import { m as y, v as P, e as S } from "../index-BoWRt-10.js";
|
|
2
|
+
import { z as i } from "../zeros--BdLQ3oG.js";
|
|
3
3
|
async function w(s) {
|
|
4
4
|
const t = i([1, s.config.gpt.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
|
|
5
5
|
await e.data(), e.dispose(), n && n.dispose(), t.dispose();
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { t as m } from "../tensor2d-
|
|
3
|
-
import { c as u } from "../concat-
|
|
1
|
+
import "../index-BoWRt-10.js";
|
|
2
|
+
import { t as m } from "../tensor2d-wxPAnDQy.js";
|
|
3
|
+
import { c as u } from "../concat-CsxrgovM.js";
|
|
4
4
|
async function v(o, r, a, c, f) {
|
|
5
5
|
if (c <= 0)
|
|
6
6
|
throw new Error("Length must be a positive integer");
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { t as p } from "../tensor-
|
|
1
|
+
import "../index-BoWRt-10.js";
|
|
2
|
+
import { t as p } from "../tensor-JwS7ZYY6.js";
|
|
3
3
|
function h(n) {
|
|
4
4
|
const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
|
|
5
5
|
let t = 0;
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { B as m, C as r,
|
|
1
|
+
import { B as m, C as r, a2 as l, E as c, a6 as i, F as p, a7 as u, j as f } from "./index-BoWRt-10.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2020 Google LLC. All Rights Reserved.
|