@genai-fi/nanogpt 0.5.4 → 0.5.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +5 -5
- package/dist/NanoGPTModel.d.ts +2 -0
- package/dist/NanoGPTModel.js +8 -8
- package/dist/{Reshape-Bt_t7RNz.js → Reshape-Biok_3X1.js} +6 -6
- package/dist/TeachableLLM.js +11 -11
- package/dist/{TiedEmbedding-DORsPlNL.js → TiedEmbedding-8S8xn8e6.js} +5 -5
- package/dist/Trainer.d.ts +1 -0
- package/dist/Trainer.js +8 -7
- package/dist/{axis_util-CVbf1vmL.js → axis_util-BczFISHz.js} +1 -1
- package/dist/{broadcast_to-BBoMQXbL.js → broadcast_to-B7NGsBSh.js} +2 -2
- package/dist/{concat-BRRtq4S2.js → concat-DdKPyAtw.js} +1 -1
- package/dist/{dataset-ZHEPJmED.js → dataset-iqT4Otvb.js} +7 -7
- package/dist/{dropout-lQm_YyX3.js → dropout-B09InSJS.js} +1 -1
- package/dist/{gather-BWyutxwi.js → gather-D6MsdXqc.js} +1 -1
- package/dist/{gpgpu_math-Df7gzJWH.js → gpgpu_math-BFbOyvk4.js} +1 -1
- package/dist/{index-CnHyhpKc.js → index-Du-bmOP8.js} +98 -98
- package/dist/{kernel_funcs_utils-Dqo82NH4.js → kernel_funcs_utils-DShm7-0k.js} +33 -33
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/{log_sum_exp-CRH7Np9v.js → log_sum_exp-CxfBtUaG.js} +5 -5
- package/dist/main.d.ts +1 -0
- package/dist/main.js +1 -1
- package/dist/{mat_mul-DeGU1U_C.js → mat_mul-CbiqIe2d.js} +1 -1
- package/dist/{max-CcnEArWK.js → max-0Xnlpv8k.js} +1 -1
- package/dist/{norm-BpWsOapl.js → norm-01kY9I2B.js} +5 -5
- package/dist/{ones-CDWGzVnm.js → ones-CrutWGas.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +96 -96
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-DzQTmLIl.js → ops-CJNniCAV.js} +13 -13
- package/dist/{random_width-DI2h9CMs.js → random_width-C-v-35bY.js} +1324 -1279
- package/dist/{range-CkOJ7090.js → range-Bvs1hidm.js} +1 -1
- package/dist/{reshape-CTIbqjwm.js → reshape-BH7eBpwq.js} +1 -1
- package/dist/{sin-HzioENy_.js → sin-CPAZXNjH.js} +1 -1
- package/dist/{slice_util-n4wHKmex.js → slice_util-DskXqRZa.js} +1 -1
- package/dist/{softmax-DX6qXAbm.js → softmax-DhWoBa7r.js} +1 -1
- package/dist/{split-CVwhL8Oe.js → split-BCUhuU7B.js} +1 -1
- package/dist/{stack-S2-D2JAQ.js → stack-BV1v7l3S.js} +1 -1
- package/dist/{sum-UdfvaNhB.js → sum-Cvq06317.js} +1 -1
- package/dist/{tensor-IZex6Bwp.js → tensor-DgTOPY6h.js} +1 -1
- package/dist/{tensor2d-CqtBzOKq.js → tensor2d-CRWjDyUe.js} +1 -1
- package/dist/{tfjs_backend-DX9yVvwk.js → tfjs_backend-D9Ytje0G.js} +39 -39
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +36 -32
- package/dist/training/Trainer.d.ts +7 -4
- package/dist/training/Trainer.js +58 -50
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.d.ts +1 -0
- package/dist/utilities/profile.js +6 -3
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BGvK-VN3.js → variable-DZ3fF0R2.js} +1 -1
- package/dist/{zeros-CYMicyqz.js → zeros-BaHhQTWf.js} +1 -1
- package/package.json +1 -1
- package/dist/moments-DLTE6-1p.js +0 -53
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
import { o as h, j as f, E as $,
|
|
2
|
-
import { s as C, t as Ke, a as Ue, b as ve } from "./ops-
|
|
3
|
-
import { r as Re, d as Ve } from "./dropout-
|
|
4
|
-
import { r as u } from "./reshape-
|
|
5
|
-
import { g as qe } from "./gather-
|
|
6
|
-
import { s as Ge } from "./sum-
|
|
7
|
-
import { m as A } from "./mat_mul-
|
|
8
|
-
import { c as M } from "./concat-
|
|
1
|
+
import { o as h, j as f, E as $, ao as Te, l as _, g as Ee, ap as xe, aq as Ie, ar as Le, as as be, at as Ne, au as Ce, av as Pe, b as H, aw as Fe, a8 as U, u as ae, q as ie, Q as le, c as fe, ax as he, ai as pe, ay as je, t as S, D as $e, al as Me, a2 as Be } from "./index-Du-bmOP8.js";
|
|
2
|
+
import { s as C, t as Ke, a as Ue, b as ve } from "./ops-CJNniCAV.js";
|
|
3
|
+
import { r as Re, d as Ve } from "./dropout-B09InSJS.js";
|
|
4
|
+
import { r as u } from "./reshape-BH7eBpwq.js";
|
|
5
|
+
import { g as qe } from "./gather-D6MsdXqc.js";
|
|
6
|
+
import { s as Ge } from "./sum-Cvq06317.js";
|
|
7
|
+
import { m as A } from "./mat_mul-CbiqIe2d.js";
|
|
8
|
+
import { c as M } from "./concat-DdKPyAtw.js";
|
|
9
9
|
/**
|
|
10
10
|
* @license
|
|
11
11
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -213,11 +213,11 @@ const X = /* @__PURE__ */ h({ slice1d_: dn });
|
|
|
213
213
|
* limitations under the License.
|
|
214
214
|
* =============================================================================
|
|
215
215
|
*/
|
|
216
|
-
function
|
|
216
|
+
function gn(e, n, t) {
|
|
217
217
|
const r = f(e, "x", "slice2d");
|
|
218
218
|
return _(r.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${r.rank} tensor`), C(r, n, t);
|
|
219
219
|
}
|
|
220
|
-
const we = /* @__PURE__ */ h({ slice2d_:
|
|
220
|
+
const we = /* @__PURE__ */ h({ slice2d_: gn });
|
|
221
221
|
/**
|
|
222
222
|
* @license
|
|
223
223
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -234,11 +234,11 @@ const we = /* @__PURE__ */ h({ slice2d_: mn });
|
|
|
234
234
|
* limitations under the License.
|
|
235
235
|
* =============================================================================
|
|
236
236
|
*/
|
|
237
|
-
function
|
|
237
|
+
function mn(e, n, t) {
|
|
238
238
|
const r = f(e, "x", "slice3d");
|
|
239
239
|
return _(r.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${r.rank} tensor`), C(r, n, t);
|
|
240
240
|
}
|
|
241
|
-
const z = /* @__PURE__ */ h({ slice3d_:
|
|
241
|
+
const z = /* @__PURE__ */ h({ slice3d_: mn });
|
|
242
242
|
/**
|
|
243
243
|
* @license
|
|
244
244
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -350,9 +350,9 @@ function _n({ a: e, b: n, transposeA: t = !1, transposeB: r = !1, bias: s, activ
|
|
|
350
350
|
}
|
|
351
351
|
let c = f(e, "a", "fused matMul"), a = f(n, "b", "fused matMul");
|
|
352
352
|
[c, a] = ae(c, a);
|
|
353
|
-
const k = t ? c.shape[c.rank - 2] : c.shape[c.rank - 1],
|
|
354
|
-
_(k ===
|
|
355
|
-
const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), q = r ? u(a, [re, d,
|
|
353
|
+
const k = t ? c.shape[c.rank - 2] : c.shape[c.rank - 1], m = r ? a.shape[a.rank - 1] : a.shape[a.rank - 2], E = t ? c.shape[c.rank - 1] : c.shape[c.rank - 2], d = r ? a.shape[a.rank - 2] : a.shape[a.rank - 1], ne = c.shape.slice(0, -2), x = a.shape.slice(0, -2), te = ie(ne), re = ie(x);
|
|
354
|
+
_(k === m, () => `Error in fused matMul: inner shapes (${k}) and (${m}) of Tensors with shapes ${c.shape} and ${a.shape} and transposeA=${t} and transposeB=${r} must match.`);
|
|
355
|
+
const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), q = r ? u(a, [re, d, m]) : u(a, [re, m, d]);
|
|
356
356
|
let I;
|
|
357
357
|
s != null && (I = f(s, "bias", "fused matMul"), [I] = ae(I, c), le(R, I.shape));
|
|
358
358
|
let se;
|
|
@@ -450,7 +450,7 @@ function Jn(e, n) {
|
|
|
450
450
|
return t.fill(e), t;
|
|
451
451
|
}
|
|
452
452
|
}
|
|
453
|
-
function
|
|
453
|
+
function ge(e, n) {
|
|
454
454
|
if (!e)
|
|
455
455
|
throw new ee(n);
|
|
456
456
|
}
|
|
@@ -473,7 +473,7 @@ function Qn(e) {
|
|
|
473
473
|
function Hn(e) {
|
|
474
474
|
return e.length <= 1 || e.indexOf("_") === -1 ? e : e.replace(/[_]+(\w|$)/g, (n, t) => t.toUpperCase());
|
|
475
475
|
}
|
|
476
|
-
let
|
|
476
|
+
let g = {};
|
|
477
477
|
function Xn(e) {
|
|
478
478
|
if (e == null)
|
|
479
479
|
return null;
|
|
@@ -498,8 +498,8 @@ function zn(e, n = {}, t = {}, r = "object", s = !1) {
|
|
|
498
498
|
let i;
|
|
499
499
|
if (o in t)
|
|
500
500
|
i = t[o];
|
|
501
|
-
else if (o in
|
|
502
|
-
i =
|
|
501
|
+
else if (o in g)
|
|
502
|
+
i = g[o];
|
|
503
503
|
else if (i = n[o], i == null)
|
|
504
504
|
throw new l(`Unknown ${r}: ${e}. This may be due to one of the following reasons:
|
|
505
505
|
1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
|
|
@@ -512,30 +512,30 @@ function zn(e, n = {}, t = {}, r = "object", s = !1) {
|
|
|
512
512
|
'className' and 'config' must set.`);
|
|
513
513
|
const i = o.className;
|
|
514
514
|
let p, c;
|
|
515
|
-
if (i in t ? [p, c] = t[i] : i in
|
|
515
|
+
if (i in t ? [p, c] = t[i] : i in g ? [p, c] = g.className : i in n && ([p, c] = n[i]), p == null)
|
|
516
516
|
throw new l(`Unknown ${r}: ${i}. This may be due to one of the following reasons:
|
|
517
517
|
1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
|
|
518
518
|
2. The custom ${r} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);
|
|
519
519
|
if (c != null) {
|
|
520
520
|
const a = {};
|
|
521
|
-
for (const d of Object.keys(
|
|
522
|
-
a[d] =
|
|
521
|
+
for (const d of Object.keys(g))
|
|
522
|
+
a[d] = g[d];
|
|
523
523
|
for (const d of Object.keys(t))
|
|
524
524
|
a[d] = t[d];
|
|
525
525
|
const k = o.config;
|
|
526
526
|
k.customObjects = a;
|
|
527
|
-
const
|
|
527
|
+
const m = Object.assign({}, g);
|
|
528
528
|
for (const d of Object.keys(t))
|
|
529
|
-
|
|
529
|
+
g[d] = t[d];
|
|
530
530
|
W(o.config);
|
|
531
531
|
const E = c(p, o.config, t, s);
|
|
532
|
-
return
|
|
532
|
+
return g = Object.assign({}, m), E;
|
|
533
533
|
} else {
|
|
534
|
-
const a = Object.assign({},
|
|
535
|
-
for (const
|
|
536
|
-
m
|
|
534
|
+
const a = Object.assign({}, g);
|
|
535
|
+
for (const m of Object.keys(t))
|
|
536
|
+
g[m] = t[m];
|
|
537
537
|
const k = new p(o.config);
|
|
538
|
-
return
|
|
538
|
+
return g = Object.assign({}, a), k;
|
|
539
539
|
}
|
|
540
540
|
}
|
|
541
541
|
}
|
|
@@ -566,7 +566,7 @@ function v(e, n, t) {
|
|
|
566
566
|
throw new l(`${t} is not a valid ${n}. Valid values are ${e} or null/undefined.`);
|
|
567
567
|
}
|
|
568
568
|
function rt(e, n, t = 0, r = 1 / 0) {
|
|
569
|
-
return
|
|
569
|
+
return ge(t >= 0), ge(r >= t), Array.isArray(e) && e.length >= t && e.length <= r && e.every((s) => typeof s === n);
|
|
570
570
|
}
|
|
571
571
|
function Ln(e, n) {
|
|
572
572
|
Array.isArray(e) ? (_(e.length > 0, () => `${n} is unexpectedly an empty array.`), e.forEach((t, r) => Ln(t, `element ${r + 1} of ${n}`))) : _(Number.isInteger(e) && e > 0, () => `Expected ${n} to be a positive integer, but got ${ye(e)}.`);
|
|
@@ -606,7 +606,7 @@ function ct(e) {
|
|
|
606
606
|
function at(e) {
|
|
607
607
|
v(xn, "PoolMode", e);
|
|
608
608
|
}
|
|
609
|
-
const F = [],
|
|
609
|
+
const F = [], me = "/";
|
|
610
610
|
function it(e, n) {
|
|
611
611
|
F.push(e);
|
|
612
612
|
try {
|
|
@@ -617,7 +617,7 @@ function it(e, n) {
|
|
|
617
617
|
}
|
|
618
618
|
}
|
|
619
619
|
function Nn() {
|
|
620
|
-
return F.length === 0 ? "" : F.join(
|
|
620
|
+
return F.length === 0 ? "" : F.join(me) + me;
|
|
621
621
|
}
|
|
622
622
|
function lt(e) {
|
|
623
623
|
if (!Oe(e))
|
|
@@ -678,7 +678,7 @@ function dt(e) {
|
|
|
678
678
|
}
|
|
679
679
|
return n;
|
|
680
680
|
}
|
|
681
|
-
function
|
|
681
|
+
function gt(e, n) {
|
|
682
682
|
if (n < e)
|
|
683
683
|
throw new l(`end (${n}) < begin (${e}) is forbidden.`);
|
|
684
684
|
const t = [];
|
|
@@ -696,7 +696,7 @@ function mt(e, n) {
|
|
|
696
696
|
* =============================================================================
|
|
697
697
|
*/
|
|
698
698
|
let G;
|
|
699
|
-
function
|
|
699
|
+
function mt() {
|
|
700
700
|
return G == null && (G = je().epsilon()), G;
|
|
701
701
|
}
|
|
702
702
|
function Y() {
|
|
@@ -876,7 +876,7 @@ function Dt(e, n, t, r) {
|
|
|
876
876
|
e = u(e, [-1, o]);
|
|
877
877
|
const i = n.shape.slice(), p = i.pop(), c = i.pop(), a = [...i, p], k = Array.from({ length: n.rank }, (ne, x) => x === 0 ? n.rank - 2 : x <= n.rank - 2 ? x - 1 : x);
|
|
878
878
|
n = u(ve(n, k), [c, -1]);
|
|
879
|
-
const
|
|
879
|
+
const m = [...s, ...a];
|
|
880
880
|
return u(de({
|
|
881
881
|
a: e,
|
|
882
882
|
b: n,
|
|
@@ -884,7 +884,7 @@ function Dt(e, n, t, r) {
|
|
|
884
884
|
transposeB: !1,
|
|
885
885
|
bias: r ? Q(e.rank, r, Y()) : null,
|
|
886
886
|
activation: t
|
|
887
|
-
}),
|
|
887
|
+
}), m);
|
|
888
888
|
}
|
|
889
889
|
}
|
|
890
890
|
function Tt(e, n, t) {
|
|
@@ -951,7 +951,7 @@ export {
|
|
|
951
951
|
J as H,
|
|
952
952
|
Pn as I,
|
|
953
953
|
Tt as J,
|
|
954
|
-
|
|
954
|
+
gt as K,
|
|
955
955
|
Zn as L,
|
|
956
956
|
It as M,
|
|
957
957
|
j as N,
|
|
@@ -1001,10 +1001,10 @@ export {
|
|
|
1001
1001
|
_t as r,
|
|
1002
1002
|
On as s,
|
|
1003
1003
|
Qn as t,
|
|
1004
|
-
|
|
1004
|
+
mt as u,
|
|
1005
1005
|
fn as v,
|
|
1006
1006
|
st as w,
|
|
1007
1007
|
wt as x,
|
|
1008
1008
|
Et as y,
|
|
1009
|
-
|
|
1009
|
+
ge as z
|
|
1010
1010
|
};
|
package/dist/training/AdamExt.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { A as r, b as c, f as h, s as g, e as o } from "../index-
|
|
1
|
+
import { A as r, b as c, f as h, s as g, e as o } from "../index-Du-bmOP8.js";
|
|
2
2
|
class u extends r {
|
|
3
3
|
constructor(t, e, s, a, i) {
|
|
4
4
|
super(t, e, s, a), this.config = i, this.startLearningRate = t;
|
|
@@ -1,21 +1,22 @@
|
|
|
1
|
-
import { generateText as
|
|
1
|
+
import { generateText as T } from "../utilities/generate.js";
|
|
2
2
|
import L from "./Trainer.js";
|
|
3
3
|
import x from "./Evaluator.js";
|
|
4
|
-
import { a as h } from "../index-
|
|
4
|
+
import { a as h } from "../index-Du-bmOP8.js";
|
|
5
|
+
import y from "../utilities/profile.js";
|
|
5
6
|
const D = {
|
|
6
7
|
desiredLoss: 0.01,
|
|
7
8
|
logInterval: 1,
|
|
8
9
|
maxSteps: 1e3
|
|
9
10
|
};
|
|
10
|
-
class
|
|
11
|
-
constructor(
|
|
12
|
-
super(
|
|
11
|
+
class I extends L {
|
|
12
|
+
constructor(i, e, o = 3e-4) {
|
|
13
|
+
super(i, e, o);
|
|
13
14
|
}
|
|
14
15
|
// Train for multiple epochs using Dataset API - FIXED memory leaks
|
|
15
|
-
async trainOnDataset(
|
|
16
|
-
const { desiredLoss:
|
|
16
|
+
async trainOnDataset(i, e, o) {
|
|
17
|
+
const { desiredLoss: p, logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
|
|
17
18
|
...D,
|
|
18
|
-
...
|
|
19
|
+
...e
|
|
19
20
|
}, n = Date.now(), t = {
|
|
20
21
|
step: 0,
|
|
21
22
|
lastLoss: 1e6,
|
|
@@ -26,52 +27,55 @@ class E extends L {
|
|
|
26
27
|
trainingDuration: 0,
|
|
27
28
|
...this.lastState || {}
|
|
28
29
|
};
|
|
29
|
-
this.lastState = t, this.dummyPass(), this.model.trainable = !0, this.running = !0, t.logStartTime = n;
|
|
30
|
-
const m = o ? new x(this.model, o) : void 0,
|
|
30
|
+
this.lastState = t, this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, t.logStartTime = n;
|
|
31
|
+
const m = o ? new x(this.model, o) : void 0, f = await i.iterator();
|
|
31
32
|
try {
|
|
32
|
-
for (; this.running && !(t.lastLoss <
|
|
33
|
-
const
|
|
34
|
-
if (
|
|
35
|
-
const
|
|
33
|
+
for (; this.running && !(t.lastLoss < p); ) {
|
|
34
|
+
const r = await f.next();
|
|
35
|
+
if (r.done) break;
|
|
36
|
+
const d = r.value, v = this.trainBatch(t, d, e.advancedMetrics || !1), s = {
|
|
36
37
|
loss: t.lastLoss,
|
|
37
38
|
step: t.step,
|
|
38
39
|
time: Date.now() - n,
|
|
39
|
-
batchSize:
|
|
40
|
+
batchSize: d.xs.shape[0],
|
|
41
|
+
learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0,
|
|
42
|
+
gradientNorm: e?.advancedMetrics ? t.gradientNorm : void 0
|
|
40
43
|
};
|
|
41
|
-
if (this.model.log.push(s), t.step %
|
|
42
|
-
await
|
|
43
|
-
const
|
|
44
|
-
if (t.trainingDuration +=
|
|
44
|
+
if (this.model.log.push(s), t.step % g === 0) {
|
|
45
|
+
await v;
|
|
46
|
+
const S = Date.now();
|
|
47
|
+
if (t.trainingDuration += S - t.logStartTime, m)
|
|
45
48
|
try {
|
|
46
|
-
const
|
|
47
|
-
t.validationLosses.push(
|
|
48
|
-
} catch (
|
|
49
|
-
console.error("Validation error:",
|
|
49
|
+
const a = await m.evaluate(5);
|
|
50
|
+
t.validationLosses.push(a), s.valLoss = a;
|
|
51
|
+
} catch (a) {
|
|
52
|
+
console.error("Validation error:", a);
|
|
50
53
|
}
|
|
51
54
|
if (l) {
|
|
52
55
|
if (c) {
|
|
53
|
-
const
|
|
56
|
+
const w = await T(this.tokenizer, this.model, c, 100, {
|
|
54
57
|
temperature: 0.8
|
|
55
58
|
});
|
|
56
|
-
s.example =
|
|
59
|
+
s.example = w;
|
|
57
60
|
}
|
|
58
|
-
const
|
|
61
|
+
const a = {
|
|
59
62
|
duration: t.trainingDuration,
|
|
60
63
|
totalSamples: t.totalSteps * s.batchSize,
|
|
61
|
-
samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3)
|
|
64
|
+
samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
|
|
65
|
+
memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
|
|
62
66
|
};
|
|
63
|
-
await l(s,
|
|
67
|
+
await l(s, a);
|
|
64
68
|
}
|
|
65
69
|
t.logStartTime = Date.now();
|
|
66
70
|
}
|
|
67
|
-
t.step >=
|
|
71
|
+
t.step >= u && this.stop();
|
|
68
72
|
}
|
|
69
|
-
} catch (
|
|
70
|
-
throw console.error("Training error:",
|
|
73
|
+
} catch (r) {
|
|
74
|
+
throw console.error("Training error:", r), h(), r;
|
|
71
75
|
}
|
|
72
76
|
return h(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
|
|
73
77
|
}
|
|
74
78
|
}
|
|
75
79
|
export {
|
|
76
|
-
|
|
80
|
+
I as default
|
|
77
81
|
};
|
|
@@ -11,11 +11,13 @@ export interface TrainingState {
|
|
|
11
11
|
totalSteps: number;
|
|
12
12
|
losses: number[];
|
|
13
13
|
validationLosses: number[];
|
|
14
|
+
gradientNorm?: number;
|
|
14
15
|
}
|
|
15
16
|
export interface TrainingProgress {
|
|
16
17
|
duration: number;
|
|
17
18
|
totalSamples: number;
|
|
18
19
|
samplesPerSecond: number;
|
|
20
|
+
memory?: number;
|
|
19
21
|
}
|
|
20
22
|
export interface AdamConfig {
|
|
21
23
|
learningRateFactor: number;
|
|
@@ -28,6 +30,7 @@ export interface TrainingOptions {
|
|
|
28
30
|
logInterval: number;
|
|
29
31
|
prompt?: string;
|
|
30
32
|
maxSteps: number;
|
|
33
|
+
advancedMetrics?: boolean;
|
|
31
34
|
onStep?: (log: TrainingLogEntry, progress: TrainingProgress) => Promise<void> | void;
|
|
32
35
|
}
|
|
33
36
|
export default abstract class GPTTrainer {
|
|
@@ -44,16 +47,16 @@ export default abstract class GPTTrainer {
|
|
|
44
47
|
stop(): void;
|
|
45
48
|
getOptimizer(): AdamExt;
|
|
46
49
|
resetOptimizer(config?: AdamConfig): void;
|
|
47
|
-
private
|
|
48
|
-
protected trainStep(batch: {
|
|
50
|
+
private maxGradNorm;
|
|
51
|
+
protected trainStep(state: Partial<TrainingState>, batch: {
|
|
49
52
|
xs: Tensor;
|
|
50
53
|
ys: Tensor;
|
|
51
|
-
}, dummy?: boolean,
|
|
54
|
+
}, dummy?: boolean, calcNorm?: boolean): Scalar;
|
|
52
55
|
protected dummyPass(): void;
|
|
53
56
|
protected trainBatch(state: TrainingState, batch: {
|
|
54
57
|
xs: Tensor;
|
|
55
58
|
ys: Tensor;
|
|
56
|
-
}): Promise<number>;
|
|
59
|
+
}, calcNorm?: boolean): Promise<number>;
|
|
57
60
|
abstract trainOnDataset(dataset: Dataset<{
|
|
58
61
|
xs: Tensor;
|
|
59
62
|
ys: Tensor;
|
package/dist/training/Trainer.js
CHANGED
|
@@ -1,13 +1,11 @@
|
|
|
1
|
-
import { DatasetBuilder as
|
|
2
|
-
import
|
|
3
|
-
import { t as
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
constructor(t, s, e = 1e-3) {
|
|
10
|
-
this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new h(s, t.config.gpt.blockSize);
|
|
1
|
+
import { DatasetBuilder as g, flattenTokens as m, PAGE_FACTOR as u } from "./DatasetBuilder.js";
|
|
2
|
+
import f from "./AdamExt.js";
|
|
3
|
+
import { t as y, v as z, a as c } from "../index-Du-bmOP8.js";
|
|
4
|
+
import { n as S } from "../norm-01kY9I2B.js";
|
|
5
|
+
import { z as p } from "../zeros-BaHhQTWf.js";
|
|
6
|
+
class R {
|
|
7
|
+
constructor(t, e, s = 1e-3) {
|
|
8
|
+
this.tokenizer = e, this.model = t, this.learningRate = s, this.resetOptimizer(), this.datasetBuilder = new g(e, t.config.gpt.blockSize);
|
|
11
9
|
}
|
|
12
10
|
model;
|
|
13
11
|
optimizer;
|
|
@@ -29,7 +27,7 @@ class G {
|
|
|
29
27
|
}
|
|
30
28
|
resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
|
|
31
29
|
this.optimizer && this.optimizer.dispose();
|
|
32
|
-
const
|
|
30
|
+
const e = new f(
|
|
33
31
|
t.learningRateFactor * this.learningRate,
|
|
34
32
|
t.beta1,
|
|
35
33
|
t.beta2,
|
|
@@ -41,68 +39,78 @@ class G {
|
|
|
41
39
|
weightDecay: 0
|
|
42
40
|
}
|
|
43
41
|
);
|
|
44
|
-
this.optimizer =
|
|
42
|
+
this.optimizer = e;
|
|
45
43
|
}
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
44
|
+
maxGradNorm(t) {
|
|
45
|
+
let e = 0;
|
|
46
|
+
return Object.keys(t).forEach((s) => {
|
|
47
|
+
const a = t[s], r = S(a), i = r.dataSync()[0];
|
|
48
|
+
r.dispose(), i > e && (e = i);
|
|
49
|
+
}), e;
|
|
51
50
|
}
|
|
52
|
-
trainStep(t, s = !1,
|
|
53
|
-
return
|
|
51
|
+
trainStep(t, e, s = !1, a = !1) {
|
|
52
|
+
return y(() => {
|
|
54
53
|
this.model.getProfiler()?.startMemory();
|
|
55
|
-
const { xs:
|
|
56
|
-
const [
|
|
57
|
-
return
|
|
58
|
-
}, { value:
|
|
59
|
-
|
|
54
|
+
const { xs: r, ys: i } = e, d = () => {
|
|
55
|
+
const [n, h] = this.model.forward({ training: !0 }, r, i);
|
|
56
|
+
return n.dispose(), h;
|
|
57
|
+
}, { value: l, grads: o } = z(d);
|
|
58
|
+
if (s)
|
|
59
|
+
this.model.getProfiler()?.endMemory("Training");
|
|
60
|
+
else {
|
|
61
|
+
if (a) {
|
|
62
|
+
const n = this.maxGradNorm(o);
|
|
63
|
+
t.gradientNorm = n;
|
|
64
|
+
}
|
|
65
|
+
this.optimizer.applyGradients(o), this.model.getProfiler()?.endMemory("Training"), c(o);
|
|
66
|
+
}
|
|
67
|
+
return l;
|
|
60
68
|
});
|
|
61
69
|
}
|
|
62
70
|
dummyPass() {
|
|
63
|
-
const t = p([1, this.model.config.gpt.blockSize], "int32"),
|
|
71
|
+
const t = p([1, this.model.config.gpt.blockSize], "int32"), e = p([1, this.model.config.gpt.blockSize], "int32");
|
|
64
72
|
try {
|
|
65
|
-
const
|
|
66
|
-
|
|
67
|
-
} catch (
|
|
68
|
-
console.error("Error during dummy pass:",
|
|
73
|
+
const s = this.trainStep({}, { xs: t, ys: e }, !0);
|
|
74
|
+
s.dataSync(), s.dispose();
|
|
75
|
+
} catch (s) {
|
|
76
|
+
console.error("Error during dummy pass:", s);
|
|
69
77
|
} finally {
|
|
70
|
-
t.dispose(),
|
|
78
|
+
t.dispose(), e.dispose();
|
|
71
79
|
}
|
|
72
80
|
}
|
|
73
|
-
async trainBatch(t, s) {
|
|
81
|
+
async trainBatch(t, e, s = !1) {
|
|
74
82
|
try {
|
|
75
|
-
const
|
|
76
|
-
return
|
|
77
|
-
} catch (
|
|
78
|
-
throw console.error(`Error processing batch at step ${t.step}:`,
|
|
83
|
+
const a = this.trainStep(t, e, !1, s);
|
|
84
|
+
return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, a.array().then((r) => (t.lastLoss = r, t.losses.push(t.lastLoss), a.dispose(), t.lastLoss));
|
|
85
|
+
} catch (a) {
|
|
86
|
+
throw console.error(`Error processing batch at step ${t.step}:`, a), c(), a;
|
|
79
87
|
}
|
|
80
88
|
}
|
|
81
|
-
async createTrainValidationSplit(t,
|
|
82
|
-
const a = await
|
|
83
|
-
if (
|
|
84
|
-
const
|
|
85
|
-
for (;
|
|
86
|
-
const
|
|
87
|
-
|
|
89
|
+
async createTrainValidationSplit(t, e = 32, s = 0.1) {
|
|
90
|
+
const a = await m(t, this.tokenizer), r = /* @__PURE__ */ new Set();
|
|
91
|
+
if (s > 0) {
|
|
92
|
+
const l = Math.floor(a.length / (this.datasetBuilder.blockSize * u)), o = Math.max(1, Math.floor(l * s));
|
|
93
|
+
for (; r.size < o; ) {
|
|
94
|
+
const n = Math.floor(Math.random() * l);
|
|
95
|
+
r.add(n);
|
|
88
96
|
}
|
|
89
97
|
}
|
|
90
|
-
const
|
|
98
|
+
const i = await this.datasetBuilder.createTextDataset(a, e, r, !1), d = await this.datasetBuilder.createTextDataset(
|
|
91
99
|
a,
|
|
92
|
-
|
|
93
|
-
|
|
100
|
+
e,
|
|
101
|
+
r,
|
|
94
102
|
!0
|
|
95
103
|
);
|
|
96
|
-
return { trainDataset:
|
|
104
|
+
return { trainDataset: i, validationDataset: d };
|
|
97
105
|
}
|
|
98
|
-
async createDataset(t,
|
|
99
|
-
const
|
|
100
|
-
return await this.datasetBuilder.createTextDataset(
|
|
106
|
+
async createDataset(t, e = 32) {
|
|
107
|
+
const s = await m(t, this.tokenizer);
|
|
108
|
+
return await this.datasetBuilder.createTextDataset(s, e);
|
|
101
109
|
}
|
|
102
110
|
dispose() {
|
|
103
111
|
this.optimizer && this.optimizer.dispose();
|
|
104
112
|
}
|
|
105
113
|
}
|
|
106
114
|
export {
|
|
107
|
-
|
|
115
|
+
R as default
|
|
108
116
|
};
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import { gatherSub as L } from "../ops/gatherSub.js";
|
|
2
2
|
import { scatterSub as y } from "../ops/scatterSub.js";
|
|
3
|
-
import { e as u, c as i, z as S, t as f, s as G } from "../index-
|
|
4
|
-
import { s as v } from "../softmax-
|
|
5
|
-
import { m as z } from "../max-
|
|
6
|
-
import { l as k } from "../log_sum_exp-
|
|
3
|
+
import { e as u, c as i, z as S, t as f, s as G } from "../index-Du-bmOP8.js";
|
|
4
|
+
import { s as v } from "../softmax-DhWoBa7r.js";
|
|
5
|
+
import { m as z } from "../max-0Xnlpv8k.js";
|
|
6
|
+
import { l as k } from "../log_sum_exp-CxfBtUaG.js";
|
|
7
7
|
function F(a, s) {
|
|
8
8
|
return f(() => {
|
|
9
9
|
const e = a.shape[a.shape.length - 1], o = a.shape.slice(0, -1).reduce((d, c) => d * c, 1), p = a.shape.length > 2 ? a.reshape([o, e]) : a, n = s.shape.length > 1 ? s.reshape([o]).cast("int32") : s.cast("int32"), t = z(p, -1, !0), r = G(p, t), h = k(r, -1);
|
package/dist/utilities/dummy.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { z as n } from "../zeros-
|
|
1
|
+
import "../index-Du-bmOP8.js";
|
|
2
|
+
import { z as n } from "../zeros-BaHhQTWf.js";
|
|
3
3
|
async function c(s) {
|
|
4
4
|
const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
|
|
5
5
|
await t.data(), t.dispose(), o && o.dispose(), i.dispose();
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { t as y } from "../index-
|
|
2
|
-
import { t as x } from "../tensor2d-
|
|
3
|
-
import { c as f } from "../concat-
|
|
1
|
+
import { t as y } from "../index-Du-bmOP8.js";
|
|
2
|
+
import { t as x } from "../tensor2d-CRWjDyUe.js";
|
|
3
|
+
import { c as f } from "../concat-DdKPyAtw.js";
|
|
4
4
|
async function A(o, r, a, c, T) {
|
|
5
5
|
if (c <= 0)
|
|
6
6
|
throw new Error("Length must be a positive integer");
|
package/dist/utilities/load.js
CHANGED
|
@@ -3,7 +3,7 @@ import { importWeights as b } from "./weights.js";
|
|
|
3
3
|
import u from "../tokeniser/CharTokeniser.js";
|
|
4
4
|
import F from "../NanoGPTModel.js";
|
|
5
5
|
import { dummyPassAsync as j } from "./dummy.js";
|
|
6
|
-
import { d as T } from "../index-
|
|
6
|
+
import { d as T } from "../index-Du-bmOP8.js";
|
|
7
7
|
import E from "../tokeniser/bpe.js";
|
|
8
8
|
async function A(t) {
|
|
9
9
|
const o = await fetch(t);
|