@genai-fi/nanogpt 0.17.4 → 0.18.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +2 -15
- package/dist/Generator.js +45 -34
- package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
- package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
- package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
- package/dist/TeachableLLM.d.ts +3 -8
- package/dist/TeachableLLM.js +61 -44
- package/dist/Trainer.d.ts +6 -4
- package/dist/Trainer.js +107 -92
- package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
- package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
- package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +5 -5
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +2 -2
- package/dist/checks/normRMS.js +6 -6
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +6 -6
- package/dist/checks/qkv.js +2 -2
- package/dist/checks/rope.js +2 -2
- package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
- package/dist/complex-3DpPEG9B.js +11 -0
- package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
- package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
- package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
- package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
- package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
- package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
- package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
- package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
- package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
- package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
- package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
- package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
- package/dist/inference/types.d.ts +16 -0
- package/dist/inference/types.js +1 -0
- package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
- package/dist/layers/BaseLayer.js +4 -4
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/LoRA.js +4 -4
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/layers/WeightStore.js +2 -2
- package/dist/loader/load.d.ts +2 -8
- package/dist/loader/loadTransformers.d.ts +2 -8
- package/dist/loader/loadTransformers.js +13 -11
- package/dist/loader/newZipLoad.d.ts +2 -8
- package/dist/loader/newZipLoad.js +25 -10
- package/dist/loader/oldZipLoad.js +13 -13
- package/dist/loader/save.d.ts +9 -2
- package/dist/loader/save.js +64 -55
- package/dist/loader/types.d.ts +29 -1
- package/dist/main.d.ts +2 -0
- package/dist/main.js +45 -43
- package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
- package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
- package/dist/mat_mul-DP86qZtZ.js +11 -0
- package/dist/mod-BXjLYwvM.js +11 -0
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/NanoGPTV2.js +2 -2
- package/dist/models/model.d.ts +3 -2
- package/dist/models/model.js +13 -13
- package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
- package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
- package/dist/ops/adamAdjust.js +3 -3
- package/dist/ops/adamMoments.js +3 -3
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +6 -6
- package/dist/ops/attentionMask.js +3 -3
- package/dist/ops/concat16.js +3 -3
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +5 -5
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +4 -4
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +4 -4
- package/dist/ops/cpu/matMul16.js +2 -2
- package/dist/ops/cpu/matMulGelu.js +7 -7
- package/dist/ops/cpu/matMulMul.js +2 -2
- package/dist/ops/cpu/mulDropout.js +5 -5
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +5 -5
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/dropout.js +6 -6
- package/dist/ops/dropout16.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/globalNorm.js +7 -7
- package/dist/ops/grads/add16.js +1 -1
- package/dist/ops/grads/attentionMask.js +2 -2
- package/dist/ops/grads/dropout16.js +1 -1
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMul16.js +3 -3
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/mul16.js +1 -1
- package/dist/ops/grads/normRMS.js +7 -7
- package/dist/ops/grads/pack16.js +3 -3
- package/dist/ops/grads/qkv.js +11 -11
- package/dist/ops/grads/rope.js +2 -2
- package/dist/ops/grads/softmax16.js +1 -1
- package/dist/ops/grads/unpack16.js +2 -2
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +6 -6
- package/dist/ops/matMulMul.js +3 -3
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +3 -3
- package/dist/ops/normRMS.js +4 -4
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +3 -3
- package/dist/ops/reshape16.js +6 -6
- package/dist/ops/rope.js +2 -2
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +6 -6
- package/dist/ops/transpose16.js +3 -3
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/dropout16.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +7 -7
- package/dist/ops/webgl/gatherSub.js +3 -3
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMul16.js +13 -13
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +2 -2
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +2 -2
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/add16.js +6 -6
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +2 -2
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/clipScale.js +7 -7
- package/dist/ops/webgpu/concat16.js +5 -5
- package/dist/ops/webgpu/dropout16.js +6 -6
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +8 -8
- package/dist/ops/webgpu/matMul16.js +16 -16
- package/dist/ops/webgpu/matMul16_program.js +2 -2
- package/dist/ops/webgpu/mul16.js +5 -5
- package/dist/ops/webgpu/norm2.js +1 -1
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +4 -4
- package/dist/ops/webgpu/pack16.js +4 -4
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +2 -2
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/slice16.js +4 -4
- package/dist/ops/webgpu/softmax16.js +4 -4
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +4 -4
- package/dist/ops/webgpu/sub16.js +6 -6
- package/dist/ops/webgpu/sum16.js +3 -3
- package/dist/ops/webgpu/transpose16.js +8 -8
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
- package/dist/ops/webgpu/unpack16.js +3 -3
- package/dist/ops/webgpu/utils/binary_op.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +5 -5
- package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
- package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
- package/dist/patches/webgpu_backend.js +6 -6
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +2 -2
- package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
- package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
- package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
- package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
- package/dist/relu-DTvZKBsZ.js +9 -0
- package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
- package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
- package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
- package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
- package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
- package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
- package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
- package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
- package/dist/slice-BvItlgXu.js +12 -0
- package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
- package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
- package/dist/split-BN9LkEgS.js +9 -0
- package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
- package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
- package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
- package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
- package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
- package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
- package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
- package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
- package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
- package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
- package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
- package/dist/tokeniser/BaseTokeniser.js +21 -5
- package/dist/tokeniser/CharTokeniser.d.ts +1 -1
- package/dist/tokeniser/CharTokeniser.js +62 -50
- package/dist/tokeniser/bpe.d.ts +1 -1
- package/dist/tokeniser/bpe.js +41 -35
- package/dist/tokeniser/type.d.ts +3 -1
- package/dist/training/AdamW.d.ts +3 -0
- package/dist/training/AdamW.js +59 -30
- package/dist/training/BasicTrainer.d.ts +1 -0
- package/dist/training/BasicTrainer.js +112 -92
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/Evaluator.js +2 -2
- package/dist/training/LRScheduler.d.ts +1 -0
- package/dist/training/LRScheduler.js +18 -12
- package/dist/training/PreTrainer.js +3 -3
- package/dist/training/SFTDatasetBuilder.js +3 -3
- package/dist/training/SFTTrainer.js +1 -1
- package/dist/training/orthoGrad.js +1 -1
- package/dist/training/sparseCrossEntropy.js +30 -30
- package/dist/training/types.d.ts +5 -3
- package/dist/training/validation.js +13 -13
- package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
- package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
- package/dist/utilities/datasetID.d.ts +2 -0
- package/dist/utilities/datasetID.js +21 -0
- package/dist/utilities/dummy.js +6 -6
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
- package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
- package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
- package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
- package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
- package/package.json +3 -3
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
package/dist/tokeniser/bpe.js
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
import { yieldIfNeeded as
|
|
1
|
+
import { yieldIfNeeded as p } from "../utilities/yielder.js";
|
|
2
2
|
import m from "../utilities/tokenParse.js";
|
|
3
|
-
import
|
|
4
|
-
function
|
|
3
|
+
import T, { SPECIALS as S } from "./BaseTokeniser.js";
|
|
4
|
+
function g(o, e) {
|
|
5
5
|
return `${o}-::-${e}`;
|
|
6
6
|
}
|
|
7
|
-
function
|
|
7
|
+
function y(o) {
|
|
8
8
|
const e = /* @__PURE__ */ new Map();
|
|
9
9
|
for (let s = 0; s < o.length; s++) {
|
|
10
10
|
const t = o[s];
|
|
11
11
|
for (let n = 0; n < t.length - 1; n++) {
|
|
12
|
-
const r =
|
|
12
|
+
const r = g(t[n], t[n + 1]), a = e.get(r) || {
|
|
13
13
|
a: t[n],
|
|
14
14
|
b: t[n + 1],
|
|
15
15
|
count: 0,
|
|
@@ -20,21 +20,21 @@ function w(o) {
|
|
|
20
20
|
}
|
|
21
21
|
return { pairs: e, tokens: o };
|
|
22
22
|
}
|
|
23
|
-
function
|
|
24
|
-
const r =
|
|
23
|
+
function f(o, e, s, t, n) {
|
|
24
|
+
const r = g(e, s);
|
|
25
25
|
if (o.pairs.has(r)) {
|
|
26
26
|
const a = o.pairs.get(r);
|
|
27
27
|
a.count += n, n > 0 ? a.instances.add(t) : a.count <= 0 ? o.pairs.delete(r) : a.instances.delete(t);
|
|
28
28
|
} else
|
|
29
29
|
o.pairs.set(r, { a: e, b: s, count: n, instances: /* @__PURE__ */ new Set([t]) });
|
|
30
30
|
}
|
|
31
|
-
function
|
|
31
|
+
function I(o) {
|
|
32
32
|
let e = null, s = 0;
|
|
33
33
|
for (const t of o.pairs.values())
|
|
34
34
|
t.count > s && (s = t.count, e = t);
|
|
35
35
|
return e;
|
|
36
36
|
}
|
|
37
|
-
function
|
|
37
|
+
function x(o, e) {
|
|
38
38
|
return o.map((s) => {
|
|
39
39
|
const t = [];
|
|
40
40
|
for (let n = 0; n < s.length; n++)
|
|
@@ -42,19 +42,19 @@ function y(o, e) {
|
|
|
42
42
|
return t;
|
|
43
43
|
});
|
|
44
44
|
}
|
|
45
|
-
function
|
|
45
|
+
function A(o, e) {
|
|
46
46
|
e.instances.forEach((s) => {
|
|
47
47
|
const t = o.tokens[s], n = [];
|
|
48
48
|
for (let r = 0; r < t.length; r++)
|
|
49
49
|
if (r < t.length - 1 && t[r] === e.a && t[r + 1] === e.b) {
|
|
50
50
|
const a = e.a + e.b;
|
|
51
|
-
n.push(a), r > 0 && (
|
|
51
|
+
n.push(a), r > 0 && (f(o, t[r - 1], e.a, s, -1), f(o, t[r - 1], a, s, 1)), r++, r < t.length - 1 && (f(o, e.b, t[r + 1], s, -1), f(o, a, t[r + 1], s, 1));
|
|
52
52
|
} else
|
|
53
53
|
n.push(t[r]);
|
|
54
54
|
o.tokens[s] = n;
|
|
55
|
-
}), o.pairs.delete(
|
|
55
|
+
}), o.pairs.delete(g(e.a, e.b));
|
|
56
56
|
}
|
|
57
|
-
class
|
|
57
|
+
class P extends T {
|
|
58
58
|
targetSize;
|
|
59
59
|
vocab = /* @__PURE__ */ new Set();
|
|
60
60
|
vocabIndex = /* @__PURE__ */ new Map();
|
|
@@ -63,7 +63,7 @@ class E extends z {
|
|
|
63
63
|
constructor(e, s) {
|
|
64
64
|
super(), Array.isArray(e) ? (e.forEach((t, n) => {
|
|
65
65
|
this.vocab.add(t), this.vocabIndex.set(t, n);
|
|
66
|
-
}), s && (this.merges = s), this.targetSize = e.length,
|
|
66
|
+
}), s && (this.merges = s), this.targetSize = e.length, S.forEach((t) => {
|
|
67
67
|
const n = e.indexOf(t);
|
|
68
68
|
n !== -1 && this.addSpecialToken(t, n);
|
|
69
69
|
})) : (this.addSpecialTokens(), this.targetSize = e);
|
|
@@ -81,7 +81,7 @@ class E extends z {
|
|
|
81
81
|
this.vocab.clear(), this.vocabIndex.clear(), this.merges = [], this.pretokenMap.clear();
|
|
82
82
|
}
|
|
83
83
|
get trained() {
|
|
84
|
-
return this.vocab.size >
|
|
84
|
+
return this.vocab.size > S.length && this.vocab.size <= this.targetSize;
|
|
85
85
|
}
|
|
86
86
|
get vocabSize() {
|
|
87
87
|
return this.vocab.size;
|
|
@@ -95,42 +95,48 @@ class E extends z {
|
|
|
95
95
|
get unkToken() {
|
|
96
96
|
return this.vocabIndex.get("") ?? 1;
|
|
97
97
|
}
|
|
98
|
-
async train(e = [], s) {
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
98
|
+
async train(e = [], s, t) {
|
|
99
|
+
this.datasetID = t;
|
|
100
|
+
let n = performance.now();
|
|
101
|
+
const r = new Array(e.length);
|
|
102
|
+
for (let i = 0; i < e.length; i++) {
|
|
103
|
+
const h = e[i], l = new Array(h.length);
|
|
104
|
+
for (let d = 0; d < h.length; d++)
|
|
105
|
+
l[d] = m(h[d].content);
|
|
106
|
+
n = await p(n, s, this.vocab.size), r[i] = l;
|
|
107
|
+
}
|
|
108
|
+
const a = r.flat(2), z = new Set(a);
|
|
103
109
|
this.vocab = /* @__PURE__ */ new Set(), this.pretokenMap.clear(), this.merges = [], this.addSpecialTokens();
|
|
104
|
-
const
|
|
105
|
-
if (
|
|
110
|
+
const b = Array.from(z), v = b.map((i) => Array.from(i).map((l) => (this.vocab.add(l), l))), k = y(v);
|
|
111
|
+
if (n = await p(n, s, this.vocab.size), this.vocab.size >= this.targetSize) {
|
|
106
112
|
console.warn("Initial vocab size is greater than or equal to target size. No merges will be performed.");
|
|
107
113
|
const i = /* @__PURE__ */ new Map();
|
|
108
|
-
|
|
114
|
+
a.forEach((c) => {
|
|
109
115
|
Array.from(c).forEach((u) => {
|
|
110
116
|
i.set(u, (i.get(u) || 0) + 1);
|
|
111
117
|
});
|
|
112
118
|
});
|
|
113
119
|
const h = Array.from(i.entries()).sort((c, u) => u[1] - c[1]);
|
|
114
120
|
this.vocab = /* @__PURE__ */ new Set(), this.addSpecialTokens(), h.slice(0, this.targetSize - this.vocab.size).map(([c]) => c).forEach((c) => this.vocab.add(c)), this.vocabIndex.clear();
|
|
115
|
-
let
|
|
121
|
+
let d = 0;
|
|
116
122
|
for (const c of this.vocab.keys())
|
|
117
|
-
this.vocabIndex.set(c,
|
|
118
|
-
return this.emit("trainStatus", "trained"), this.vocab.size;
|
|
123
|
+
this.vocabIndex.set(c, d++);
|
|
124
|
+
return this.generateID(), this.emit("trainStatus", "trained"), this.vocab.size;
|
|
119
125
|
}
|
|
120
126
|
for (; this.vocab.size < this.targetSize && this.merges.length < this.targetSize; ) {
|
|
121
|
-
const i =
|
|
127
|
+
const i = I(k);
|
|
122
128
|
if (!i)
|
|
123
129
|
break;
|
|
124
|
-
this.merges.push([i.a, i.b]), this.vocab.add(i.a + i.b),
|
|
130
|
+
this.merges.push([i.a, i.b]), this.vocab.add(i.a + i.b), A(k, i), n = await p(n, s, this.vocab.size);
|
|
125
131
|
}
|
|
126
|
-
|
|
127
|
-
const l =
|
|
132
|
+
b.forEach((i, h) => {
|
|
133
|
+
const l = v[h];
|
|
128
134
|
this.pretokenMap.set(i, l);
|
|
129
135
|
}), this.vocabIndex.clear();
|
|
130
|
-
let
|
|
136
|
+
let w = 0;
|
|
131
137
|
for (const i of this.vocab.keys())
|
|
132
|
-
this.vocabIndex.set(i,
|
|
133
|
-
return this.emit("trainStatus", "trained"), this.vocab.size;
|
|
138
|
+
this.vocabIndex.set(i, w++);
|
|
139
|
+
return this.generateID(), this.emit("trainStatus", "trained"), this.vocab.size;
|
|
134
140
|
}
|
|
135
141
|
getVocab() {
|
|
136
142
|
return Array.from(this.vocab);
|
|
@@ -141,7 +147,7 @@ class E extends z {
|
|
|
141
147
|
tokeniseWord(e) {
|
|
142
148
|
let s = Array.from(e);
|
|
143
149
|
return this.merges.forEach((t) => {
|
|
144
|
-
s =
|
|
150
|
+
s = x([s], t)[0];
|
|
145
151
|
}), this.pretokenMap.set(e, s), s;
|
|
146
152
|
}
|
|
147
153
|
tokeniseStrings(e) {
|
|
@@ -163,5 +169,5 @@ class E extends z {
|
|
|
163
169
|
}
|
|
164
170
|
}
|
|
165
171
|
export {
|
|
166
|
-
|
|
172
|
+
P as default
|
|
167
173
|
};
|
package/dist/tokeniser/type.d.ts
CHANGED
|
@@ -5,7 +5,9 @@ export interface Conversation {
|
|
|
5
5
|
content: string;
|
|
6
6
|
}
|
|
7
7
|
export interface ITokeniser extends EE<'trainStatus'> {
|
|
8
|
-
|
|
8
|
+
id: string;
|
|
9
|
+
datasetID?: string;
|
|
10
|
+
train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
|
|
9
11
|
getVocab(): string[];
|
|
10
12
|
getMerges(): [string, string][];
|
|
11
13
|
destroy(): void;
|
package/dist/training/AdamW.d.ts
CHANGED
|
@@ -21,6 +21,9 @@ export declare class AdamWOptimizer extends Optimizer {
|
|
|
21
21
|
protected orthGrad: boolean;
|
|
22
22
|
constructor(config: AdamWOptimizerConfig);
|
|
23
23
|
get lr(): number;
|
|
24
|
+
saveMoments(): Promise<ArrayBuffer>;
|
|
25
|
+
loadMoments(momentData: ArrayBuffer): Promise<void>;
|
|
26
|
+
serializeConfig(): AdamWOptimizerConfig;
|
|
24
27
|
private orthogonalizeGradient;
|
|
25
28
|
updateConfig(newConfig: Partial<AdamWOptimizerConfig>): void;
|
|
26
29
|
applyGradients(variableGradients: NamedVariableMap | NamedTensor[]): Tensor;
|
package/dist/training/AdamW.js
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
import { adamAdjust as
|
|
2
|
-
import { adamMoments as
|
|
3
|
-
import { O as
|
|
4
|
-
import
|
|
5
|
-
import { clipScale as
|
|
6
|
-
import {
|
|
7
|
-
|
|
1
|
+
import { adamAdjust as B } from "../ops/adamAdjust.js";
|
|
2
|
+
import { adamMoments as N } from "../ops/adamMoments.js";
|
|
3
|
+
import { O as S, e as b, t as c, b as M, l as w } from "../index-DSGwv2Yx.js";
|
|
4
|
+
import R from "./LRScheduler.js";
|
|
5
|
+
import { clipScale as f } from "../ops/globalNorm.js";
|
|
6
|
+
import { save_safetensors as v, load_safetensors as A } from "../utilities/safetensors.js";
|
|
7
|
+
import { z as O } from "../zeros-Bw0puq_w.js";
|
|
8
|
+
class _ extends S {
|
|
8
9
|
constructor(t) {
|
|
9
|
-
super(), this.config = t, this.accBeta1 = t.beta1, this.accBeta2 = t.beta2, this.learningRate = t.learningRate, this.beta1 = t.beta1, this.beta2 = t.beta2, this.weightDecay = t.weightDecay, this.lossScaling = t.lossScaling, this.clipNorm = t.clipNorm, this.orthGrad = t.orthoGrad ?? !1, t.epsilon === null || t.epsilon === void 0 ? this.epsilon = b().backend.epsilon() : this.epsilon = t.epsilon, this.lrScheduler = new
|
|
10
|
+
super(), this.config = t, this.accBeta1 = t.accBeta1 ?? t.beta1, this.accBeta2 = t.accBeta2 ?? t.beta2, this.learningRate = t.learningRate, this.beta1 = t.beta1, this.beta2 = t.beta2, this.weightDecay = t.weightDecay, this.lossScaling = t.lossScaling, this.clipNorm = t.clipNorm, this.orthGrad = t.orthoGrad ?? !1, t.epsilon === null || t.epsilon === void 0 ? this.epsilon = b().backend.epsilon() : this.epsilon = t.epsilon, this.lrScheduler = new R(t.learningRate, t);
|
|
10
11
|
}
|
|
11
12
|
className = "AdamW";
|
|
12
13
|
accBeta1 = 0;
|
|
@@ -25,10 +26,38 @@ class G extends R {
|
|
|
25
26
|
get lr() {
|
|
26
27
|
return this.learningRate;
|
|
27
28
|
}
|
|
29
|
+
saveMoments() {
|
|
30
|
+
const t = {};
|
|
31
|
+
return this.accumulatedMoments.forEach((e) => {
|
|
32
|
+
t[e.originalName] = e.variable;
|
|
33
|
+
}), v(t);
|
|
34
|
+
}
|
|
35
|
+
async loadMoments(t) {
|
|
36
|
+
const e = await A(t);
|
|
37
|
+
Object.entries(e).forEach(([a, s]) => {
|
|
38
|
+
const n = s.variable(!1);
|
|
39
|
+
this.accumulatedMoments.push({ originalName: a, variable: n });
|
|
40
|
+
});
|
|
41
|
+
}
|
|
42
|
+
serializeConfig() {
|
|
43
|
+
return {
|
|
44
|
+
learningRate: this.learningRate,
|
|
45
|
+
beta1: this.beta1,
|
|
46
|
+
beta2: this.beta2,
|
|
47
|
+
accBeta1: this.accBeta1,
|
|
48
|
+
accBeta2: this.accBeta2,
|
|
49
|
+
epsilon: this.epsilon ?? void 0,
|
|
50
|
+
weightDecay: this.weightDecay,
|
|
51
|
+
lossScaling: this.lossScaling,
|
|
52
|
+
clipNorm: this.clipNorm,
|
|
53
|
+
orthoGrad: this.orthGrad,
|
|
54
|
+
...this.lrScheduler.serializeConfig()
|
|
55
|
+
};
|
|
56
|
+
}
|
|
28
57
|
orthogonalizeGradient(t, e) {
|
|
29
|
-
return
|
|
30
|
-
const a = t.reshape([-1]), s = e.reshape([-1]),
|
|
31
|
-
return
|
|
58
|
+
return c(() => {
|
|
59
|
+
const a = t.reshape([-1]), s = e.reshape([-1]), n = a.mul(a).sum().add(this.orthGradEpsilon), h = a.mul(s).sum().div(n), o = s.sub(a.mul(h)), l = s.norm(), i = o.norm().add(this.orthGradEpsilon);
|
|
60
|
+
return o.mul(l.div(i)).reshape(e.shape);
|
|
32
61
|
});
|
|
33
62
|
}
|
|
34
63
|
updateConfig(t) {
|
|
@@ -38,42 +67,42 @@ class G extends R {
|
|
|
38
67
|
applyGradients(t) {
|
|
39
68
|
const e = this.lrScheduler.getNextLR();
|
|
40
69
|
this.learningRate = e;
|
|
41
|
-
const a = Array.isArray(t) ? t.map((
|
|
42
|
-
const
|
|
43
|
-
let
|
|
70
|
+
const a = Array.isArray(t) ? t.map((n) => n.name) : Object.keys(t), s = c(() => {
|
|
71
|
+
const n = 1 - this.accBeta1, h = 1 - this.accBeta2;
|
|
72
|
+
let o;
|
|
44
73
|
if (this.clipNorm !== void 0) {
|
|
45
|
-
const
|
|
46
|
-
|
|
74
|
+
const l = a.map((i, r) => Array.isArray(t) ? t[r].tensor : t[i]);
|
|
75
|
+
o = f(l, 1 / this.lossScaling, this.clipNorm);
|
|
47
76
|
} else
|
|
48
|
-
|
|
49
|
-
return a.forEach((
|
|
50
|
-
const r = b().registeredVariables[
|
|
77
|
+
o = M(1 / this.lossScaling);
|
|
78
|
+
return a.forEach((l, i) => {
|
|
79
|
+
const r = b().registeredVariables[l], p = !1;
|
|
51
80
|
this.accumulatedMoments[i] == null && (this.accumulatedMoments[i] = {
|
|
52
|
-
originalName: `${
|
|
53
|
-
variable:
|
|
81
|
+
originalName: `${l}/m`,
|
|
82
|
+
variable: c(() => O([...r.shape, 2]).variable(p))
|
|
54
83
|
});
|
|
55
|
-
const m = Array.isArray(t) ? t[i].tensor : t[
|
|
84
|
+
const m = Array.isArray(t) ? t[i].tensor : t[l];
|
|
56
85
|
if (m == null)
|
|
57
86
|
return;
|
|
58
|
-
const u = this.orthGrad ? this.orthogonalizeGradient(r, m) : m, d = this.accumulatedMoments[i].variable, g =
|
|
87
|
+
const u = this.orthGrad ? this.orthogonalizeGradient(r, m) : m, d = this.accumulatedMoments[i].variable, g = N(d, u, this.beta1, this.beta2, o);
|
|
59
88
|
d.assign(g), this.orthGrad && u.dispose();
|
|
60
|
-
const y =
|
|
89
|
+
const y = B(
|
|
61
90
|
g,
|
|
62
91
|
r,
|
|
63
|
-
|
|
64
|
-
|
|
92
|
+
n,
|
|
93
|
+
h,
|
|
65
94
|
this.epsilon ?? 1e-8,
|
|
66
95
|
this.learningRate,
|
|
67
96
|
// Only apply weight decay if the variable is multi-dimensional (e.g. weights, not biases)
|
|
68
97
|
r.shape.length > 1 ? this.weightDecay : 0
|
|
69
98
|
);
|
|
70
99
|
r.assign(y);
|
|
71
|
-
}), this.accBeta1 = this.accBeta1 * this.beta1, this.accBeta2 = this.accBeta2 * this.beta2,
|
|
100
|
+
}), this.accBeta1 = this.accBeta1 * this.beta1, this.accBeta2 = this.accBeta2 * this.beta2, o;
|
|
72
101
|
});
|
|
73
102
|
return this.incrementIterations(), s;
|
|
74
103
|
}
|
|
75
104
|
dispose() {
|
|
76
|
-
this.accumulatedMoments != null &&
|
|
105
|
+
this.accumulatedMoments != null && w(this.accumulatedMoments.map((t) => t.variable));
|
|
77
106
|
}
|
|
78
107
|
async getWeights() {
|
|
79
108
|
const t = [...this.accumulatedMoments];
|
|
@@ -82,7 +111,7 @@ class G extends R {
|
|
|
82
111
|
);
|
|
83
112
|
}
|
|
84
113
|
async setWeights(t) {
|
|
85
|
-
t = await this.extractIterations(t),
|
|
114
|
+
t = await this.extractIterations(t), c(() => {
|
|
86
115
|
this.accBeta1 = Math.pow(this.beta1, this.iterations_ + 1), this.accBeta2 = Math.pow(this.beta2, this.iterations_ + 1);
|
|
87
116
|
});
|
|
88
117
|
const e = t.length / 2, a = !1;
|
|
@@ -105,5 +134,5 @@ class G extends R {
|
|
|
105
134
|
}
|
|
106
135
|
}
|
|
107
136
|
export {
|
|
108
|
-
|
|
137
|
+
_ as AdamWOptimizer
|
|
109
138
|
};
|
|
@@ -31,6 +31,7 @@ export default class BasicTrainer {
|
|
|
31
31
|
get isRunning(): boolean;
|
|
32
32
|
getOptimizer(): AdamWOptimizer;
|
|
33
33
|
updateOptimizer(config?: Partial<AdamWOptimizerConfig>): void;
|
|
34
|
+
resumeFromLog(log: TrainingLogEntry): void;
|
|
34
35
|
protected trainStep(state: Partial<TrainingState>, batch: {
|
|
35
36
|
xs: Tensor;
|
|
36
37
|
ys: Tensor;
|
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
import
|
|
2
|
-
import { t as
|
|
3
|
-
import
|
|
4
|
-
import { createTensorStatistics as
|
|
5
|
-
import { calculateLoss as x, calculateAccuracy as
|
|
6
|
-
import { AdamWOptimizer as
|
|
7
|
-
import { z as
|
|
8
|
-
const
|
|
1
|
+
import y from "./Evaluator.js";
|
|
2
|
+
import { t as L, Z as k, k as u, l as p, b as S } from "../index-DSGwv2Yx.js";
|
|
3
|
+
import w from "../utilities/profile.js";
|
|
4
|
+
import { createTensorStatistics as b } from "../checks/weights.js";
|
|
5
|
+
import { calculateLoss as x, calculateAccuracy as P } from "./loss.js";
|
|
6
|
+
import { AdamWOptimizer as T } from "./AdamW.js";
|
|
7
|
+
import { z as v } from "../zeros-Bw0puq_w.js";
|
|
8
|
+
const z = {
|
|
9
9
|
logInterval: 1,
|
|
10
10
|
maxEpochs: 100,
|
|
11
11
|
sftMode: "full",
|
|
12
12
|
batchSize: 32
|
|
13
|
-
},
|
|
13
|
+
}, D = {
|
|
14
14
|
learningRate: 3e-4,
|
|
15
15
|
beta1: 0.9,
|
|
16
16
|
beta2: 0.99,
|
|
@@ -23,14 +23,14 @@ const v = {
|
|
|
23
23
|
lossScaling: 1
|
|
24
24
|
};
|
|
25
25
|
class G {
|
|
26
|
-
constructor(s,
|
|
27
|
-
this.tokenizer =
|
|
28
|
-
...
|
|
29
|
-
...
|
|
26
|
+
constructor(s, e, n, l) {
|
|
27
|
+
this.tokenizer = e, this.model = s, this.optimizerConfig = {
|
|
28
|
+
...D,
|
|
29
|
+
...n,
|
|
30
30
|
lossScaling: s.lossScaling
|
|
31
31
|
};
|
|
32
|
-
const
|
|
33
|
-
|
|
32
|
+
const m = l || new T(this.optimizerConfig);
|
|
33
|
+
l && l.updateConfig(this.optimizerConfig), this.optimizer = m;
|
|
34
34
|
}
|
|
35
35
|
model;
|
|
36
36
|
optimizer;
|
|
@@ -80,11 +80,22 @@ class G {
|
|
|
80
80
|
updateOptimizer(s) {
|
|
81
81
|
s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
|
|
82
82
|
}
|
|
83
|
+
resumeFromLog(s) {
|
|
84
|
+
(!this.lastState || this.lastState.step === 0) && (this.lastState = {
|
|
85
|
+
losses: [],
|
|
86
|
+
validationLosses: [],
|
|
87
|
+
logStartTime: 0,
|
|
88
|
+
step: s.step,
|
|
89
|
+
lastLoss: s.trainingMetrics.loss,
|
|
90
|
+
totalSteps: s.step,
|
|
91
|
+
trainingDuration: s.duration
|
|
92
|
+
});
|
|
93
|
+
}
|
|
83
94
|
// A single forward pass, backward pass, and optimizer step
|
|
84
|
-
trainStep(s,
|
|
85
|
-
return
|
|
95
|
+
trainStep(s, e, n = !1, l = !1) {
|
|
96
|
+
return L(() => {
|
|
86
97
|
this.model.getProfiler()?.startMemory();
|
|
87
|
-
const { xs:
|
|
98
|
+
const { xs: m, ys: i } = e, d = () => {
|
|
88
99
|
const r = this.model.forward(
|
|
89
100
|
{
|
|
90
101
|
training: !0,
|
|
@@ -93,32 +104,32 @@ class G {
|
|
|
93
104
|
dropout: this._dropout,
|
|
94
105
|
layerDrop: this._layerDrop
|
|
95
106
|
},
|
|
96
|
-
|
|
97
|
-
),
|
|
98
|
-
this.metrics.has("accuracy") && (s.accuracy =
|
|
99
|
-
const
|
|
100
|
-
return
|
|
101
|
-
}, { value: t, grads:
|
|
102
|
-
if (
|
|
107
|
+
m
|
|
108
|
+
), o = x(r, i, this.maskedLoss, !1, this._labelSmoothing);
|
|
109
|
+
this.metrics.has("accuracy") && (s.accuracy = P(r, i), u(s.accuracy)), r.dispose();
|
|
110
|
+
const a = o.mul(S(this.optimizerConfig.lossScaling));
|
|
111
|
+
return o.dispose(), a;
|
|
112
|
+
}, { value: t, grads: c } = k(d);
|
|
113
|
+
if (n)
|
|
103
114
|
this.model.getProfiler()?.endMemory("Training");
|
|
104
115
|
else {
|
|
105
|
-
const r = this.optimizer.applyGradients(
|
|
106
|
-
this.metrics.has("gradientNorm") ? (s.gradientNorm = r,
|
|
107
|
-
const
|
|
108
|
-
this.model.weightStore.touchVariables(
|
|
116
|
+
const r = this.optimizer.applyGradients(c);
|
|
117
|
+
this.metrics.has("gradientNorm") ? (s.gradientNorm = r, u(r)) : (s.gradientNorm = void 0, r.dispose());
|
|
118
|
+
const o = Object.keys(c);
|
|
119
|
+
this.model.weightStore.touchVariables(o), this.model.getProfiler()?.endMemory("Training"), l ? (s.gradients = c, Object.values(c).forEach((a) => u(a))) : p(c);
|
|
109
120
|
}
|
|
110
|
-
return t.mul(
|
|
121
|
+
return t.mul(S(1 / this.optimizerConfig.lossScaling));
|
|
111
122
|
});
|
|
112
123
|
}
|
|
113
124
|
async dummyPass() {
|
|
114
|
-
const s =
|
|
125
|
+
const s = v([1, this.model.config.blockSize], "int32"), e = v([1, this.model.config.blockSize], "int32");
|
|
115
126
|
try {
|
|
116
|
-
const
|
|
117
|
-
await
|
|
118
|
-
} catch (
|
|
119
|
-
console.error("Error during dummy pass:",
|
|
127
|
+
const n = this.trainStep({}, { xs: s, ys: e }, !0);
|
|
128
|
+
await n.data(), n.dispose();
|
|
129
|
+
} catch (n) {
|
|
130
|
+
console.error("Error during dummy pass:", n);
|
|
120
131
|
} finally {
|
|
121
|
-
s.dispose(),
|
|
132
|
+
s.dispose(), e.dispose();
|
|
122
133
|
}
|
|
123
134
|
}
|
|
124
135
|
dispose() {
|
|
@@ -136,33 +147,40 @@ class G {
|
|
|
136
147
|
...this.lastState || {}
|
|
137
148
|
};
|
|
138
149
|
}
|
|
139
|
-
async stepDataset(s,
|
|
140
|
-
const { logInterval:
|
|
141
|
-
...
|
|
142
|
-
...
|
|
150
|
+
async stepDataset(s, e, n) {
|
|
151
|
+
const { logInterval: l = 10 } = {
|
|
152
|
+
...z,
|
|
153
|
+
...e
|
|
143
154
|
};
|
|
144
|
-
|
|
145
|
-
const
|
|
146
|
-
this.lastState =
|
|
147
|
-
const d =
|
|
155
|
+
e.metrics && this.setMetrics(e.metrics);
|
|
156
|
+
const m = Date.now(), i = this.createEmptyState();
|
|
157
|
+
this.lastState = i, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, i.logStartTime = m;
|
|
158
|
+
const d = n ? new y(this.model, n, void 0, this.maskedLoss) : void 0, t = await s.iterator();
|
|
148
159
|
try {
|
|
149
160
|
for (; this.running; ) {
|
|
150
|
-
const
|
|
151
|
-
if (
|
|
152
|
-
const r =
|
|
153
|
-
r.xs.dispose(), r.ys.dispose(),
|
|
161
|
+
const c = await t.next();
|
|
162
|
+
if (c.done) break;
|
|
163
|
+
const r = c.value, o = this.trainStep(i, r, !1);
|
|
164
|
+
r.xs.dispose(), r.ys.dispose(), i.step++, i.totalSteps++, i.step % l === 0 ? await this.performLogging(o, r.xs.shape[0], e, d) : (i.gradientNorm && (i.gradientNorm.dispose(), i.gradientNorm = void 0), i.accuracy && (i.accuracy.dispose(), i.accuracy = void 0)), o.dispose();
|
|
154
165
|
}
|
|
155
|
-
} catch (
|
|
156
|
-
throw console.error("Training error:",
|
|
166
|
+
} catch (c) {
|
|
167
|
+
throw console.error("Training error:", c), c;
|
|
157
168
|
}
|
|
158
|
-
throw
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
169
|
+
throw this.model.trainingState = {
|
|
170
|
+
steps: i.totalSteps,
|
|
171
|
+
learningRate: this.optimizer.lr,
|
|
172
|
+
batchSize: e.batchSize || 32,
|
|
173
|
+
loss: i.lastLoss,
|
|
174
|
+
tokensProcessed: i.totalSteps * (e.batchSize || 32) * this.model.config.blockSize,
|
|
175
|
+
duration: i.trainingDuration
|
|
176
|
+
}, p(), this.running = !1, new Error("No log returned before training stopped.");
|
|
177
|
+
}
|
|
178
|
+
async performLogging(s, e, n, l) {
|
|
179
|
+
const m = n?.onStep, i = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
|
|
162
180
|
t.lastLoss = d;
|
|
163
|
-
const
|
|
164
|
-
t.trainingDuration +=
|
|
165
|
-
const r = {
|
|
181
|
+
const c = Date.now();
|
|
182
|
+
t.trainingDuration += c - t.logStartTime;
|
|
183
|
+
const r = t.totalSteps * e * this.model.config.blockSize, o = {
|
|
166
184
|
trainingMetrics: {
|
|
167
185
|
loss: t.lastLoss,
|
|
168
186
|
perplexity: this.metrics.has("perplexity") ? Math.exp(t.lastLoss) : void 0,
|
|
@@ -171,55 +189,57 @@ class G {
|
|
|
171
189
|
step: t.step,
|
|
172
190
|
time: Date.now() - t.logStartTime,
|
|
173
191
|
gradientNorm: t.gradientNorm ? (await t.gradientNorm.data())[1] : void 0,
|
|
174
|
-
batchSize:
|
|
192
|
+
batchSize: e,
|
|
175
193
|
learningRate: this.metrics.has("learningRate") ? this.optimizer.lr : void 0,
|
|
176
194
|
duration: t.trainingDuration,
|
|
177
|
-
|
|
178
|
-
|
|
195
|
+
totalTokens: r,
|
|
196
|
+
tokensPerSecond: r / (t.trainingDuration / 1e3),
|
|
179
197
|
memoryUsage: this.metrics.has("memoryUsage") ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
180
198
|
};
|
|
181
|
-
if (
|
|
199
|
+
if (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0), this.model.trainingState = {
|
|
182
200
|
steps: t.totalSteps,
|
|
183
201
|
learningRate: this.optimizer.lr,
|
|
184
|
-
batchSize:
|
|
185
|
-
loss: t.lastLoss
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
202
|
+
batchSize: e,
|
|
203
|
+
loss: t.lastLoss,
|
|
204
|
+
tokensProcessed: r,
|
|
205
|
+
duration: t.trainingDuration
|
|
206
|
+
}, i && t.gradients) {
|
|
207
|
+
const a = /* @__PURE__ */ new Map();
|
|
208
|
+
for (const [h, g] of Object.entries(t.gradients))
|
|
209
|
+
a.set(h, await b(g)), g.dispose();
|
|
210
|
+
o.gradientMetrics = a;
|
|
191
211
|
}
|
|
192
|
-
if (
|
|
212
|
+
if (l)
|
|
193
213
|
try {
|
|
194
|
-
const
|
|
195
|
-
Array.isArray(
|
|
196
|
-
accuracy:
|
|
197
|
-
loss:
|
|
198
|
-
perplexity: this.metrics.has("perplexity") ? Math.exp(
|
|
214
|
+
const a = await l.evaluate(5);
|
|
215
|
+
Array.isArray(a) ? o.validationMetrics = { loss: a[0].loss, accuracy: a[0].accuracy } : (t.validationLosses.push(a.loss), o.validationMetrics = {
|
|
216
|
+
accuracy: a.accuracy,
|
|
217
|
+
loss: a.loss,
|
|
218
|
+
perplexity: this.metrics.has("perplexity") ? Math.exp(a.loss) : void 0
|
|
199
219
|
});
|
|
200
|
-
} catch (
|
|
201
|
-
console.error("Validation error:",
|
|
220
|
+
} catch (a) {
|
|
221
|
+
console.error("Validation error:", a);
|
|
202
222
|
}
|
|
203
|
-
|
|
204
|
-
}
|
|
205
|
-
async trainOnDataset(s,
|
|
206
|
-
const { logInterval:
|
|
207
|
-
...
|
|
208
|
-
...
|
|
209
|
-
},
|
|
210
|
-
|
|
223
|
+
m && await m(o), t.logStartTime = Date.now();
|
|
224
|
+
}
|
|
225
|
+
async trainOnDataset(s, e, n) {
|
|
226
|
+
const { logInterval: l = 10, maxEpochs: m = 1 / 0 } = {
|
|
227
|
+
...z,
|
|
228
|
+
...e
|
|
229
|
+
}, i = m * (e?.epochSteps || 1e3);
|
|
230
|
+
e.metrics && this.setMetrics(e.metrics);
|
|
211
231
|
const d = Date.now(), t = this.createEmptyState();
|
|
212
|
-
this.lastState = t, await this.dummyPass(),
|
|
213
|
-
const
|
|
232
|
+
this.lastState = t, await this.dummyPass(), e?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, t.logStartTime = d;
|
|
233
|
+
const c = n ? new y(this.model, n, void 0, this.maskedLoss) : void 0, r = await s.iterator();
|
|
214
234
|
try {
|
|
215
235
|
for (; this.running; ) {
|
|
216
|
-
const
|
|
217
|
-
if (
|
|
218
|
-
const
|
|
219
|
-
|
|
236
|
+
const o = await r.next();
|
|
237
|
+
if (o.done) break;
|
|
238
|
+
const a = o.value, h = t.step % l === 0, g = (e?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, a, !1, g);
|
|
239
|
+
a.xs.dispose(), a.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, a.xs.shape[0], e, c) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= i && this.stop();
|
|
220
240
|
}
|
|
221
|
-
} catch (
|
|
222
|
-
throw console.error("Training error:",
|
|
241
|
+
} catch (o) {
|
|
242
|
+
throw console.error("Training error:", o), p(), o;
|
|
223
243
|
}
|
|
224
244
|
return p(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
|
|
225
245
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { t as f } from "../index-
|
|
2
|
-
import "../dataset-
|
|
3
|
-
import { g as a } from "../readers-
|
|
1
|
+
import { t as f } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import "../dataset-DlqAN81i.js";
|
|
3
|
+
import { g as a } from "../readers-17HLdxVM.js";
|
|
4
4
|
import "../index-Cp39cXWe.js";
|
|
5
5
|
const g = 8;
|
|
6
6
|
async function p(n, e) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { t as p } from "../index-
|
|
1
|
+
import { t as p } from "../index-DSGwv2Yx.js";
|
|
2
2
|
import { calculateLoss as d, calculateAccuracy as m } from "./loss.js";
|
|
3
3
|
import { buildSFTExample as x } from "./SFTDatasetBuilder.js";
|
|
4
|
-
import { t as h } from "../tensor-
|
|
4
|
+
import { t as h } from "../tensor-D8e0Gd7c.js";
|
|
5
5
|
class k {
|
|
6
6
|
constructor(i, t, o, c) {
|
|
7
7
|
if (this.model = i, this.masked = !!c, Array.isArray(t)) {
|