@genai-fi/nanogpt 0.7.1 → 0.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. package/dist/Generator.d.ts +11 -2
  2. package/dist/Generator.js +81 -68
  3. package/dist/NanoGPTModel.js +8 -8
  4. package/dist/{RealDiv-CVYNbZxu.js → RealDiv-Dy0p8Bvo.js} +7 -7
  5. package/dist/{Reshape-CEsEp0AI.js → Reshape-DH5srBP0.js} +2 -2
  6. package/dist/{Reshape-Do18N3gO.js → Reshape-DvudQDvJ.js} +1 -1
  7. package/dist/TeachableLLM.js +33 -32
  8. package/dist/{TiedEmbedding-ccLBFiZi.js → TiedEmbedding-BxOerUmB.js} +4 -4
  9. package/dist/Trainer.d.ts +6 -1
  10. package/dist/Trainer.js +53 -19
  11. package/dist/{axis_util-5DTW2tFV.js → axis_util-BzbKo31C.js} +1 -1
  12. package/dist/backend.js +2 -2
  13. package/dist/{backend_util-C9Ut8n0Q.js → backend_util-TE7aTPhZ.js} +4 -4
  14. package/dist/{broadcast_to-Ba9h_8DO.js → broadcast_to-CdbwV-Dj.js} +2 -2
  15. package/dist/{concat-CbXTetof.js → concat-CsxrgovM.js} +1 -1
  16. package/dist/{dataset-U3PrjwgU.js → dataset-CtdBYwjo.js} +3 -3
  17. package/dist/{dropout-DPfPgWWe.js → dropout-DYs5QFGQ.js} +1 -1
  18. package/dist/{gather-Bbh8DHhM.js → gather-CMMy2KEG.js} +1 -1
  19. package/dist/{gelu-BFwVnd1r.js → gelu-C-dPj6Ku.js} +1 -1
  20. package/dist/{gpgpu_math-DffelNS-.js → gpgpu_math-DGNLNL4I.js} +2 -2
  21. package/dist/{index-UdZhlibC.js → index-BoWRt-10.js} +4 -4
  22. package/dist/{index-DYD_yPa-.js → index-CLthM0TO.js} +10 -10
  23. package/dist/{kernel_funcs_utils-CXDy3EN7.js → kernel_funcs_utils-BYKWV8Aa.js} +3 -3
  24. package/dist/layers/BaseLayer.js +2 -2
  25. package/dist/layers/CausalSelfAttention.js +6 -6
  26. package/dist/layers/MLP.js +5 -5
  27. package/dist/layers/RMSNorm.js +3 -3
  28. package/dist/layers/RoPECache.js +4 -4
  29. package/dist/layers/TiedEmbedding.js +5 -5
  30. package/dist/layers/TransformerBlock.js +1 -1
  31. package/dist/loader/loadTransformers.js +1 -1
  32. package/dist/loader/oldZipLoad.js +5 -5
  33. package/dist/{log_sum_exp-BnmCkHWl.js → log_sum_exp-DbjkV734.js} +5 -5
  34. package/dist/main.js +5 -5
  35. package/dist/{mat_mul-dwmZz69e.js → mat_mul-8m8pfdcx.js} +1 -1
  36. package/dist/{max-ByjEGoFx.js → max-Ddnnb5xe.js} +1 -1
  37. package/dist/{mulmat_packed_gpu-IGPBp6h9.js → mulmat_packed_gpu-VSekgsNv.js} +1 -1
  38. package/dist/{ones-C8Mfln6-.js → ones-Dj0SDhHf.js} +2 -2
  39. package/dist/ops/adamAdjust.js +1 -1
  40. package/dist/ops/adamMoments.js +1 -1
  41. package/dist/ops/appendCache.js +3 -3
  42. package/dist/ops/attentionMask.js +1 -1
  43. package/dist/ops/cpu/adamAdjust.js +1 -1
  44. package/dist/ops/cpu/adamMoments.js +2 -2
  45. package/dist/ops/cpu/appendCache.js +2 -2
  46. package/dist/ops/cpu/attentionMask.js +5 -5
  47. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  48. package/dist/ops/cpu/gatherSub.js +3 -3
  49. package/dist/ops/cpu/gelu.js +1 -1
  50. package/dist/ops/cpu/matMulGelu.js +2 -2
  51. package/dist/ops/cpu/matMulMul.js +1 -1
  52. package/dist/ops/cpu/mulDropout.js +1 -1
  53. package/dist/ops/cpu/normRMS.js +1 -1
  54. package/dist/ops/cpu/qkv.js +3 -3
  55. package/dist/ops/cpu/rope.js +5 -5
  56. package/dist/ops/cpu/scatterSub.js +5 -5
  57. package/dist/ops/fusedSoftmax.js +1 -1
  58. package/dist/ops/gatherSub.js +1 -1
  59. package/dist/ops/gelu.js +2 -2
  60. package/dist/ops/grads/attentionMask.js +1 -1
  61. package/dist/ops/grads/fusedSoftmax.js +2 -2
  62. package/dist/ops/grads/gelu.js +2 -2
  63. package/dist/ops/grads/matMulGelu.js +1 -1
  64. package/dist/ops/grads/normRMS.js +1 -1
  65. package/dist/ops/grads/qkv.js +1 -1
  66. package/dist/ops/grads/rope.js +1 -1
  67. package/dist/ops/matMulGelu.js +1 -1
  68. package/dist/ops/matMulMul.js +1 -1
  69. package/dist/ops/mulDrop.js +1 -1
  70. package/dist/ops/normRMS.js +1 -1
  71. package/dist/ops/qkv.js +1 -1
  72. package/dist/ops/rope.js +4 -4
  73. package/dist/ops/scatterSub.js +1 -1
  74. package/dist/ops/webgl/adamAdjust.js +2 -2
  75. package/dist/ops/webgl/adamMoments.js +7 -5
  76. package/dist/ops/webgl/appendCache.js +1 -1
  77. package/dist/ops/webgl/attentionMask.js +1 -1
  78. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  79. package/dist/ops/webgl/gatherSub.js +1 -1
  80. package/dist/ops/webgl/gelu.js +2 -2
  81. package/dist/ops/webgl/log.js +3 -3
  82. package/dist/ops/webgl/matMulGelu.js +4 -4
  83. package/dist/ops/webgl/matMulMul.js +1 -1
  84. package/dist/ops/webgl/mulDropout.js +1 -1
  85. package/dist/ops/webgl/normRMS.js +2 -2
  86. package/dist/ops/webgl/qkv.js +1 -1
  87. package/dist/ops/webgl/rope.js +1 -1
  88. package/dist/ops/webgl/scatterSub.js +1 -1
  89. package/dist/ops/webgpu/adamAdjust.js +15 -13
  90. package/dist/ops/webgpu/adamMoments.js +18 -11
  91. package/dist/ops/webgpu/appendCache.js +18 -15
  92. package/dist/ops/webgpu/attentionMask.js +24 -18
  93. package/dist/ops/webgpu/gatherSub.js +17 -30
  94. package/dist/ops/webgpu/gelu.js +3 -3
  95. package/dist/ops/webgpu/normRMS.js +16 -8
  96. package/dist/ops/webgpu/normRMSGrad.js +25 -20
  97. package/dist/ops/webgpu/qkv.js +23 -19
  98. package/dist/ops/webgpu/rope.js +37 -24
  99. package/dist/ops/webgpu/scatterSub.js +16 -14
  100. package/dist/ops/webgpu/utils/reductions.js +4 -4
  101. package/dist/{ops-aRTXR2Sr.js → ops-BFGCx8Ri.js} +15 -15
  102. package/dist/{random_width-DbSpgl4o.js → random_width-sZORGo5k.js} +22 -22
  103. package/dist/{range-D9CZhVlR.js → range-CRuAh-gd.js} +1 -1
  104. package/dist/{reciprocal-CGB48wZB.js → reciprocal-BvGAyKyu.js} +1 -1
  105. package/dist/{register_all_kernels-DnbAyBXt.js → register_all_kernels-BwDSRN-f.js} +30 -30
  106. package/dist/{reshape-BR0eoLYN.js → reshape-CdBq1WJ6.js} +1 -1
  107. package/dist/{scatter_nd_util-OjyAxku2.js → scatter_nd_util-DUstGbU1.js} +1 -1
  108. package/dist/{selu_util-Ce6pu9IM.js → selu_util-BJEXVvjX.js} +3 -3
  109. package/dist/{shared-Czipaeb6.js → shared-B8ztnyEk.js} +6 -6
  110. package/dist/{shared-DS5waSIY.js → shared-wS99K7_n.js} +1 -1
  111. package/dist/{sin-CiBxrDqX.js → sin-BeA3tsEd.js} +1 -1
  112. package/dist/{slice-BHbDHObE.js → slice-BiOsknYS.js} +1 -1
  113. package/dist/{softmax-JMEIUo2J.js → softmax-Bv_6lyMX.js} +1 -1
  114. package/dist/{split-CRU0PjVV.js → split-B-dikLRw.js} +1 -1
  115. package/dist/{stack-ikk2Y8_P.js → stack-B17UN2nn.js} +1 -1
  116. package/dist/{sum-NLYbiDag.js → sum-66ew2byf.js} +1 -1
  117. package/dist/{tensor-Do9PKbIE.js → tensor-JwS7ZYY6.js} +1 -1
  118. package/dist/{tensor2d-CWHxHpLh.js → tensor2d-wxPAnDQy.js} +1 -1
  119. package/dist/training/Adam.js +2 -2
  120. package/dist/training/AdamExt.js +1 -1
  121. package/dist/training/DatasetBuilder.js +35 -32
  122. package/dist/training/FullTrainer.d.ts +15 -2
  123. package/dist/training/FullTrainer.js +97 -51
  124. package/dist/training/Trainer.d.ts +10 -0
  125. package/dist/training/Trainer.js +2 -2
  126. package/dist/training/sparseCrossEntropy.js +4 -4
  127. package/dist/utilities/dummy.js +2 -2
  128. package/dist/utilities/generate.js +3 -3
  129. package/dist/utilities/multinomialCPU.js +2 -2
  130. package/dist/utilities/performance.js +1 -1
  131. package/dist/utilities/profile.js +1 -1
  132. package/dist/utilities/safetensors.js +2 -2
  133. package/dist/utilities/weights.js +2 -2
  134. package/dist/{variable-BTBkayv_.js → variable-BuddVFLa.js} +1 -1
  135. package/dist/{webgpu_program-WaoMq-WD.js → webgpu_program-PFzf1hAQ.js} +1 -1
  136. package/dist/{webgpu_util-DhSeP4b6.js → webgpu_util-D____QpY.js} +1 -1
  137. package/dist/{zeros-DnPT2nD4.js → zeros--BdLQ3oG.js} +1 -1
  138. package/package.json +1 -1
package/dist/training/DatasetBuilder.js CHANGED
@@ -1,5 +1,5 @@
- import { t as u } from "../index-UdZhlibC.js";
- import { d as z, i as f } from "../dataset-U3PrjwgU.js";
+ import { t as g } from "../index-BoWRt-10.js";
+ import { d as u, i as d } from "../dataset-CtdBYwjo.js";
  import "../index-Tf7vU29b.js";
  /**
  * @license
@@ -18,57 +18,60 @@ import "../index-Tf7vU29b.js";
  *
  * =============================================================================
  */
- function S(c) {
- return z(async () => {
- const t = await c();
- return f(() => t.next());
+ function z(r) {
+ return u(async () => {
+ const t = await r();
+ return d(() => t.next());
  });
  }
- const p = 8;
- async function y(c, t) {
- const s = await Promise.all(c.map((n) => t.encode(n))), i = t.eosToken >= 0;
- return s.map((n) => i ? [...n, t.eosToken] : n).flat();
+ const S = 8;
+ async function y(r, t) {
+ const s = await Promise.all(r.map((e) => t.encode(e))), o = t.eosToken >= 0, a = s.map((e) => o ? [...e, t.eosToken] : e).flat();
+ for (const e of a)
+ if (e < 0 || e >= t.vocabSize)
+ throw new Error(`Invalid token index ${e} found in tokenised data`);
+ return a;
  }
  class w {
  tokenizer;
  blockSize;
  pageSize;
  constructor(t, s = 128) {
- this.tokenizer = t, this.blockSize = s, this.pageSize = s * p;
+ this.tokenizer = t, this.blockSize = s, this.pageSize = s * S;
  }
  // Create dataset from text files
- async createTextDataset(t, s = 32, i, r) {
+ async createTextDataset(t, s = 32, o, a) {
  if (t.length < this.blockSize + 1)
  throw new Error(`Not enough tokens (${t.length}) for block size ${this.blockSize}`);
- if (i && i.size > t.length / this.pageSize / 2)
+ if (o && o.size > t.length / this.pageSize / 2)
  throw new Error("Too many masked pages - would leave insufficient training data");
- const n = (function* () {
- if (i && r) {
- const e = Array.from(i);
+ const e = (function* () {
+ if (o && a) {
+ const i = Array.from(o);
  for (; ; ) {
- const a = Math.floor(Math.random() * e.length), l = Math.floor(Math.random() * this.pageSize), o = e[a] * this.pageSize + l;
- if (o + this.blockSize + 1 > t.length)
+ const c = Math.floor(Math.random() * i.length), l = Math.floor(Math.random() * this.pageSize), n = i[c] * this.pageSize + l;
+ if (n + this.blockSize + 1 > t.length)
  continue;
- const h = t.slice(o, o + this.blockSize), g = t.slice(o + 1, o + this.blockSize + 1);
- yield { xs: h, ys: g };
+ const h = t.slice(n, n + this.blockSize), f = t.slice(n + 1, n + this.blockSize + 1);
+ yield { xs: h, ys: f };
  }
  } else
  for (; ; ) {
- const e = Math.floor(Math.random() * (t.length - this.blockSize - 1));
- if (i) {
- const o = Math.floor(e / this.pageSize), h = i.has(o);
- if (h && !r || !h && r)
+ const i = Math.floor(Math.random() * (t.length - this.blockSize - 1));
+ if (o) {
+ const n = Math.floor(i / this.pageSize), h = o.has(n);
+ if (h && !a || !h && a)
  continue;
  }
- const a = t.slice(e, e + this.blockSize), l = t.slice(e + 1, e + this.blockSize + 1);
- yield { xs: a, ys: l };
+ const c = t.slice(i, i + this.blockSize), l = t.slice(i + 1, i + this.blockSize + 1);
+ yield { xs: c, ys: l };
  }
  }).bind(this);
- return S(n).batch(s).map((e) => {
- const a = e;
- return u(() => ({
- xs: a.xs.cast("int32"),
- ys: a.ys.cast("int32")
+ return z(e).batch(s).map((i) => {
+ const c = i;
+ return g(() => ({
+ xs: c.xs.cast("int32"),
+ ys: c.ys.cast("int32")
  // this.tf.oneHot(batchData.ys.cast('int32'), this.tokenizer.vocabSize),
  }));
  }).prefetch(2);
@@ -76,6 +79,6 @@ class w {
  }
  export {
  w as DatasetBuilder,
- p as PAGE_FACTOR,
+ S as PAGE_FACTOR,
  y as flattenTokens
  };
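The one behavioural change in this file is the new validation loop in `y` (exported as `flattenTokens`). As a reading aid, a de-minified sketch of what the bundled helper now appears to do; the ITokeniserLike shape is inferred from the calls visible in the diff, and all names are reconstructions rather than the package's published API.

// Hypothetical, readable reconstruction of the minified `flattenTokens` (`y`).
interface ITokeniserLike {
  encode(text: string): Promise<number[]>;
  eosToken: number; // negative means "no EOS token"
  vocabSize: number;
}

async function flattenTokens(texts: string[], tokenizer: ITokeniserLike): Promise<number[]> {
  const encoded = await Promise.all(texts.map((t) => tokenizer.encode(t)));
  const hasEos = tokenizer.eosToken >= 0;
  // Append EOS after each document, then flatten into one token stream.
  const flat = encoded.map((seq) => (hasEos ? [...seq, tokenizer.eosToken] : seq)).flat();
  // New in 0.7.3: reject out-of-range token ids before they reach training.
  for (const id of flat) {
    if (id < 0 || id >= tokenizer.vocabSize) {
      throw new Error(`Invalid token index ${id} found in tokenised data`);
    }
  }
  return flat;
}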
package/dist/training/FullTrainer.d.ts CHANGED
@@ -1,10 +1,23 @@
  import { ITokeniser } from '../tokeniser/type';
- import { default as NanoGPT } from '../NanoGPTModel';
- import { default as GPTTrainer, TrainingOptions } from './Trainer';
+ import { default as NanoGPT, TrainingLogEntry } from '../NanoGPTModel';
+ import { default as GPTTrainer, TrainingOptions, TrainingProgress } from './Trainer';
  import { Tensor } from '@tensorflow/tfjs-core';
  import { Dataset } from '@tensorflow/tfjs-data';
  export default class FullTrainer extends GPTTrainer {
  constructor(model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+ private createEmptyState;
+ private createLogEntry;
+ private createProgress;
+ stepDataset(dataset: Dataset<{
+ xs: Tensor;
+ ys: Tensor;
+ }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+ xs: Tensor;
+ ys: Tensor;
+ }>): Promise<{
+ log: TrainingLogEntry;
+ progress: TrainingProgress;
+ }>;
  trainOnDataset(dataset: Dataset<{
  xs: Tensor;
  ys: Tensor;
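Based only on the declaration above, a hypothetical sketch of how the new stepDataset API could be driven from application code. The option names come from the defaults visible in the FullTrainer.js diff below; trainer and dataset construction are out of scope here, so both arrive as parameters with structural types of our choosing.

import type { Tensor } from "@tensorflow/tfjs-core";
import type { Dataset } from "@tensorflow/tfjs-data";

type Batch = { xs: Tensor; ys: Tensor };

// Structural stand-in for FullTrainer, matching the declaration above.
interface StepTrainer {
  stepDataset(
    dataset: Dataset<Batch>,
    options: Partial<{ logInterval: number; maxSteps: number; prompt: string; advancedMetrics: boolean }>,
    validationDataset?: Dataset<Batch>
  ): Promise<{
    log: { step: number; loss: number; valLoss?: number };
    progress: { samplesPerSecond: number };
  }>;
}

// Unlike trainOnDataset, stepDataset resolves after the first logged step,
// so the caller owns the outer loop (useful for UI-driven training).
async function trainStepwise(trainer: StepTrainer, data: Dataset<Batch>) {
  for (let i = 0; i < 100; i++) {
    const { log, progress } = await trainer.stepDataset(data, { logInterval: 1 });
    console.log(`step ${log.step}: loss=${log.loss.toFixed(4)} @ ${progress.samplesPerSecond.toFixed(1)} samples/s`);
  }
}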
package/dist/training/FullTrainer.js CHANGED
@@ -1,81 +1,127 @@
- import { generateText as w } from "../utilities/generate.js";
- import T from "./Trainer.js";
- import L from "./Evaluator.js";
- import { d as h } from "../index-UdZhlibC.js";
- import x from "../utilities/profile.js";
- const y = {
+ import { generateText as v } from "../utilities/generate.js";
+ import x from "./Trainer.js";
+ import S from "./Evaluator.js";
+ import { d as w } from "../index-BoWRt-10.js";
+ import y from "../utilities/profile.js";
+ const T = {
  desiredLoss: 0.01,
  logInterval: 1,
  maxSteps: 1e3
  };
- class E extends T {
- constructor(i, e, r = 3e-4) {
- super(i, e, r);
+ class z extends x {
+ constructor(r, t, s = 3e-4) {
+ super(r, t, s);
  }
- // Train for multiple epochs using Dataset API - FIXED memory leaks
- async trainOnDataset(i, e, r) {
- const { logInterval: g, onStep: l, prompt: c, maxSteps: u } = {
- ...y,
- ...e
- }, n = Date.now(), t = {
+ createEmptyState() {
+ return {
  step: 0,
  lastLoss: 1e6,
  totalSteps: 0,
  losses: [],
  validationLosses: [],
- logStartTime: n,
+ logStartTime: 0,
  trainingDuration: 0,
  ...this.lastState || {}
  };
- this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new x())), this.running = !0, t.logStartTime = n;
- const m = r ? new L(this.model, r) : void 0, f = await i.iterator();
+ }
+ createLogEntry(r, t, s, h) {
+ return {
+ loss: r.lastLoss,
+ step: r.step,
+ time: Date.now() - t,
+ batchSize: s,
+ learningRate: h ? this.optimizer.lr : void 0
+ };
+ }
+ createProgress(r, t, s) {
+ return {
+ duration: r.trainingDuration,
+ totalSamples: r.totalSteps * t.batchSize,
+ samplesPerSecond: r.totalSteps * t.batchSize / (r.trainingDuration / 1e3),
+ memory: s ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
+ };
+ }
+ async stepDataset(r, t, s) {
+ const { logInterval: h, prompt: m } = {
+ ...T,
+ ...t
+ }, g = Date.now(), a = this.createEmptyState();
+ this.lastState = a, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, a.logStartTime = g;
+ const p = s ? new S(this.model, s) : void 0, e = await r.iterator();
+ try {
+ for (; this.running; ) {
+ const i = await e.next();
+ if (i.done) break;
+ const u = i.value, o = this.trainBatch(a, u), n = this.createLogEntry(a, g, u.xs.shape[0], t?.advancedMetrics);
+ if (this.model.log.push(n), a.step % h === 0) {
+ await o.data();
+ const f = Date.now();
+ if (a.trainingDuration += f - a.logStartTime, p)
+ try {
+ const l = await p.evaluate(5);
+ a.validationLosses.push(l), n.valLoss = l;
+ } catch (l) {
+ console.error("Validation error:", l);
+ }
+ if (m) {
+ const l = await v(this.tokenizer, this.model, m, 100, {
+ temperature: 0.8
+ });
+ n.example = l;
+ }
+ const c = this.createProgress(a, n, t?.advancedMetrics);
+ return o.dispose(), this.stop(), { log: n, progress: c };
+ }
+ o.dispose();
+ }
+ } catch (i) {
+ throw console.error("Training error:", i), w(), i;
+ }
+ throw w(), this.running = !1, new Error("No log returned before training stopped.");
+ }
+ // Train for multiple epochs using Dataset API - FIXED memory leaks
+ async trainOnDataset(r, t, s) {
+ const { logInterval: h, onStep: m, prompt: g, maxSteps: a } = {
+ ...T,
+ ...t
+ }, p = Date.now(), e = this.createEmptyState();
+ this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || (this.model.config.layerConfig.profiler = new y())), this.running = !0, e.logStartTime = p;
+ const i = s ? new S(this.model, s) : void 0, u = await r.iterator();
  try {
  for (; this.running; ) {
- const o = await f.next();
+ const o = await u.next();
  if (o.done) break;
- const d = o.value, p = this.trainBatch(t, d), s = {
- loss: t.lastLoss,
- step: t.step,
- time: Date.now() - n,
- batchSize: d.xs.shape[0],
- learningRate: e?.advancedMetrics ? this.optimizer.lr : void 0
- //gradientNorm: options?.advancedMetrics ? await state.gradientNorm : undefined,
- };
- if (this.model.log.push(s), t.step % g === 0) {
- await p.data();
- const S = Date.now();
- if (t.trainingDuration += S - t.logStartTime, m)
+ const n = o.value, f = this.trainBatch(e, n), c = this.createLogEntry(e, p, n.xs.shape[0], t?.advancedMetrics);
+ if (this.model.log.push(c), e.step % h === 0) {
+ await f.data();
+ const l = Date.now();
+ if (e.trainingDuration += l - e.logStartTime, i)
  try {
- const a = await m.evaluate(5);
- t.validationLosses.push(a), s.valLoss = a;
- } catch (a) {
- console.error("Validation error:", a);
+ const d = await i.evaluate(5);
+ e.validationLosses.push(d), c.valLoss = d;
+ } catch (d) {
+ console.error("Validation error:", d);
  }
- if (l) {
- if (c) {
- const v = await w(this.tokenizer, this.model, c, 100, {
+ if (m) {
+ if (g) {
+ const L = await v(this.tokenizer, this.model, g, 100, {
  temperature: 0.8
  });
- s.example = v;
+ c.example = L;
  }
- const a = {
- duration: t.trainingDuration,
- totalSamples: t.totalSteps * s.batchSize,
- samplesPerSecond: t.totalSteps * s.batchSize / (t.trainingDuration / 1e3),
- memory: e.advancedMetrics ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
- };
- await l(s, a);
+ const d = this.createProgress(e, c, t?.advancedMetrics);
+ await m(c, d);
  }
- t.logStartTime = Date.now();
+ e.logStartTime = Date.now();
  }
- p.dispose(), t.step >= u && this.stop();
+ f.dispose(), e.step >= a && this.stop();
  }
  } catch (o) {
- throw console.error("Training error:", o), h(), o;
+ throw console.error("Training error:", o), w(), o;
  }
- return h(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
+ return w(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
  }
  }
  export {
- E as default
+ z as default
  };
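The refactor above moves the bookkeeping that trainOnDataset previously built inline into shared createEmptyState/createLogEntry/createProgress helpers, so the new stepDataset and the existing trainOnDataset report identical metrics. De-minified, the progress calculation reads roughly as follows; this is a reconstruction of the minified body with descriptive names of our choosing.

// Readable reconstruction of the minified `createProgress` above.
interface TrainerState { totalSteps: number; trainingDuration: number; }
interface LogEntry { batchSize: number; }

function createProgress(state: TrainerState, log: LogEntry, peakMemoryBytes?: number) {
  const totalSamples = state.totalSteps * log.batchSize;
  return {
    duration: state.trainingDuration,                       // ms spent actively training
    totalSamples,
    samplesPerSecond: totalSamples / (state.trainingDuration / 1000),
    memory: peakMemoryBytes,                                // only set with advancedMetrics
  };
}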
package/dist/training/Trainer.d.ts CHANGED
@@ -66,6 +66,16 @@ export default abstract class GPTTrainer {
  losses: number[];
  validationLosses: number[];
  }>;
+ abstract stepDataset(dataset: Dataset<{
+ xs: Tensor;
+ ys: Tensor;
+ }>, options: Partial<TrainingOptions>, validationDataset?: Dataset<{
+ xs: Tensor;
+ ys: Tensor;
+ }>): Promise<{
+ log: TrainingLogEntry;
+ progress: TrainingProgress;
+ }>;
  createTrainValidationSplit(textData: string[], batchSize?: number, validationSplit?: number): Promise<{
  trainDataset: Dataset<{
  xs: Tensor;
package/dist/training/Trainer.js CHANGED
@@ -1,7 +1,7 @@
  import { DatasetBuilder as h, flattenTokens as p, PAGE_FACTOR as g } from "./DatasetBuilder.js";
  import u from "./AdamExt.js";
- import { t as f, v as y, d as c } from "../index-UdZhlibC.js";
- import { z as m } from "../zeros-DnPT2nD4.js";
+ import { t as f, v as y, d as c } from "../index-BoWRt-10.js";
+ import { z as m } from "../zeros--BdLQ3oG.js";
  class x {
  constructor(t, e, a = 1e-3) {
  this.tokenizer = e, this.model = t, this.learningRate = a, this.resetOptimizer(), this.datasetBuilder = new h(e, t.config.gpt.blockSize);
package/dist/training/sparseCrossEntropy.js CHANGED
@@ -1,9 +1,9 @@
  import { gatherSub as x } from "../ops/gatherSub.js";
  import { scatterSub as L } from "../ops/scatterSub.js";
- import { y, t as u, z as C, c as E } from "../index-UdZhlibC.js";
- import { s as G } from "../softmax-JMEIUo2J.js";
- import { m as z } from "../max-ByjEGoFx.js";
- import { l as v } from "../log_sum_exp-BnmCkHWl.js";
+ import { y, t as u, z as C, c as E } from "../index-BoWRt-10.js";
+ import { s as G } from "../softmax-Bv_6lyMX.js";
+ import { m as z } from "../max-Ddnnb5xe.js";
+ import { l as v } from "../log_sum_exp-DbjkV734.js";
  function k(t, s) {
  return u(() => {
  const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a = E(h, r), m = v(a, -1);
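The visible fragment of `k` is the standard numerically stable sparse cross-entropy: flatten the logits, subtract the per-row max so exp cannot overflow, then use logSumExp (the imports suggest `z`, `E` and `v` correspond to max, subtract and logSumExp). A hedged sketch of that computation in plain tfjs-core; this is our reconstruction, not the package's exported function.

import * as tf from "@tensorflow/tfjs-core";

// loss_i = logSumExp_j(x_ij - m_i) - (x_i[y_i] - m_i), where m_i = max_j x_ij.
function sparseCrossEntropy(logits: tf.Tensor2D, labels: tf.Tensor1D): tf.Tensor {
  return tf.tidy(() => {
    const numClasses = logits.shape[1];
    const m = tf.max(logits, -1, true);                 // [rows, 1] row maxima
    const shifted = tf.sub(logits, m);                  // stable logits
    const lse = tf.logSumExp(shifted, -1);              // [rows]
    const oneHot = tf.oneHot(labels.cast("int32"), numClasses).cast("float32");
    const picked = tf.sum(tf.mul(shifted, oneHot), -1); // shifted[i, labels[i]]
    return tf.mean(tf.sub(lse, picked));                // mean negative log-likelihood
  });
}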
package/dist/utilities/dummy.js CHANGED
@@ -1,5 +1,5 @@
- import { m as y, v as P, e as S } from "../index-UdZhlibC.js";
- import { z as i } from "../zeros-DnPT2nD4.js";
+ import { m as y, v as P, e as S } from "../index-BoWRt-10.js";
+ import { z as i } from "../zeros--BdLQ3oG.js";
  async function w(s) {
  const t = i([1, s.config.gpt.blockSize], "int32"), [e, n] = s.forward({ training: !1 }, t);
  await e.data(), e.dispose(), n && n.dispose(), t.dispose();
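For context, the function above (`w`) is the warm-up pass the trainers call before training: one throwaway forward pass over a zeros batch so backend kernels are compiled and weights materialised up front. De-minified, under assumed names; the model interface is a structural guess based on the calls visible in the fragment.

import * as tf from "@tensorflow/tfjs-core";

interface ForwardModel {
  config: { gpt: { blockSize: number } };
  forward(opts: { training: boolean }, xs: tf.Tensor): [tf.Tensor, tf.Tensor?];
}

async function dummyPass(model: ForwardModel): Promise<void> {
  const xs = tf.zeros([1, model.config.gpt.blockSize], "int32");
  const [logits, cache] = model.forward({ training: false }, xs);
  await logits.data(); // force execution (and shader/kernel compilation)
  logits.dispose();
  cache?.dispose();
  xs.dispose();
}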
package/dist/utilities/generate.js CHANGED
@@ -1,6 +1,6 @@
- import "../index-UdZhlibC.js";
- import { t as m } from "../tensor2d-CWHxHpLh.js";
- import { c as u } from "../concat-CbXTetof.js";
+ import "../index-BoWRt-10.js";
+ import { t as m } from "../tensor2d-wxPAnDQy.js";
+ import { c as u } from "../concat-CsxrgovM.js";
  async function v(o, r, a, c, f) {
  if (c <= 0)
  throw new Error("Length must be a positive integer");
package/dist/utilities/multinomialCPU.js CHANGED
@@ -1,5 +1,5 @@
- import "../index-UdZhlibC.js";
- import { t as e } from "../tensor2d-CWHxHpLh.js";
+ import "../index-BoWRt-10.js";
+ import { t as e } from "../tensor2d-wxPAnDQy.js";
  function l(n) {
  let r = 0;
  const i = Math.random();
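The visible start of `l` (an accumulator plus a single Math.random() draw) is the usual shape of a CPU-side multinomial sample: walk the cumulative distribution until it passes a uniform threshold. A hedged, illustrative reconstruction; the bundled function is minified and may differ in detail.

// Illustrative CPU multinomial draw over a probability vector.
function sampleMultinomial(probs: ArrayLike<number>): number {
  let acc = 0;
  const u = Math.random();          // uniform threshold in [0, 1)
  for (let i = 0; i < probs.length; i++) {
    acc += probs[i];
    if (u < acc) return i;          // first index whose CDF exceeds u
  }
  return probs.length - 1;          // guard against floating-point shortfall
}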
package/dist/utilities/performance.js CHANGED
@@ -1,4 +1,4 @@
- import { t as s } from "../index-UdZhlibC.js";
+ import { t as s } from "../index-BoWRt-10.js";
  async function f(e, o = 10, r = !1) {
  for (let t = 0; t < 100; t++) {
  const a = r ? await e() : s(e);
package/dist/utilities/profile.js CHANGED
@@ -1,4 +1,4 @@
- import { m as a } from "../index-UdZhlibC.js";
+ import { m as a } from "../index-BoWRt-10.js";
  const s = 1024 * 1024;
  class l {
  log = /* @__PURE__ */ new Map();
package/dist/utilities/safetensors.js CHANGED
@@ -1,5 +1,5 @@
- import "../index-UdZhlibC.js";
- import { t as y } from "../tensor-Do9PKbIE.js";
+ import "../index-BoWRt-10.js";
+ import { t as y } from "../tensor-JwS7ZYY6.js";
  function l(t) {
  if (t === "float32") return "F32";
  if (t === "int32") return "I32";
package/dist/utilities/weights.js CHANGED
@@ -1,5 +1,5 @@
- import "../index-UdZhlibC.js";
- import { t as p } from "../tensor-Do9PKbIE.js";
+ import "../index-BoWRt-10.js";
+ import { t as p } from "../tensor-JwS7ZYY6.js";
  function h(n) {
  const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
  let t = 0;
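The visible start of `h` is a plain typed-array concatenation: sum the chunk lengths, allocate once, then copy with a running offset. Written out in full, such a helper looks roughly like this; ours is illustrative, since only the first lines of the bundled version are shown.

// Concatenate weight chunks into one contiguous Float32Array.
function concatFloat32(chunks: Float32Array[]): Float32Array {
  const total = chunks.reduce((sum, c) => sum + c.length, 0);
  const out = new Float32Array(total);
  let offset = 0;
  for (const c of chunks) {
    out.set(c, offset);
    offset += c.length;
  }
  return out;
}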
package/dist/{variable-BTBkayv_.js → variable-BuddVFLa.js} CHANGED
@@ -1,4 +1,4 @@
- import { E as i } from "./index-UdZhlibC.js";
+ import { E as i } from "./index-BoWRt-10.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
package/dist/{webgpu_program-WaoMq-WD.js → webgpu_program-PFzf1hAQ.js} CHANGED
@@ -1,4 +1,4 @@
- import { aa as k, ab as z, ac as E, a1 as j, l as A } from "./index-UdZhlibC.js";
+ import { aa as k, ab as z, ac as E, a1 as j, l as A } from "./index-BoWRt-10.js";
  /**
  * @license
  * Copyright 2019 Google LLC. All Rights Reserved.
package/dist/{webgpu_util-DhSeP4b6.js → webgpu_util-D____QpY.js} CHANGED
@@ -1,4 +1,4 @@
- import { l as u } from "./index-UdZhlibC.js";
+ import { l as u } from "./index-BoWRt-10.js";
  /**
  * @license
  * Copyright 2019 Google LLC. All Rights Reserved.
package/dist/{zeros-DnPT2nD4.js → zeros--BdLQ3oG.js} CHANGED
@@ -1,4 +1,4 @@
- import { B as m, C as r, a5 as l, E as c, a6 as i, F as p, a7 as u, j as f } from "./index-UdZhlibC.js";
+ import { B as m, C as r, a2 as l, E as c, a6 as i, F as p, a7 as u, j as f } from "./index-BoWRt-10.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@genai-fi/nanogpt",
- "version": "0.7.1",
+ "version": "0.7.3",
  "type": "module",
  "main": "dist/main.js",
  "types": "dist/main.d.ts",