@genai-fi/nanogpt 0.9.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,15 @@
1
- import y from "./Trainer.js";
2
- import v from "./Evaluator.js";
1
+ import L from "./Trainer.js";
2
+ import w from "./Evaluator.js";
3
3
  import { d as S } from "../index-BzFyqcy-.js";
4
- import w from "../utilities/profile.js";
5
- const f = {
4
+ import f from "../utilities/profile.js";
5
+ const y = {
6
6
  desiredLoss: 0.01,
7
7
  logInterval: 1,
8
8
  maxSteps: 1e3
9
9
  };
10
- class b extends y {
11
- constructor(s, t, a = 3e-4) {
12
- super(s, t, a);
10
+ class x extends L {
11
+ constructor(s, e, a = 3e-4) {
12
+ super(s, e, a);
13
13
  }
14
14
  createEmptyState() {
15
15
  return {
@@ -23,52 +23,55 @@ class b extends y {
23
23
  ...this.lastState || {}
24
24
  };
25
25
  }
26
- createLogEntry(s, t, a, n) {
26
+ createLogEntry(s, e, a, l) {
27
27
  return {
28
28
  loss: s.lastLoss,
29
29
  step: s.step,
30
- time: Date.now() - t,
30
+ time: Date.now() - e,
31
31
  batchSize: a,
32
- learningRate: n ? this.optimizer.lr : void 0
32
+ learningRate: l ? this.optimizer.lr : void 0
33
33
  };
34
34
  }
35
- createProgress(s, t, a) {
35
+ createProgress(s, e, a) {
36
36
  return {
37
37
  duration: s.trainingDuration,
38
- totalSamples: s.totalSteps * t.batchSize,
39
- samplesPerSecond: s.totalSteps * t.batchSize / (s.trainingDuration / 1e3),
38
+ totalSamples: s.totalSteps * e.batchSize,
39
+ samplesPerSecond: s.totalSteps * e.batchSize / (s.trainingDuration / 1e3),
40
40
  memory: a ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
41
41
  };
42
42
  }
43
- async stepDataset(s, t, a) {
44
- const { logInterval: n } = {
45
- ...f,
46
- ...t
47
- }, l = Date.now(), r = this.createEmptyState();
48
- this.lastState = r, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, r.logStartTime = l;
49
- const m = a ? new v(this.model, a) : void 0, e = await s.iterator();
43
+ async stepDataset(s, e, a) {
44
+ const { logInterval: l } = {
45
+ ...y,
46
+ ...e
47
+ }, c = Date.now(), r = this.createEmptyState();
48
+ this.lastState = r, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, r.logStartTime = c;
49
+ const d = a ? new w(this.model, a) : void 0, t = await s.iterator();
50
50
  try {
51
51
  for (; this.running; ) {
52
- const i = await e.next();
52
+ const i = await t.next();
53
53
  if (i.done) break;
54
- const g = i.value, o = this.trainBatch(r, g), c = this.createLogEntry(r, l, g.xs.shape[0], t?.advancedMetrics);
55
- if (this.model.trainingState = {
56
- steps: r.totalSteps,
57
- learningRate: this.optimizer.lr,
58
- batchSize: g.xs.shape[0],
59
- loss: r.lastLoss
60
- }, r.step % n === 0) {
61
- await o.data();
54
+ const m = i.value, o = this.trainBatch(r, m);
55
+ if (r.step % l === 0) {
56
+ const g = (await o.data())[0];
57
+ r.lastLoss = g;
62
58
  const u = Date.now();
63
- if (r.trainingDuration += u - r.logStartTime, m)
59
+ r.trainingDuration += u - r.logStartTime;
60
+ const p = this.createLogEntry(r, c, m.xs.shape[0], e?.advancedMetrics);
61
+ if (this.model.trainingState = {
62
+ steps: r.totalSteps,
63
+ learningRate: this.optimizer.lr,
64
+ batchSize: m.xs.shape[0],
65
+ loss: r.lastLoss
66
+ }, d)
64
67
  try {
65
- const h = await m.evaluate(5);
66
- r.validationLosses.push(h), c.valLoss = h;
67
- } catch (h) {
68
- console.error("Validation error:", h);
68
+ const n = await d.evaluate(5);
69
+ r.validationLosses.push(n), p.valLoss = n;
70
+ } catch (n) {
71
+ console.error("Validation error:", n);
69
72
  }
70
- const p = this.createProgress(r, c, t?.advancedMetrics);
71
- return o.dispose(), this.stop(), { log: c, progress: p };
73
+ const v = this.createProgress(r, p, e?.advancedMetrics);
74
+ return o.dispose(), this.stop(), { log: p, progress: v };
72
75
  }
73
76
  o.dispose();
74
77
  }
@@ -78,42 +81,50 @@ class b extends y {
78
81
  throw S(), this.running = !1, new Error("No log returned before training stopped.");
79
82
  }
80
83
  // Train for multiple epochs using Dataset API - FIXED memory leaks
81
- async trainOnDataset(s, t, a) {
82
- const { logInterval: n, onStep: l, maxSteps: r } = {
83
- ...f,
84
- ...t
85
- }, m = Date.now(), e = this.createEmptyState();
86
- this.lastState = e, await this.dummyPass(), this.model.trainable = !0, t?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, e.logStartTime = m;
87
- const i = a ? new v(this.model, a) : void 0, g = await s.iterator();
84
+ async trainOnDataset(s, e, a) {
85
+ const { logInterval: l, onStep: c, maxSteps: r } = {
86
+ ...y,
87
+ ...e
88
+ }, d = Date.now(), t = this.createEmptyState();
89
+ this.lastState = t, await this.dummyPass(), this.model.trainable = !0, e?.advancedMetrics && (this.model.getProfiler() || this.model.setProfiler(new f())), this.running = !0, t.logStartTime = d;
90
+ const i = a ? new w(this.model, a) : void 0, m = await s.iterator();
88
91
  try {
89
92
  for (; this.running; ) {
90
- const o = await g.next();
93
+ const o = await m.next();
91
94
  if (o.done) break;
92
- const c = o.value, u = this.trainBatch(e, c), p = this.createLogEntry(e, m, c.xs.shape[0], t?.advancedMetrics);
93
- if (e.step % n === 0) {
94
- await u.data();
95
- const h = Date.now();
96
- if (e.trainingDuration += h - e.logStartTime, i)
95
+ const g = o.value, u = this.trainBatch(t, g);
96
+ if (t.step % l === 0) {
97
+ const p = (await u.data())[0];
98
+ t.lastLoss = p;
99
+ const v = Date.now();
100
+ t.trainingDuration += v - t.logStartTime;
101
+ const n = this.createLogEntry(t, d, g.xs.shape[0], e?.advancedMetrics);
102
+ if (this.model.trainingState = {
103
+ steps: t.totalSteps,
104
+ learningRate: this.optimizer.lr,
105
+ batchSize: g.xs.shape[0],
106
+ loss: t.lastLoss
107
+ }, i)
97
108
  try {
98
- const d = await i.evaluate(5);
99
- e.validationLosses.push(d), p.valLoss = d;
100
- } catch (d) {
101
- console.error("Validation error:", d);
109
+ const h = await i.evaluate(5);
110
+ t.validationLosses.push(h), n.valLoss = h;
111
+ } catch (h) {
112
+ console.error("Validation error:", h);
102
113
  }
103
- if (l) {
104
- const d = this.createProgress(e, p, t?.advancedMetrics);
105
- await l(p, d);
114
+ if (c) {
115
+ const h = this.createProgress(t, n, e?.advancedMetrics);
116
+ await c(n, h);
106
117
  }
107
- e.logStartTime = Date.now();
118
+ t.logStartTime = Date.now();
108
119
  }
109
- u.dispose(), e.step >= r && this.stop();
120
+ u.dispose(), t.step >= r && this.stop();
110
121
  }
111
122
  } catch (o) {
112
123
  throw console.error("Training error:", o), S(), o;
113
124
  }
114
- return S(), this.running = !1, { losses: e.losses, validationLosses: e.validationLosses };
125
+ return S(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
115
126
  }
116
127
  }
117
128
  export {
118
- b as default
129
+ x as default
119
130
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.9.0",
3
+ "version": "0.9.1",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",