@genai-fi/nanogpt 0.5.4 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/dist/Generator.js +5 -5
  2. package/dist/NanoGPTModel.d.ts +2 -0
  3. package/dist/NanoGPTModel.js +8 -8
  4. package/dist/{Reshape-Bt_t7RNz.js → Reshape-Biok_3X1.js} +6 -6
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/{TiedEmbedding-DORsPlNL.js → TiedEmbedding-8S8xn8e6.js} +5 -5
  7. package/dist/Trainer.d.ts +1 -0
  8. package/dist/Trainer.js +8 -7
  9. package/dist/{axis_util-CVbf1vmL.js → axis_util-BczFISHz.js} +1 -1
  10. package/dist/{broadcast_to-BBoMQXbL.js → broadcast_to-B7NGsBSh.js} +2 -2
  11. package/dist/{concat-BRRtq4S2.js → concat-DdKPyAtw.js} +1 -1
  12. package/dist/{dataset-ZHEPJmED.js → dataset-iqT4Otvb.js} +7 -7
  13. package/dist/{dropout-lQm_YyX3.js → dropout-B09InSJS.js} +1 -1
  14. package/dist/{gather-BWyutxwi.js → gather-D6MsdXqc.js} +1 -1
  15. package/dist/{gpgpu_math-Df7gzJWH.js → gpgpu_math-BFbOyvk4.js} +1 -1
  16. package/dist/{index-CnHyhpKc.js → index-Du-bmOP8.js} +98 -98
  17. package/dist/{kernel_funcs_utils-Dqo82NH4.js → kernel_funcs_utils-DShm7-0k.js} +33 -33
  18. package/dist/layers/BaseLayer.js +2 -2
  19. package/dist/layers/CausalSelfAttention.js +6 -6
  20. package/dist/layers/MLP.js +5 -5
  21. package/dist/layers/RMSNorm.js +3 -3
  22. package/dist/layers/RoPECache.js +3 -3
  23. package/dist/layers/TiedEmbedding.js +6 -6
  24. package/dist/layers/TransformerBlock.js +1 -1
  25. package/dist/{log_sum_exp-CRH7Np9v.js → log_sum_exp-CxfBtUaG.js} +5 -5
  26. package/dist/main.js +1 -1
  27. package/dist/{mat_mul-DeGU1U_C.js → mat_mul-CbiqIe2d.js} +1 -1
  28. package/dist/{max-CcnEArWK.js → max-0Xnlpv8k.js} +1 -1
  29. package/dist/{norm-BpWsOapl.js → norm-01kY9I2B.js} +5 -5
  30. package/dist/{ones-CDWGzVnm.js → ones-CrutWGas.js} +2 -2
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +3 -3
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +4 -4
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +1 -1
  49. package/dist/ops/grads/fusedSoftmax.js +2 -2
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/matMulGelu.js +1 -1
  56. package/dist/ops/matMulMul.js +1 -1
  57. package/dist/ops/mulDrop.js +1 -1
  58. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  59. package/dist/ops/normRMS.js +1 -1
  60. package/dist/ops/qkv.js +1 -1
  61. package/dist/ops/scatterSub.js +1 -1
  62. package/dist/ops/webgl/appendCache.js +1 -1
  63. package/dist/ops/webgl/attentionMask.js +1 -1
  64. package/dist/ops/webgl/fusedSoftmax.js +96 -96
  65. package/dist/ops/webgl/gatherSub.js +1 -1
  66. package/dist/ops/webgl/gelu.js +2 -2
  67. package/dist/ops/webgl/matMulGelu.js +4 -4
  68. package/dist/ops/webgl/matMulMul.js +1 -1
  69. package/dist/ops/webgl/mulDropout.js +1 -1
  70. package/dist/ops/webgl/normRMS.js +2 -2
  71. package/dist/ops/webgl/qkv.js +1 -1
  72. package/dist/ops/webgl/rope.js +1 -1
  73. package/dist/ops/webgl/scatterSub.js +1 -1
  74. package/dist/{ops-DzQTmLIl.js → ops-CJNniCAV.js} +13 -13
  75. package/dist/{random_width-DI2h9CMs.js → random_width-C-v-35bY.js} +1324 -1279
  76. package/dist/{range-CkOJ7090.js → range-Bvs1hidm.js} +1 -1
  77. package/dist/{reshape-CTIbqjwm.js → reshape-BH7eBpwq.js} +1 -1
  78. package/dist/{sin-HzioENy_.js → sin-CPAZXNjH.js} +1 -1
  79. package/dist/{slice_util-n4wHKmex.js → slice_util-DskXqRZa.js} +1 -1
  80. package/dist/{softmax-DX6qXAbm.js → softmax-DhWoBa7r.js} +1 -1
  81. package/dist/{split-CVwhL8Oe.js → split-BCUhuU7B.js} +1 -1
  82. package/dist/{stack-S2-D2JAQ.js → stack-BV1v7l3S.js} +1 -1
  83. package/dist/{sum-UdfvaNhB.js → sum-Cvq06317.js} +1 -1
  84. package/dist/{tensor-IZex6Bwp.js → tensor-DgTOPY6h.js} +1 -1
  85. package/dist/{tensor2d-CqtBzOKq.js → tensor2d-CRWjDyUe.js} +1 -1
  86. package/dist/{tfjs_backend-DX9yVvwk.js → tfjs_backend-D9Ytje0G.js} +39 -39
  87. package/dist/training/AdamExt.js +1 -1
  88. package/dist/training/DatasetBuilder.js +2 -2
  89. package/dist/training/FullTrainer.js +36 -32
  90. package/dist/training/Trainer.d.ts +7 -4
  91. package/dist/training/Trainer.js +58 -50
  92. package/dist/training/sparseCrossEntropy.js +4 -4
  93. package/dist/utilities/dummy.js +2 -2
  94. package/dist/utilities/generate.js +3 -3
  95. package/dist/utilities/load.js +1 -1
  96. package/dist/utilities/profile.d.ts +1 -0
  97. package/dist/utilities/profile.js +6 -3
  98. package/dist/utilities/weights.js +2 -2
  99. package/dist/{variable-BGvK-VN3.js → variable-DZ3fF0R2.js} +1 -1
  100. package/dist/{zeros-CYMicyqz.js → zeros-BaHhQTWf.js} +1 -1
  101. package/package.json +1 -1
  102. package/dist/moments-DLTE6-1p.js +0 -53
@@ -1,4 +1,4 @@
1
- import { E as e, R as f } from "./index-CnHyhpKc.js";
1
+ import { E as e, R as f } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o, j as t, E as a, w as p } from "./index-CnHyhpKc.js";
1
+ import { o, j as t, E as a, w as p } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o, j as t, E as c, _ as a, $ as e } from "./index-CnHyhpKc.js";
1
+ import { o, j as t, E as c, _ as a, $ as e } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { l as s } from "./index-CnHyhpKc.js";
1
+ import { l as s } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2021 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as r, j as f, E as e, S as i } from "./index-CnHyhpKc.js";
1
+ import { o as r, j as f, E as e, S as i } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as p, j as i, E as a, x as c } from "./index-CnHyhpKc.js";
1
+ import { o as p, j as i, E as a, x as c } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as e, k as c, l as n, E as k, P as i } from "./index-CnHyhpKc.js";
1
+ import { o as e, k as c, l as n, E as k, P as i } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as e, j as u, D as c, E as l, F as m } from "./index-CnHyhpKc.js";
1
+ import { o as e, j as u, D as c, E as l, F as m } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { J as t, K as a } from "./index-CnHyhpKc.js";
1
+ import { J as t, K as a } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { I as t, J as s, K as a } from "./index-CnHyhpKc.js";
1
+ import { I as t, J as s, K as a } from "./index-Du-bmOP8.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,11 +1,11 @@
1
- import { o as h, j as f, E as $, ap as Te, l as _, g as Ee, aq as xe, ar as Ie, as as Le, at as be, au as Ne, av as Ce, aw as Pe, b as H, ax as Fe, a9 as U, u as ae, q as ie, Q as le, c as fe, ay as he, aj as pe, az as je, t as S, D as $e, am as Me, a4 as Be } from "./index-CnHyhpKc.js";
2
- import { s as C, t as Ke, a as Ue, b as ve } from "./ops-DzQTmLIl.js";
3
- import { r as Re, d as Ve } from "./dropout-lQm_YyX3.js";
4
- import { r as u } from "./reshape-CTIbqjwm.js";
5
- import { g as qe } from "./gather-BWyutxwi.js";
6
- import { s as Ge } from "./sum-UdfvaNhB.js";
7
- import { m as A } from "./mat_mul-DeGU1U_C.js";
8
- import { c as M } from "./concat-BRRtq4S2.js";
1
+ import { o as h, j as f, E as $, ao as Te, l as _, g as Ee, ap as xe, aq as Ie, ar as Le, as as be, at as Ne, au as Ce, av as Pe, b as H, aw as Fe, a8 as U, u as ae, q as ie, Q as le, c as fe, ax as he, ai as pe, ay as je, t as S, D as $e, al as Me, a2 as Be } from "./index-Du-bmOP8.js";
2
+ import { s as C, t as Ke, a as Ue, b as ve } from "./ops-CJNniCAV.js";
3
+ import { r as Re, d as Ve } from "./dropout-B09InSJS.js";
4
+ import { r as u } from "./reshape-BH7eBpwq.js";
5
+ import { g as qe } from "./gather-D6MsdXqc.js";
6
+ import { s as Ge } from "./sum-Cvq06317.js";
7
+ import { m as A } from "./mat_mul-CbiqIe2d.js";
8
+ import { c as M } from "./concat-DdKPyAtw.js";
9
9
  /**
10
10
  * @license
11
11
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -213,11 +213,11 @@ const X = /* @__PURE__ */ h({ slice1d_: dn });
213
213
  * limitations under the License.
214
214
  * =============================================================================
215
215
  */
216
- function mn(e, n, t) {
216
+ function gn(e, n, t) {
217
217
  const r = f(e, "x", "slice2d");
218
218
  return _(r.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${r.rank} tensor`), C(r, n, t);
219
219
  }
220
- const we = /* @__PURE__ */ h({ slice2d_: mn });
220
+ const we = /* @__PURE__ */ h({ slice2d_: gn });
221
221
  /**
222
222
  * @license
223
223
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -234,11 +234,11 @@ const we = /* @__PURE__ */ h({ slice2d_: mn });
234
234
  * limitations under the License.
235
235
  * =============================================================================
236
236
  */
237
- function gn(e, n, t) {
237
+ function mn(e, n, t) {
238
238
  const r = f(e, "x", "slice3d");
239
239
  return _(r.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${r.rank} tensor`), C(r, n, t);
240
240
  }
241
- const z = /* @__PURE__ */ h({ slice3d_: gn });
241
+ const z = /* @__PURE__ */ h({ slice3d_: mn });
242
242
  /**
243
243
  * @license
244
244
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -350,9 +350,9 @@ function _n({ a: e, b: n, transposeA: t = !1, transposeB: r = !1, bias: s, activ
350
350
  }
351
351
  let c = f(e, "a", "fused matMul"), a = f(n, "b", "fused matMul");
352
352
  [c, a] = ae(c, a);
353
- const k = t ? c.shape[c.rank - 2] : c.shape[c.rank - 1], g = r ? a.shape[a.rank - 1] : a.shape[a.rank - 2], E = t ? c.shape[c.rank - 1] : c.shape[c.rank - 2], d = r ? a.shape[a.rank - 2] : a.shape[a.rank - 1], ne = c.shape.slice(0, -2), x = a.shape.slice(0, -2), te = ie(ne), re = ie(x);
354
- _(k === g, () => `Error in fused matMul: inner shapes (${k}) and (${g}) of Tensors with shapes ${c.shape} and ${a.shape} and transposeA=${t} and transposeB=${r} must match.`);
355
- const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), q = r ? u(a, [re, d, g]) : u(a, [re, g, d]);
353
+ const k = t ? c.shape[c.rank - 2] : c.shape[c.rank - 1], m = r ? a.shape[a.rank - 1] : a.shape[a.rank - 2], E = t ? c.shape[c.rank - 1] : c.shape[c.rank - 2], d = r ? a.shape[a.rank - 2] : a.shape[a.rank - 1], ne = c.shape.slice(0, -2), x = a.shape.slice(0, -2), te = ie(ne), re = ie(x);
354
+ _(k === m, () => `Error in fused matMul: inner shapes (${k}) and (${m}) of Tensors with shapes ${c.shape} and ${a.shape} and transposeA=${t} and transposeB=${r} must match.`);
355
+ const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), q = r ? u(a, [re, d, m]) : u(a, [re, m, d]);
356
356
  let I;
357
357
  s != null && (I = f(s, "bias", "fused matMul"), [I] = ae(I, c), le(R, I.shape));
358
358
  let se;
@@ -450,7 +450,7 @@ function Jn(e, n) {
450
450
  return t.fill(e), t;
451
451
  }
452
452
  }
453
- function me(e, n) {
453
+ function ge(e, n) {
454
454
  if (!e)
455
455
  throw new ee(n);
456
456
  }
@@ -473,7 +473,7 @@ function Qn(e) {
473
473
  function Hn(e) {
474
474
  return e.length <= 1 || e.indexOf("_") === -1 ? e : e.replace(/[_]+(\w|$)/g, (n, t) => t.toUpperCase());
475
475
  }
476
- let m = {};
476
+ let g = {};
477
477
  function Xn(e) {
478
478
  if (e == null)
479
479
  return null;
@@ -498,8 +498,8 @@ function zn(e, n = {}, t = {}, r = "object", s = !1) {
498
498
  let i;
499
499
  if (o in t)
500
500
  i = t[o];
501
- else if (o in m)
502
- i = m[o];
501
+ else if (o in g)
502
+ i = g[o];
503
503
  else if (i = n[o], i == null)
504
504
  throw new l(`Unknown ${r}: ${e}. This may be due to one of the following reasons:
505
505
  1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
@@ -512,30 +512,30 @@ function zn(e, n = {}, t = {}, r = "object", s = !1) {
512
512
  'className' and 'config' must set.`);
513
513
  const i = o.className;
514
514
  let p, c;
515
- if (i in t ? [p, c] = t[i] : i in m ? [p, c] = m.className : i in n && ([p, c] = n[i]), p == null)
515
+ if (i in t ? [p, c] = t[i] : i in g ? [p, c] = g.className : i in n && ([p, c] = n[i]), p == null)
516
516
  throw new l(`Unknown ${r}: ${i}. This may be due to one of the following reasons:
517
517
  1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
518
518
  2. The custom ${r} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);
519
519
  if (c != null) {
520
520
  const a = {};
521
- for (const d of Object.keys(m))
522
- a[d] = m[d];
521
+ for (const d of Object.keys(g))
522
+ a[d] = g[d];
523
523
  for (const d of Object.keys(t))
524
524
  a[d] = t[d];
525
525
  const k = o.config;
526
526
  k.customObjects = a;
527
- const g = Object.assign({}, m);
527
+ const m = Object.assign({}, g);
528
528
  for (const d of Object.keys(t))
529
- m[d] = t[d];
529
+ g[d] = t[d];
530
530
  W(o.config);
531
531
  const E = c(p, o.config, t, s);
532
- return m = Object.assign({}, g), E;
532
+ return g = Object.assign({}, m), E;
533
533
  } else {
534
- const a = Object.assign({}, m);
535
- for (const g of Object.keys(t))
536
- m[g] = t[g];
534
+ const a = Object.assign({}, g);
535
+ for (const m of Object.keys(t))
536
+ g[m] = t[m];
537
537
  const k = new p(o.config);
538
- return m = Object.assign({}, a), k;
538
+ return g = Object.assign({}, a), k;
539
539
  }
540
540
  }
541
541
  }
@@ -566,7 +566,7 @@ function v(e, n, t) {
566
566
  throw new l(`${t} is not a valid ${n}. Valid values are ${e} or null/undefined.`);
567
567
  }
568
568
  function rt(e, n, t = 0, r = 1 / 0) {
569
- return me(t >= 0), me(r >= t), Array.isArray(e) && e.length >= t && e.length <= r && e.every((s) => typeof s === n);
569
+ return ge(t >= 0), ge(r >= t), Array.isArray(e) && e.length >= t && e.length <= r && e.every((s) => typeof s === n);
570
570
  }
571
571
  function Ln(e, n) {
572
572
  Array.isArray(e) ? (_(e.length > 0, () => `${n} is unexpectedly an empty array.`), e.forEach((t, r) => Ln(t, `element ${r + 1} of ${n}`))) : _(Number.isInteger(e) && e > 0, () => `Expected ${n} to be a positive integer, but got ${ye(e)}.`);
@@ -606,7 +606,7 @@ function ct(e) {
606
606
  function at(e) {
607
607
  v(xn, "PoolMode", e);
608
608
  }
609
- const F = [], ge = "/";
609
+ const F = [], me = "/";
610
610
  function it(e, n) {
611
611
  F.push(e);
612
612
  try {
@@ -617,7 +617,7 @@ function it(e, n) {
617
617
  }
618
618
  }
619
619
  function Nn() {
620
- return F.length === 0 ? "" : F.join(ge) + ge;
620
+ return F.length === 0 ? "" : F.join(me) + me;
621
621
  }
622
622
  function lt(e) {
623
623
  if (!Oe(e))
@@ -678,7 +678,7 @@ function dt(e) {
678
678
  }
679
679
  return n;
680
680
  }
681
- function mt(e, n) {
681
+ function gt(e, n) {
682
682
  if (n < e)
683
683
  throw new l(`end (${n}) < begin (${e}) is forbidden.`);
684
684
  const t = [];
@@ -696,7 +696,7 @@ function mt(e, n) {
696
696
  * =============================================================================
697
697
  */
698
698
  let G;
699
- function gt() {
699
+ function mt() {
700
700
  return G == null && (G = je().epsilon()), G;
701
701
  }
702
702
  function Y() {
@@ -876,7 +876,7 @@ function Dt(e, n, t, r) {
876
876
  e = u(e, [-1, o]);
877
877
  const i = n.shape.slice(), p = i.pop(), c = i.pop(), a = [...i, p], k = Array.from({ length: n.rank }, (ne, x) => x === 0 ? n.rank - 2 : x <= n.rank - 2 ? x - 1 : x);
878
878
  n = u(ve(n, k), [c, -1]);
879
- const g = [...s, ...a];
879
+ const m = [...s, ...a];
880
880
  return u(de({
881
881
  a: e,
882
882
  b: n,
@@ -884,7 +884,7 @@ function Dt(e, n, t, r) {
884
884
  transposeB: !1,
885
885
  bias: r ? Q(e.rank, r, Y()) : null,
886
886
  activation: t
887
- }), g);
887
+ }), m);
888
888
  }
889
889
  }
890
890
  function Tt(e, n, t) {
@@ -951,7 +951,7 @@ export {
951
951
  J as H,
952
952
  Pn as I,
953
953
  Tt as J,
954
- mt as K,
954
+ gt as K,
955
955
  Zn as L,
956
956
  It as M,
957
957
  j as N,
@@ -1001,10 +1001,10 @@ export {
1001
1001
  _t as r,
1002
1002
  On as s,
1003
1003
  Qn as t,
1004
- gt as u,
1004
+ mt as u,
1005
1005
  fn as v,
1006
1006
  st as w,
1007
1007
  wt as x,
1008
1008
  Et as y,
1009
- me as z
1009
+ ge as z
1010
1010
  };
@@ -1,4 +1,4 @@
1
- import { A as r, b as c, f as h, s as g, e as o } from "../index-CnHyhpKc.js";
1
+ import { A as r, b as c, f as h, s as g, e as o } from "../index-Du-bmOP8.js";
2
2
  class u extends r {
3
3
  constructor(t, e, s, a, i) {
4
4
  super(t, e, s, a), this.config = i, this.startLearningRate = t;
@@ -1,5 +1,5 @@
1
- import { t as u } from "../index-CnHyhpKc.js";
2
- import { d as z, i as f } from "../dataset-ZHEPJmED.js";
1
+ import { t as u } from "../index-Du-bmOP8.js";
2
+ import { d as z, i as f } from "../dataset-iqT4Otvb.js";
3
3
  import "../index-Tf7vU29b.js";
4
4
  /**
5
5
  * @license
@@ -1,21 +1,22 @@
1
- import { generateText as v } from "../utilities/generate.js";
1
+ import { generateText as T } from "../utilities/generate.js";
2
2
  import L from "./Trainer.js";
3
3
  import x from "./Evaluator.js";
4
- import { a as h } from "../index-CnHyhpKc.js";
4
+ import { a as h } from "../index-Du-bmOP8.js";
5
+ import y from "../utilities/profile.js";
5
6
  const D = {
6
7
  desiredLoss: 0.01,
7
8
  logInterval: 1,
8
9
  maxSteps: 1e3
9
10
  };
10
- class E extends L {
11
- constructor(r, i, o = 3e-4) {
12
- super(r, i, o);
11
+ class I extends L {
12
+ constructor(i, e, o = 3e-4) {
13
+ super(i, e, o);
13
14
  }
14
15
  // Train for multiple epochs using Dataset API - FIXED memory leaks
15
- async trainOnDataset(r, i, o) {
16
- const { desiredLoss: u, logInterval: d, onStep: l, prompt: c, maxSteps: g } = {
16
+ async trainOnDataset(i, e, o) {
17
+ const { desiredLoss: p, logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
17
18
  ...D,
18
- ...i
19
+ ...e
19
20
  }, n = Date.now(), t = {
20
21
  step: 0,
21
22
  lastLoss: 1e6,
@@ -26,52 +27,55 @@ class E extends L {
26
27
  trainingDuration: 0,
27
28
  ...this.lastState || {}
28
29
  };
29
- this.lastState = t, this.dummyPass(), this.model.trainable = !0, this.running = !0, t.logStartTime = n;
30
- const m = o ? new x(this.model, o) : void 0, S = await r.iterator();
30
+ this.lastState = t, this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, t.logStartTime = n;
31
+ const m = o ? new x(this.model, o) : void 0, f = await i.iterator();
31
32
  try {
32
- for (; this.running && !(t.lastLoss < u); ) {
33
- const a = await S.next();
34
- if (a.done) break;
35
- const p = a.value, f = this.trainBatch(t, p), s = {
33
+ for (; this.running && !(t.lastLoss < p); ) {
34
+ const r = await f.next();
35
+ if (r.done) break;
36
+ const d = r.value, v = this.trainBatch(t, d, e.advancedMetrics || !1), s = {
36
37
  loss: t.lastLoss,
37
38
  step: t.step,
38
39
  time: Date.now() - n,
39
- batchSize: p.xs.shape[0]
40
+ batchSize: d.xs.shape[0],
41
+ learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0,
42
+ gradientNorm: e?.advancedMetrics ? t.gradientNorm : void 0
40
43
  };
41
- if (this.model.log.push(s), t.step % d === 0) {
42
- await f;
43
- const w = Date.now();
44
- if (t.trainingDuration += w - t.logStartTime, m)
44
+ if (this.model.log.push(s), t.step % g === 0) {
45
+ await v;
46
+ const S = Date.now();
47
+ if (t.trainingDuration += S - t.logStartTime, m)
45
48
  try {
46
- const e = await m.evaluate(5);
47
- t.validationLosses.push(e), s.valLoss = e;
48
- } catch (e) {
49
- console.error("Validation error:", e);
49
+ const a = await m.evaluate(5);
50
+ t.validationLosses.push(a), s.valLoss = a;
51
+ } catch (a) {
52
+ console.error("Validation error:", a);
50
53
  }
51
54
  if (l) {
52
55
  if (c) {
53
- const T = await v(this.tokenizer, this.model, c, 100, {
56
+ const w = await T(this.tokenizer, this.model, c, 100, {
54
57
  temperature: 0.8
55
58
  });
56
- s.example = T;
59
+ s.example = w;
57
60
  }
58
- const e = {
61
+ const a = {
59
62
  duration: t.trainingDuration,
60
63
  totalSamples: t.totalSteps * s.batchSize,
61
- samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3)
64
+ samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
65
+ memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
62
66
  };
63
- await l(s, e);
67
+ await l(s, a);
64
68
  }
65
69
  t.logStartTime = Date.now();
66
70
  }
67
- t.step >= g && this.stop();
71
+ t.step >= u && this.stop();
68
72
  }
69
- } catch (a) {
70
- throw console.error("Training error:", a), h(), a;
73
+ } catch (r) {
74
+ throw console.error("Training error:", r), h(), r;
71
75
  }
72
76
  return h(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
73
77
  }
74
78
  }
75
79
  export {
76
- E as default
80
+ I as default
77
81
  };
@@ -11,11 +11,13 @@ export interface TrainingState {
11
11
  totalSteps: number;
12
12
  losses: number[];
13
13
  validationLosses: number[];
14
+ gradientNorm?: number;
14
15
  }
15
16
  export interface TrainingProgress {
16
17
  duration: number;
17
18
  totalSamples: number;
18
19
  samplesPerSecond: number;
20
+ memory?: number;
19
21
  }
20
22
  export interface AdamConfig {
21
23
  learningRateFactor: number;
@@ -28,6 +30,7 @@ export interface TrainingOptions {
28
30
  logInterval: number;
29
31
  prompt?: string;
30
32
  maxSteps: number;
33
+ advancedMetrics?: boolean;
31
34
  onStep?: (log: TrainingLogEntry, progress: TrainingProgress) => Promise<void> | void;
32
35
  }
33
36
  export default abstract class GPTTrainer {
@@ -44,16 +47,16 @@ export default abstract class GPTTrainer {
44
47
  stop(): void;
45
48
  getOptimizer(): AdamExt;
46
49
  resetOptimizer(config?: AdamConfig): void;
47
- private printGradients;
48
- protected trainStep(batch: {
50
+ private maxGradNorm;
51
+ protected trainStep(state: Partial<TrainingState>, batch: {
49
52
  xs: Tensor;
50
53
  ys: Tensor;
51
- }, dummy?: boolean, print?: boolean): Scalar;
54
+ }, dummy?: boolean, calcNorm?: boolean): Scalar;
52
55
  protected dummyPass(): void;
53
56
  protected trainBatch(state: TrainingState, batch: {
54
57
  xs: Tensor;
55
58
  ys: Tensor;
56
- }): Promise<number>;
59
+ }, calcNorm?: boolean): Promise<number>;
57
60
  abstract trainOnDataset(dataset: Dataset<{
58
61
  xs: Tensor;
59
62
  ys: Tensor;
@@ -1,13 +1,11 @@
1
- import { DatasetBuilder as h, flattenTokens as d, PAGE_FACTOR as g } from "./DatasetBuilder.js";
2
- import u from "./AdamExt.js";
3
- import { t as f, v as y, a as m } from "../index-CnHyhpKc.js";
4
- import { m as S, n as z } from "../norm-BpWsOapl.js";
5
- import { m as w, a as T } from "../moments-DLTE6-1p.js";
6
- import { m as x } from "../max-CcnEArWK.js";
7
- import { z as p } from "../zeros-CYMicyqz.js";
8
- class G {
9
- constructor(t, s, e = 1e-3) {
10
- this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new h(s, t.config.gpt.blockSize);
1
+ import { DatasetBuilder as g, flattenTokens as m, PAGE_FACTOR as u } from "./DatasetBuilder.js";
2
+ import f from "./AdamExt.js";
3
+ import { t as y, v as z, a as c } from "../index-Du-bmOP8.js";
4
+ import { n as S } from "../norm-01kY9I2B.js";
5
+ import { z as p } from "../zeros-BaHhQTWf.js";
6
+ class R {
7
+ constructor(t, e, s = 1e-3) {
8
+ this.tokenizer = e, this.model = t, this.learningRate = s, this.resetOptimizer(), this.datasetBuilder = new g(e, t.config.gpt.blockSize);
11
9
  }
12
10
  model;
13
11
  optimizer;
@@ -29,7 +27,7 @@ class G {
29
27
  }
30
28
  resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
31
29
  this.optimizer && this.optimizer.dispose();
32
- const s = new u(
30
+ const e = new f(
33
31
  t.learningRateFactor * this.learningRate,
34
32
  t.beta1,
35
33
  t.beta2,
@@ -41,68 +39,78 @@ class G {
41
39
  weightDecay: 0
42
40
  }
43
41
  );
44
- this.optimizer = s;
42
+ this.optimizer = e;
45
43
  }
46
- printGradients(t) {
47
- Object.keys(t).forEach((s) => {
48
- const e = t[s];
49
- console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${w(e).dataSync()[0]}`), console.log(` Std: ${T(e).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${S(e).dataSync()[0]}`), console.log(` Max: ${x(e).dataSync()[0]}`), console.log(` Norm: ${z(e).dataSync()[0]}`);
50
- });
44
+ maxGradNorm(t) {
45
+ let e = 0;
46
+ return Object.keys(t).forEach((s) => {
47
+ const a = t[s], r = S(a), i = r.dataSync()[0];
48
+ r.dispose(), i > e && (e = i);
49
+ }), e;
51
50
  }
52
- trainStep(t, s = !1, e = !1) {
53
- return f(() => {
51
+ trainStep(t, e, s = !1, a = !1) {
52
+ return y(() => {
54
53
  this.model.getProfiler()?.startMemory();
55
- const { xs: a, ys: i } = t, o = () => {
56
- const [l, c] = this.model.forward({ training: !0 }, a, i);
57
- return l.dispose(), c;
58
- }, { value: n, grads: r } = y(o);
59
- return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), m(r)), n;
54
+ const { xs: r, ys: i } = e, d = () => {
55
+ const [n, h] = this.model.forward({ training: !0 }, r, i);
56
+ return n.dispose(), h;
57
+ }, { value: l, grads: o } = z(d);
58
+ if (s)
59
+ this.model.getProfiler()?.endMemory("Training");
60
+ else {
61
+ if (a) {
62
+ const n = this.maxGradNorm(o);
63
+ t.gradientNorm = n;
64
+ }
65
+ this.optimizer.applyGradients(o), this.model.getProfiler()?.endMemory("Training"), c(o);
66
+ }
67
+ return l;
60
68
  });
61
69
  }
62
70
  dummyPass() {
63
- const t = p([1, this.model.config.gpt.blockSize], "int32"), s = p([1, this.model.config.gpt.blockSize], "int32");
71
+ const t = p([1, this.model.config.gpt.blockSize], "int32"), e = p([1, this.model.config.gpt.blockSize], "int32");
64
72
  try {
65
- const e = this.trainStep({ xs: t, ys: s }, !0);
66
- e.dataSync(), e.dispose();
67
- } catch (e) {
68
- console.error("Error during dummy pass:", e);
73
+ const s = this.trainStep({}, { xs: t, ys: e }, !0);
74
+ s.dataSync(), s.dispose();
75
+ } catch (s) {
76
+ console.error("Error during dummy pass:", s);
69
77
  } finally {
70
- t.dispose(), s.dispose();
78
+ t.dispose(), e.dispose();
71
79
  }
72
80
  }
73
- async trainBatch(t, s) {
81
+ async trainBatch(t, e, s = !1) {
74
82
  try {
75
- const e = this.trainStep(s, !1, !1);
76
- return s.xs.dispose(), s.ys.dispose(), t.step++, t.totalSteps++, e.array().then((a) => (t.lastLoss = a, t.losses.push(t.lastLoss), e.dispose(), t.lastLoss));
77
- } catch (e) {
78
- throw console.error(`Error processing batch at step ${t.step}:`, e), m(), e;
83
+ const a = this.trainStep(t, e, !1, s);
84
+ return e.xs.dispose(), e.ys.dispose(), t.step++, t.totalSteps++, a.array().then((r) => (t.lastLoss = r, t.losses.push(t.lastLoss), a.dispose(), t.lastLoss));
85
+ } catch (a) {
86
+ throw console.error(`Error processing batch at step ${t.step}:`, a), c(), a;
79
87
  }
80
88
  }
81
- async createTrainValidationSplit(t, s = 32, e = 0.1) {
82
- const a = await d(t, this.tokenizer), i = /* @__PURE__ */ new Set();
83
- if (e > 0) {
84
- const r = Math.floor(a.length / (this.datasetBuilder.blockSize * g)), l = Math.max(1, Math.floor(r * e));
85
- for (; i.size < l; ) {
86
- const c = Math.floor(Math.random() * r);
87
- i.add(c);
89
+ async createTrainValidationSplit(t, e = 32, s = 0.1) {
90
+ const a = await m(t, this.tokenizer), r = /* @__PURE__ */ new Set();
91
+ if (s > 0) {
92
+ const l = Math.floor(a.length / (this.datasetBuilder.blockSize * u)), o = Math.max(1, Math.floor(l * s));
93
+ for (; r.size < o; ) {
94
+ const n = Math.floor(Math.random() * l);
95
+ r.add(n);
88
96
  }
89
97
  }
90
- const o = await this.datasetBuilder.createTextDataset(a, s, i, !1), n = await this.datasetBuilder.createTextDataset(
98
+ const i = await this.datasetBuilder.createTextDataset(a, e, r, !1), d = await this.datasetBuilder.createTextDataset(
91
99
  a,
92
- s,
93
- i,
100
+ e,
101
+ r,
94
102
  !0
95
103
  );
96
- return { trainDataset: o, validationDataset: n };
104
+ return { trainDataset: i, validationDataset: d };
97
105
  }
98
- async createDataset(t, s = 32) {
99
- const e = await d(t, this.tokenizer);
100
- return await this.datasetBuilder.createTextDataset(e, s);
106
+ async createDataset(t, e = 32) {
107
+ const s = await m(t, this.tokenizer);
108
+ return await this.datasetBuilder.createTextDataset(s, e);
101
109
  }
102
110
  dispose() {
103
111
  this.optimizer && this.optimizer.dispose();
104
112
  }
105
113
  }
106
114
  export {
107
- G as default
115
+ R as default
108
116
  };
@@ -1,9 +1,9 @@
1
1
  import { gatherSub as L } from "../ops/gatherSub.js";
2
2
  import { scatterSub as y } from "../ops/scatterSub.js";
3
- import { e as u, c as i, z as S, t as f, s as G } from "../index-CnHyhpKc.js";
4
- import { s as v } from "../softmax-DX6qXAbm.js";
5
- import { m as z } from "../max-CcnEArWK.js";
6
- import { l as k } from "../log_sum_exp-CRH7Np9v.js";
3
+ import { e as u, c as i, z as S, t as f, s as G } from "../index-Du-bmOP8.js";
4
+ import { s as v } from "../softmax-DhWoBa7r.js";
5
+ import { m as z } from "../max-0Xnlpv8k.js";
6
+ import { l as k } from "../log_sum_exp-CxfBtUaG.js";
7
7
  function F(a, s) {
8
8
  return f(() => {
9
9
  const e = a.shape[a.shape.length - 1], o = a.shape.slice(0, -1).reduce((d, c) => d * c, 1), p = a.shape.length > 2 ? a.reshape([o, e]) : a, n = s.shape.length > 1 ? s.reshape([o]).cast("int32") : s.cast("int32"), t = z(p, -1, !0), r = G(p, t), h = k(r, -1);
@@ -1,5 +1,5 @@
1
- import "../index-CnHyhpKc.js";
2
- import { z as n } from "../zeros-CYMicyqz.js";
1
+ import "../index-Du-bmOP8.js";
2
+ import { z as n } from "../zeros-BaHhQTWf.js";
3
3
  async function c(s) {
4
4
  const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
5
5
  await t.data(), t.dispose(), o && o.dispose(), i.dispose();
@@ -1,6 +1,6 @@
1
- import { t as y } from "../index-CnHyhpKc.js";
2
- import { t as x } from "../tensor2d-CqtBzOKq.js";
3
- import { c as f } from "../concat-BRRtq4S2.js";
1
+ import { t as y } from "../index-Du-bmOP8.js";
2
+ import { t as x } from "../tensor2d-CRWjDyUe.js";
3
+ import { c as f } from "../concat-DdKPyAtw.js";
4
4
  async function A(o, r, a, c, T) {
5
5
  if (c <= 0)
6
6
  throw new Error("Length must be a positive integer");
@@ -3,7 +3,7 @@ import { importWeights as b } from "./weights.js";
3
3
  import u from "../tokeniser/CharTokeniser.js";
4
4
  import F from "../NanoGPTModel.js";
5
5
  import { dummyPassAsync as j } from "./dummy.js";
6
- import { d as T } from "../index-CnHyhpKc.js";
6
+ import { d as T } from "../index-Du-bmOP8.js";
7
7
  import E from "../tokeniser/bpe.js";
8
8
  async function A(t) {
9
9
  const o = await fetch(t);