@genai-fi/nanogpt 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. package/LICENSE +7 -0
  2. package/README.md +20 -0
  3. package/dist/Generator.d.ts +14 -0
  4. package/dist/Generator.js +39 -0
  5. package/dist/NanoGPTModel.d.ts +35 -0
  6. package/dist/NanoGPTModel.js +129 -0
  7. package/dist/TeachableLLM.d.ts +21 -0
  8. package/dist/TeachableLLM.js +47 -0
  9. package/dist/Trainer.d.ts +19 -0
  10. package/dist/Trainer.js +34 -0
  11. package/dist/_commonjsHelpers-DaMA6jEr.js +8 -0
  12. package/dist/assets/worker-BYeSPNkq.js +1 -0
  13. package/dist/config.d.ts +11 -0
  14. package/dist/config.js +19 -0
  15. package/dist/index-B8nyc6IR.js +3899 -0
  16. package/dist/index-SOhdqzHq.js +113 -0
  17. package/dist/jszip.min-BLbRbbKt.js +2324 -0
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -0
  19. package/dist/layers/CausalSelfAttention.js +75 -0
  20. package/dist/layers/LayerNorm.d.ts +12 -0
  21. package/dist/layers/LayerNorm.js +30 -0
  22. package/dist/layers/MLP.d.ts +17 -0
  23. package/dist/layers/MLP.js +57 -0
  24. package/dist/layers/TiedEmbedding.d.ts +22 -0
  25. package/dist/layers/TiedEmbedding.js +532 -0
  26. package/dist/layers/TransformerBlock.d.ts +19 -0
  27. package/dist/layers/TransformerBlock.js +47 -0
  28. package/dist/main.d.ts +6 -0
  29. package/dist/main.js +8 -0
  30. package/dist/tokeniser/CharTokeniser.d.ts +20 -0
  31. package/dist/tokeniser/CharTokeniser.js +52 -0
  32. package/dist/tokeniser/NodeTokeniser.d.ts +19 -0
  33. package/dist/tokeniser/NodeTokeniser.js +46 -0
  34. package/dist/tokeniser/WebTokeniser.d.ts +18 -0
  35. package/dist/tokeniser/WebTokeniser.js +96 -0
  36. package/dist/tokeniser/bpe.d.ts +14 -0
  37. package/dist/tokeniser/bpe.js +102 -0
  38. package/dist/tokeniser/messages.d.ts +61 -0
  39. package/dist/tokeniser/messages.js +1 -0
  40. package/dist/tokeniser/type.d.ts +14 -0
  41. package/dist/tokeniser/type.js +1 -0
  42. package/dist/tokeniser/worker.d.ts +1 -0
  43. package/dist/tokeniser/worker.js +53 -0
  44. package/dist/training/AdamExt.d.ts +23 -0
  45. package/dist/training/AdamExt.js +43 -0
  46. package/dist/training/DatasetBuilder.d.ts +12 -0
  47. package/dist/training/DatasetBuilder.js +27 -0
  48. package/dist/training/FullTrainer.d.ts +17 -0
  49. package/dist/training/FullTrainer.js +75 -0
  50. package/dist/training/LayerTrainer.d.ts +28 -0
  51. package/dist/training/LayerTrainer.js +108 -0
  52. package/dist/training/Trainer.d.ts +73 -0
  53. package/dist/training/Trainer.js +87 -0
  54. package/dist/training/lwSchedule.d.ts +7 -0
  55. package/dist/training/lwSchedule.js +162 -0
  56. package/dist/utilities/generate.d.ts +3 -0
  57. package/dist/utilities/generate.js +22 -0
  58. package/dist/utilities/load.d.ts +7 -0
  59. package/dist/utilities/load.js +47 -0
  60. package/dist/utilities/save.d.ts +3 -0
  61. package/dist/utilities/save.js +21 -0
  62. package/dist/utilities/textLoader.d.ts +1 -0
  63. package/dist/utilities/textLoader.js +438 -0
  64. package/dist/utilities/tokenParse.d.ts +1 -0
  65. package/dist/utilities/tokenParse.js +66 -0
  66. package/dist/utilities/weights.d.ts +12 -0
  67. package/dist/utilities/weights.js +43 -0
  68. package/package.json +59 -0
package/dist/tokeniser/NodeTokeniser.d.ts
@@ -0,0 +1,19 @@
+ import { default as EE } from 'eventemitter3';
+ import { ITokeniser } from './type';
+ export default class NodeTokeniser extends EE<'trainStatus'> implements ITokeniser {
+ vocabSize: number;
+ eosToken: number;
+ private bpe;
+ constructor(vocab?: string[], merges?: [string, string][]);
+ get trained(): boolean;
+ destroy(): void;
+ train(text: string[], vocabSize: number): Promise<number>;
+ tokenise(text: string[], numeric: true): Promise<number[][]>;
+ tokenise(text: string[]): Promise<string[][]>;
+ detokenise(tokens: number[][]): Promise<string[]>;
+ encode(text: string): Promise<number[]>;
+ decode(tokens: number[]): Promise<string>;
+ getVocab(): string[];
+ getMerges(): Promise<[string, string][]>;
+ createTrainingData(text: string[], windowSize?: number): Promise<[number[], number[]]>;
+ }
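A minimal usage sketch of the NodeTokeniser declaration above. The import path is an assumption based on the dist layout; what package/dist/main.js re-exports is not shown in this excerpt of the diff.

    import NodeTokeniser from '@genai-fi/nanogpt/dist/tokeniser/NodeTokeniser.js';

    const tok = new NodeTokeniser();              // or new NodeTokeniser(vocab, merges) to restore a trained tokeniser
    const size = await tok.train(['hello world, hello there'], 300);  // resolves with the final vocab size
    const ids = await tok.encode('hello');        // Promise<number[]>
    const text = await tok.decode(ids);           // Promise<string>
    tok.destroy();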
package/dist/tokeniser/NodeTokeniser.js
@@ -0,0 +1,46 @@
+ import { E as a } from "../index-SOhdqzHq.js";
+ import o from "./bpe.js";
+ class b extends a {
+ vocabSize = 0;
+ eosToken = 0;
+ bpe = new o();
+ constructor(e, t) {
+ super(), e && (this.bpe = new o(e, t), this.vocabSize = e.length);
+ }
+ get trained() {
+ return this.vocabSize > 0;
+ }
+ destroy() {
+ }
+ async train(e, t) {
+ return this.bpe.train(e, t), this.vocabSize = this.bpe.getVocab().length, this.vocabSize;
+ }
+ async tokenise(e, t) {
+ return t ? this.bpe.tokenise(e, !0) : this.bpe.tokenise(e);
+ }
+ async detokenise(e) {
+ const t = this.bpe.getVocab();
+ return e.map((n) => n.map((i) => t[i]).join(""));
+ }
+ async encode(e) {
+ return (await this.tokenise([e], !0))[0];
+ }
+ async decode(e) {
+ return (await this.detokenise([e]))[0];
+ }
+ getVocab() {
+ return this.bpe.getVocab();
+ }
+ async getMerges() {
+ return this.bpe.getMerges();
+ }
+ async createTrainingData(e, t = 5) {
+ const s = this.bpe.tokenise(e, !0), n = [], i = [];
+ for (let r = 0; r < s.length - t; r++)
+ n.push(...s[r].slice(0, t)), i.push(s[r + 1][0]);
+ return [n, i];
+ }
+ }
+ export {
+ b as default
+ };
package/dist/tokeniser/WebTokeniser.d.ts
@@ -0,0 +1,18 @@
+ import { default as EE } from 'eventemitter3';
+ import { ITokeniser } from './type';
+ export default class WebTokeniser extends EE<'trainStatus'> implements ITokeniser {
+ private id;
+ vocabSize: number;
+ private handler?;
+ constructor();
+ destroy(): void;
+ private post;
+ train(text: string[], vocabSize: number): Promise<number>;
+ tokenise(text: string[], numeric: true): Promise<number[][]>;
+ tokenise(text: string[]): Promise<string[][]>;
+ detokenise(tokens: number[][]): Promise<string[]>;
+ encode(text: string): Promise<number[]>;
+ decode(tokens: number[]): Promise<string>;
+ getVocab(): Promise<string[]>;
+ createTrainingData(text: string[], windowSize?: number): Promise<[number[], number[]]>;
+ }
package/dist/tokeniser/WebTokeniser.js
@@ -0,0 +1,96 @@
+ import { E as d } from "../index-SOhdqzHq.js";
+ const t = new Worker(new URL(
+ /* @vite-ignore */
+ "/assets/worker-BYeSPNkq.js",
+ import.meta.url
+ ), {
+ type: "module"
+ });
+ let r = 0;
+ class m extends d {
+ id;
+ vocabSize = 0;
+ handler;
+ constructor() {
+ super(), this.id = r++, this.handler = (e) => {
+ e.data.type === "trainStatus" && e.data.id === this.id && (this.vocabSize = e.data.vocabSize, this.emit("trainStatus", e.data.progress, e.data.vocabSize));
+ }, t.addEventListener("message", this.handler);
+ }
+ destroy() {
+ this.handler && (t.removeEventListener("message", this.handler), this.handler = void 0);
+ }
+ post(e) {
+ t.postMessage(e);
+ }
+ async train(e, n) {
+ return new Promise((s) => {
+ const i = (a) => {
+ a.data.type === "trainResponse" && a.data.id === this.id && (t.removeEventListener("message", i), this.vocabSize = a.data.vocabSize, s(this.vocabSize));
+ };
+ t.addEventListener("message", i), this.post({
+ type: "train",
+ id: this.id,
+ text: e,
+ vocabSize: n
+ });
+ });
+ }
+ async tokenise(e, n) {
+ return new Promise((s) => {
+ const i = (a) => {
+ a.data.type === "tokeniseResponse" && a.data.id === this.id && (t.removeEventListener("message", i), s(a.data.tokens));
+ };
+ t.addEventListener("message", i), this.post({
+ type: "tokenise",
+ id: this.id,
+ text: e,
+ numeric: n
+ });
+ });
+ }
+ async detokenise(e) {
+ return new Promise((n) => {
+ const s = (i) => {
+ i.data.type === "detokeniseResponse" && i.data.id === this.id && (t.removeEventListener("message", s), n(i.data.text));
+ };
+ t.addEventListener("message", s), this.post({
+ type: "detokenise",
+ id: this.id,
+ tokens: e
+ });
+ });
+ }
+ async encode(e) {
+ return (await this.tokenise([e], !0))[0];
+ }
+ async decode(e) {
+ return (await this.detokenise([e]))[0];
+ }
+ async getVocab() {
+ return new Promise((e) => {
+ const n = (s) => {
+ s.data.type === "tokensResponse" && s.data.id === this.id && (t.removeEventListener("message", n), e(s.data.tokens));
+ };
+ t.addEventListener("message", n), this.post({
+ type: "tokens",
+ id: this.id
+ });
+ });
+ }
+ async createTrainingData(e, n = 5) {
+ return new Promise((s) => {
+ const i = (a) => {
+ a.data.type === "buildTrainingDataResponse" && a.data.id === this.id && (t.removeEventListener("message", i), s(a.data.trainingData));
+ };
+ t.addEventListener("message", i), this.post({
+ type: "buildTrainingData",
+ id: this.id,
+ text: e,
+ windowSize: n
+ });
+ });
+ }
+ }
+ export {
+ m as default
+ };
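WebTokeniser above proxies every call through a single module worker and matches responses to requests by instance id and message type. A generic sketch of that request/response round trip, with illustrative names that are not part of the package:

    // Hypothetical helper showing the pattern used by each method above.
    function request<T>(worker: Worker, msg: { type: string; id: number }, responseType: string): Promise<T> {
      return new Promise((resolve) => {
        const onMessage = (e: MessageEvent) => {
          if (e.data.type === responseType && e.data.id === msg.id) {
            worker.removeEventListener('message', onMessage);  // one-shot listener per call
            resolve(e.data as T);
          }
        };
        worker.addEventListener('message', onMessage);
        worker.postMessage(msg);
      });
    }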
package/dist/tokeniser/bpe.d.ts
@@ -0,0 +1,14 @@
+ export default class BPE {
+ private vocab;
+ private vocabIndex;
+ private merges;
+ private pretokenMap;
+ constructor(vocab?: string[], merges?: [string, string][]);
+ train(text: string[], vocabSize: number, onUpdate?: (progress: number, vocabSize: number) => void): void;
+ getVocab(): string[];
+ getMerges(): [string, string][];
+ private tokeniseWord;
+ private tokeniseStrings;
+ tokenise(text: string[], numeric: true): number[][];
+ tokenise(text: string[]): string[][];
+ }
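A usage sketch for the BPE declaration above. Unlike the tokeniser wrappers, BPE is synchronous; the import path is assumed from the dist layout.

    import BPE from '@genai-fi/nanogpt/dist/tokeniser/bpe.js';

    const bpe = new BPE();
    bpe.train(['the quick brown fox', 'the lazy dog'], 200, (progress, size) => {
      console.log(`progress ${Math.round(progress * 100)}%, vocab ${size}`);  // onUpdate fires as the vocab grows
    });
    const ids = bpe.tokenise(['the quick fox'], true);   // number[][] when numeric is true
    const pieces = bpe.tokenise(['the quick fox']);      // string[][] otherwise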
package/dist/tokeniser/bpe.js
@@ -0,0 +1,102 @@
+ import f from "../utilities/tokenParse.js";
+ function b(r) {
+ const s = /* @__PURE__ */ new Map();
+ for (let e = 0; e < r.length; e++) {
+ const t = r[e];
+ for (let n = 0; n < t.length - 1; n++) {
+ const o = `${t[n]}${t[n + 1]}`, i = s.get(o) || {
+ a: t[n],
+ b: t[n + 1],
+ count: 0,
+ instances: /* @__PURE__ */ new Set()
+ };
+ i.count += 1, i.instances.add(e), s.set(o, i);
+ }
+ }
+ return { pairs: s, tokens: r };
+ }
+ function h(r, s, e, t, n) {
+ const o = `${s}${e}`;
+ if (r.pairs.has(o)) {
+ const i = r.pairs.get(o);
+ i.count += n, i.instances.add(t);
+ } else
+ r.pairs.set(o, { a: s, b: e, count: n, instances: /* @__PURE__ */ new Set([t]) });
+ }
+ function g(r) {
+ let s = null, e = 0;
+ for (const t of r.pairs.values())
+ t.count > e && (e = t.count, s = t);
+ return s;
+ }
+ function m(r, s) {
+ return r.map((e) => {
+ const t = [];
+ for (let n = 0; n < e.length; n++)
+ n < e.length - 1 && e[n] === s[0] && e[n + 1] === s[1] ? (t.push(s[0] + s[1]), n++) : t.push(e[n]);
+ return t;
+ });
+ }
+ function d(r, s) {
+ s.instances.forEach((e) => {
+ const t = r.tokens[e], n = [];
+ for (let o = 0; o < t.length; o++)
+ if (o < t.length - 1 && t[o] === s.a && t[o + 1] === s.b) {
+ const i = s.a + s.b;
+ n.push(i), o > 0 && (h(r, t[o - 1], s.a, e, -1), h(r, t[o - 1], i, e, 1)), o++, o < t.length - 1 && (h(r, s.b, t[o + 1], e, -1), h(r, i, t[o + 1], e, 1));
+ } else
+ n.push(t[o]);
+ r.tokens[e] = n;
+ }), r.pairs.delete(`${s.a}${s.b}`);
+ }
+ class w {
+ vocab = /* @__PURE__ */ new Set();
+ vocabIndex = /* @__PURE__ */ new Map();
+ merges = [];
+ pretokenMap = /* @__PURE__ */ new Map();
+ constructor(s, e) {
+ s && s.forEach((t, n) => {
+ this.vocab.add(t), this.vocabIndex.set(t, n);
+ }), e && (this.merges = e);
+ }
+ train(s, e, t) {
+ const n = s.map((a) => f(a, !0)).flat(1), o = new Set(n);
+ this.vocab = /* @__PURE__ */ new Set(), this.pretokenMap.clear(), this.merges = [], this.vocab.add("<eos>");
+ const i = Array.from(o), u = i.map((a) => a.split("").map((c) => (this.vocab.add(c), c))), p = b(u);
+ for (; this.vocab.size < e && this.merges.length < e; ) {
+ const a = g(p);
+ if (!a)
+ break;
+ this.merges.push([a.a, a.b]), this.vocab.add(a.a + a.b), d(p, a), t && this.vocab.size % 100 === 0 && t(this.vocab.size / e, this.vocab.size);
+ }
+ i.forEach((a, l) => {
+ const c = u[l];
+ this.pretokenMap.set(a, c);
+ }), this.vocabIndex.clear();
+ let k = 0;
+ for (const a of this.vocab.keys())
+ this.vocabIndex.set(a, k++);
+ }
+ getVocab() {
+ return Array.from(this.vocab);
+ }
+ getMerges() {
+ return this.merges;
+ }
+ tokeniseWord(s) {
+ let e = s.split("");
+ return this.merges.forEach((t) => {
+ e = m([e], t)[0];
+ }), this.pretokenMap.set(s, e), e;
+ }
+ tokeniseStrings(s) {
+ return s.map((e) => f(e, !0).map((o) => this.pretokenMap.has(o) ? this.pretokenMap.get(o) : this.tokeniseWord(o)).flat(1));
+ }
+ tokenise(s, e) {
+ const t = this.tokeniseStrings(s);
+ return e ? t.map((n) => n.map((o) => this.vocabIndex.get(o) ?? -1)) : t;
+ }
+ }
+ export {
+ w as default
+ };
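For readability, the core of the minified train() above is: pre-tokenise the corpus into words, split each word into characters, then repeatedly merge the most frequent adjacent pair until the vocabulary reaches the target size. A simplified, self-contained re-statement of that greedy loop (illustrative only; the package version updates pair counts incrementally instead of recounting every iteration):

    function greedyBPE(words: string[][], targetVocab: number, vocab: Set<string>): [string, string][] {
      const merges: [string, string][] = [];
      while (vocab.size < targetVocab) {
        // Count every adjacent pair across all words.
        const counts = new Map<string, { a: string; b: string; count: number }>();
        for (const w of words) {
          for (let i = 0; i < w.length - 1; i++) {
            const key = w[i] + '\u0000' + w[i + 1];
            const entry = counts.get(key) ?? { a: w[i], b: w[i + 1], count: 0 };
            entry.count++;
            counts.set(key, entry);
          }
        }
        // Pick the most frequent pair; stop when nothing is left to merge.
        let best: { a: string; b: string; count: number } | null = null;
        for (const e of counts.values()) if (!best || e.count > best.count) best = e;
        if (!best) break;
        merges.push([best.a, best.b]);
        vocab.add(best.a + best.b);
        // Replace each occurrence of [a, b] with the merged token.
        for (let wi = 0; wi < words.length; wi++) {
          const out: string[] = [];
          for (let i = 0; i < words[wi].length; i++) {
            if (i < words[wi].length - 1 && words[wi][i] === best.a && words[wi][i + 1] === best.b) {
              out.push(best.a + best.b);
              i++;
            } else out.push(words[wi][i]);
          }
          words[wi] = out;
        }
      }
      return merges;
    }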
package/dist/tokeniser/messages.d.ts
@@ -0,0 +1,61 @@
+ interface TrainMessage {
+ type: 'train';
+ id: number;
+ text: string[];
+ vocabSize: number;
+ }
+ interface TrainResponse {
+ type: 'trainResponse';
+ id: number;
+ vocabSize: number;
+ }
+ interface TrainStatusMessage {
+ type: 'trainStatus';
+ id: number;
+ progress: number;
+ vocabSize: number;
+ }
+ interface TokeniseMessage {
+ type: 'tokenise';
+ id: number;
+ numeric?: boolean;
+ text: string[];
+ }
+ interface TokeniseResponse {
+ type: 'tokeniseResponse';
+ id: number;
+ numeric: boolean;
+ tokens: string[][] | number[][];
+ }
+ interface DetokeniseMessage {
+ type: 'detokenise';
+ id: number;
+ tokens: number[][];
+ }
+ interface DetokeniseResponse {
+ type: 'detokeniseResponse';
+ id: number;
+ text: string[];
+ }
+ interface TokensMessage {
+ type: 'tokens';
+ id: number;
+ }
+ interface TokensResponse {
+ type: 'tokensResponse';
+ id: number;
+ tokens: string[];
+ }
+ interface BuildTrainingDataMessage {
+ type: 'buildTrainingData';
+ id: number;
+ text: string[];
+ windowSize: number;
+ }
+ interface BuildTrainingDataResponse {
+ type: 'buildTrainingDataResponse';
+ id: number;
+ trainingData: [number[], number[]];
+ }
+ export type TokeniserMessage = TrainMessage | TrainResponse | TrainStatusMessage | TokeniseMessage | DetokeniseMessage | TokeniseResponse | DetokeniseResponse | TokensMessage | TokensResponse | BuildTrainingDataMessage | BuildTrainingDataResponse;
+ export {};
package/dist/tokeniser/messages.js
@@ -0,0 +1 @@
+
package/dist/tokeniser/type.d.ts
@@ -0,0 +1,14 @@
+ import { default as EE } from 'eventemitter3';
+ export interface ITokeniser extends EE<'trainStatus'> {
+ train(text: string[], vocabSize: number): Promise<number>;
+ tokenise(text: string[], numeric?: boolean): Promise<string[][] | number[][]>;
+ detokenise(tokens: string[][] | number[][]): Promise<string[]>;
+ getVocab(): string[];
+ getMerges(): Promise<[string, string][]>;
+ destroy(): void;
+ encode(text: string): Promise<number[]>;
+ decode(tokens: number[]): Promise<string>;
+ vocabSize: number;
+ eosToken: number;
+ trained: boolean;
+ }
package/dist/tokeniser/type.js
@@ -0,0 +1 @@
+
package/dist/tokeniser/worker.d.ts
@@ -0,0 +1 @@
+ export {};
package/dist/tokeniser/worker.js
@@ -0,0 +1,53 @@
+ import d from "./bpe.js";
+ let e = new d();
+ onmessage = async (s) => {
+ if (s.data.type === "tokenise")
+ if (s.data.numeric) {
+ const t = e.tokenise(s.data.text, !0), a = {
+ type: "tokeniseResponse",
+ id: s.data.id,
+ tokens: t,
+ numeric: !0
+ };
+ postMessage(a);
+ } else {
+ const t = e.tokenise(s.data.text), a = {
+ type: "tokeniseResponse",
+ id: s.data.id,
+ tokens: t,
+ numeric: !1
+ };
+ postMessage(a);
+ }
+ else if (s.data.type === "detokenise") {
+ const t = e.getVocab(), a = s.data.tokens.map((i) => i.map((n) => t[n]).join("")), o = {
+ type: "detokeniseResponse",
+ id: s.data.id,
+ text: a
+ };
+ postMessage(o);
+ } else if (s.data.type === "train") {
+ e = new d(), e.train(s.data.text, s.data.vocabSize ?? 100, (a, o) => {
+ const i = {
+ type: "trainStatus",
+ id: s.data.id,
+ progress: a,
+ vocabSize: o
+ };
+ postMessage(i);
+ });
+ const t = {
+ type: "trainResponse",
+ id: s.data.id,
+ vocabSize: e.getVocab().length
+ };
+ postMessage(t);
+ } else if (s.data.type === "tokens") {
+ const t = e.getVocab(), a = {
+ type: "tokensResponse",
+ id: s.data.id,
+ tokens: t
+ };
+ postMessage(a);
+ }
+ };
package/dist/training/AdamExt.d.ts
@@ -0,0 +1,23 @@
+ import { AdamOptimizer } from '@tensorflow/tfjs-core';
+ import { NamedTensor, NamedVariableMap } from '@tensorflow/tfjs-core/dist/tensor_types';
+ interface AdamExtConfig {
+ warmupSteps: number;
+ decaySteps: number;
+ minLearningRate: number;
+ weightDecay?: number;
+ }
+ /**
+ * Extended Adam optimizer with warmup, cosine decay, and optional weight decay.
+ */
+ export default class AdamExt extends AdamOptimizer {
+ private config;
+ private step;
+ private startLearningRate;
+ constructor(learningRate: number, beta1: number, beta2: number, epsilon: number, config: AdamExtConfig);
+ get lr(): number;
+ private getAdjustedLearningRate;
+ applyGradients(gradientsAndVariables: NamedVariableMap | NamedTensor[]): void;
+ private decayVariable;
+ private applyWeightDecay;
+ }
+ export {};
package/dist/training/AdamExt.js
@@ -0,0 +1,43 @@
+ import { A as r, m as c, s as h, a as g, e as o } from "../index-B8nyc6IR.js";
+ class u extends r {
+ constructor(t, e, s, a, i) {
+ super(t, e, s, a), this.config = i, this.startLearningRate = t;
+ }
+ step = 0;
+ startLearningRate;
+ get lr() {
+ return this.learningRate;
+ }
+ getAdjustedLearningRate() {
+ if (this.step++, this.step < this.config.warmupSteps) {
+ const s = Math.min(1, (this.step + 1) / (this.config.warmupSteps + 1));
+ return this.startLearningRate * s;
+ }
+ if (this.step > this.config.decaySteps)
+ return this.config.minLearningRate;
+ const t = (this.step - this.config.warmupSteps) / (this.config.decaySteps - this.config.warmupSteps), e = 0.5 * (1 + Math.cos(Math.PI * t));
+ return this.config.minLearningRate + e * (this.startLearningRate - this.config.minLearningRate);
+ }
+ applyGradients(t) {
+ this.learningRate = this.getAdjustedLearningRate(), super.applyGradients(t), this.config.weightDecay && this.config.weightDecay > 0 && this.applyWeightDecay(t);
+ }
+ decayVariable(t, e, s) {
+ if (t && t.shape.length >= 2) {
+ const a = c(t, h(s * e));
+ t.assign(g(t, a)), a.dispose();
+ }
+ }
+ applyWeightDecay(t) {
+ const e = this.config.weightDecay, s = this.learningRate, a = o().registeredVariables;
+ Array.isArray(t) ? t.forEach(({ name: i }) => {
+ const n = a[i];
+ this.decayVariable(n, e, s);
+ }) : Object.keys(t).forEach((i) => {
+ const n = a[i];
+ this.decayVariable(n, e, s);
+ });
+ }
+ }
+ export {
+ u as default
+ };
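The learning-rate schedule in AdamExt is linear warmup followed by cosine decay down to minLearningRate. A readable sketch of that calculation, mirroring the minified getAdjustedLearningRate above with illustrative names:

    function scheduledLearningRate(
      step: number,
      startLearningRate: number,
      config: { warmupSteps: number; decaySteps: number; minLearningRate: number }
    ): number {
      const { warmupSteps, decaySteps, minLearningRate } = config;
      if (step < warmupSteps) {
        // Linear warmup from near zero up to the base learning rate.
        return startLearningRate * Math.min(1, (step + 1) / (warmupSteps + 1));
      }
      if (step > decaySteps) return minLearningRate;
      // Cosine decay between warmupSteps and decaySteps.
      const progress = (step - warmupSteps) / (decaySteps - warmupSteps);
      const cosine = 0.5 * (1 + Math.cos(Math.PI * progress));
      return minLearningRate + cosine * (startLearningRate - minLearningRate);
    }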
package/dist/training/DatasetBuilder.d.ts
@@ -0,0 +1,12 @@
+ import { ITokeniser } from '../tokeniser/type';
+ import { default as TF } from '@tensorflow/tfjs';
+ export declare class DatasetBuilder {
+ tokenizer: ITokeniser;
+ blockSize: number;
+ private tf;
+ constructor(tf: typeof TF, tokenizer: ITokeniser, blockSize?: number);
+ createTextDataset(textData: string[], batchSize?: number): Promise<TF.data.Dataset<{
+ xs: TF.Tensor;
+ ys: TF.Tensor;
+ }>>;
+ }
package/dist/training/DatasetBuilder.js
@@ -0,0 +1,27 @@
+ class l {
+ tokenizer;
+ blockSize;
+ tf;
+ constructor(s, i, o = 128) {
+ this.tokenizer = i, this.blockSize = o, this.tf = s;
+ }
+ // Create dataset from text files
+ async createTextDataset(s, i = 32) {
+ const o = await Promise.all(s.map((t) => this.tokenizer.encode(t))), a = this.tokenizer.eosToken >= 0, n = o.map((t) => a ? [...t, this.tokenizer.eosToken] : t).flat(), c = (function* () {
+ for (; ; ) {
+ const t = Math.floor(Math.random() * (n.length - this.blockSize - 1)), e = n.slice(t, t + this.blockSize), r = n.slice(t + 1, t + this.blockSize + 1);
+ yield { xs: e, ys: r };
+ }
+ }).bind(this);
+ return this.tf.data.generator(c).batch(i).map((t) => {
+ const e = t;
+ return this.tf.tidy(() => ({
+ xs: e.xs.cast("int32"),
+ ys: this.tf.oneHot(e.ys.cast("int32"), this.tokenizer.vocabSize)
+ }));
+ }).prefetch(2);
+ }
+ }
+ export {
+ l as DatasetBuilder
+ };
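A minimal sketch of building a dataset with the class above. Import paths are assumptions based on the dist layout; the tokeniser is the NodeTokeniser shown earlier in this diff.

    import * as tf from '@tensorflow/tfjs';
    import NodeTokeniser from '@genai-fi/nanogpt/dist/tokeniser/NodeTokeniser.js';
    import { DatasetBuilder } from '@genai-fi/nanogpt/dist/training/DatasetBuilder.js';

    const texts = ['some training text', 'more training text'];
    const tokenizer = new NodeTokeniser();
    await tokenizer.train(texts, 500);

    const builder = new DatasetBuilder(tf, tokenizer, 128);      // blockSize of 128 tokens
    const dataset = await builder.createTextDataset(texts, 32);  // batches of 32 {xs, ys} pairs
    // xs holds int32 token ids; ys holds one-hot targets shifted one position ahead.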
package/dist/training/FullTrainer.d.ts
@@ -0,0 +1,17 @@
+ import { ITokeniser } from '../tokeniser/type';
+ import { default as NanoGPT } from '../NanoGPTModel';
+ import { default as TF } from '@tensorflow/tfjs';
+ import { default as GPTTrainer, TrainingOptions } from './Trainer';
+ export default class FullTrainer extends GPTTrainer {
+ constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+ trainOnDataset(dataset: TF.data.Dataset<{
+ xs: TF.Tensor;
+ ys: TF.Tensor;
+ }>, options: Partial<TrainingOptions>, validationDataset?: TF.data.Dataset<{
+ xs: TF.Tensor;
+ ys: TF.Tensor;
+ }>): Promise<{
+ losses: number[];
+ validationLosses: number[];
+ }>;
+ }
package/dist/training/FullTrainer.js
@@ -0,0 +1,75 @@
+ import { generateText as g } from "../utilities/generate.js";
+ import T from "./Trainer.js";
+ const b = {
+ epochs: 1,
+ stepsPerEpoch: 1e6,
+ desiredLoss: 0.01,
+ logInterval: 1
+ };
+ class S extends T {
+ constructor(a, r, t, i = 3e-4) {
+ super(a, r, t, i);
+ }
+ // Train for multiple epochs using Dataset API - FIXED memory leaks
+ async trainOnDataset(a, r, t) {
+ const { epochs: i, stepsPerEpoch: n, desiredLoss: c, logInterval: L, onStep: h, onEpoch: o, prompt: l } = {
+ ...b,
+ ...r
+ }, s = {
+ epoch: 0,
+ pass: 0,
+ depth: 1,
+ step: 0,
+ stepSinceDepthChange: 0,
+ lastLoss: 1e6,
+ epochLoss: 0,
+ totalSteps: 0,
+ losses: [],
+ validationLosses: []
+ };
+ this.dummyPass(), this.model.trainable = !0;
+ const m = Date.now();
+ for (s.epoch = 0; s.epoch < i; s.epoch++) {
+ s.step = 0, s.epochLoss = 0, s.pass = 0, s.depth = 1, s.stepSinceDepthChange = 0;
+ const u = await a.iterator();
+ try {
+ for (; !(n && s.step >= n || s.lastLoss < c); ) {
+ const e = await u.next();
+ if (e.done) break;
+ const f = e.value, w = this.trainBatch(s, f), p = {
+ epoch: s.epoch,
+ loss: s.lastLoss,
+ step: s.step,
+ time: Date.now() - m,
+ batchSize: f.xs.shape[0]
+ };
+ if (this.model.log.push(p), s.step % L === 0 && (await w, h)) {
+ if (l) {
+ const v = await g(this.tokenizer, this.model, l, 100, 0.8);
+ p.example = v;
+ }
+ await h(p);
+ }
+ }
+ } catch (e) {
+ throw console.error("Training error:", e), this.tf.dispose(), e;
+ }
+ const d = s.epochLoss / s.step;
+ if (t)
+ try {
+ const e = await this.evaluateOnDataset(t, 5);
+ s.validationLosses.push(e), o && await o(s.epoch, d, e);
+ } catch (e) {
+ console.error("Validation error:", e);
+ }
+ else
+ o && o(s.epoch, d);
+ if (this.tf.dispose(), s.lastLoss < c)
+ break;
+ }
+ return { losses: s.losses, validationLosses: s.validationLosses };
+ }
+ }
+ export {
+ S as default
+ };
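A sketch of a full training run with the classes above. The import path is assumed from the dist layout, the option names come from the defaults and destructuring in FullTrainer.js, and model is a NanoGPT instance whose constructor (in NanoGPTModel.d.ts) is not shown in this excerpt; tokenizer and dataset are from the previous example.

    import FullTrainer from '@genai-fi/nanogpt/dist/training/FullTrainer.js';

    const trainer = new FullTrainer(tf, model, tokenizer, 3e-4);  // 3e-4 is the constructor's default learning rate
    const { losses, validationLosses } = await trainer.trainOnDataset(dataset, {
      epochs: 2,
      stepsPerEpoch: 1000,
      desiredLoss: 0.5,           // stop early once the last loss drops below this
      logInterval: 10,
      prompt: 'Once upon a time', // when set, each logged step also carries a generated example
      onStep: async (info) => console.log(info.step, info.loss),
    }, validationDataset);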
package/dist/training/LayerTrainer.d.ts
@@ -0,0 +1,28 @@
+ import { ITokeniser } from '../tokeniser/type';
+ import { default as NanoGPT } from '../NanoGPTModel';
+ import { default as TF } from '@tensorflow/tfjs';
+ import { default as GPTTrainer, TrainingOptions } from './Trainer';
+ interface LayerTrainingOptions extends TrainingOptions {
+ stepsPerLayer: number;
+ maxPasses: number;
+ onLayerChange?: (layer: number, pass: number, valLoss?: number) => Promise<void> | void;
+ onPassComplete?: (pass: number) => Promise<void> | void;
+ }
+ export default class LayerTrainer extends GPTTrainer {
+ private trainingPattern;
+ private startPass;
+ private startLayer;
+ constructor(tf: typeof TF, model: NanoGPT, tokenizer: ITokeniser, learningRate?: number);
+ private applyTrainingPattern;
+ trainOnDataset(dataset: TF.data.Dataset<{
+ xs: TF.Tensor;
+ ys: TF.Tensor;
+ }>, options: Partial<LayerTrainingOptions>, validationDataset?: TF.data.Dataset<{
+ xs: TF.Tensor;
+ ys: TF.Tensor;
+ }>): Promise<{
+ losses: number[];
+ validationLosses: number[];
+ }>;
+ }
+ export {};
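LayerTrainer shares the FullTrainer call shape but adds per-layer scheduling options. A hedged sketch using the extra fields declared in LayerTrainingOptions above (import path assumed; model, tokenizer, and dataset as in the previous examples; LayerTrainer.js itself appears later in this diff):

    import LayerTrainer from '@genai-fi/nanogpt/dist/training/LayerTrainer.js';

    const layerTrainer = new LayerTrainer(tf, model, tokenizer);
    await layerTrainer.trainOnDataset(dataset, {
      stepsPerLayer: 200,   // steps spent on each layer before moving on
      maxPasses: 3,         // passes over the layer schedule
      onLayerChange: (layer, pass) => console.log(`pass ${pass}, layer ${layer}`),
    });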