@genai-fi/nanogpt 0.5.6 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. package/dist/Generator.js +8 -7
  2. package/dist/NanoGPTModel.js +8 -8
  3. package/dist/{Reshape-Biok_3X1.js → Reshape-CLOrdpve.js} +2 -2
  4. package/dist/TeachableLLM.js +16 -15
  5. package/dist/{TiedEmbedding-8S8xn8e6.js → TiedEmbedding-BhxWO8QR.js} +5 -5
  6. package/dist/{axis_util-BczFISHz.js → axis_util-D17qZRQm.js} +1 -1
  7. package/dist/{broadcast_to-B7NGsBSh.js → broadcast_to-BMQLjvt_.js} +2 -2
  8. package/dist/{concat-DdKPyAtw.js → concat-DhZfF1GY.js} +1 -1
  9. package/dist/{dataset-iqT4Otvb.js → dataset-oilnemHf.js} +3 -3
  10. package/dist/{dropout-B09InSJS.js → dropout-CrMQPCeG.js} +1 -1
  11. package/dist/{gather-D6MsdXqc.js → gather-DZCMHZuN.js} +1 -1
  12. package/dist/{gpgpu_math-BFbOyvk4.js → gpgpu_math-Ctc31slO.js} +1 -1
  13. package/dist/{index-Du-bmOP8.js → index-bMBtI-WR.js} +50 -50
  14. package/dist/{kernel_funcs_utils-DShm7-0k.js → kernel_funcs_utils-CNmjLWnB.js} +26 -24
  15. package/dist/layers/BaseLayer.js +2 -2
  16. package/dist/layers/CausalSelfAttention.js +6 -6
  17. package/dist/layers/MLP.js +5 -5
  18. package/dist/layers/RMSNorm.js +3 -3
  19. package/dist/layers/RoPECache.js +3 -3
  20. package/dist/layers/TiedEmbedding.js +6 -6
  21. package/dist/layers/TransformerBlock.js +1 -1
  22. package/dist/{log_sum_exp-CxfBtUaG.js → log_sum_exp-BHdkCb4s.js} +5 -5
  23. package/dist/main.js +20 -19
  24. package/dist/{mat_mul-CbiqIe2d.js → mat_mul-BsrLfy81.js} +1 -1
  25. package/dist/{max-0Xnlpv8k.js → max-DechV4Bc.js} +1 -1
  26. package/dist/{norm-01kY9I2B.js → norm-B9hWHZH1.js} +5 -5
  27. package/dist/{ones-CrutWGas.js → ones-g0K8jVwm.js} +2 -2
  28. package/dist/ops/appendCache.js +3 -3
  29. package/dist/ops/attentionMask.js +1 -1
  30. package/dist/ops/cpu/appendCache.js +2 -2
  31. package/dist/ops/cpu/attentionMask.js +5 -5
  32. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  33. package/dist/ops/cpu/gatherSub.js +3 -3
  34. package/dist/ops/cpu/gelu.js +1 -1
  35. package/dist/ops/cpu/matMulGelu.js +1 -1
  36. package/dist/ops/cpu/matMulMul.js +1 -1
  37. package/dist/ops/cpu/mulDropout.js +1 -1
  38. package/dist/ops/cpu/normRMS.js +1 -1
  39. package/dist/ops/cpu/qkv.js +3 -3
  40. package/dist/ops/cpu/rope.js +5 -5
  41. package/dist/ops/cpu/scatterSub.js +4 -4
  42. package/dist/ops/fusedSoftmax.js +1 -1
  43. package/dist/ops/gatherSub.js +1 -1
  44. package/dist/ops/gelu.js +1 -1
  45. package/dist/ops/grads/attentionMask.js +15 -11
  46. package/dist/ops/grads/fusedSoftmax.js +12 -10
  47. package/dist/ops/grads/gelu.js +1 -1
  48. package/dist/ops/grads/matMulGelu.js +1 -1
  49. package/dist/ops/grads/normRMS.js +1 -1
  50. package/dist/ops/grads/qkv.js +1 -1
  51. package/dist/ops/grads/rope.js +1 -1
  52. package/dist/ops/log.d.ts +0 -0
  53. package/dist/ops/log.js +1 -0
  54. package/dist/ops/matMulGelu.js +1 -1
  55. package/dist/ops/matMulMul.js +1 -1
  56. package/dist/ops/mulDrop.js +1 -1
  57. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  58. package/dist/ops/normRMS.js +1 -1
  59. package/dist/ops/qkv.js +1 -1
  60. package/dist/ops/scatterSub.js +1 -1
  61. package/dist/ops/webgl/appendCache.js +1 -1
  62. package/dist/ops/webgl/attentionMask.js +1 -1
  63. package/dist/ops/webgl/fusedSoftmax.js +205 -3022
  64. package/dist/ops/webgl/gatherSub.js +1 -1
  65. package/dist/ops/webgl/gelu.js +2 -2
  66. package/dist/ops/webgl/log.d.ts +17 -0
  67. package/dist/ops/webgl/log.js +39 -0
  68. package/dist/ops/webgl/matMulGelu.js +4 -4
  69. package/dist/ops/webgl/matMulMul.js +1 -1
  70. package/dist/ops/webgl/mulDropout.js +1 -1
  71. package/dist/ops/webgl/normRMS.js +2 -2
  72. package/dist/ops/webgl/qkv.js +1 -1
  73. package/dist/ops/webgl/rope.js +1 -1
  74. package/dist/ops/webgl/scatterSub.js +1 -1
  75. package/dist/{ops-CJNniCAV.js → ops-Mv7Ta72x.js} +13 -13
  76. package/dist/{random_width-C-v-35bY.js → random_width-BBAWzDym.js} +23 -23
  77. package/dist/{range-Bvs1hidm.js → range-DMaG9A3G.js} +1 -1
  78. package/dist/{reshape-BH7eBpwq.js → reshape-T4yDEqoF.js} +1 -1
  79. package/dist/shared-XNAoXhOa.js +2826 -0
  80. package/dist/{sin-CPAZXNjH.js → sin-EEhbrRO_.js} +1 -1
  81. package/dist/{slice_util-DskXqRZa.js → slice_util-Ddk0uxGJ.js} +1 -1
  82. package/dist/{softmax-DhWoBa7r.js → softmax-B2_IKPDR.js} +1 -1
  83. package/dist/{split-BCUhuU7B.js → split-dcks18H1.js} +1 -1
  84. package/dist/{stack-BV1v7l3S.js → stack-lpJ5kYvE.js} +1 -1
  85. package/dist/{sum-Cvq06317.js → sum-CutF5lj2.js} +1 -1
  86. package/dist/{tensor-DgTOPY6h.js → tensor-C15NA2LA.js} +1 -1
  87. package/dist/{tensor2d-CRWjDyUe.js → tensor2d-DZ_e5eKM.js} +1 -1
  88. package/dist/{tfjs_backend-D9Ytje0G.js → tfjs_backend-BDb8r9qx.js} +28 -28
  89. package/dist/training/AdamExt.js +1 -1
  90. package/dist/training/DatasetBuilder.js +2 -2
  91. package/dist/training/FullTrainer.js +1 -1
  92. package/dist/training/Trainer.js +3 -3
  93. package/dist/training/sparseCrossEntropy.js +4 -4
  94. package/dist/utilities/dummy.js +2 -2
  95. package/dist/utilities/generate.js +3 -3
  96. package/dist/utilities/load.d.ts +25 -0
  97. package/dist/utilities/load.js +89 -37
  98. package/dist/utilities/profile.js +4 -4
  99. package/dist/utilities/safetensors.d.ts +3 -0
  100. package/dist/utilities/safetensors.js +83 -0
  101. package/dist/utilities/save.js +47 -29
  102. package/dist/utilities/weights.js +2 -2
  103. package/dist/{variable-DZ3fF0R2.js → variable-CdRKKp8x.js} +1 -1
  104. package/dist/{zeros-BaHhQTWf.js → zeros-CAbHfODe.js} +1 -1
  105. package/package.json +1 -1
@@ -1,4 +1,4 @@
1
- import { o, j as t, E as c, _ as a, $ as e } from "./index-Du-bmOP8.js";
1
+ import { o, j as t, E as c, _ as a, $ as e } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { l as s } from "./index-Du-bmOP8.js";
1
+ import { l as s } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2021 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as r, j as f, E as e, S as i } from "./index-Du-bmOP8.js";
1
+ import { o as r, j as f, E as e, S as i } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as p, j as i, E as a, x as c } from "./index-Du-bmOP8.js";
1
+ import { o as p, j as i, E as a, x as c } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as e, k as c, l as n, E as k, P as i } from "./index-Du-bmOP8.js";
1
+ import { o as e, k as c, l as n, E as k, P as i } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { o as e, j as u, D as c, E as l, F as m } from "./index-Du-bmOP8.js";
1
+ import { o as e, j as u, D as c, E as l, F as m } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { J as t, K as a } from "./index-Du-bmOP8.js";
1
+ import { J as t, K as a } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { I as t, J as s, K as a } from "./index-Du-bmOP8.js";
1
+ import { I as t, J as s, K as a } from "./index-bMBtI-WR.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,11 +1,11 @@
1
- import { o as h, j as f, E as $, ao as Te, l as _, g as Ee, ap as xe, aq as Ie, ar as Le, as as be, at as Ne, au as Ce, av as Pe, b as H, aw as Fe, a8 as U, u as ae, q as ie, Q as le, c as fe, ax as he, ai as pe, ay as je, t as S, D as $e, al as Me, a2 as Be } from "./index-Du-bmOP8.js";
2
- import { s as C, t as Ke, a as Ue, b as ve } from "./ops-CJNniCAV.js";
3
- import { r as Re, d as Ve } from "./dropout-B09InSJS.js";
4
- import { r as u } from "./reshape-BH7eBpwq.js";
5
- import { g as qe } from "./gather-D6MsdXqc.js";
6
- import { s as Ge } from "./sum-Cvq06317.js";
7
- import { m as A } from "./mat_mul-CbiqIe2d.js";
8
- import { c as M } from "./concat-DdKPyAtw.js";
1
+ import { o as h, j as f, E as $, ar as Te, l as _, g as Ee, as as xe, at as Ie, au as Le, av as be, aw as Ne, ax as Ce, ay as Pe, b as H, az as Fe, a8 as U, u as ae, q as ie, Q as le, c as fe, aA as he, ai as pe, aB as je, t as S, D as $e, al as Be, a2 as Me } from "./index-bMBtI-WR.js";
2
+ import { s as C, t as Ke, a as Ue, b as ve } from "./ops-Mv7Ta72x.js";
3
+ import { r as Re, d as Ve } from "./dropout-CrMQPCeG.js";
4
+ import { r as u } from "./reshape-T4yDEqoF.js";
5
+ import { g as Ge } from "./gather-DZCMHZuN.js";
6
+ import { s as qe } from "./sum-CutF5lj2.js";
7
+ import { m as A } from "./mat_mul-BsrLfy81.js";
8
+ import { c as B } from "./concat-DhZfF1GY.js";
9
9
  /**
10
10
  * @license
11
11
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -52,7 +52,7 @@ function We(e, n, t) {
52
52
  }
53
53
  const Ye = /* @__PURE__ */ h({ clipByValue_: We });
54
54
  function Qe(e) {
55
- return M(
55
+ return B(
56
56
  e,
57
57
  0
58
58
  /* axis */
@@ -60,15 +60,15 @@ function Qe(e) {
60
60
  }
61
61
  const He = /* @__PURE__ */ h({ concat1d_: Qe });
62
62
  function Xe(e, n) {
63
- return M(e, n);
63
+ return B(e, n);
64
64
  }
65
65
  const ze = /* @__PURE__ */ h({ concat2d_: Xe });
66
66
  function en(e, n) {
67
- return M(e, n);
67
+ return B(e, n);
68
68
  }
69
69
  const nn = /* @__PURE__ */ h({ concat3d_: en });
70
70
  function tn(e, n) {
71
- return M(e, n);
71
+ return B(e, n);
72
72
  }
73
73
  const rn = /* @__PURE__ */ h({ concat4d_: tn });
74
74
  /**
@@ -307,7 +307,7 @@ function An(e, n, t) {
307
307
  function Sn(e, n) {
308
308
  let t = n;
309
309
  const r = Fe(e.shape, n.shape);
310
- return r.length > 0 && (t = Ge(t, r)), u(t, e.shape);
310
+ return r.length > 0 && (t = qe(t, r)), u(t, e.shape);
311
311
  }
312
312
  function yn(e, n, t, r) {
313
313
  if (n === "linear")
@@ -352,22 +352,22 @@ function _n({ a: e, b: n, transposeA: t = !1, transposeB: r = !1, bias: s, activ
352
352
  [c, a] = ae(c, a);
353
353
  const k = t ? c.shape[c.rank - 2] : c.shape[c.rank - 1], m = r ? a.shape[a.rank - 1] : a.shape[a.rank - 2], E = t ? c.shape[c.rank - 1] : c.shape[c.rank - 2], d = r ? a.shape[a.rank - 2] : a.shape[a.rank - 1], ne = c.shape.slice(0, -2), x = a.shape.slice(0, -2), te = ie(ne), re = ie(x);
354
354
  _(k === m, () => `Error in fused matMul: inner shapes (${k}) and (${m}) of Tensors with shapes ${c.shape} and ${a.shape} and transposeA=${t} and transposeB=${r} must match.`);
355
- const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), q = r ? u(a, [re, d, m]) : u(a, [re, m, d]);
355
+ const R = le(c.shape.slice(0, -2), a.shape.slice(0, -2)).concat([E, d]), V = t ? u(c, [te, k, E]) : u(c, [te, E, k]), G = r ? u(a, [re, d, m]) : u(a, [re, m, d]);
356
356
  let I;
357
357
  s != null && (I = f(s, "bias", "fused matMul"), [I] = ae(I, c), le(R, I.shape));
358
358
  let se;
359
359
  i != null && (se = f(i, "prelu weights", "fused matMul"));
360
360
  const oe = (D, P) => {
361
- const [y, O, T, B] = P, w = An(u(D, T.shape), T, o);
361
+ const [y, O, T, M] = P, w = An(u(D, T.shape), T, o);
362
362
  let L, b;
363
363
  if (!t && !r ? (L = A(w, O, !1, !0), b = A(y, w, !0, !1)) : !t && r ? (L = A(w, O, !1, !1), b = A(w, y, !0, !1)) : t && !r ? (L = A(O, w, !1, !0), b = A(y, w, !1, !1)) : (L = A(O, w, !0, !0), b = A(w, y, !0, !0)), s != null) {
364
- const De = Sn(B, w);
364
+ const De = Sn(M, w);
365
365
  return [L, b, De];
366
366
  } else
367
367
  return [L, b];
368
368
  }, ue = {
369
369
  a: V,
370
- b: q,
370
+ b: G,
371
371
  bias: I,
372
372
  preluActivationWeights: se
373
373
  }, ce = { transposeA: t, transposeB: r, activation: o, leakyreluAlpha: p };
@@ -377,13 +377,13 @@ function _n({ a: e, b: n, transposeA: t = !1, transposeB: r = !1, bias: s, activ
377
377
  $.runKernel(he, ue, ce)
378
378
  );
379
379
  return O([P, y, T]), { value: u(T, R), gradFunc: oe };
380
- })(V, q) : fe((P, y, O, T) => {
381
- const B = (
380
+ })(V, G) : fe((P, y, O, T) => {
381
+ const M = (
382
382
  // tslint:disable-next-line: no-unnecessary-type-assertion
383
383
  $.runKernel(he, ue, ce)
384
384
  );
385
- return T([P, y, B, O]), { value: u(B, R), gradFunc: oe };
386
- })(V, q, I);
385
+ return T([P, y, M, O]), { value: u(M, R), gradFunc: oe };
386
+ })(V, G, I);
387
387
  }
388
388
  const de = /* @__PURE__ */ h({ fusedMatMul_: _n });
389
389
  /**
@@ -395,7 +395,7 @@ const de = /* @__PURE__ */ h({ fusedMatMul_: _n });
395
395
  * https://opensource.org/licenses/MIT.
396
396
  * =============================================================================
397
397
  */
398
- const Dn = ["channelsFirst", "channelsLast"], Tn = ["nearest", "bilinear"], En = ["valid", "same", "causal"], xn = ["max", "avg"], Gn = ["sum", "mul", "concat", "ave"];
398
+ const Dn = ["channelsFirst", "channelsLast"], Tn = ["nearest", "bilinear"], En = ["valid", "same", "causal"], xn = ["max", "avg"], qn = ["sum", "mul", "concat", "ave"];
399
399
  /**
400
400
  * @license
401
401
  * Copyright 2018 Google LLC
@@ -695,9 +695,9 @@ function gt(e, n) {
695
695
  * https://opensource.org/licenses/MIT.
696
696
  * =============================================================================
697
697
  */
698
- let G;
698
+ let q;
699
699
  function mt() {
700
- return G == null && (G = je().epsilon()), G;
700
+ return q == null && (q = je().epsilon()), q;
701
701
  }
702
702
  function Y() {
703
703
  return "channelsLast";
@@ -830,7 +830,7 @@ function St(e, n, t, r) {
830
830
  }
831
831
  function yt(e, n = -1) {
832
832
  let t;
833
- return n < 0 && (t = e[0].rank, t !== 0 ? n = t : n = 0), n === e[0].rank && (n = -1), M(e, n);
833
+ return n < 0 && (t = e[0].rank, t !== 0 ? n = t : n = 0), n === e[0].rank && (n = -1), B(e, n);
834
834
  }
835
835
  function Ot(e, n) {
836
836
  switch (e.rank) {
@@ -888,7 +888,7 @@ function Dt(e, n, t, r) {
888
888
  }
889
889
  }
890
890
  function Tt(e, n, t) {
891
- return S(() => (Array.isArray(n) ? n = Ke(n, "int32") : n = $e(n, "int32"), qe(e, n, t)));
891
+ return S(() => (Array.isArray(n) ? n = Ke(n, "int32") : n = $e(n, "int32"), Ge(e, n, t)));
892
892
  }
893
893
  function Et(e) {
894
894
  return H(e, e);
@@ -925,7 +925,7 @@ function It(e, n = 1) {
925
925
  return ke(e);
926
926
  }
927
927
  function Lt(e) {
928
- return S(() => Me(e, U(Be(e), 1)));
928
+ return S(() => Be(e, U(Me(e), 1)));
929
929
  }
930
930
  function bt(e, n, t, r) {
931
931
  return S(() => Ve(e, n, t, r));
@@ -981,7 +981,7 @@ export {
981
981
  At as a9,
982
982
  kt as aa,
983
983
  at as ab,
984
- Gn as ac,
984
+ qn as ac,
985
985
  Sn as b,
986
986
  v as c,
987
987
  Dt as d,
@@ -1,4 +1,4 @@
1
- import { A as r, b as c, f as h, s as g, e as o } from "../index-Du-bmOP8.js";
1
+ import { A as r, b as c, f as h, s as g, e as o } from "../index-bMBtI-WR.js";
2
2
  class u extends r {
3
3
  constructor(t, e, s, a, i) {
4
4
  super(t, e, s, a), this.config = i, this.startLearningRate = t;
@@ -1,5 +1,5 @@
1
- import { t as u } from "../index-Du-bmOP8.js";
2
- import { d as z, i as f } from "../dataset-iqT4Otvb.js";
1
+ import { t as u } from "../index-bMBtI-WR.js";
2
+ import { d as z, i as f } from "../dataset-oilnemHf.js";
3
3
  import "../index-Tf7vU29b.js";
4
4
  /**
5
5
  * @license
@@ -1,7 +1,7 @@
1
1
  import { generateText as T } from "../utilities/generate.js";
2
2
  import L from "./Trainer.js";
3
3
  import x from "./Evaluator.js";
4
- import { a as h } from "../index-Du-bmOP8.js";
4
+ import { a as h } from "../index-bMBtI-WR.js";
5
5
  import y from "../utilities/profile.js";
6
6
  const D = {
7
7
  desiredLoss: 0.01,
@@ -1,8 +1,8 @@
1
1
  import { DatasetBuilder as g, flattenTokens as m, PAGE_FACTOR as u } from "./DatasetBuilder.js";
2
2
  import f from "./AdamExt.js";
3
- import { t as y, v as z, a as c } from "../index-Du-bmOP8.js";
4
- import { n as S } from "../norm-01kY9I2B.js";
5
- import { z as p } from "../zeros-BaHhQTWf.js";
3
+ import { t as y, v as z, a as c } from "../index-bMBtI-WR.js";
4
+ import { n as S } from "../norm-B9hWHZH1.js";
5
+ import { z as p } from "../zeros-CAbHfODe.js";
6
6
  class R {
7
7
  constructor(t, e, s = 1e-3) {
8
8
  this.tokenizer = e, this.model = t, this.learningRate = s, this.resetOptimizer(), this.datasetBuilder = new g(e, t.config.gpt.blockSize);
@@ -1,9 +1,9 @@
1
1
  import { gatherSub as L } from "../ops/gatherSub.js";
2
2
  import { scatterSub as y } from "../ops/scatterSub.js";
3
- import { e as u, c as i, z as S, t as f, s as G } from "../index-Du-bmOP8.js";
4
- import { s as v } from "../softmax-DhWoBa7r.js";
5
- import { m as z } from "../max-0Xnlpv8k.js";
6
- import { l as k } from "../log_sum_exp-CxfBtUaG.js";
3
+ import { e as u, c as i, z as S, t as f, s as G } from "../index-bMBtI-WR.js";
4
+ import { s as v } from "../softmax-B2_IKPDR.js";
5
+ import { m as z } from "../max-DechV4Bc.js";
6
+ import { l as k } from "../log_sum_exp-BHdkCb4s.js";
7
7
  function F(a, s) {
8
8
  return f(() => {
9
9
  const e = a.shape[a.shape.length - 1], o = a.shape.slice(0, -1).reduce((d, c) => d * c, 1), p = a.shape.length > 2 ? a.reshape([o, e]) : a, n = s.shape.length > 1 ? s.reshape([o]).cast("int32") : s.cast("int32"), t = z(p, -1, !0), r = G(p, t), h = k(r, -1);
@@ -1,5 +1,5 @@
1
- import "../index-Du-bmOP8.js";
2
- import { z as n } from "../zeros-BaHhQTWf.js";
1
+ import "../index-bMBtI-WR.js";
2
+ import { z as n } from "../zeros-CAbHfODe.js";
3
3
  async function c(s) {
4
4
  const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
5
5
  await t.data(), t.dispose(), o && o.dispose(), i.dispose();
@@ -1,6 +1,6 @@
1
- import { t as y } from "../index-Du-bmOP8.js";
2
- import { t as x } from "../tensor2d-CRWjDyUe.js";
3
- import { c as f } from "../concat-DdKPyAtw.js";
1
+ import { t as y } from "../index-bMBtI-WR.js";
2
+ import { t as x } from "../tensor2d-DZ_e5eKM.js";
3
+ import { c as f } from "../concat-DhZfF1GY.js";
4
4
  async function A(o, r, a, c, T) {
5
5
  if (c <= 0)
6
6
  throw new Error("Length must be a positive integer");
@@ -1,6 +1,31 @@
1
+ import { default as zip } from 'jszip';
1
2
  import { default as NanoGPT } from '../NanoGPTModel';
2
3
  import { ITokeniser } from '../tokeniser/type';
4
+ export declare const VERSION = 2;
5
+ export interface TransformersConfig {
6
+ model_type: string;
7
+ vocab_size: number;
8
+ hidden_size: number;
9
+ num_hidden_layers: number;
10
+ num_attention_heads: number;
11
+ block_size: number;
12
+ dropout: number;
13
+ biasInLinear: boolean;
14
+ biasInLayerNorm: boolean;
15
+ mlpFactor: number;
16
+ useRope: boolean;
17
+ }
18
+ export interface Metadata {
19
+ version: string;
20
+ application: string;
21
+ name?: string;
22
+ }
23
+ export declare function loadOldModel(zipFile: zip): Promise<{
24
+ model: NanoGPT;
25
+ tokeniser: ITokeniser;
26
+ }>;
3
27
  export declare function loadModel(data: Blob | Buffer | string): Promise<{
4
28
  model: NanoGPT;
5
29
  tokeniser: ITokeniser;
30
+ name?: string;
6
31
  }>;
@@ -1,47 +1,99 @@
1
- import { j as k } from "../jszip.min-CjP2V1VV.js";
2
- import { importWeights as b } from "./weights.js";
3
- import u from "../tokeniser/CharTokeniser.js";
4
- import F from "../NanoGPTModel.js";
5
- import { dummyPassAsync as j } from "./dummy.js";
6
- import { d as T } from "../index-Du-bmOP8.js";
7
- import E from "../tokeniser/bpe.js";
8
- async function A(t) {
9
- const o = await fetch(t);
10
- if (!o.ok)
11
- throw new Error(`Failed to fetch ${t}: ${o.statusText}`);
12
- return o.arrayBuffer();
1
+ import { j as v } from "../jszip.min-CjP2V1VV.js";
2
+ import { importWeights as F } from "./weights.js";
3
+ import h from "../tokeniser/CharTokeniser.js";
4
+ import b from "../NanoGPTModel.js";
5
+ import { dummyPassAsync as u } from "./dummy.js";
6
+ import { d as k } from "../index-bMBtI-WR.js";
7
+ import j from "../tokeniser/bpe.js";
8
+ import { load_safetensors as N } from "./safetensors.js";
9
+ const I = 2;
10
+ async function O(t) {
11
+ const s = await fetch(t);
12
+ if (!s.ok)
13
+ throw new Error(`Failed to fetch ${t}: ${s.statusText}`);
14
+ return s.arrayBuffer();
13
15
  }
14
- async function P(t) {
15
- const o = typeof t == "string" ? await A(t) : t, n = await k.loadAsync(o), i = /* @__PURE__ */ new Map(), f = await n.file("manifest.json")?.async("string");
16
- if (!f)
16
+ async function S(t) {
17
+ const s = /* @__PURE__ */ new Map(), r = await t.file("manifest.json")?.async("string");
18
+ if (!r)
17
19
  throw new Error("Manifest file not found in the zip archive");
18
- const l = JSON.parse(f);
19
- for (const [e, r] of Object.entries(l.weightSpec))
20
- i.set(e, { spec: r, data: new Float32Array() });
21
- const p = await n.file("tokeniser.json")?.async("string");
22
- if (!p)
20
+ const p = JSON.parse(r);
21
+ for (const [o, a] of Object.entries(p.weightSpec))
22
+ s.set(o, { spec: a, data: new Float32Array() });
23
+ const e = await t.file("tokeniser.json")?.async("string");
24
+ if (!e)
23
25
  throw new Error("Tokeniser file not found in the zip archive");
24
- const s = JSON.parse(p), y = (s.type ?? "char") === "char" ? new u(s.vocab) : new E(s.vocab, s.merges), w = /* @__PURE__ */ new Map();
25
- for (const e of Object.keys(n.files))
26
- if (e.endsWith(".bin")) {
27
- const r = e.replace(".bin", ""), g = await n.file(e).async("arraybuffer"), h = new Float32Array(g), c = i.get(r) || { spec: [], data: new Float32Array() };
28
- c.data = h, i.set(r, c);
29
- const d = await b(c);
30
- w.set(r, d);
26
+ const i = JSON.parse(e), c = (i.type ?? "char") === "char" ? new h(i.vocab) : new j(i.vocab, i.merges), d = /* @__PURE__ */ new Map();
27
+ for (const o of Object.keys(t.files))
28
+ if (o.endsWith(".bin")) {
29
+ const a = o.replace(".bin", ""), w = await t.file(o).async("arraybuffer"), g = new Float32Array(w), l = s.get(a) || { spec: [], data: new Float32Array() };
30
+ l.data = g, s.set(a, l);
31
+ const n = await F(l);
32
+ d.set(a, n);
31
33
  }
32
- T();
33
- const a = new F(l.config);
34
- await j(a), a.loadWeights(w);
35
- const m = await n.file("log.json")?.async("string");
34
+ k();
35
+ const f = new b(p.config);
36
+ await u(f), f.loadWeights(d);
37
+ const m = await t.file("log.json")?.async("string");
36
38
  if (m)
37
39
  try {
38
- const e = JSON.parse(m);
39
- a.log = e;
40
- } catch (e) {
41
- throw console.error("Error parsing training log:", e), new Error(`Failed to parse training log: ${e}`);
40
+ const o = JSON.parse(m);
41
+ f.log = o;
42
+ } catch (o) {
43
+ throw console.error("Error parsing training log:", o), new Error(`Failed to parse training log: ${o}`);
42
44
  }
43
- return { model: a, tokeniser: y };
45
+ return { model: f, tokeniser: c };
46
+ }
47
+ async function R(t) {
48
+ const s = typeof t == "string" ? await O(t) : t, r = await v.loadAsync(s);
49
+ if (r.file("manifest.json"))
50
+ return S(r);
51
+ {
52
+ const p = await r.file("config.json")?.async("string");
53
+ if (!p)
54
+ throw new Error("Config file not found in the zip archive");
55
+ const e = JSON.parse(p), i = {
56
+ vocabSize: e.vocab_size,
57
+ blockSize: e.block_size,
58
+ nLayer: e.num_hidden_layers,
59
+ nHead: e.num_attention_heads,
60
+ nEmbed: e.hidden_size,
61
+ dropout: e.dropout,
62
+ biasInLinear: e.biasInLinear,
63
+ biasInLayerNorm: e.biasInLayerNorm,
64
+ mlpFactor: e.mlpFactor,
65
+ useRope: e.useRope
66
+ }, y = await r.file("tokeniser.json")?.async("string");
67
+ if (!y)
68
+ throw new Error("Tokeniser file not found in the zip archive");
69
+ const c = JSON.parse(y), f = (c.type ?? "char") === "char" ? new h(c.vocab) : new j(c.vocab, c.merges), m = await N(await r.file("model.safetensors").async("arraybuffer")), o = /* @__PURE__ */ new Map();
70
+ for (const [n, E] of Object.entries(m))
71
+ o.set(n, [E]);
72
+ k();
73
+ const a = new b(i);
74
+ await u(a), a.loadWeights(o);
75
+ const w = await r.file("meta.json")?.async("string");
76
+ let g;
77
+ if (w)
78
+ try {
79
+ const n = JSON.parse(w);
80
+ n.name && (g = n.name);
81
+ } catch (n) {
82
+ console.error("Error parsing meta file:", n);
83
+ }
84
+ const l = await r.file("log.json")?.async("string");
85
+ if (l)
86
+ try {
87
+ const n = JSON.parse(l);
88
+ a.log = n;
89
+ } catch (n) {
90
+ throw console.error("Error parsing training log:", n), new Error(`Failed to parse training log: ${n}`);
91
+ }
92
+ return { model: a, tokeniser: f, name: g };
93
+ }
44
94
  }
45
95
  export {
46
- P as loadModel
96
+ I as VERSION,
97
+ R as loadModel,
98
+ S as loadOldModel
47
99
  };
@@ -1,6 +1,6 @@
1
- import { m as s } from "../index-Du-bmOP8.js";
1
+ import { m as s } from "../index-bMBtI-WR.js";
2
2
  const m = 1024 * 1024;
3
- class M {
3
+ class l {
4
4
  log = /* @__PURE__ */ new Map();
5
5
  maxMemory = 0;
6
6
  maxLabel;
@@ -18,7 +18,7 @@ class M {
18
18
  return;
19
19
  }
20
20
  const o = s(), t = o.numBytes - (this.lastMemInfo.pop()?.numBytes || 0);
21
- this.log.set(e, Math.max(this.log.get(e) || 0, t)), t > this.maxMemory && (this.maxMemory = t, this.maxLabel = e), this.peakMemory = Math.max(this.peakMemory, o.numBytes);
21
+ this.log.set(e, Math.max(this.log.get(e) || 0, t)), t > this.maxMemory && (this.maxMemory = t, this.maxLabel = e), this.peakMemory = Math.max(this.peakMemory, o.numBytesInGPUAllocated || o.numBytes);
22
22
  }
23
23
  printSummary() {
24
24
  console.log("Memory Usage Summary:");
@@ -28,5 +28,5 @@ class M {
28
28
  }
29
29
  }
30
30
  export {
31
- M as default
31
+ l as default
32
32
  };
@@ -0,0 +1,3 @@
1
+ import { Tensor } from '@tensorflow/tfjs-core';
2
+ export declare function save_safetensors(tensors: Record<string, Tensor>): Promise<ArrayBuffer>;
3
+ export declare function load_safetensors(buffer: ArrayBuffer): Promise<Record<string, Tensor>>;
@@ -0,0 +1,83 @@
1
+ import "../index-bMBtI-WR.js";
2
+ import { t as y } from "../tensor-C15NA2LA.js";
3
/**
 * Map a tfjs dtype name to its safetensors dtype tag.
 *
 * @param {string} t - tfjs dtype ("float32" or "int32").
 * @returns {string} "F32" or "I32".
 * @throws {Error} For any other dtype.
 */
function l(t) {
  switch (t) {
    case "float32":
      return "F32";
    case "int32":
      return "I32";
    default:
      throw new Error(`Unsupported dtype: ${t}`);
  }
}
8
/**
 * Map a safetensors dtype tag back to the matching tfjs dtype name.
 *
 * @param {string} t - safetensors dtype tag ("F32" or "I32").
 * @returns {string} "float32" or "int32".
 * @throws {Error} For any other tag.
 */
function h(t) {
  const mapped = t === "F32" ? "float32" : t === "I32" ? "int32" : undefined;
  if (mapped === undefined) {
    throw new Error(`Unsupported dtype: ${t}`);
  }
  return mapped;
}
13
/**
 * Serialize a name → tensor record into a safetensors-style ArrayBuffer.
 *
 * Output layout: header byte length as a little-endian uint32 in bytes 0-3
 * (bytes 4-7 stay zero), then the UTF-8 JSON header padded with spaces to a
 * multiple of 4 bytes, then each tensor's raw 4-byte elements back to back.
 * The padding keeps the data section 4-byte aligned so typed-array views can
 * be written directly into the buffer.
 *
 * @param {Record<string, Tensor>} t - Tensors to serialize (float32/int32 only).
 * @returns {Promise<ArrayBuffer>} The encoded file contents.
 * @throws {Error} If any tensor has an unsupported dtype.
 */
async function _(t) {
  const BYTES_PER_ELEMENT = 4;

  // Build the JSON header first; this validates every dtype before any data is read.
  const header = {};
  let dataLength = 0;
  for (const name of Object.keys(t)) {
    const tensor = t[name];
    const byteLength = tensor.size * BYTES_PER_ELEMENT;
    header[name] = {
      dtype: l(tensor.dtype),
      shape: tensor.shape,
      data_offsets: [dataLength, dataLength + byteLength],
    };
    dataLength += byteLength;
  }

  // Space-pad the encoded header up to the next multiple of 4 bytes.
  const rawHeader = new TextEncoder().encode(JSON.stringify(header));
  const padding = (4 - (rawHeader.length % 4)) % 4;
  const headerBytes = new Uint8Array(rawHeader.length + padding).fill(32);
  headerBytes.set(rawHeader);

  const headerLength = headerBytes.length;
  const buffer = new ArrayBuffer(8 + headerLength + dataLength);
  new DataView(buffer).setUint32(0, headerLength, true);
  new Uint8Array(buffer, 8, headerLength).set(headerBytes);

  let cursor = 8 + headerLength;
  for (const tensor of Object.values(t)) {
    if (tensor.size === 0) {
      continue;
    }
    const values = await tensor.data();
    let target;
    switch (tensor.dtype) {
      case "float32":
        target = new Float32Array(buffer, cursor, tensor.size);
        break;
      case "int32":
        target = new Int32Array(buffer, cursor, tensor.size);
        break;
      default:
        throw new Error(`Unsupported dtype: ${tensor.dtype}`);
    }
    target.set(values);
    cursor += tensor.size * BYTES_PER_ELEMENT;
  }
  return buffer;
}
46
/**
 * Parse a safetensors buffer into a record of tensors keyed by name.
 *
 * Layout read here:
 * - bytes 0-7 hold the header byte length as a little-endian u64;
 * - the UTF-8 JSON header of that length follows;
 * - the raw tensor bytes come last.
 *
 * Each header entry's `data_offsets` are relative to the first byte after the
 * header. Only F32 and I32 entries are supported. The reserved `__metadata__`
 * header entry is skipped because it describes the file, not a tensor.
 *
 * @param {ArrayBuffer} t - Complete safetensors file contents.
 * @returns {Promise<Record<string, Tensor>>} Tensors keyed by header name.
 * @throws {Error} If the header length is out of range or a dtype is unsupported.
 */
async function U(t) {
  const view = new DataView(t);
  // The length field is a u64. Only the low word is usable as a typed-array
  // length, so reject a non-zero high word or a header that overruns the buffer.
  const headerLength = view.getUint32(0, true);
  if (view.getUint32(4, true) !== 0 || 8 + headerLength > t.byteLength) {
    throw new Error("Invalid safetensors header length");
  }
  const dataStart = 8 + headerLength;
  const header = JSON.parse(new TextDecoder().decode(new Uint8Array(t, 8, headerLength)));

  const tensors = {};
  for (const [name, entry] of Object.entries(header)) {
    // `__metadata__` is a free-form string map with no `data_offsets`.
    // Files written by other safetensors producers commonly include it.
    if (name === "__metadata__") {
      continue;
    }
    const dtype = h(entry.dtype);
    const [begin, end] = entry.data_offsets;
    if (begin === end) {
      tensors[name] = y([], entry.shape, dtype);
      continue;
    }
    const start = dataStart + begin;
    const length = (end - begin) / 4;
    // Typed-array views must start on a 4-byte boundary. When the producer did
    // not pad the header to keep the data aligned, copy the bytes out instead.
    const aligned = start % 4 === 0;
    const source = aligned ? t : t.slice(start, dataStart + end);
    const offset = aligned ? start : 0;
    const values =
      entry.dtype === "F32"
        ? new Float32Array(source, offset, length)
        : new Int32Array(source, offset, length);
    tensors[name] = y(values, entry.shape, dtype);
  }
  return tensors;
}
80
+ export {
81
+ U as load_safetensors,
82
+ _ as save_safetensors
83
+ };