@genai-fi/nanogpt 0.15.13 → 0.15.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/Trainer.js CHANGED
@@ -1,7 +1,8 @@
1
1
  import { E as g } from "./index-DvYrXKkX.js";
2
2
  import o from "./training/PreTrainer.js";
3
- import { createTrainValidationSplit as p } from "./training/validation.js";
3
+ import { createTrainValidationSplit as d } from "./training/validation.js";
4
4
  import h from "./training/SFTTrainer.js";
5
+ import p from "./training/tasks/splitter.js";
5
6
  class l extends g {
6
7
  trainer;
7
8
  trainingType = "pretraining";
@@ -81,7 +82,7 @@ class l extends g {
81
82
  async prepare(t = []) {
82
83
  const i = this.options;
83
84
  if (this.trainingType === "pretraining" && this.trainer instanceof o) {
84
- const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await p(
85
+ const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await d(
85
86
  t,
86
87
  this.trainer.tokenizer,
87
88
  this.trainer.datasetBuilder,
@@ -92,12 +93,16 @@ class l extends g {
92
93
  } else if (this.trainingType === "sft" && this.trainer instanceof h) {
93
94
  if (t instanceof Uint16Array)
94
95
  throw new Error("SFT training requires Task[] input");
95
- const e = await this.trainer.datasetBuilder.createSFTDataset(
96
- t,
96
+ const e = p(t, i?.validationSplit || 0.1), a = await this.trainer.datasetBuilder.createSFTDataset(
97
+ [e.training],
98
+ i?.batchSize || 32,
99
+ -100
100
+ ), r = await this.trainer.datasetBuilder.createSFTDataset(
101
+ [e.validation],
97
102
  i?.batchSize || 32,
98
103
  -100
99
104
  );
100
- this.trainDataset = e, this.totalSamples = t.reduce((a, r) => a + r.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
105
+ this.validationDataset = r, this.trainDataset = a, this.totalSamples = t.reduce((n, s) => n + s.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
101
106
  }
102
107
  }
103
108
  configureModel(t) {
@@ -1,14 +1,14 @@
1
1
  import { p as u } from "../papaparse.min-C0cScC2i.js";
2
- import { loadParquet as d } from "./parquet.js";
3
- import { loadPDF as f } from "./pdf.js";
2
+ import { loadParquet as f } from "./parquet.js";
3
+ import { loadPDF as d } from "./pdf.js";
4
4
  import { loadDOCX as m } from "./docx.js";
5
5
  import { z as x } from "../jszip.min-BZhlzntC.js";
6
- function w(t, n) {
7
- const r = t.findIndex((i) => i.toLowerCase() === n.toLowerCase());
8
- return r === -1 ? 0 : r;
6
+ function y(t, r) {
7
+ const a = t.findIndex((i) => i.toLowerCase() === r.toLowerCase());
8
+ return a === -1 ? 0 : a;
9
9
  }
10
- function y(t) {
11
- return t.every((n) => n.length < 64);
10
+ function w(t) {
11
+ return t.every((r) => r.length < 64);
12
12
  }
13
13
  function h(t) {
14
14
  return t.split(".").pop() || "";
@@ -35,66 +35,72 @@ function g(t) {
35
35
  return "unknown";
36
36
  }
37
37
  }
38
- async function z(t, n) {
39
- const r = t.type !== "" ? t.type : g(t.name);
40
- if (r === "application/parquet")
41
- return d(t, n?.maxSize, n?.column);
42
- if (r === "application/pdf")
43
- return f(t, n?.maxSize);
44
- if (r === "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
38
+ function j(t) {
39
+ if (!Array.isArray(t)) return !1;
40
+ const r = t[0];
41
+ return typeof r == "object" && r !== null && "role" in r && "content" in r && typeof r.role == "string" && typeof r.content == "string";
42
+ }
43
+ async function z(t, r) {
44
+ const a = t.type !== "" ? t.type : g(t.name);
45
+ if (a === "application/parquet")
46
+ return f(t, r?.maxSize, r?.column);
47
+ if (a === "application/pdf")
48
+ return d(t, r?.maxSize);
49
+ if (a === "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
45
50
  return m(t);
46
- if (r === "application/json") {
47
- const i = await t.text(), a = JSON.parse(i);
48
- if (Array.isArray(a))
49
- return a.map(
51
+ if (a === "application/json") {
52
+ const i = await t.text(), o = JSON.parse(i);
53
+ if (Array.isArray(o))
54
+ return o.map(
50
55
  (e) => typeof e == "string" ? e : "text" in e ? e.text : JSON.stringify(e)
51
56
  );
52
57
  throw new Error("Expected JSON array");
53
58
  }
54
- if (r === "application/jsonl")
59
+ if (a === "application/jsonl")
55
60
  return (await t.text()).split(`
56
- `).filter((a) => a.trim() !== "").map((a) => {
61
+ `).filter((o) => o.trim() !== "").map((o) => {
57
62
  try {
58
- const e = JSON.parse(a);
59
- return typeof e == "string" ? e : "text" in e ? e.text : JSON.stringify(e);
63
+ const e = JSON.parse(o);
64
+ return j(e) ? e.map((n) => `${n.content}`).join(`
65
+ `) : typeof e == "string" ? e : "text" in e ? e.text : JSON.stringify(e);
60
66
  } catch {
61
- return a;
67
+ return o;
62
68
  }
63
69
  });
64
- if (r === "application/zip") {
65
- const i = await x.loadAsync(t), a = [];
70
+ if (a === "application/zip") {
71
+ const i = await x.loadAsync(t), o = [];
66
72
  for (const e of Object.keys(i.files)) {
67
- const o = i.file(e);
68
- if (o) {
69
- const c = await o.async("blob"), p = await z(new File([c], e), n);
70
- a.push(...p);
73
+ const n = i.file(e);
74
+ if (n) {
75
+ const s = await n.async("blob"), c = await z(new File([s], e), r);
76
+ o.push(...c);
71
77
  }
72
78
  }
73
- return a;
79
+ return o;
74
80
  }
75
- if (r === "text/csv") {
81
+ if (a === "text/csv") {
76
82
  const i = await t.text();
77
- return new Promise((a, e) => {
83
+ return new Promise((o, e) => {
78
84
  u.parse(i, {
79
85
  header: !1,
80
86
  skipEmptyLines: !0,
81
87
  delimiter: ",",
82
- complete: (o) => {
83
- if (o.errors.length > 0)
84
- console.error(o.errors), e(new Error("Error parsing file"));
88
+ complete: (n) => {
89
+ if (n.errors.length > 0)
90
+ console.error(n.errors), e(new Error("Error parsing file"));
85
91
  else {
86
- const c = w(o.data[0], n?.column || "text"), s = n?.hasHeader ?? y(o.data[0]) ? o.data.slice(1) : o.data;
87
- a(s.map((l) => l[c]));
92
+ const s = y(n.data[0], r?.column || "text"), p = r?.hasHeader ?? w(n.data[0]) ? n.data.slice(1) : n.data;
93
+ o(p.map((l) => l[s]));
88
94
  }
89
95
  },
90
- error: (o) => {
91
- e(o);
96
+ error: (n) => {
97
+ e(n);
92
98
  }
93
99
  });
94
100
  });
95
- } else if (r === "text/plain")
101
+ } else if (a === "text/plain")
96
102
  return [await t.text()];
97
- throw new Error(`Unsupported file type: ${r}`);
103
+ throw new Error(`Unsupported file type: ${a}`);
98
104
  }
99
105
  export {
100
106
  z as default
@@ -1,8 +1,8 @@
1
1
  import u from "./Evaluator.js";
2
- import { t as L, v as P, k as h, d as g, a as y } from "../index-CUXkjxiT.js";
2
+ import { t as z, v as P, k as g, d as p, a as y } from "../index-CUXkjxiT.js";
3
3
  import S from "../utilities/profile.js";
4
- import { createTensorStatistics as x } from "../checks/weights.js";
5
- import { calculateLoss as k, calculateAccuracy as T } from "./loss.js";
4
+ import { createTensorStatistics as k } from "../checks/weights.js";
5
+ import { calculateLoss as x, calculateAccuracy as T } from "./loss.js";
6
6
  import { AdamWOptimizer as N } from "./AdamW.js";
7
7
  import { z as w } from "../zeros-DvZpK8s6.js";
8
8
  const v = {
@@ -23,11 +23,11 @@ const v = {
23
23
  lossScaling: 1
24
24
  };
25
25
  class G {
26
- constructor(e, i, o, c) {
27
- this.tokenizer = i, this.model = e, this.optimizerConfig = {
26
+ constructor(s, i, o, c) {
27
+ this.tokenizer = i, this.model = s, this.optimizerConfig = {
28
28
  ...b,
29
29
  ...o,
30
- lossScaling: e.lossScaling
30
+ lossScaling: s.lossScaling
31
31
  };
32
32
  const l = c || new N(this.optimizerConfig);
33
33
  c && c.updateConfig(this.optimizerConfig), this.optimizer = l;
@@ -44,26 +44,26 @@ class G {
44
44
  _labelSmoothing = 0;
45
45
  _layerDrop = 0;
46
46
  _dropout = 0;
47
- setGradientCheckpointing(e) {
48
- this._gradientCheckpointing = e;
47
+ setGradientCheckpointing(s) {
48
+ this._gradientCheckpointing = s;
49
49
  }
50
- setMixedPrecision(e) {
51
- this._mixedPrecision = e;
50
+ setMixedPrecision(s) {
51
+ this._mixedPrecision = s;
52
52
  }
53
- setLabelSmoothing(e) {
54
- this._labelSmoothing = e;
53
+ setLabelSmoothing(s) {
54
+ this._labelSmoothing = s;
55
55
  }
56
- setDropout(e) {
57
- this._dropout = e;
56
+ setDropout(s) {
57
+ this._dropout = s;
58
58
  }
59
- setLayerDrop(e) {
60
- this._layerDrop = e;
59
+ setLayerDrop(s) {
60
+ this._layerDrop = s;
61
61
  }
62
- setLearningRate(e) {
63
- this.optimizerConfig.learningRate = e, this.updateOptimizer();
62
+ setLearningRate(s) {
63
+ this.optimizerConfig.learningRate = s, this.updateOptimizer();
64
64
  }
65
- setMetrics(e) {
66
- this.metrics = new Set(e);
65
+ setMetrics(s) {
66
+ this.metrics = new Set(s);
67
67
  }
68
68
  reset() {
69
69
  this.lastState = void 0, this.running = !1;
@@ -77,12 +77,12 @@ class G {
77
77
  getOptimizer() {
78
78
  return this.optimizer;
79
79
  }
80
- updateOptimizer(e) {
81
- e && (this.optimizerConfig = { ...this.optimizerConfig, ...e }), this.optimizer.updateConfig(this.optimizerConfig);
80
+ updateOptimizer(s) {
81
+ s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
82
82
  }
83
83
  // A single forward pass, backward pass, and optimizer step
84
- trainStep(e, i, o = !1, c = !1) {
85
- return L(() => {
84
+ trainStep(s, i, o = !1, c = !1) {
85
+ return z(() => {
86
86
  this.model.getProfiler()?.startMemory();
87
87
  const { xs: l, ys: a } = i, d = () => {
88
88
  const r = this.model.forward(
@@ -94,31 +94,31 @@ class G {
94
94
  layerDrop: this._layerDrop
95
95
  },
96
96
  l
97
- ), s = k(r, a, this.maskedLoss, !1, this._labelSmoothing);
98
- this.metrics.has("accuracy") && (e.accuracy = T(r, a), h(e.accuracy)), r.dispose();
99
- const m = s.mul(y(this.optimizerConfig.lossScaling));
100
- return s.dispose(), m;
97
+ ), e = x(r, a, this.maskedLoss, !1, this._labelSmoothing);
98
+ this.metrics.has("accuracy") && (s.accuracy = T(r, a), g(s.accuracy)), r.dispose();
99
+ const m = e.mul(y(this.optimizerConfig.lossScaling));
100
+ return e.dispose(), m;
101
101
  }, { value: t, grads: n } = P(d);
102
102
  if (o)
103
103
  this.model.getProfiler()?.endMemory("Training");
104
104
  else {
105
105
  const r = this.optimizer.applyGradients(n);
106
- this.metrics.has("gradientNorm") ? (e.gradientNorm = r, h(r)) : (e.gradientNorm = void 0, r.dispose());
107
- const s = Object.keys(n);
108
- this.model.weightStore.touchVariables(s), this.model.getProfiler()?.endMemory("Training"), c ? (e.gradients = n, Object.values(n).forEach((m) => h(m))) : g(n);
106
+ this.metrics.has("gradientNorm") ? (s.gradientNorm = r, g(r)) : (s.gradientNorm = void 0, r.dispose());
107
+ const e = Object.keys(n);
108
+ this.model.weightStore.touchVariables(e), this.model.getProfiler()?.endMemory("Training"), c ? (s.gradients = n, Object.values(n).forEach((m) => g(m))) : p(n);
109
109
  }
110
110
  return t.mul(y(1 / this.optimizerConfig.lossScaling));
111
111
  });
112
112
  }
113
113
  async dummyPass() {
114
- const e = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
114
+ const s = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
115
115
  try {
116
- const o = this.trainStep({}, { xs: e, ys: i }, !0);
116
+ const o = this.trainStep({}, { xs: s, ys: i }, !0);
117
117
  await o.data(), o.dispose();
118
118
  } catch (o) {
119
119
  console.error("Error during dummy pass:", o);
120
120
  } finally {
121
- e.dispose(), i.dispose();
121
+ s.dispose(), i.dispose();
122
122
  }
123
123
  }
124
124
  dispose() {
@@ -136,7 +136,7 @@ class G {
136
136
  ...this.lastState || {}
137
137
  };
138
138
  }
139
- async stepDataset(e, i, o) {
139
+ async stepDataset(s, i, o) {
140
140
  const { logInterval: c = 10 } = {
141
141
  ...v,
142
142
  ...i
@@ -144,21 +144,21 @@ class G {
144
144
  i.metrics && this.setMetrics(i.metrics);
145
145
  const l = Date.now(), a = this.createEmptyState();
146
146
  this.lastState = a, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, a.logStartTime = l;
147
- const d = o ? new u(this.model, o) : void 0, t = await e.iterator();
147
+ const d = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, t = await s.iterator();
148
148
  try {
149
149
  for (; this.running; ) {
150
150
  const n = await t.next();
151
151
  if (n.done) break;
152
- const r = n.value, s = this.trainStep(a, r, !1);
153
- r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(s, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), s.dispose();
152
+ const r = n.value, e = this.trainStep(a, r, !1);
153
+ r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(e, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), e.dispose();
154
154
  }
155
155
  } catch (n) {
156
- throw console.error("Training error:", n), g(), n;
156
+ throw console.error("Training error:", n), p(), n;
157
157
  }
158
- throw g(), this.running = !1, new Error("No log returned before training stopped.");
158
+ throw p(), this.running = !1, new Error("No log returned before training stopped.");
159
159
  }
160
- async performLogging(e, i, o, c) {
161
- const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await e.data())[0], t = this.lastState;
160
+ async performLogging(s, i, o, c) {
161
+ const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
162
162
  t.lastLoss = d;
163
163
  const n = Date.now();
164
164
  t.trainingDuration += n - t.logStartTime;
@@ -184,25 +184,25 @@ class G {
184
184
  batchSize: i,
185
185
  loss: t.lastLoss
186
186
  }, a && t.gradients) {
187
- const s = /* @__PURE__ */ new Map();
188
- for (const [m, p] of Object.entries(t.gradients))
189
- s.set(m, await x(p)), p.dispose();
190
- r.gradientMetrics = s;
187
+ const e = /* @__PURE__ */ new Map();
188
+ for (const [m, h] of Object.entries(t.gradients))
189
+ e.set(m, await k(h)), h.dispose();
190
+ r.gradientMetrics = e;
191
191
  }
192
192
  if (c)
193
193
  try {
194
- const s = await c.evaluate(5);
195
- Array.isArray(s) ? r.validationMetrics = { loss: s[0].loss, accuracy: s[0].accuracy } : (t.validationLosses.push(s.loss), r.validationMetrics = {
196
- accuracy: s.accuracy,
197
- loss: s.loss,
198
- perplexity: this.metrics.has("perplexity") ? Math.exp(s.loss) : void 0
194
+ const e = await c.evaluate(5);
195
+ Array.isArray(e) ? r.validationMetrics = { loss: e[0].loss, accuracy: e[0].accuracy } : (t.validationLosses.push(e.loss), r.validationMetrics = {
196
+ accuracy: e.accuracy,
197
+ loss: e.loss,
198
+ perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
199
199
  });
200
- } catch (s) {
201
- console.error("Validation error:", s);
200
+ } catch (e) {
201
+ console.error("Validation error:", e);
202
202
  }
203
203
  l && await l(r), t.logStartTime = Date.now();
204
204
  }
205
- async trainOnDataset(e, i, o) {
205
+ async trainOnDataset(s, i, o) {
206
206
  const { logInterval: c = 10, maxEpochs: l = 1 / 0 } = {
207
207
  ...v,
208
208
  ...i
@@ -210,18 +210,18 @@ class G {
210
210
  i.metrics && this.setMetrics(i.metrics);
211
211
  const d = Date.now(), t = this.createEmptyState();
212
212
  this.lastState = t, await this.dummyPass(), i?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, t.logStartTime = d;
213
- const n = o ? new u(this.model, o) : void 0, r = await e.iterator();
213
+ const n = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, r = await s.iterator();
214
214
  try {
215
215
  for (; this.running; ) {
216
- const s = await r.next();
217
- if (s.done) break;
218
- const m = s.value, p = t.step % c === 0, z = (i?.metrics?.includes("gradientStatistics") || !1) && p, f = this.trainStep(t, m, !1, z);
219
- m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, p ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
216
+ const e = await r.next();
217
+ if (e.done) break;
218
+ const m = e.value, h = t.step % c === 0, L = (i?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, m, !1, L);
219
+ m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
220
220
  }
221
- } catch (s) {
222
- throw console.error("Training error:", s), g(), s;
221
+ } catch (e) {
222
+ throw console.error("Training error:", e), p(), e;
223
223
  }
224
- return g(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
224
+ return p(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
225
225
  }
226
226
  }
227
227
  export {
@@ -11,7 +11,8 @@ export default class Evaluator {
11
11
  private iterator?;
12
12
  private xs?;
13
13
  private ys?;
14
- constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer> | Conversation[][], tokeniser?: ITokeniser);
14
+ private masked;
15
+ constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer> | Conversation[][], tokeniser?: ITokeniser, masked?: boolean);
15
16
  dispose(): void;
16
17
  private calculateBatchLoss;
17
18
  evaluate(maxBatches?: number): Promise<Result | Result[]>;
@@ -2,12 +2,12 @@ import { t as p } from "../index-CUXkjxiT.js";
2
2
  import { calculateLoss as d, calculateAccuracy as m } from "./loss.js";
3
3
  import { buildSFTExample as x } from "./SFTDatasetBuilder.js";
4
4
  import { t as h } from "../tensor-BWFldCso.js";
5
- class g {
6
- constructor(c, t, o) {
7
- if (this.model = c, Array.isArray(t)) {
5
+ class k {
6
+ constructor(i, t, o, c) {
7
+ if (this.model = i, this.masked = !!c, Array.isArray(t)) {
8
8
  if (!o)
9
9
  throw new Error("Tokeniser is required when dataset is an array of conversations");
10
- const a = t.map((s) => x(s, -100, o, c.config.blockSize)).filter((s) => s !== null);
10
+ const a = t.map((s) => x(s, -100, o, i.config.blockSize)).filter((s) => s !== null);
11
11
  if (a.length === 0)
12
12
  return;
13
13
  this.xs = h(a.map((s) => s.xs)), this.ys = h(a.map((s) => s.ys));
@@ -17,32 +17,33 @@ class g {
17
17
  iterator;
18
18
  xs;
19
19
  ys;
20
+ masked = !1;
20
21
  dispose() {
21
22
  this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
22
23
  }
23
- async calculateBatchLoss(c, t, o, a) {
24
- const [s, e] = p(() => {
25
- const r = this.model.forward({ training: !1 }, c), y = d(r, t, a, o), f = m(r, t);
24
+ async calculateBatchLoss(i, t, o, c) {
25
+ const [a, s] = p(() => {
26
+ const r = this.model.forward({ training: !1 }, i), y = d(r, t, c, o), f = m(r, t);
26
27
  return r.dispose(), [y, f];
27
- }), n = await s.array(), u = await e.array(), i = n, l = u;
28
- return e.dispose(), s.dispose(), Array.isArray(i) ? i.map((r) => ({ loss: r, accuracy: l })) : { loss: i, accuracy: l };
28
+ }), n = await a.array(), u = await s.array(), e = n, l = u;
29
+ return s.dispose(), a.dispose(), Array.isArray(e) ? e.map((r) => ({ loss: r, accuracy: l })) : { loss: e, accuracy: l };
29
30
  }
30
- async evaluate(c = 100) {
31
- let t = 0, o = 0, a = 0;
31
+ async evaluate(i = 100) {
32
+ let t = 0, o = 0, c = 0;
32
33
  if (this.iterator) {
33
- const s = await this.iterator;
34
- for (let e = 0; e < c; e++) {
35
- const n = await s.next();
34
+ const a = await this.iterator;
35
+ for (let s = 0; s < i; s++) {
36
+ const n = await a.next();
36
37
  if (n.done) break;
37
- const u = n.value, { xs: i, ys: l } = u, r = await this.calculateBatchLoss(i, l, !1, !1);
38
- i.dispose(), l.dispose(), t += r.loss, o += r.accuracy, a++;
38
+ const u = n.value, { xs: e, ys: l } = u, r = await this.calculateBatchLoss(e, l, !1, this.masked);
39
+ e.dispose(), l.dispose(), t += r.loss, o += r.accuracy, c++;
39
40
  }
40
- return { loss: t / a, accuracy: o / a };
41
+ return { loss: t / c, accuracy: o / c };
41
42
  } else if (this.xs && this.ys)
42
43
  return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
43
44
  throw new Error("No data available for evaluation");
44
45
  }
45
46
  }
46
47
  export {
47
- g as default
48
+ k as default
48
49
  };
@@ -1,50 +1,50 @@
1
- import { t as x } from "../index-CUXkjxiT.js";
1
+ import { t as y } from "../index-CUXkjxiT.js";
2
2
  import "../dataset-CGGp1z9P.js";
3
3
  import { g as I } from "../readers-iz5u3HBo.js";
4
4
  import "../index-Cp39cXWe.js";
5
- function w(u, a, t, r) {
6
- const s = [t.bosToken], n = [!1], f = {
5
+ function w(p, o, t, l) {
6
+ const s = [t.bosToken], a = [!1], u = {
7
7
  user: t.getSpecialTokenIndex("<|user_start|>"),
8
8
  assistant: t.getSpecialTokenIndex("<|assistant_start|>"),
9
9
  system: t.getSpecialTokenIndex("<|system_start|>")
10
- }, i = {
10
+ }, c = {
11
11
  user: t.getSpecialTokenIndex("<|user_end|>"),
12
12
  assistant: t.getSpecialTokenIndex("<|assistant_end|>"),
13
13
  system: t.getSpecialTokenIndex("<|system_end|>")
14
14
  };
15
- for (const e of u) {
16
- const c = f[e.role], h = i[e.role];
17
- if (!c || !h)
15
+ for (const e of p) {
16
+ const r = u[e.role], h = c[e.role];
17
+ if (!r || !h)
18
18
  throw new Error(`Missing special tokens for role: ${e.role}`);
19
- s.push(c), n.push(!1);
19
+ s.push(r), a.push(!1);
20
20
  const m = e.role === "assistant", S = t.encode(e.content);
21
21
  for (const T of S) {
22
22
  s.push(T);
23
- const y = t.isSpecialToken(T);
24
- n.push(m && !y);
23
+ const x = t.isSpecialToken(T);
24
+ a.push(m && !x);
25
25
  }
26
- s.push(h), n.push(m);
26
+ s.push(h), a.push(m);
27
27
  }
28
- s.push(t.eosToken), n.push(!1);
29
- const o = r + 1;
30
- if (s.length < o) {
31
- const e = o - s.length, c = t.getSpecialTokenIndex("<pad>");
28
+ s.push(t.eosToken), a.push(!1);
29
+ const n = l + 1;
30
+ if (s.length < n) {
31
+ const e = n - s.length, r = t.getSpecialTokenIndex("<pad>");
32
32
  for (let h = 0; h < e; h++)
33
- s.push(c), n.push(!1);
34
- } else s.length > o && (s.length = o, n.length = o);
35
- const p = new Int32Array(s.slice(0, r)), l = s.slice(1, r + 1), k = n.slice(1, r + 1), d = new Int32Array(l.length);
33
+ s.push(r), a.push(!1);
34
+ } else s.length > n && (s.length = n, a.length = n);
35
+ const f = new Int32Array(s.slice(0, l)), i = s.slice(1, l + 1), k = a.slice(1, l + 1), d = new Int32Array(i.length);
36
36
  let g = !1;
37
- for (let e = 0; e < l.length; e++) {
38
- const c = k[e] ? l[e] : a;
39
- d[e] = c, c !== a && (g = !0);
37
+ for (let e = 0; e < i.length; e++) {
38
+ const r = k[e] ? i[e] : o;
39
+ d[e] = r, r !== o && (g = !0);
40
40
  }
41
- return g ? { xs: p, ys: d } : null;
41
+ return g ? { xs: f, ys: d } : null;
42
42
  }
43
- class A {
43
+ class D {
44
44
  tokenizer;
45
45
  blockSize;
46
- constructor(a, t = 128) {
47
- this.tokenizer = a, this.blockSize = t;
46
+ constructor(o, t = 128) {
47
+ this.tokenizer = o, this.blockSize = t;
48
48
  }
49
49
  /**
50
50
  * Create SFT dataset from structured conversations.
@@ -52,25 +52,32 @@ class A {
52
52
  * - Pads with eosToken and masks padding.
53
53
  * - Masks non-assistant tokens in labels with ignoreIndex (default -100).
54
54
  */
55
- async createSFTDataset(a, t = 32, r = -100) {
56
- if (!a.length)
55
+ async createSFTDataset(o, t = 32, l = -100) {
56
+ if (!o.length)
57
57
  throw new Error("No conversations provided.");
58
- const s = this.tokenizer, n = this.blockSize;
58
+ const s = this.tokenizer, a = this.blockSize;
59
+ for (const c of o)
60
+ c.shuffle();
59
61
  return I(function* () {
60
62
  for (; ; ) {
61
- const i = Math.floor(Math.random() * a.length), p = a[i].getRandomConversation(), l = w(p, r, s, n);
62
- l && (yield l);
63
+ const c = Math.floor(Math.random() * o.length), n = o[c], f = n.nextConversation();
64
+ if (!f) {
65
+ n.shuffle();
66
+ continue;
67
+ }
68
+ const i = w(f, l, s, a);
69
+ i && (yield i);
63
70
  }
64
- }).batch(t).map((i) => {
65
- const o = i;
66
- return x(() => ({
67
- xs: o.xs.cast("int32"),
68
- ys: o.ys.cast("int32")
71
+ }).batch(t).map((c) => {
72
+ const n = c;
73
+ return y(() => ({
74
+ xs: n.xs.cast("int32"),
75
+ ys: n.ys.cast("int32")
69
76
  }));
70
77
  }).prefetch(2);
71
78
  }
72
79
  }
73
80
  export {
74
- A as SFTDatasetBuilder,
81
+ D as SFTDatasetBuilder,
75
82
  w as buildSFTExample
76
83
  };
@@ -2,13 +2,13 @@ import { Conversation, ITokeniser } from '../../main';
2
2
  import { Task } from './Task';
3
3
  export default class ConversationTask extends Task {
4
4
  private rawConvo;
5
+ private shuffledIndices;
5
6
  private index;
6
7
  get length(): number;
7
8
  constructor(conversations: Conversation[][]);
8
9
  hasMoreConversations(): boolean;
9
10
  nextConversation(): Conversation[] | null;
10
11
  nextTokens(tokeniser: ITokeniser): number[] | null;
11
- getRandomConversation(): Conversation[];
12
- getRandomTokens(tokeniser: ITokeniser): number[];
12
+ shuffle(): void;
13
13
  estimateTokens(tokeniser: ITokeniser): Promise<number>;
14
14
  }
@@ -1,6 +1,8 @@
1
1
  import { Task as t } from "./Task.js";
2
+ import { shuffle as s } from "../DatasetBuilder.js";
2
3
  class a extends t {
3
4
  rawConvo;
5
+ shuffledIndices = null;
4
6
  index = 0;
5
7
  get length() {
6
8
  return this.rawConvo.length;
@@ -14,20 +16,20 @@ class a extends t {
14
16
  nextConversation() {
15
17
  if (this.index >= this.rawConvo.length)
16
18
  return null;
17
- const n = this.rawConvo[this.index];
19
+ const n = this.rawConvo[this.shuffledIndices ? this.shuffledIndices[this.index] : this.index];
18
20
  return this.index++, n;
19
21
  }
20
22
  nextTokens(n) {
21
- const o = this.nextConversation();
22
- return o ? n.encodeConversation(o) : null;
23
- }
24
- getRandomConversation() {
25
- const n = Math.floor(Math.random() * this.rawConvo.length);
26
- return this.rawConvo[n];
27
- }
28
- getRandomTokens(n) {
29
- const o = Math.floor(Math.random() * this.rawConvo.length);
30
- return n.encodeConversation(this.rawConvo[o]);
23
+ const e = this.nextConversation();
24
+ return e ? n.encodeConversation(e) : null;
25
+ }
26
+ shuffle() {
27
+ if (!this.shuffledIndices) {
28
+ this.shuffledIndices = new Uint32Array(this.rawConvo.length);
29
+ for (let n = 0; n < this.rawConvo.length; n++)
30
+ this.shuffledIndices[n] = n;
31
+ }
32
+ s(this.shuffledIndices), this.index = 0;
31
33
  }
32
34
  async estimateTokens(n) {
33
35
  return (await n.encodeConversation(this.rawConvo[0])).length * this.length;
@@ -8,7 +8,6 @@ export default class PretrainingTask extends Task {
8
8
  hasMoreConversations(): boolean;
9
9
  nextConversation(): Conversation[] | null;
10
10
  nextTokens(tokeniser: ITokeniser): number[] | null;
11
- getRandomConversation(): Conversation[];
12
- getRandomTokens(tokeniser: ITokeniser): number[];
11
+ shuffle(): void;
13
12
  estimateTokens(tokeniser: ITokeniser): Promise<number>;
14
13
  }
@@ -1,5 +1,5 @@
1
1
  import { Task as n } from "./Task.js";
2
- class i extends n {
2
+ class r extends n {
3
3
  rawText;
4
4
  index = 0;
5
5
  get length() {
@@ -26,18 +26,8 @@ class i extends n {
26
26
  const e = t.encodeSequence(this.rawText[this.index]);
27
27
  return this.index++, e;
28
28
  }
29
- getRandomConversation() {
30
- const t = Math.floor(Math.random() * this.rawText.length);
31
- return [
32
- {
33
- role: "assistant",
34
- content: this.rawText[t]
35
- }
36
- ];
37
- }
38
- getRandomTokens(t) {
39
- const e = Math.floor(Math.random() * this.rawText.length);
40
- return t.encodeSequence(this.rawText[e]);
29
+ shuffle() {
30
+ this.index = 0;
41
31
  }
42
32
  async estimateTokens(t) {
43
33
  return (await t.encodeConversation([
@@ -49,5 +39,5 @@ class i extends n {
49
39
  }
50
40
  }
51
41
  export {
52
- i as default
42
+ r as default
53
43
  };
@@ -8,8 +8,7 @@ export default class StartSentenceTask extends Task {
8
8
  hasMoreConversations(): boolean;
9
9
  nextConversation(): Conversation[] | null;
10
10
  nextTokens(tokeniser: ITokeniser): number[] | null;
11
- getRandomConversation(): Conversation[];
12
- getRandomTokens(tokeniser: ITokeniser): number[];
11
+ shuffle(): void;
13
12
  private conversationFromString;
14
13
  estimateTokens(tokeniser: ITokeniser): Promise<number>;
15
14
  }
@@ -21,13 +21,8 @@ class a extends e {
21
21
  const n = this.nextConversation();
22
22
  return n ? t.encodeConversation(n) : null;
23
23
  }
24
- getRandomConversation() {
25
- const t = Math.floor(Math.random() * this.rawText.length);
26
- return this.conversationFromString(this.rawText[t]);
27
- }
28
- getRandomTokens(t) {
29
- const n = this.getRandomConversation();
30
- return t.encodeConversation(n);
24
+ shuffle() {
25
+ this.index = 0;
31
26
  }
32
27
  conversationFromString(t) {
33
28
  const n = t.indexOf(".");
@@ -5,7 +5,6 @@ export declare abstract class Task {
5
5
  abstract nextConversation(): Conversation[] | null;
6
6
  abstract nextTokens(tokeniser: ITokeniser): number[] | null;
7
7
  abstract estimateTokens(tokeniser: ITokeniser): Promise<number>;
8
- abstract getRandomConversation(): Conversation[];
9
- abstract getRandomTokens(tokeniser: ITokeniser): number[];
8
+ abstract shuffle(): void;
10
9
  }
11
10
  export declare function tokensFromTasks(tasks: Task[], tokenizer: ITokeniser, cb?: (tokens: number) => void): Promise<Uint16Array>;
@@ -0,0 +1,5 @@
1
+ import { Task } from './Task';
2
+ export default function splitValidation(tasks: Task[], validationSplit: number): {
3
+ training: Task;
4
+ validation: Task;
5
+ };
@@ -0,0 +1,21 @@
1
+ import s from "./ConversationTask.js";
2
+ function f(e, o) {
3
+ if (o <= 0 || o >= 1)
4
+ throw new Error("validationSplit must be between 0 and 1");
5
+ e.forEach((n) => n.shuffle());
6
+ const r = [], a = [];
7
+ for (const n of e)
8
+ for (; n.hasMoreConversations(); ) {
9
+ const t = n.nextConversation();
10
+ if (!t)
11
+ break;
12
+ Math.random() < o ? a.push(t) : r.push(t);
13
+ }
14
+ return {
15
+ training: new s(r),
16
+ validation: new s(a)
17
+ };
18
+ }
19
+ export {
20
+ f as default
21
+ };
@@ -39,8 +39,8 @@ import "../ops/webgl/adamAdjust.js";
39
39
  import "../ops/cpu/adamMoments.js";
40
40
  import "../ops/webgl/adamMoments.js";
41
41
  import { PAGE_FACTOR as m, shuffle as h } from "./DatasetBuilder.js";
42
- import "../papaparse.min-C0cScC2i.js";
43
42
  import { tokensFromTasks as k } from "./tasks/Task.js";
43
+ import "../papaparse.min-C0cScC2i.js";
44
44
  import "../ops/cpu/matMulGelu.js";
45
45
  import "../matMulGelu-JNLZqKQp.js";
46
46
  import "../ops/grads/matMulGelu.js";
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.15.13",
3
+ "version": "0.15.14",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",