@genai-fi/nanogpt 0.15.13 → 0.16.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  import { yieldIfNeeded as f } from "../utilities/yielder.js";
2
- import k from "../utilities/tokenParse.js";
3
- import z, { SPECIALS as m } from "./BaseTokeniser.js";
2
+ import m from "../utilities/tokenParse.js";
3
+ import z, { SPECIALS as k } from "./BaseTokeniser.js";
4
4
  function p(o, e) {
5
5
  return `${o}-::-${e}`;
6
6
  }
@@ -8,25 +8,25 @@ function w(o) {
8
8
  const e = /* @__PURE__ */ new Map();
9
9
  for (let s = 0; s < o.length; s++) {
10
10
  const t = o[s];
11
- for (let r = 0; r < t.length - 1; r++) {
12
- const n = p(t[r], t[r + 1]), a = e.get(n) || {
13
- a: t[r],
14
- b: t[r + 1],
11
+ for (let n = 0; n < t.length - 1; n++) {
12
+ const r = p(t[n], t[n + 1]), a = e.get(r) || {
13
+ a: t[n],
14
+ b: t[n + 1],
15
15
  count: 0,
16
16
  instances: /* @__PURE__ */ new Set()
17
17
  };
18
- a.count += 1, a.instances.add(s), e.set(n, a);
18
+ a.count += 1, a.instances.add(s), e.set(r, a);
19
19
  }
20
20
  }
21
21
  return { pairs: e, tokens: o };
22
22
  }
23
- function d(o, e, s, t, r) {
24
- const n = p(e, s);
25
- if (o.pairs.has(n)) {
26
- const a = o.pairs.get(n);
27
- a.count += r, r > 0 ? a.instances.add(t) : a.count <= 0 ? o.pairs.delete(n) : a.instances.delete(t);
23
+ function d(o, e, s, t, n) {
24
+ const r = p(e, s);
25
+ if (o.pairs.has(r)) {
26
+ const a = o.pairs.get(r);
27
+ a.count += n, n > 0 ? a.instances.add(t) : a.count <= 0 ? o.pairs.delete(r) : a.instances.delete(t);
28
28
  } else
29
- o.pairs.set(n, { a: e, b: s, count: r, instances: /* @__PURE__ */ new Set([t]) });
29
+ o.pairs.set(r, { a: e, b: s, count: n, instances: /* @__PURE__ */ new Set([t]) });
30
30
  }
31
31
  function T(o) {
32
32
  let e = null, s = 0;
@@ -37,21 +37,21 @@ function T(o) {
37
37
  function y(o, e) {
38
38
  return o.map((s) => {
39
39
  const t = [];
40
- for (let r = 0; r < s.length; r++)
41
- r < s.length - 1 && s[r] === e[0] && s[r + 1] === e[1] ? (t.push(e[0] + e[1]), r++) : t.push(s[r]);
40
+ for (let n = 0; n < s.length; n++)
41
+ n < s.length - 1 && s[n] === e[0] && s[n + 1] === e[1] ? (t.push(e[0] + e[1]), n++) : t.push(s[n]);
42
42
  return t;
43
43
  });
44
44
  }
45
45
  function I(o, e) {
46
46
  e.instances.forEach((s) => {
47
- const t = o.tokens[s], r = [];
48
- for (let n = 0; n < t.length; n++)
49
- if (n < t.length - 1 && t[n] === e.a && t[n + 1] === e.b) {
47
+ const t = o.tokens[s], n = [];
48
+ for (let r = 0; r < t.length; r++)
49
+ if (r < t.length - 1 && t[r] === e.a && t[r + 1] === e.b) {
50
50
  const a = e.a + e.b;
51
- r.push(a), n > 0 && (d(o, t[n - 1], e.a, s, -1), d(o, t[n - 1], a, s, 1)), n++, n < t.length - 1 && (d(o, e.b, t[n + 1], s, -1), d(o, a, t[n + 1], s, 1));
51
+ n.push(a), r > 0 && (d(o, t[r - 1], e.a, s, -1), d(o, t[r - 1], a, s, 1)), r++, r < t.length - 1 && (d(o, e.b, t[r + 1], s, -1), d(o, a, t[r + 1], s, 1));
52
52
  } else
53
- r.push(t[n]);
54
- o.tokens[s] = r;
53
+ n.push(t[r]);
54
+ o.tokens[s] = n;
55
55
  }), o.pairs.delete(p(e.a, e.b));
56
56
  }
57
57
  class E extends z {
@@ -61,11 +61,11 @@ class E extends z {
61
61
  merges = [];
62
62
  pretokenMap = /* @__PURE__ */ new Map();
63
63
  constructor(e, s) {
64
- super(), Array.isArray(e) ? (e.forEach((t, r) => {
65
- this.vocab.add(t), this.vocabIndex.set(t, r);
66
- }), s && (this.merges = s), this.targetSize = e.length, m.forEach((t) => {
67
- const r = e.indexOf(t);
68
- r !== -1 && this.addSpecialToken(t, r);
64
+ super(), Array.isArray(e) ? (e.forEach((t, n) => {
65
+ this.vocab.add(t), this.vocabIndex.set(t, n);
66
+ }), s && (this.merges = s), this.targetSize = e.length, k.forEach((t) => {
67
+ const n = e.indexOf(t);
68
+ n !== -1 && this.addSpecialToken(t, n);
69
69
  })) : (this.addSpecialTokens(), this.targetSize = e);
70
70
  }
71
71
  addToken(e, s) {
@@ -81,7 +81,7 @@ class E extends z {
81
81
  this.vocab.clear(), this.vocabIndex.clear(), this.merges = [], this.pretokenMap.clear();
82
82
  }
83
83
  get trained() {
84
- return this.vocab.size > m.length && this.vocab.size <= this.targetSize;
84
+ return this.vocab.size > k.length && this.vocab.size <= this.targetSize;
85
85
  }
86
86
  get vocabSize() {
87
87
  return this.vocab.size;
@@ -95,23 +95,23 @@ class E extends z {
95
95
  get unkToken() {
96
96
  return this.vocabIndex.get("") ?? 1;
97
97
  }
98
- async train(e, s) {
98
+ async train(e = [], s) {
99
99
  let t = performance.now();
100
- const r = e.map((i) => k(i)).flat(1);
100
+ const n = e.map((i) => i.map((h) => m(h.content))).flat(2);
101
101
  t = await f(t, s, this.vocab.size);
102
- const n = new Set(r);
102
+ const r = new Set(n);
103
103
  this.vocab = /* @__PURE__ */ new Set(), this.pretokenMap.clear(), this.merges = [], this.addSpecialTokens();
104
- const a = Array.from(n), b = a.map((i) => Array.from(i).map((h) => (this.vocab.add(h), h))), g = w(b);
104
+ const a = Array.from(r), b = a.map((i) => Array.from(i).map((l) => (this.vocab.add(l), l))), g = w(b);
105
105
  if (t = await f(t, s, this.vocab.size), this.vocab.size >= this.targetSize) {
106
106
  console.warn("Initial vocab size is greater than or equal to target size. No merges will be performed.");
107
107
  const i = /* @__PURE__ */ new Map();
108
- r.forEach((c) => {
108
+ n.forEach((c) => {
109
109
  Array.from(c).forEach((u) => {
110
110
  i.set(u, (i.get(u) || 0) + 1);
111
111
  });
112
112
  });
113
- const l = Array.from(i.entries()).sort((c, u) => u[1] - c[1]);
114
- this.vocab = /* @__PURE__ */ new Set(), this.addSpecialTokens(), l.slice(0, this.targetSize - this.vocab.size).map(([c]) => c).forEach((c) => this.vocab.add(c)), this.vocabIndex.clear();
113
+ const h = Array.from(i.entries()).sort((c, u) => u[1] - c[1]);
114
+ this.vocab = /* @__PURE__ */ new Set(), this.addSpecialTokens(), h.slice(0, this.targetSize - this.vocab.size).map(([c]) => c).forEach((c) => this.vocab.add(c)), this.vocabIndex.clear();
115
115
  let S = 0;
116
116
  for (const c of this.vocab.keys())
117
117
  this.vocabIndex.set(c, S++);
@@ -123,9 +123,9 @@ class E extends z {
123
123
  break;
124
124
  this.merges.push([i.a, i.b]), this.vocab.add(i.a + i.b), I(g, i), t = await f(t, s, this.vocab.size);
125
125
  }
126
- a.forEach((i, l) => {
127
- const h = b[l];
128
- this.pretokenMap.set(i, h);
126
+ a.forEach((i, h) => {
127
+ const l = b[h];
128
+ this.pretokenMap.set(i, l);
129
129
  }), this.vocabIndex.clear();
130
130
  let v = 0;
131
131
  for (const i of this.vocab.keys())
@@ -145,15 +145,15 @@ class E extends z {
145
145
  }), this.pretokenMap.set(e, s), s;
146
146
  }
147
147
  tokeniseStrings(e) {
148
- return e.map((s) => k(s).map((n) => this.pretokenMap.has(n) ? this.pretokenMap.get(n) : this.tokeniseWord(n)).flat(1));
148
+ return e.map((s) => m(s).map((r) => this.pretokenMap.has(r) ? this.pretokenMap.get(r) : this.tokeniseWord(r)).flat(1));
149
149
  }
150
150
  tokenise(e, s) {
151
151
  const t = this.tokeniseStrings(e);
152
- return s ? t.map((r) => r.map((n) => this.vocabIndex.get(n) ?? this.unkToken)) : t.map((r) => r.map((n) => this.vocab.has(n) ? n : ""));
152
+ return s ? t.map((n) => n.map((r) => this.vocabIndex.get(r) ?? this.unkToken)) : t.map((n) => n.map((r) => this.vocab.has(r) ? r : ""));
153
153
  }
154
154
  detokenise(e) {
155
155
  const s = this.getVocab();
156
- return e.map((r) => r.map((n) => s[n]).join(""));
156
+ return e.map((n) => n.map((r) => s[r]).join(""));
157
157
  }
158
158
  encode(e) {
159
159
  return this.tokenise([e], !0)[0];
@@ -1,11 +1,11 @@
1
1
  import { default as EE } from 'eventemitter3';
2
- export type Roles = 'user' | 'assistant' | 'system';
2
+ export type Roles = 'user' | 'assistant' | 'system' | 'text';
3
3
  export interface Conversation {
4
4
  role: Roles;
5
5
  content: string;
6
6
  }
7
7
  export interface ITokeniser extends EE<'trainStatus'> {
8
- train(text: string[], cb?: (vocab: number) => void): Promise<number>;
8
+ train(text: Conversation[][], cb?: (vocab: number) => void): Promise<number>;
9
9
  getVocab(): string[];
10
10
  getMerges(): [string, string][];
11
11
  destroy(): void;
@@ -1,8 +1,8 @@
1
1
  import u from "./Evaluator.js";
2
- import { t as L, v as P, k as h, d as g, a as y } from "../index-CUXkjxiT.js";
2
+ import { t as z, v as P, k as g, d as p, a as y } from "../index-CUXkjxiT.js";
3
3
  import S from "../utilities/profile.js";
4
- import { createTensorStatistics as x } from "../checks/weights.js";
5
- import { calculateLoss as k, calculateAccuracy as T } from "./loss.js";
4
+ import { createTensorStatistics as k } from "../checks/weights.js";
5
+ import { calculateLoss as x, calculateAccuracy as T } from "./loss.js";
6
6
  import { AdamWOptimizer as N } from "./AdamW.js";
7
7
  import { z as w } from "../zeros-DvZpK8s6.js";
8
8
  const v = {
@@ -23,11 +23,11 @@ const v = {
23
23
  lossScaling: 1
24
24
  };
25
25
  class G {
26
- constructor(e, i, o, c) {
27
- this.tokenizer = i, this.model = e, this.optimizerConfig = {
26
+ constructor(s, i, o, c) {
27
+ this.tokenizer = i, this.model = s, this.optimizerConfig = {
28
28
  ...b,
29
29
  ...o,
30
- lossScaling: e.lossScaling
30
+ lossScaling: s.lossScaling
31
31
  };
32
32
  const l = c || new N(this.optimizerConfig);
33
33
  c && c.updateConfig(this.optimizerConfig), this.optimizer = l;
@@ -44,26 +44,26 @@ class G {
44
44
  _labelSmoothing = 0;
45
45
  _layerDrop = 0;
46
46
  _dropout = 0;
47
- setGradientCheckpointing(e) {
48
- this._gradientCheckpointing = e;
47
+ setGradientCheckpointing(s) {
48
+ this._gradientCheckpointing = s;
49
49
  }
50
- setMixedPrecision(e) {
51
- this._mixedPrecision = e;
50
+ setMixedPrecision(s) {
51
+ this._mixedPrecision = s;
52
52
  }
53
- setLabelSmoothing(e) {
54
- this._labelSmoothing = e;
53
+ setLabelSmoothing(s) {
54
+ this._labelSmoothing = s;
55
55
  }
56
- setDropout(e) {
57
- this._dropout = e;
56
+ setDropout(s) {
57
+ this._dropout = s;
58
58
  }
59
- setLayerDrop(e) {
60
- this._layerDrop = e;
59
+ setLayerDrop(s) {
60
+ this._layerDrop = s;
61
61
  }
62
- setLearningRate(e) {
63
- this.optimizerConfig.learningRate = e, this.updateOptimizer();
62
+ setLearningRate(s) {
63
+ this.optimizerConfig.learningRate = s, this.updateOptimizer();
64
64
  }
65
- setMetrics(e) {
66
- this.metrics = new Set(e);
65
+ setMetrics(s) {
66
+ this.metrics = new Set(s);
67
67
  }
68
68
  reset() {
69
69
  this.lastState = void 0, this.running = !1;
@@ -77,12 +77,12 @@ class G {
77
77
  getOptimizer() {
78
78
  return this.optimizer;
79
79
  }
80
- updateOptimizer(e) {
81
- e && (this.optimizerConfig = { ...this.optimizerConfig, ...e }), this.optimizer.updateConfig(this.optimizerConfig);
80
+ updateOptimizer(s) {
81
+ s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
82
82
  }
83
83
  // A single forward pass, backward pass, and optimizer step
84
- trainStep(e, i, o = !1, c = !1) {
85
- return L(() => {
84
+ trainStep(s, i, o = !1, c = !1) {
85
+ return z(() => {
86
86
  this.model.getProfiler()?.startMemory();
87
87
  const { xs: l, ys: a } = i, d = () => {
88
88
  const r = this.model.forward(
@@ -94,31 +94,31 @@ class G {
94
94
  layerDrop: this._layerDrop
95
95
  },
96
96
  l
97
- ), s = k(r, a, this.maskedLoss, !1, this._labelSmoothing);
98
- this.metrics.has("accuracy") && (e.accuracy = T(r, a), h(e.accuracy)), r.dispose();
99
- const m = s.mul(y(this.optimizerConfig.lossScaling));
100
- return s.dispose(), m;
97
+ ), e = x(r, a, this.maskedLoss, !1, this._labelSmoothing);
98
+ this.metrics.has("accuracy") && (s.accuracy = T(r, a), g(s.accuracy)), r.dispose();
99
+ const m = e.mul(y(this.optimizerConfig.lossScaling));
100
+ return e.dispose(), m;
101
101
  }, { value: t, grads: n } = P(d);
102
102
  if (o)
103
103
  this.model.getProfiler()?.endMemory("Training");
104
104
  else {
105
105
  const r = this.optimizer.applyGradients(n);
106
- this.metrics.has("gradientNorm") ? (e.gradientNorm = r, h(r)) : (e.gradientNorm = void 0, r.dispose());
107
- const s = Object.keys(n);
108
- this.model.weightStore.touchVariables(s), this.model.getProfiler()?.endMemory("Training"), c ? (e.gradients = n, Object.values(n).forEach((m) => h(m))) : g(n);
106
+ this.metrics.has("gradientNorm") ? (s.gradientNorm = r, g(r)) : (s.gradientNorm = void 0, r.dispose());
107
+ const e = Object.keys(n);
108
+ this.model.weightStore.touchVariables(e), this.model.getProfiler()?.endMemory("Training"), c ? (s.gradients = n, Object.values(n).forEach((m) => g(m))) : p(n);
109
109
  }
110
110
  return t.mul(y(1 / this.optimizerConfig.lossScaling));
111
111
  });
112
112
  }
113
113
  async dummyPass() {
114
- const e = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
114
+ const s = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
115
115
  try {
116
- const o = this.trainStep({}, { xs: e, ys: i }, !0);
116
+ const o = this.trainStep({}, { xs: s, ys: i }, !0);
117
117
  await o.data(), o.dispose();
118
118
  } catch (o) {
119
119
  console.error("Error during dummy pass:", o);
120
120
  } finally {
121
- e.dispose(), i.dispose();
121
+ s.dispose(), i.dispose();
122
122
  }
123
123
  }
124
124
  dispose() {
@@ -136,7 +136,7 @@ class G {
136
136
  ...this.lastState || {}
137
137
  };
138
138
  }
139
- async stepDataset(e, i, o) {
139
+ async stepDataset(s, i, o) {
140
140
  const { logInterval: c = 10 } = {
141
141
  ...v,
142
142
  ...i
@@ -144,21 +144,21 @@ class G {
144
144
  i.metrics && this.setMetrics(i.metrics);
145
145
  const l = Date.now(), a = this.createEmptyState();
146
146
  this.lastState = a, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, a.logStartTime = l;
147
- const d = o ? new u(this.model, o) : void 0, t = await e.iterator();
147
+ const d = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, t = await s.iterator();
148
148
  try {
149
149
  for (; this.running; ) {
150
150
  const n = await t.next();
151
151
  if (n.done) break;
152
- const r = n.value, s = this.trainStep(a, r, !1);
153
- r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(s, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), s.dispose();
152
+ const r = n.value, e = this.trainStep(a, r, !1);
153
+ r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(e, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), e.dispose();
154
154
  }
155
155
  } catch (n) {
156
- throw console.error("Training error:", n), g(), n;
156
+ throw console.error("Training error:", n), p(), n;
157
157
  }
158
- throw g(), this.running = !1, new Error("No log returned before training stopped.");
158
+ throw p(), this.running = !1, new Error("No log returned before training stopped.");
159
159
  }
160
- async performLogging(e, i, o, c) {
161
- const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await e.data())[0], t = this.lastState;
160
+ async performLogging(s, i, o, c) {
161
+ const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
162
162
  t.lastLoss = d;
163
163
  const n = Date.now();
164
164
  t.trainingDuration += n - t.logStartTime;
@@ -184,25 +184,25 @@ class G {
184
184
  batchSize: i,
185
185
  loss: t.lastLoss
186
186
  }, a && t.gradients) {
187
- const s = /* @__PURE__ */ new Map();
188
- for (const [m, p] of Object.entries(t.gradients))
189
- s.set(m, await x(p)), p.dispose();
190
- r.gradientMetrics = s;
187
+ const e = /* @__PURE__ */ new Map();
188
+ for (const [m, h] of Object.entries(t.gradients))
189
+ e.set(m, await k(h)), h.dispose();
190
+ r.gradientMetrics = e;
191
191
  }
192
192
  if (c)
193
193
  try {
194
- const s = await c.evaluate(5);
195
- Array.isArray(s) ? r.validationMetrics = { loss: s[0].loss, accuracy: s[0].accuracy } : (t.validationLosses.push(s.loss), r.validationMetrics = {
196
- accuracy: s.accuracy,
197
- loss: s.loss,
198
- perplexity: this.metrics.has("perplexity") ? Math.exp(s.loss) : void 0
194
+ const e = await c.evaluate(5);
195
+ Array.isArray(e) ? r.validationMetrics = { loss: e[0].loss, accuracy: e[0].accuracy } : (t.validationLosses.push(e.loss), r.validationMetrics = {
196
+ accuracy: e.accuracy,
197
+ loss: e.loss,
198
+ perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
199
199
  });
200
- } catch (s) {
201
- console.error("Validation error:", s);
200
+ } catch (e) {
201
+ console.error("Validation error:", e);
202
202
  }
203
203
  l && await l(r), t.logStartTime = Date.now();
204
204
  }
205
- async trainOnDataset(e, i, o) {
205
+ async trainOnDataset(s, i, o) {
206
206
  const { logInterval: c = 10, maxEpochs: l = 1 / 0 } = {
207
207
  ...v,
208
208
  ...i
@@ -210,18 +210,18 @@ class G {
210
210
  i.metrics && this.setMetrics(i.metrics);
211
211
  const d = Date.now(), t = this.createEmptyState();
212
212
  this.lastState = t, await this.dummyPass(), i?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, t.logStartTime = d;
213
- const n = o ? new u(this.model, o) : void 0, r = await e.iterator();
213
+ const n = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, r = await s.iterator();
214
214
  try {
215
215
  for (; this.running; ) {
216
- const s = await r.next();
217
- if (s.done) break;
218
- const m = s.value, p = t.step % c === 0, z = (i?.metrics?.includes("gradientStatistics") || !1) && p, f = this.trainStep(t, m, !1, z);
219
- m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, p ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
216
+ const e = await r.next();
217
+ if (e.done) break;
218
+ const m = e.value, h = t.step % c === 0, L = (i?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, m, !1, L);
219
+ m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
220
220
  }
221
- } catch (s) {
222
- throw console.error("Training error:", s), g(), s;
221
+ } catch (e) {
222
+ throw console.error("Training error:", e), p(), e;
223
223
  }
224
- return g(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
224
+ return p(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
225
225
  }
226
226
  }
227
227
  export {
@@ -11,7 +11,8 @@ export default class Evaluator {
11
11
  private iterator?;
12
12
  private xs?;
13
13
  private ys?;
14
- constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer> | Conversation[][], tokeniser?: ITokeniser);
14
+ private masked;
15
+ constructor(model: Model<ModelForwardAttributes>, dataset: Dataset<TensorContainer> | Conversation[][], tokeniser?: ITokeniser, masked?: boolean);
15
16
  dispose(): void;
16
17
  private calculateBatchLoss;
17
18
  evaluate(maxBatches?: number): Promise<Result | Result[]>;
@@ -2,12 +2,12 @@ import { t as p } from "../index-CUXkjxiT.js";
2
2
  import { calculateLoss as d, calculateAccuracy as m } from "./loss.js";
3
3
  import { buildSFTExample as x } from "./SFTDatasetBuilder.js";
4
4
  import { t as h } from "../tensor-BWFldCso.js";
5
- class g {
6
- constructor(c, t, o) {
7
- if (this.model = c, Array.isArray(t)) {
5
+ class k {
6
+ constructor(i, t, o, c) {
7
+ if (this.model = i, this.masked = !!c, Array.isArray(t)) {
8
8
  if (!o)
9
9
  throw new Error("Tokeniser is required when dataset is an array of conversations");
10
- const a = t.map((s) => x(s, -100, o, c.config.blockSize)).filter((s) => s !== null);
10
+ const a = t.map((s) => x(s, -100, o, i.config.blockSize)).filter((s) => s !== null);
11
11
  if (a.length === 0)
12
12
  return;
13
13
  this.xs = h(a.map((s) => s.xs)), this.ys = h(a.map((s) => s.ys));
@@ -17,32 +17,33 @@ class g {
17
17
  iterator;
18
18
  xs;
19
19
  ys;
20
+ masked = !1;
20
21
  dispose() {
21
22
  this.xs && this.xs.dispose(), this.ys && this.ys.dispose();
22
23
  }
23
- async calculateBatchLoss(c, t, o, a) {
24
- const [s, e] = p(() => {
25
- const r = this.model.forward({ training: !1 }, c), y = d(r, t, a, o), f = m(r, t);
24
+ async calculateBatchLoss(i, t, o, c) {
25
+ const [a, s] = p(() => {
26
+ const r = this.model.forward({ training: !1 }, i), y = d(r, t, c, o), f = m(r, t);
26
27
  return r.dispose(), [y, f];
27
- }), n = await s.array(), u = await e.array(), i = n, l = u;
28
- return e.dispose(), s.dispose(), Array.isArray(i) ? i.map((r) => ({ loss: r, accuracy: l })) : { loss: i, accuracy: l };
28
+ }), n = await a.array(), u = await s.array(), e = n, l = u;
29
+ return s.dispose(), a.dispose(), Array.isArray(e) ? e.map((r) => ({ loss: r, accuracy: l })) : { loss: e, accuracy: l };
29
30
  }
30
- async evaluate(c = 100) {
31
- let t = 0, o = 0, a = 0;
31
+ async evaluate(i = 100) {
32
+ let t = 0, o = 0, c = 0;
32
33
  if (this.iterator) {
33
- const s = await this.iterator;
34
- for (let e = 0; e < c; e++) {
35
- const n = await s.next();
34
+ const a = await this.iterator;
35
+ for (let s = 0; s < i; s++) {
36
+ const n = await a.next();
36
37
  if (n.done) break;
37
- const u = n.value, { xs: i, ys: l } = u, r = await this.calculateBatchLoss(i, l, !1, !1);
38
- i.dispose(), l.dispose(), t += r.loss, o += r.accuracy, a++;
38
+ const u = n.value, { xs: e, ys: l } = u, r = await this.calculateBatchLoss(e, l, !1, this.masked);
39
+ e.dispose(), l.dispose(), t += r.loss, o += r.accuracy, c++;
39
40
  }
40
- return { loss: t / a, accuracy: o / a };
41
+ return { loss: t / c, accuracy: o / c };
41
42
  } else if (this.xs && this.ys)
42
43
  return this.calculateBatchLoss(this.xs, this.ys, !0, !0);
43
44
  throw new Error("No data available for evaluation");
44
45
  }
45
46
  }
46
47
  export {
47
- g as default
48
+ k as default
48
49
  };
@@ -1,50 +1,52 @@
1
- import { t as x } from "../index-CUXkjxiT.js";
1
+ import { t as y } from "../index-CUXkjxiT.js";
2
2
  import "../dataset-CGGp1z9P.js";
3
3
  import { g as I } from "../readers-iz5u3HBo.js";
4
4
  import "../index-Cp39cXWe.js";
5
- function w(u, a, t, r) {
6
- const s = [t.bosToken], n = [!1], f = {
5
+ function w(p, o, t, l) {
6
+ const s = [t.bosToken], a = [!1], u = {
7
7
  user: t.getSpecialTokenIndex("<|user_start|>"),
8
8
  assistant: t.getSpecialTokenIndex("<|assistant_start|>"),
9
- system: t.getSpecialTokenIndex("<|system_start|>")
10
- }, i = {
9
+ system: t.getSpecialTokenIndex("<|system_start|>"),
10
+ text: void 0
11
+ }, c = {
11
12
  user: t.getSpecialTokenIndex("<|user_end|>"),
12
13
  assistant: t.getSpecialTokenIndex("<|assistant_end|>"),
13
- system: t.getSpecialTokenIndex("<|system_end|>")
14
+ system: t.getSpecialTokenIndex("<|system_end|>"),
15
+ text: void 0
14
16
  };
15
- for (const e of u) {
16
- const c = f[e.role], h = i[e.role];
17
- if (!c || !h)
17
+ for (const e of p) {
18
+ const r = u[e.role], h = c[e.role];
19
+ if (!r || !h)
18
20
  throw new Error(`Missing special tokens for role: ${e.role}`);
19
- s.push(c), n.push(!1);
20
- const m = e.role === "assistant", S = t.encode(e.content);
21
- for (const T of S) {
21
+ s.push(r), a.push(!1);
22
+ const m = e.role === "assistant", x = t.encode(e.content);
23
+ for (const T of x) {
22
24
  s.push(T);
23
- const y = t.isSpecialToken(T);
24
- n.push(m && !y);
25
+ const S = t.isSpecialToken(T);
26
+ a.push(m && !S);
25
27
  }
26
- s.push(h), n.push(m);
28
+ s.push(h), a.push(m);
27
29
  }
28
- s.push(t.eosToken), n.push(!1);
29
- const o = r + 1;
30
- if (s.length < o) {
31
- const e = o - s.length, c = t.getSpecialTokenIndex("<pad>");
30
+ s.push(t.eosToken), a.push(!1);
31
+ const n = l + 1;
32
+ if (s.length < n) {
33
+ const e = n - s.length, r = t.getSpecialTokenIndex("<pad>");
32
34
  for (let h = 0; h < e; h++)
33
- s.push(c), n.push(!1);
34
- } else s.length > o && (s.length = o, n.length = o);
35
- const p = new Int32Array(s.slice(0, r)), l = s.slice(1, r + 1), k = n.slice(1, r + 1), d = new Int32Array(l.length);
35
+ s.push(r), a.push(!1);
36
+ } else s.length > n && (s.length = n, a.length = n);
37
+ const f = new Int32Array(s.slice(0, l)), i = s.slice(1, l + 1), k = a.slice(1, l + 1), d = new Int32Array(i.length);
36
38
  let g = !1;
37
- for (let e = 0; e < l.length; e++) {
38
- const c = k[e] ? l[e] : a;
39
- d[e] = c, c !== a && (g = !0);
39
+ for (let e = 0; e < i.length; e++) {
40
+ const r = k[e] ? i[e] : o;
41
+ d[e] = r, r !== o && (g = !0);
40
42
  }
41
- return g ? { xs: p, ys: d } : null;
43
+ return g ? { xs: f, ys: d } : null;
42
44
  }
43
45
  class A {
44
46
  tokenizer;
45
47
  blockSize;
46
- constructor(a, t = 128) {
47
- this.tokenizer = a, this.blockSize = t;
48
+ constructor(o, t = 128) {
49
+ this.tokenizer = o, this.blockSize = t;
48
50
  }
49
51
  /**
50
52
  * Create SFT dataset from structured conversations.
@@ -52,20 +54,27 @@ class A {
52
54
  * - Pads with eosToken and masks padding.
53
55
  * - Masks non-assistant tokens in labels with ignoreIndex (default -100).
54
56
  */
55
- async createSFTDataset(a, t = 32, r = -100) {
56
- if (!a.length)
57
+ async createSFTDataset(o, t = 32, l = -100) {
58
+ if (!o.length)
57
59
  throw new Error("No conversations provided.");
58
- const s = this.tokenizer, n = this.blockSize;
60
+ const s = this.tokenizer, a = this.blockSize;
61
+ for (const c of o)
62
+ c.shuffle();
59
63
  return I(function* () {
60
64
  for (; ; ) {
61
- const i = Math.floor(Math.random() * a.length), p = a[i].getRandomConversation(), l = w(p, r, s, n);
62
- l && (yield l);
65
+ const c = Math.floor(Math.random() * o.length), n = o[c], f = n.nextConversation();
66
+ if (!f) {
67
+ n.shuffle();
68
+ continue;
69
+ }
70
+ const i = w(f, l, s, a);
71
+ i && (yield i);
63
72
  }
64
- }).batch(t).map((i) => {
65
- const o = i;
66
- return x(() => ({
67
- xs: o.xs.cast("int32"),
68
- ys: o.ys.cast("int32")
73
+ }).batch(t).map((c) => {
74
+ const n = c;
75
+ return y(() => ({
76
+ xs: n.xs.cast("int32"),
77
+ ys: n.ys.cast("int32")
69
78
  }));
70
79
  }).prefetch(2);
71
80
  }
@@ -2,13 +2,13 @@ import { Conversation, ITokeniser } from '../../main';
2
2
  import { Task } from './Task';
3
3
  export default class ConversationTask extends Task {
4
4
  private rawConvo;
5
+ private shuffledIndices;
5
6
  private index;
6
7
  get length(): number;
7
8
  constructor(conversations: Conversation[][]);
8
9
  hasMoreConversations(): boolean;
9
10
  nextConversation(): Conversation[] | null;
10
11
  nextTokens(tokeniser: ITokeniser): number[] | null;
11
- getRandomConversation(): Conversation[];
12
- getRandomTokens(tokeniser: ITokeniser): number[];
12
+ shuffle(): void;
13
13
  estimateTokens(tokeniser: ITokeniser): Promise<number>;
14
14
  }