@genai-fi/nanogpt 0.4.5 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +1 -1
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -128
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -10
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +1 -1
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.js +1 -1
  49. package/dist/ops/cpu/qkv.js +3 -3
  50. package/dist/ops/cpu/rope.js +5 -5
  51. package/dist/ops/cpu/scatterSub.js +8 -8
  52. package/dist/ops/fusedSoftmax.js +1 -1
  53. package/dist/ops/gatherSub.js +1 -1
  54. package/dist/ops/gelu.js +1 -1
  55. package/dist/ops/grads/attentionMask.js +13 -9
  56. package/dist/ops/grads/fusedSoftmax.js +12 -9
  57. package/dist/ops/grads/gelu.js +1 -1
  58. package/dist/ops/grads/matMulGelu.js +1 -1
  59. package/dist/ops/grads/normRMS.js +1 -1
  60. package/dist/ops/grads/qkv.js +19 -9
  61. package/dist/ops/grads/rope.js +1 -1
  62. package/dist/ops/matMulGelu.js +1 -1
  63. package/dist/ops/matMulMul.d.ts +2 -0
  64. package/dist/ops/matMulMul.js +9 -0
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +13 -12
  72. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/matMulGelu.js +17 -17
  76. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  77. package/dist/ops/webgl/matMulMul.js +28 -0
  78. package/dist/ops/webgl/mulDropout.js +1 -1
  79. package/dist/ops/webgl/normRMS.js +29 -21
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops-ObfXLHYQ.js +1269 -0
  84. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  85. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  86. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  87. package/dist/slice_util-D-kaD4ZV.js +49 -0
  88. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  89. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  90. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  91. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  92. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  93. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  94. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +44 -44
  97. package/dist/training/Evaluator.js +6 -6
  98. package/dist/training/FullTrainer.js +1 -1
  99. package/dist/training/Trainer.js +7 -7
  100. package/dist/training/sparseCrossEntropy.js +4 -4
  101. package/dist/utilities/dummy.js +10 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.js +1 -1
  104. package/dist/utilities/profile.js +1 -1
  105. package/dist/utilities/save.js +10 -8
  106. package/dist/utilities/weights.js +2 -2
  107. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  108. package/package.json +1 -1
  109. package/dist/slice_util-BdhYwFY_.js +0 -90
  110. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  111. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,4 +1,4 @@
1
- import { A as r, b as c, f as h, s as g, e as o } from "../index--6vO-cOz.js";
1
+ import { A as r, b as c, f as h, s as g, e as o } from "../index-iNhkcAEQ.js";
2
2
  class u extends r {
3
3
  constructor(t, e, s, a, i) {
4
4
  super(t, e, s, a), this.config = i, this.startLearningRate = t;
@@ -1,7 +1,7 @@
1
- import { aj as $, ah as d, L as M, a as R, ak as f, al as v, am as z, j as _, t as x } from "../index--6vO-cOz.js";
1
+ import { aj as $, V as d, K as M, a as R, ak as f, al as v, am as z, k as _, t as x } from "../index-iNhkcAEQ.js";
2
2
  import { s as E } from "../index-C4L8Cm77.js";
3
- import { s as P } from "../stack-CmqSdsfs.js";
4
- import { t as D } from "../tensor-BGYi41cj.js";
3
+ import { s as P } from "../stack--cqr9Dgc.js";
4
+ import { t as D } from "../tensor-CfiPXsW4.js";
5
5
  import "../index-Tf7vU29b.js";
6
6
  /**
7
7
  * @license
@@ -37,13 +37,13 @@ function I(s, t, e = /* @__PURE__ */ new Map(), r = /* @__PURE__ */ new Set()) {
37
37
  throw new Error("A deep map function may not return both a value and recurse=true.");
38
38
  if (n.recurse)
39
39
  if (p(s)) {
40
- const a = Array.isArray(s) ? [] : {};
40
+ const i = Array.isArray(s) ? [] : {};
41
41
  r.add(s);
42
42
  for (const l in s) {
43
43
  const h = s[l], c = I(h, t, e, r);
44
- a[l] = c;
44
+ i[l] = c;
45
45
  }
46
- return r.delete(s), s.__proto__ && (a.__proto__ = s.__proto__), a;
46
+ return r.delete(s), s.__proto__ && (i.__proto__ = s.__proto__), i;
47
47
  } else
48
48
  throw new Error(`Can't recurse into non-iterable type: ${s}`);
49
49
  else return e.set(s, n.value), n.value;
@@ -60,13 +60,13 @@ function A(s, t, e = /* @__PURE__ */ new Set()) {
60
60
  throw new Error("A deep zip function may not return both a value and recurse=true.");
61
61
  if (n.recurse)
62
62
  if (p(r)) {
63
- const a = Array.isArray(r) ? [] : {};
63
+ const i = Array.isArray(r) ? [] : {};
64
64
  e.add(r);
65
65
  for (const l in r) {
66
66
  const h = s.map((w) => w[l]), c = A(h, t, e);
67
- a[l] = c;
67
+ i[l] = c;
68
68
  }
69
- return e.delete(r), a;
69
+ return e.delete(r), i;
70
70
  } else
71
71
  throw new Error(`Can't recurse into non-iterable type: ${r}`);
72
72
  else return n.value;
@@ -303,15 +303,15 @@ y.INITIAL_CAPACITY = 32;
303
303
  * =============================================================================
304
304
  */
305
305
  function W(s) {
306
- return new Y(s);
306
+ return new V(s);
307
307
  }
308
308
  function k(s) {
309
- return new J(s);
309
+ return new Y(s);
310
310
  }
311
311
  function U(s, t) {
312
312
  return new F(s, t);
313
313
  }
314
- class i {
314
+ class a {
315
315
  /**
316
316
  * Collect all remaining elements of a bounded stream into an array.
317
317
  * Obviously this will succeed only for small streams that fit in memory.
@@ -477,7 +477,7 @@ class i {
477
477
  * of the original element type.
478
478
  */
479
479
  rowMajorBatch(t, e = !0) {
480
- return new K(this, t, e);
480
+ return new j(this, t, e);
481
481
  }
482
482
  /**
483
483
  * Groups elements into batches, represented in column-major form.
@@ -512,7 +512,7 @@ class i {
512
512
  * with collections at the leaves.
513
513
  */
514
514
  columnMajorBatch(t, e = !0, r = T) {
515
- return this.rowMajorBatch(t, e).map((a) => O(a, r));
515
+ return this.rowMajorBatch(t, e).map((i) => O(i, r));
516
516
  }
517
517
  /**
518
518
  * Concatenate this `LazyIterator` with another.
@@ -535,7 +535,7 @@ class i {
535
535
  * unaltered.
536
536
  */
537
537
  take(t) {
538
- return t < 0 || t == null ? this : new j(this, t);
538
+ return t < 0 || t == null ? this : new X(this, t);
539
539
  }
540
540
  /**
541
541
  * Skips the first `count` items in this stream.
@@ -544,7 +544,7 @@ class i {
544
544
  * value is given, the entire stream is returned unaltered.
545
545
  */
546
546
  skip(t) {
547
- return t < 0 || t == null ? this : new X(this, t);
547
+ return t < 0 || t == null ? this : new K(this, t);
548
548
  }
549
549
  /**
550
550
  * Prefetch the first `bufferSize` items in this stream.
@@ -575,10 +575,10 @@ class i {
575
575
  * prior one, so that they cannot execute concurrently.
576
576
  */
577
577
  serial() {
578
- return new V(this);
578
+ return new J(this);
579
579
  }
580
580
  }
581
- class Y extends i {
581
+ class V extends a {
582
582
  constructor(t) {
583
583
  super(), this.items = t, this.trav = 0;
584
584
  }
@@ -592,7 +592,7 @@ class Y extends i {
592
592
  return this.trav++, { value: Q(t), done: !1 };
593
593
  }
594
594
  }
595
- class J extends i {
595
+ class Y extends a {
596
596
  constructor(t) {
597
597
  super(), this.nextFn = t;
598
598
  }
@@ -607,7 +607,7 @@ class J extends i {
607
607
  }
608
608
  }
609
609
  }
610
- class V extends i {
610
+ class J extends a {
611
611
  constructor(t) {
612
612
  super(), this.upstream = t, this.lastRead = Promise.resolve({ value: null, done: !1 });
613
613
  }
@@ -621,7 +621,7 @@ class V extends i {
621
621
  return this.upstream.next();
622
622
  }
623
623
  }
624
- class X extends i {
624
+ class K extends a {
625
625
  constructor(t, e) {
626
626
  super(), this.upstream = t, this.maxCount = e, this.count = 0, this.lastRead = Promise.resolve({ value: null, done: !1 });
627
627
  }
@@ -641,7 +641,7 @@ class X extends i {
641
641
  return this.upstream.next();
642
642
  }
643
643
  }
644
- class j extends i {
644
+ class X extends a {
645
645
  constructor(t, e) {
646
646
  super(), this.upstream = t, this.maxCount = e, this.count = 0;
647
647
  }
@@ -652,7 +652,7 @@ class j extends i {
652
652
  return this.count++ >= this.maxCount ? { value: null, done: !0 } : this.upstream.next();
653
653
  }
654
654
  }
655
- class K extends i {
655
+ class j extends a {
656
656
  constructor(t, e, r = !0) {
657
657
  super(), this.upstream = t, this.batchSize = e, this.enableSmallLastBatch = r, this.lastRead = Promise.resolve({ value: null, done: !1 });
658
658
  }
@@ -673,7 +673,7 @@ class K extends i {
673
673
  return { value: t, done: !1 };
674
674
  }
675
675
  }
676
- class Z extends i {
676
+ class Z extends a {
677
677
  constructor(t, e) {
678
678
  super(), this.upstream = t, this.predicate = e, this.lastRead = Promise.resolve({ value: null, done: !1 });
679
679
  }
@@ -692,7 +692,7 @@ class Z extends i {
692
692
  }
693
693
  }
694
694
  }
695
- class tt extends i {
695
+ class tt extends a {
696
696
  constructor(t, e) {
697
697
  super(), this.upstream = t, this.transform = e;
698
698
  }
@@ -704,12 +704,12 @@ class tt extends i {
704
704
  if (t.done)
705
705
  return { value: null, done: !0 };
706
706
  const e = f(t.value), r = this.transform(t.value), n = f(r);
707
- for (const a of e)
708
- v(a, n) || a.dispose();
707
+ for (const i of e)
708
+ v(i, n) || i.dispose();
709
709
  return { value: r, done: !1 };
710
710
  }
711
711
  }
712
- class et extends i {
712
+ class et extends a {
713
713
  constructor(t, e) {
714
714
  super(), this.upstream = t, this.handler = e, this.count = 0, this.lastRead = Promise.resolve({ value: null, done: !1 });
715
715
  }
@@ -729,7 +729,7 @@ class et extends i {
729
729
  }
730
730
  }
731
731
  }
732
- class g extends i {
732
+ class g extends a {
733
733
  constructor(t, e) {
734
734
  super(), this.upstream = t, this.transform = e;
735
735
  }
@@ -741,12 +741,12 @@ class g extends i {
741
741
  if (t.done)
742
742
  return { value: null, done: !0 };
743
743
  const e = f(t.value), r = await this.transform(t.value), n = f(r);
744
- for (const a of e)
745
- v(a, n) || a.dispose();
744
+ for (const i of e)
745
+ v(i, n) || i.dispose();
746
746
  return { value: r, done: !1 };
747
747
  }
748
748
  }
749
- class st extends i {
749
+ class st extends a {
750
750
  constructor() {
751
751
  super(), this.outputQueue = new y(), this.lastRead = Promise.resolve({ value: null, done: !1 });
752
752
  }
@@ -773,12 +773,12 @@ class rt extends st {
773
773
  return !1;
774
774
  const e = f(t.value), r = this.transform(t.value), n = f(r);
775
775
  this.outputQueue.pushAll(r);
776
- for (const a of e)
777
- v(a, n) || a.dispose();
776
+ for (const i of e)
777
+ v(i, n) || i.dispose();
778
778
  return !0;
779
779
  }
780
780
  }
781
- class F extends i {
781
+ class F extends a {
782
782
  constructor(t, e) {
783
783
  super(), this.baseErrorHandler = e, this.lastRead = null, this.iterator = null, this.moreIterators = t;
784
784
  }
@@ -803,7 +803,7 @@ var b;
803
803
  (function(s) {
804
804
  s[s.FAIL = 0] = "FAIL", s[s.SHORTEST = 1] = "SHORTEST", s[s.LONGEST = 2] = "LONGEST";
805
805
  })(b || (b = {}));
806
- class S extends i {
806
+ class S extends a {
807
807
  constructor(t, e) {
808
808
  super(), this.upstream = t, this.bufferSize = e, this.buffer = new C(e);
809
809
  }
@@ -934,7 +934,7 @@ class N {
934
934
  _(t > 0, () => `batchSize needs to be positive, but it is
935
935
  ${t}`);
936
936
  let n;
937
- return this.size === 1 / 0 || this.size == null ? n = this.size : e ? n = Math.ceil(this.size / t) : n = Math.floor(this.size / t), o(async () => (await r.iterator()).columnMajorBatch(t, e, at), n);
937
+ return this.size === 1 / 0 || this.size == null ? n = this.size : e ? n = Math.ceil(this.size / t) : n = Math.floor(this.size / t), o(async () => (await r.iterator()).columnMajorBatch(t, e, it), n);
938
938
  }
939
939
  /**
940
940
  * Concatenates this `Dataset` with another.
@@ -1129,10 +1129,10 @@ class N {
1129
1129
  shuffle(t, e, r = !0) {
1130
1130
  if (t == null || t < 0)
1131
1131
  throw this.size == null ? new RangeError("`Dataset.shuffle()` requires bufferSize to be specified.") : new RangeError(`\`Dataset.shuffle()\` requires bufferSize to be specified. If your data fits in main memory (for regular JS objects), and/or GPU memory (for \`tf.Tensor\`s), consider setting bufferSize to the dataset size (${this.size} elements)`);
1132
- const n = this, a = E.alea(e || z().toString());
1132
+ const n = this, i = E.alea(e || z().toString());
1133
1133
  return o(async () => {
1134
- let l = a.int32();
1135
- return r && (l += a.int32()), (await n.iterator()).shuffle(t, l.toString());
1134
+ let l = i.int32();
1135
+ return r && (l += i.int32()), (await n.iterator()).shuffle(t, l.toString());
1136
1136
  }, this.size);
1137
1137
  }
1138
1138
  /**
@@ -1210,13 +1210,13 @@ function o(s, t = null) {
1210
1210
  }
1211
1211
  }();
1212
1212
  }
1213
- function at(s) {
1213
+ function it(s) {
1214
1214
  if (s === null)
1215
1215
  return null;
1216
1216
  const t = s[0];
1217
- return H(t) ? { value: it(s), recurse: !1 } : { value: null, recurse: !0 };
1217
+ return H(t) ? { value: at(s), recurse: !1 } : { value: null, recurse: !0 };
1218
1218
  }
1219
- function it(s) {
1219
+ function at(s) {
1220
1220
  if (s.length === 0)
1221
1221
  throw new Error("Can't make a batch of zero elements.");
1222
1222
  return s[0] instanceof d ? P(s) : D(s);
@@ -1252,7 +1252,7 @@ class mt {
1252
1252
  }
1253
1253
  // Create dataset from text files
1254
1254
  async createTextDataset(t, e = 32, r = 0, n = 1) {
1255
- const a = await Promise.all(t.map((u) => this.tokenizer.encode(u))), l = this.tokenizer.eosToken >= 0, h = a.map((u) => l ? [...u, this.tokenizer.eosToken] : u).flat(), c = h.slice(
1255
+ const i = await Promise.all(t.map((u) => this.tokenizer.encode(u))), l = this.tokenizer.eosToken >= 0, h = i.map((u) => l ? [...u, this.tokenizer.eosToken] : u).flat(), c = h.slice(
1256
1256
  Math.floor(r * h.length),
1257
1257
  n === 1 ? void 0 : Math.floor(n * h.length)
1258
1258
  ), w = (function* () {
@@ -5,14 +5,14 @@ class p {
5
5
  iterator;
6
6
  async evaluate(s = 100) {
7
7
  let t = 0, o = 0;
8
- const c = await this.iterator;
8
+ const n = await this.iterator;
9
9
  for (let a = 0; a < s; a++) {
10
- const e = await c.next();
10
+ const e = await n.next();
11
11
  if (e.done) break;
12
- const n = e.value, { xs: r, ys: l } = n, { loss: i, logits: u } = this.model.forward(r, l, !1, !1);
13
- u.dispose(), r.dispose(), l.dispose();
14
- const d = i.arraySync();
15
- i.dispose(), t += d, o++;
12
+ const c = e.value, { xs: r, ys: i } = c, [u, l] = this.model.forward({ training: !1 }, r, i);
13
+ u.dispose(), r.dispose(), i.dispose();
14
+ const d = l.arraySync();
15
+ l.dispose(), t += d, o++;
16
16
  }
17
17
  return t / o;
18
18
  }
@@ -1,7 +1,7 @@
1
1
  import { generateText as v } from "../utilities/generate.js";
2
2
  import L from "./Trainer.js";
3
3
  import x from "./Evaluator.js";
4
- import { a as h } from "../index--6vO-cOz.js";
4
+ import { a as h } from "../index-iNhkcAEQ.js";
5
5
  const D = {
6
6
  desiredLoss: 0.01,
7
7
  logInterval: 1,
@@ -1,10 +1,10 @@
1
1
  import { DatasetBuilder as d } from "./DatasetBuilder.js";
2
2
  import h from "./AdamExt.js";
3
- import { t as g, v as u, a as o } from "../index--6vO-cOz.js";
4
- import { m as y, n as f } from "../norm-DSva3hI3.js";
5
- import { m as S, a as z } from "../moments-DYOHXoRV.js";
6
- import { m as b } from "../max-BUShNgfh.js";
7
- import { z as n } from "../zeros-8xl-W2DC.js";
3
+ import { t as g, v as u, a as o } from "../index-iNhkcAEQ.js";
4
+ import { m as y, n as f } from "../norm-D3676xIo.js";
5
+ import { m as S, a as z } from "../moments-B06NlR_V.js";
6
+ import { m as b } from "../max-CYaAjEEp.js";
7
+ import { z as n } from "../zeros-NMYTayy7.js";
8
8
  class G {
9
9
  constructor(t, s, e = 1e-3) {
10
10
  this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new d(s, t.config.gpt.blockSize);
@@ -53,8 +53,8 @@ class G {
53
53
  return g(() => {
54
54
  this.model.getProfiler()?.startMemory();
55
55
  const { xs: a, ys: r } = t, l = () => {
56
- const { loss: m, logits: p } = this.model.forward(a, r, !0);
57
- return p.dispose(), m;
56
+ const [m, p] = this.model.forward({ training: !0 }, a, r);
57
+ return m.dispose(), p;
58
58
  }, { value: c, grads: i } = u(l);
59
59
  return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(i), console.log("-------")), this.optimizer.applyGradients(i), this.model.getProfiler()?.endMemory("Training"), o(i)), c;
60
60
  });
@@ -1,9 +1,9 @@
1
1
  import { gatherSub as L } from "../ops/gatherSub.js";
2
2
  import { scatterSub as y } from "../ops/scatterSub.js";
3
- import { e as u, c as i, z as S, t as f, s as G } from "../index--6vO-cOz.js";
4
- import { s as v } from "../softmax-Dsxflvdl.js";
5
- import { m as z } from "../max-BUShNgfh.js";
6
- import { l as k } from "../log_sum_exp-CiEy1aUe.js";
3
+ import { e as u, c as i, z as S, t as f, s as G } from "../index-iNhkcAEQ.js";
4
+ import { s as v } from "../softmax-BjsptB07.js";
5
+ import { m as z } from "../max-CYaAjEEp.js";
6
+ import { l as k } from "../log_sum_exp-CkumwesB.js";
7
7
  function F(a, s) {
8
8
  return f(() => {
9
9
  const e = a.shape[a.shape.length - 1], o = a.shape.slice(0, -1).reduce((d, c) => d * c, 1), p = a.shape.length > 2 ? a.reshape([o, e]) : a, n = s.shape.length > 1 ? s.reshape([o]).cast("int32") : s.cast("int32"), t = z(p, -1, !0), r = G(p, t), h = k(r, -1);
@@ -1,14 +1,14 @@
1
- import "../index--6vO-cOz.js";
2
- import { z as n } from "../zeros-8xl-W2DC.js";
3
- async function a(s) {
4
- const o = n([1, s.config.gpt.blockSize], "int32"), { logits: t, loss: i } = s.forward(o, void 0, !1);
5
- await t.data(), t.dispose(), i && i.dispose(), o.dispose();
1
+ import "../index-iNhkcAEQ.js";
2
+ import { z as n } from "../zeros-NMYTayy7.js";
3
+ async function c(s) {
4
+ const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
5
+ await t.data(), t.dispose(), o && o.dispose(), i.dispose();
6
6
  }
7
- function c(s) {
8
- const o = n([1, s.config.gpt.blockSize], "int32"), { logits: t, loss: i } = s.forward(o, void 0, !1);
9
- t.dispose(), i && i.dispose(), o.dispose();
7
+ function d(s) {
8
+ const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
9
+ t.dispose(), o && o.dispose(), i.dispose();
10
10
  }
11
11
  export {
12
- c as dummyPass,
13
- a as dummyPassAsync
12
+ d as dummyPass,
13
+ c as dummyPassAsync
14
14
  };
@@ -1,6 +1,6 @@
1
- import { t as y } from "../index--6vO-cOz.js";
2
- import { t as x } from "../tensor2d-DUr_htjt.js";
3
- import { c as f } from "../concat-DvWM7HGZ.js";
1
+ import { t as y } from "../index-iNhkcAEQ.js";
2
+ import { t as x } from "../tensor2d-tSxWdFMH.js";
3
+ import { c as f } from "../concat-Cxbo2sOz.js";
4
4
  async function A(o, r, a, c, T) {
5
5
  if (c <= 0)
6
6
  throw new Error("Length must be a positive integer");
@@ -3,7 +3,7 @@ import { importWeights as b } from "./weights.js";
3
3
  import u from "../tokeniser/CharTokeniser.js";
4
4
  import F from "../NanoGPTModel.js";
5
5
  import { dummyPassAsync as j } from "./dummy.js";
6
- import { d as T } from "../index--6vO-cOz.js";
6
+ import { d as T } from "../index-iNhkcAEQ.js";
7
7
  import E from "../tokeniser/bpe.js";
8
8
  async function A(t) {
9
9
  const o = await fetch(t);
@@ -1,4 +1,4 @@
1
- import { m as s } from "../index--6vO-cOz.js";
1
+ import { m as s } from "../index-iNhkcAEQ.js";
2
2
  const m = 1024 * 1024;
3
3
  class i {
4
4
  log = /* @__PURE__ */ new Map();
@@ -1,19 +1,21 @@
1
1
  import { j as g } from "../jszip.min-CjP2V1VV.js";
2
2
  import { exportWeights as l } from "./weights.js";
3
3
  import b from "../tokeniser/CharTokeniser.js";
4
- const y = "1.0.0";
4
+ const p = "1.0.0";
5
5
  async function h(t, a, i) {
6
- const o = i?.includeLog ?? !0, c = t.saveWeights(), e = new g(), f = {};
7
- for (const [n, s] of c) {
8
- const r = await l(s);
9
- f[n] = r.spec, e.file(`${n}.bin`, r.data.buffer, { binary: !0 });
6
+ const c = i?.includeLog ?? !0, f = /* @__PURE__ */ new Map();
7
+ t.saveWeights(f);
8
+ const e = new g(), r = {};
9
+ for (const [n, s] of f) {
10
+ const o = await l(s);
11
+ r[n] = o.spec, e.file(`${n}.bin`, o.data.buffer, { binary: !0 });
10
12
  }
11
13
  if (e.file(
12
14
  "manifest.json",
13
15
  JSON.stringify({
14
- weightSpec: f,
16
+ weightSpec: r,
15
17
  config: t.config,
16
- version: y,
18
+ version: p,
17
19
  application: "@genai-fi/nanogpt",
18
20
  meta: i?.metadata,
19
21
  name: i?.name
@@ -31,7 +33,7 @@ async function h(t, a, i) {
31
33
  {
32
34
  binary: !1
33
35
  }
34
- ), o && e.file("log.json", JSON.stringify(t.log), { binary: !1 }), i?.files)
36
+ ), c && e.file("log.json", JSON.stringify(t.log), { binary: !1 }), i?.files)
35
37
  for (const [n, s] of Object.entries(i.files))
36
38
  e.file(n, JSON.stringify(s), { binary: !1 });
37
39
  return e.generateAsync({ type: "blob" });
@@ -1,5 +1,5 @@
1
- import "../index--6vO-cOz.js";
2
- import { t as p } from "../tensor-BGYi41cj.js";
1
+ import "../index-iNhkcAEQ.js";
2
+ import { t as p } from "../tensor-CfiPXsW4.js";
3
3
  function h(n) {
4
4
  const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
5
5
  let t = 0;
@@ -1,4 +1,4 @@
1
- import { o as m, h as r, X as l, E as c, Y as i, k as p, Z as u, n as f } from "./index--6vO-cOz.js";
1
+ import { o as l, i as r, X as m, E as c, Y as i, l as p, Z as u, p as f } from "./index-iNhkcAEQ.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -17,11 +17,11 @@ import { o as m, h as r, X as l, E as c, Y as i, k as p, Z as u, n as f } from "
17
17
  */
18
18
  function x(a, e) {
19
19
  const o = r(a, "real", "complex"), s = r(e, "imag", "complex");
20
- l(o.shape, s.shape, `real and imag shapes, ${o.shape} and ${s.shape}, must match in call to tf.complex().`);
20
+ m(o.shape, s.shape, `real and imag shapes, ${o.shape} and ${s.shape}, must match in call to tf.complex().`);
21
21
  const n = { real: o, imag: s };
22
22
  return c.runKernel(i, n);
23
23
  }
24
- const g = /* @__PURE__ */ m({ complex_: x });
24
+ const g = /* @__PURE__ */ l({ complex_: x });
25
25
  /**
26
26
  * @license
27
27
  * Copyright 2018 Google LLC. All Rights Reserved.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.4.5",
3
+ "version": "0.5.0",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",
@@ -1,90 +0,0 @@
1
- import { o as u, h as p, k as g, w as m, E as w, a4 as x, j as i } from "./index--6vO-cOz.js";
2
- import { r as y } from "./reshape-z51Eu-re.js";
3
- /**
4
- * @license
5
- * Copyright 2020 Google LLC. All Rights Reserved.
6
- * Licensed under the Apache License, Version 2.0 (the "License");
7
- * you may not use this file except in compliance with the License.
8
- * You may obtain a copy of the License at
9
- *
10
- * http://www.apache.org/licenses/LICENSE-2.0
11
- *
12
- * Unless required by applicable law or agreed to in writing, software
13
- * distributed under the License is distributed on an "AS IS" BASIS,
14
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- * See the License for the specific language governing permissions and
16
- * limitations under the License.
17
- * =============================================================================
18
- */
19
- function T(a, t) {
20
- let e = p(a, "broadcastTo", "x");
21
- const r = e.shape;
22
- if (g(t), t.length < e.rank)
23
- throw new Error(`broadcastTo(): shape.length=${t.length} < input.rank=${e.rank}.`);
24
- if (t.length > e.rank) {
25
- const l = e.shape.slice();
26
- for (; l.length < t.length; )
27
- l.unshift(1);
28
- e = y(e, l);
29
- }
30
- const n = e.shape, o = Array.from(t);
31
- for (let l = t.length - 1; l >= 0; l--)
32
- if (n[l] === t[l])
33
- o[l] = 1;
34
- else if (e.shape[l] !== 1)
35
- throw new Error(`broadcastTo(): [${r}] cannot be broadcast to [${t}].`);
36
- if (o.map((l, h) => l > 1 ? h : -1).filter((l) => l >= 0).length === 0)
37
- return m(e);
38
- const f = { x: e }, c = { reps: o };
39
- return w.runKernel(x, f, c);
40
- }
41
- const A = /* @__PURE__ */ u({ broadcastTo_: T });
42
- /**
43
- * @license
44
- * Copyright 2021 Google LLC. All Rights Reserved.
45
- * Licensed under the Apache License, Version 2.0 (the "License");
46
- * you may not use this file except in compliance with the License.
47
- * You may obtain a copy of the License at
48
- *
49
- * http://www.apache.org/licenses/LICENSE-2.0
50
- *
51
- * Unless required by applicable law or agreed to in writing, software
52
- * distributed under the License is distributed on an "AS IS" BASIS,
53
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
54
- * See the License for the specific language governing permissions and
55
- * limitations under the License.
56
- * =============================================================================
57
- */
58
- function b(a, t, e) {
59
- let r = e.length;
60
- for (let n = 0; n < e.length; n++)
61
- if (e[n] > 1) {
62
- r = n;
63
- break;
64
- }
65
- for (let n = r + 1; n < e.length; n++)
66
- if (t[n] > 0 || e[n] !== a[n])
67
- return !1;
68
- return !0;
69
- }
70
- function E(a, t) {
71
- let e = a.length > 0 ? a[a.length - 1] : 1;
72
- for (let r = 0; r < a.length - 1; r++)
73
- e += a[r] * t[r];
74
- return e;
75
- }
76
- function N(a, t, e) {
77
- let r;
78
- const n = a.shape.length;
79
- typeof t == "number" ? r = [t, ...new Array(n - 1).fill(0)] : t.length < n ? r = t.concat(new Array(n - t.length).fill(0)) : r = t.slice(), r.forEach((s) => {
80
- i(s !== -1, () => "slice() does not support negative begin indexing.");
81
- });
82
- let o;
83
- return e == null ? o = new Array(n).fill(-1) : typeof e == "number" ? o = [e, ...new Array(n - 1).fill(-1)] : e.length < n ? o = e.concat(new Array(n - e.length).fill(-1)) : o = e, o = o.map((s, f) => s >= 0 ? s : (i(s === -1, () => `Negative size values should be exactly -1 but got ${s} for the slice() size at index ${f}.`), a.shape[f] - r[f])), [r, o];
84
- }
85
- export {
86
- A as b,
87
- E as c,
88
- b as i,
89
- N as p
90
- };