@genai-fi/nanogpt 0.7.1 → 0.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. package/dist/Generator.d.ts +11 -2
  2. package/dist/Generator.js +81 -68
  3. package/dist/NanoGPTModel.js +8 -8
  4. package/dist/{RealDiv-CVYNbZxu.js → RealDiv-Dy0p8Bvo.js} +7 -7
  5. package/dist/{Reshape-CEsEp0AI.js → Reshape-DH5srBP0.js} +2 -2
  6. package/dist/{Reshape-Do18N3gO.js → Reshape-DvudQDvJ.js} +1 -1
  7. package/dist/TeachableLLM.js +33 -32
  8. package/dist/{TiedEmbedding-ccLBFiZi.js → TiedEmbedding-BxOerUmB.js} +4 -4
  9. package/dist/Trainer.d.ts +6 -1
  10. package/dist/Trainer.js +53 -19
  11. package/dist/{axis_util-5DTW2tFV.js → axis_util-BzbKo31C.js} +1 -1
  12. package/dist/backend.js +2 -2
  13. package/dist/{backend_util-C9Ut8n0Q.js → backend_util-TE7aTPhZ.js} +4 -4
  14. package/dist/{broadcast_to-Ba9h_8DO.js → broadcast_to-CdbwV-Dj.js} +2 -2
  15. package/dist/{concat-CbXTetof.js → concat-CsxrgovM.js} +1 -1
  16. package/dist/{dataset-U3PrjwgU.js → dataset-CtdBYwjo.js} +3 -3
  17. package/dist/{dropout-DPfPgWWe.js → dropout-DYs5QFGQ.js} +1 -1
  18. package/dist/{gather-Bbh8DHhM.js → gather-CMMy2KEG.js} +1 -1
  19. package/dist/{gelu-BFwVnd1r.js → gelu-C-dPj6Ku.js} +1 -1
  20. package/dist/{gpgpu_math-DffelNS-.js → gpgpu_math-DGNLNL4I.js} +2 -2
  21. package/dist/{index-UdZhlibC.js → index-BoWRt-10.js} +4 -4
  22. package/dist/{index-DYD_yPa-.js → index-CLthM0TO.js} +10 -10
  23. package/dist/{kernel_funcs_utils-CXDy3EN7.js → kernel_funcs_utils-BYKWV8Aa.js} +3 -3
  24. package/dist/layers/BaseLayer.js +2 -2
  25. package/dist/layers/CausalSelfAttention.js +6 -6
  26. package/dist/layers/MLP.js +5 -5
  27. package/dist/layers/RMSNorm.js +3 -3
  28. package/dist/layers/RoPECache.js +4 -4
  29. package/dist/layers/TiedEmbedding.js +5 -5
  30. package/dist/layers/TransformerBlock.js +1 -1
  31. package/dist/loader/loadTransformers.js +1 -1
  32. package/dist/loader/oldZipLoad.js +5 -5
  33. package/dist/{log_sum_exp-BnmCkHWl.js → log_sum_exp-DbjkV734.js} +5 -5
  34. package/dist/main.js +5 -5
  35. package/dist/{mat_mul-dwmZz69e.js → mat_mul-8m8pfdcx.js} +1 -1
  36. package/dist/{max-ByjEGoFx.js → max-Ddnnb5xe.js} +1 -1
  37. package/dist/{mulmat_packed_gpu-IGPBp6h9.js → mulmat_packed_gpu-VSekgsNv.js} +1 -1
  38. package/dist/{ones-C8Mfln6-.js → ones-Dj0SDhHf.js} +2 -2
  39. package/dist/ops/adamAdjust.js +1 -1
  40. package/dist/ops/adamMoments.js +1 -1
  41. package/dist/ops/appendCache.js +3 -3
  42. package/dist/ops/attentionMask.js +1 -1
  43. package/dist/ops/cpu/adamAdjust.js +1 -1
  44. package/dist/ops/cpu/adamMoments.js +2 -2
  45. package/dist/ops/cpu/appendCache.js +2 -2
  46. package/dist/ops/cpu/attentionMask.js +5 -5
  47. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  48. package/dist/ops/cpu/gatherSub.js +3 -3
  49. package/dist/ops/cpu/gelu.js +1 -1
  50. package/dist/ops/cpu/matMulGelu.js +2 -2
  51. package/dist/ops/cpu/matMulMul.js +1 -1
  52. package/dist/ops/cpu/mulDropout.js +1 -1
  53. package/dist/ops/cpu/normRMS.js +1 -1
  54. package/dist/ops/cpu/qkv.js +3 -3
  55. package/dist/ops/cpu/rope.js +5 -5
  56. package/dist/ops/cpu/scatterSub.js +5 -5
  57. package/dist/ops/fusedSoftmax.js +1 -1
  58. package/dist/ops/gatherSub.js +1 -1
  59. package/dist/ops/gelu.js +2 -2
  60. package/dist/ops/grads/attentionMask.js +1 -1
  61. package/dist/ops/grads/fusedSoftmax.js +2 -2
  62. package/dist/ops/grads/gelu.js +2 -2
  63. package/dist/ops/grads/matMulGelu.js +1 -1
  64. package/dist/ops/grads/normRMS.js +1 -1
  65. package/dist/ops/grads/qkv.js +1 -1
  66. package/dist/ops/grads/rope.js +1 -1
  67. package/dist/ops/matMulGelu.js +1 -1
  68. package/dist/ops/matMulMul.js +1 -1
  69. package/dist/ops/mulDrop.js +1 -1
  70. package/dist/ops/normRMS.js +1 -1
  71. package/dist/ops/qkv.js +1 -1
  72. package/dist/ops/rope.js +4 -4
  73. package/dist/ops/scatterSub.js +1 -1
  74. package/dist/ops/webgl/adamAdjust.js +2 -2
  75. package/dist/ops/webgl/adamMoments.js +7 -5
  76. package/dist/ops/webgl/appendCache.js +1 -1
  77. package/dist/ops/webgl/attentionMask.js +1 -1
  78. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  79. package/dist/ops/webgl/gatherSub.js +1 -1
  80. package/dist/ops/webgl/gelu.js +2 -2
  81. package/dist/ops/webgl/log.js +3 -3
  82. package/dist/ops/webgl/matMulGelu.js +4 -4
  83. package/dist/ops/webgl/matMulMul.js +1 -1
  84. package/dist/ops/webgl/mulDropout.js +1 -1
  85. package/dist/ops/webgl/normRMS.js +2 -2
  86. package/dist/ops/webgl/qkv.js +1 -1
  87. package/dist/ops/webgl/rope.js +1 -1
  88. package/dist/ops/webgl/scatterSub.js +1 -1
  89. package/dist/ops/webgpu/adamAdjust.js +15 -13
  90. package/dist/ops/webgpu/adamMoments.js +18 -11
  91. package/dist/ops/webgpu/appendCache.js +18 -15
  92. package/dist/ops/webgpu/attentionMask.js +24 -18
  93. package/dist/ops/webgpu/gatherSub.js +17 -30
  94. package/dist/ops/webgpu/gelu.js +3 -3
  95. package/dist/ops/webgpu/normRMS.js +16 -8
  96. package/dist/ops/webgpu/normRMSGrad.js +25 -20
  97. package/dist/ops/webgpu/qkv.js +23 -19
  98. package/dist/ops/webgpu/rope.js +37 -24
  99. package/dist/ops/webgpu/scatterSub.js +16 -14
  100. package/dist/ops/webgpu/utils/reductions.js +4 -4
  101. package/dist/{ops-aRTXR2Sr.js → ops-BFGCx8Ri.js} +15 -15
  102. package/dist/{random_width-DbSpgl4o.js → random_width-sZORGo5k.js} +22 -22
  103. package/dist/{range-D9CZhVlR.js → range-CRuAh-gd.js} +1 -1
  104. package/dist/{reciprocal-CGB48wZB.js → reciprocal-BvGAyKyu.js} +1 -1
  105. package/dist/{register_all_kernels-DnbAyBXt.js → register_all_kernels-BwDSRN-f.js} +30 -30
  106. package/dist/{reshape-BR0eoLYN.js → reshape-CdBq1WJ6.js} +1 -1
  107. package/dist/{scatter_nd_util-OjyAxku2.js → scatter_nd_util-DUstGbU1.js} +1 -1
  108. package/dist/{selu_util-Ce6pu9IM.js → selu_util-BJEXVvjX.js} +3 -3
  109. package/dist/{shared-Czipaeb6.js → shared-B8ztnyEk.js} +6 -6
  110. package/dist/{shared-DS5waSIY.js → shared-wS99K7_n.js} +1 -1
  111. package/dist/{sin-CiBxrDqX.js → sin-BeA3tsEd.js} +1 -1
  112. package/dist/{slice-BHbDHObE.js → slice-BiOsknYS.js} +1 -1
  113. package/dist/{softmax-JMEIUo2J.js → softmax-Bv_6lyMX.js} +1 -1
  114. package/dist/{split-CRU0PjVV.js → split-B-dikLRw.js} +1 -1
  115. package/dist/{stack-ikk2Y8_P.js → stack-B17UN2nn.js} +1 -1
  116. package/dist/{sum-NLYbiDag.js → sum-66ew2byf.js} +1 -1
  117. package/dist/{tensor-Do9PKbIE.js → tensor-JwS7ZYY6.js} +1 -1
  118. package/dist/{tensor2d-CWHxHpLh.js → tensor2d-wxPAnDQy.js} +1 -1
  119. package/dist/training/Adam.js +2 -2
  120. package/dist/training/AdamExt.js +1 -1
  121. package/dist/training/DatasetBuilder.js +35 -32
  122. package/dist/training/FullTrainer.d.ts +15 -2
  123. package/dist/training/FullTrainer.js +97 -51
  124. package/dist/training/Trainer.d.ts +10 -0
  125. package/dist/training/Trainer.js +2 -2
  126. package/dist/training/sparseCrossEntropy.js +4 -4
  127. package/dist/utilities/dummy.js +2 -2
  128. package/dist/utilities/generate.js +3 -3
  129. package/dist/utilities/multinomialCPU.js +2 -2
  130. package/dist/utilities/performance.js +1 -1
  131. package/dist/utilities/profile.js +1 -1
  132. package/dist/utilities/safetensors.js +2 -2
  133. package/dist/utilities/weights.js +2 -2
  134. package/dist/{variable-BTBkayv_.js → variable-BuddVFLa.js} +1 -1
  135. package/dist/{webgpu_program-WaoMq-WD.js → webgpu_program-PFzf1hAQ.js} +1 -1
  136. package/dist/{webgpu_util-DhSeP4b6.js → webgpu_util-D____QpY.js} +1 -1
  137. package/dist/{zeros-DnPT2nD4.js → zeros--BdLQ3oG.js} +1 -1
  138. package/package.json +1 -1
@@ -9,11 +9,20 @@ export default class Generator extends EE<'start' | 'stop' | 'tokens'> {
9
9
  private readonly model;
10
10
  private readonly tokeniser;
11
11
  private active;
12
+ private cache;
13
+ private initialPrompt;
14
+ private outputText;
15
+ private actualTokeniser;
16
+ private lastToken;
12
17
  constructor(model: NanoGPT, tokeniser: ITokeniser);
13
18
  private tokenisePrompt;
14
- private generateNoCache;
15
19
  private processResponse;
16
- private generateCache;
20
+ private _generate;
21
+ reset(): void;
22
+ dispose(): void;
23
+ private initialise;
24
+ step(prompt?: string, options?: IGenerateOptions): Promise<string>;
17
25
  generate(prompt?: string, options?: IGenerateOptions): Promise<string>;
18
26
  stop(): void;
27
+ getText(): string;
19
28
  }
package/dist/Generator.js CHANGED
@@ -1,15 +1,15 @@
1
- import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-UdZhlibC.js";
1
+ import { E as l } from "./index-Dwqa6Zy2.js";
2
+ import "./index-BoWRt-10.js";
3
3
  import "./ops/cpu/attentionMask.js";
4
4
  import "./ops/webgl/attentionMask.js";
5
5
  import "./ops/grads/attentionMask.js";
6
6
  import "./ops/cpu/qkv.js";
7
7
  import "./ops/webgl/qkv.js";
8
8
  import "./ops/grads/qkv.js";
9
- import "./random_width-DbSpgl4o.js";
10
- import "./register_all_kernels-DnbAyBXt.js";
9
+ import "./random_width-sZORGo5k.js";
10
+ import "./register_all_kernels-BwDSRN-f.js";
11
11
  import "./index-Tf7vU29b.js";
12
- import "./dataset-U3PrjwgU.js";
12
+ import "./dataset-CtdBYwjo.js";
13
13
  import "./ops/cpu/rope.js";
14
14
  import "./ops/webgl/rope.js";
15
15
  import "./ops/grads/rope.js";
@@ -29,7 +29,7 @@ import "./ops/webgl/gatherSub.js";
29
29
  import "./ops/cpu/scatterSub.js";
30
30
  import "./ops/webgl/scatterSub.js";
31
31
  import "./jszip.min-CjP2V1VV.js";
32
- import f from "./tokeniser/CharTokeniser.js";
32
+ import u from "./tokeniser/CharTokeniser.js";
33
33
  import "./ops/cpu/adamAdjust.js";
34
34
  import "./ops/webgl/adamAdjust.js";
35
35
  import "./ops/cpu/adamMoments.js";
@@ -37,12 +37,12 @@ import "./ops/webgl/adamMoments.js";
37
37
  import "./papaparse.min-C8l2Kvo1.js";
38
38
  import "./ops/cpu/gelu.js";
39
39
  import "./ops/webgl/gelu.js";
40
- import "./gelu-BFwVnd1r.js";
40
+ import "./gelu-C-dPj6Ku.js";
41
41
  import "./ops/webgl/log.js";
42
- import { t as d } from "./tensor2d-CWHxHpLh.js";
43
- import { c as g } from "./concat-CbXTetof.js";
42
+ import { t as p } from "./tensor2d-wxPAnDQy.js";
43
+ import { c as f } from "./concat-CsxrgovM.js";
44
44
  const k = [
45
- ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
45
+ ...Array.from({ length: 95 }, (r, t) => String.fromCharCode(t + 32)),
46
46
  // ASCII
47
47
  // Spanish accented letters and punctuation
48
48
  ..."áéíóúüñ¿¡",
@@ -53,80 +53,93 @@ const k = [
53
53
  // Cyrillic letters
54
54
  ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
55
55
  ];
56
- function w(a, t) {
57
- return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
56
+ function d(r, t) {
57
+ return r.length === t ? r : r.length > t ? r.slice(0, t) : r.concat(Array(t - r.length).fill(""));
58
58
  }
59
- class pt extends u {
60
- constructor(t, o) {
61
- super(), this.model = t, this.tokeniser = o;
59
+ class nt extends l {
60
+ constructor(t, i) {
61
+ super(), this.model = t, this.tokeniser = i, this.actualTokeniser = i;
62
62
  }
63
63
  active = !1;
64
- async tokenisePrompt(t, o) {
65
- const r = o ? await t.tokenise([o], !0) : [[t.eosToken]];
66
- return d(r, [1, r[0].length], "int32");
64
+ cache = null;
65
+ initialPrompt = null;
66
+ outputText = "";
67
+ actualTokeniser;
68
+ lastToken = -1;
69
+ async tokenisePrompt(t, i) {
70
+ const e = i ? await t.tokenise([i], !0) : [[t.eosToken]];
71
+ return p(e, [1, e[0].length], "int32");
67
72
  }
68
- async generateNoCache(t, o, r) {
69
- let i = await this.tokenisePrompt(t, o), s = o || "";
70
- const n = r?.maxLength ?? 1e3;
71
- for (let m = 0; m < n && this.active; m++) {
72
- const {
73
- output: e,
74
- attention: p,
75
- probabilities: c
76
- } = await this.model.generate(i, void 0, r), h = i;
77
- i = g([i, e], 1), h.dispose();
78
- const l = await this.processResponse(t, e, p, c);
79
- if (e.dispose(), l === null)
80
- break;
81
- s += l;
82
- }
83
- return i.dispose(), s;
84
- }
85
- async processResponse(t, o, r, i) {
86
- const s = (await o.array())[0][0];
87
- if (s === this.tokeniser.eosToken)
73
+ async processResponse(t, i, e, o) {
74
+ const s = (await i.array())[0][0];
75
+ if (this.lastToken = s, s === this.tokeniser.eosToken)
88
76
  return null;
89
77
  const n = await t.decode([s]);
90
- let m;
91
- r && (m = await Promise.all(r.map((p) => p.array().then((c) => c))), r.forEach((p) => p.dispose()));
92
- let e;
93
- return i && (e = await i.array(), i.dispose()), this.emit("tokens", [s], n, m, e), n;
78
+ let c;
79
+ e && (c = await Promise.all(e.map((h) => h.array().then((m) => m))), e.forEach((h) => h.dispose()));
80
+ let a;
81
+ return o && (a = await o.array(), o.dispose()), this.emit("tokens", [s], n, c, a), n;
94
82
  }
95
- async generateCache(t, o, r) {
96
- let i = await this.tokenisePrompt(t, o), s = o || "";
97
- const n = new Array(this.model.config.gpt.nLayer);
98
- for (let e = 0; e < this.model.config.gpt.nLayer; e++)
99
- n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
100
- const m = r?.maxLength ?? 1e3;
101
- for (let e = 0; e < m && this.active; e++) {
83
+ async _generate(t) {
84
+ let i = this.lastToken >= 0 && this.cache ? p([this.lastToken], [1, 1], "int32") : await this.tokenisePrompt(this.actualTokeniser, this.outputText);
85
+ const e = t?.maxLength ?? 1e3;
86
+ for (let o = 0; o < e && this.active; o++) {
102
87
  const {
103
- output: p,
104
- probabilities: c,
105
- attention: h
106
- } = await this.model.generate(i, n, {
107
- ...r,
108
- usePadding: !1
88
+ output: s,
89
+ probabilities: n,
90
+ attention: c
91
+ } = await this.model.generate(i, this.cache ? this.cache : void 0, {
92
+ ...t,
93
+ usePadding: !this.cache
109
94
  });
110
- i.dispose(), i = p;
111
- const l = await this.processResponse(t, p, h, c);
112
- if (l === null)
95
+ if (this.cache)
96
+ i.dispose(), i = s;
97
+ else {
98
+ const h = i;
99
+ i = f([i, s], 1), h.dispose();
100
+ }
101
+ const a = await this.processResponse(this.actualTokeniser, s, c, n);
102
+ if (this.cache || s.dispose(), a === null)
113
103
  break;
114
- s += l;
104
+ this.outputText += a;
105
+ }
106
+ return i.dispose(), this.outputText;
107
+ }
108
+ reset() {
109
+ this.cache && (this.cache.forEach((t) => {
110
+ t && (t.k && t.k.dispose(), t.v && t.v.dispose());
111
+ }), this.cache = null), this.outputText = "", this.initialPrompt = null, this.lastToken = -1;
112
+ }
113
+ dispose() {
114
+ this.reset();
115
+ }
116
+ initialise(t, i) {
117
+ const e = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t ?? null;
118
+ if (this.cache && i?.noCache && this.reset(), this.initialPrompt = e || null, this.lastToken === -1 && (this.outputText = this.initialPrompt || ""), !this.cache && !i?.noCache && this.model.config.gpt.useRope) {
119
+ const s = new Array(this.model.config.gpt.nLayer);
120
+ for (let n = 0; n < this.model.config.gpt.nLayer; n++)
121
+ s[n] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
122
+ this.cache = s, this.lastToken = -1;
115
123
  }
116
- return n.forEach((e) => {
117
- e && (e.k && e.k.dispose(), e.v && e.v.dispose());
118
- }), i.dispose(), s;
124
+ const o = this.tokeniser.trained ? this.tokeniser : new u(d(k, this.tokeniser.vocabSize));
125
+ this.actualTokeniser = o;
119
126
  }
120
- async generate(t, o) {
121
- const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
122
- this.active = !0, this.emit("start");
123
- const i = this.tokeniser.trained ? this.tokeniser : new f(w(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
124
- return this.active = !1, this.emit("stop"), n;
127
+ async step(t, i) {
128
+ const e = { ...i, maxLength: 1 };
129
+ return this.generate(t, e);
130
+ }
131
+ async generate(t, i) {
132
+ this.initialise(t, i), this.active = !0, this.emit("start");
133
+ const o = await this._generate(i);
134
+ return this.active = !1, this.emit("stop"), o;
125
135
  }
126
136
  stop() {
127
137
  this.active = !1;
128
138
  }
139
+ getText() {
140
+ return this.outputText;
141
+ }
129
142
  }
130
143
  export {
131
- pt as default
144
+ nt as default
132
145
  };
@@ -1,19 +1,19 @@
1
1
  import { defaultConfig as M } from "./config.js";
2
2
  import v from "./layers/TransformerBlock.js";
3
- import { T as x, r as T } from "./TiedEmbedding-ccLBFiZi.js";
3
+ import { T as x, r as T } from "./TiedEmbedding-BxOerUmB.js";
4
4
  import F from "./layers/RoPECache.js";
5
5
  import O from "./layers/RMSNorm.js";
6
6
  import { estimateParameterCount as _ } from "./utilities/parameters.js";
7
7
  import { createSoftmaxCrossEntropyWithGrad as D } from "./training/sparseCrossEntropy.js";
8
8
  import K from "./layers/BaseLayer.js";
9
- import { E as N, D as R, p as q } from "./random_width-DbSpgl4o.js";
10
- import { B as A, C as B, E as G, ad as V, t as C, o as j, b as z, w as U } from "./index-UdZhlibC.js";
9
+ import { E as N, D as R, p as q } from "./random_width-sZORGo5k.js";
10
+ import { B as A, C as B, E as G, ad as V, t as C, o as j, b as z, w as U } from "./index-BoWRt-10.js";
11
11
  import W from "./utilities/multinomialCPU.js";
12
- import { m as H, t as J } from "./register_all_kernels-DnbAyBXt.js";
13
- import { r as P } from "./reshape-BR0eoLYN.js";
14
- import { r as Q } from "./range-D9CZhVlR.js";
15
- import { s as $ } from "./softmax-JMEIUo2J.js";
16
- import { g as X } from "./gather-Bbh8DHhM.js";
12
+ import { m as H, t as J } from "./register_all_kernels-BwDSRN-f.js";
13
+ import { r as P } from "./reshape-CdBq1WJ6.js";
14
+ import { r as Q } from "./range-CRuAh-gd.js";
15
+ import { s as $ } from "./softmax-Bv_6lyMX.js";
16
+ import { g as X } from "./gather-CMMy2KEG.js";
17
17
  /**
18
18
  * @license
19
19
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,10 +1,10 @@
1
- import { aq as T, ac as E, p as O, j as V, ay as B, Y as F, U, az as j } from "./index-UdZhlibC.js";
2
- import { r as $ } from "./Reshape-CEsEp0AI.js";
3
- import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-5DTW2tFV.js";
4
- import { t as K, m as W } from "./shared-DS5waSIY.js";
5
- import { c as _ } from "./backend_util-C9Ut8n0Q.js";
6
- import { f as y } from "./gpgpu_math-DffelNS-.js";
7
- import { g as G, b as L } from "./kernel_funcs_utils-CXDy3EN7.js";
1
+ import { aq as T, ac as E, p as O, j as V, ay as B, Y as F, U, az as j } from "./index-BoWRt-10.js";
2
+ import { r as $ } from "./Reshape-DH5srBP0.js";
3
+ import { g as A, a as k, b as C, c as N, e as R } from "./axis_util-BzbKo31C.js";
4
+ import { t as K, m as W } from "./shared-wS99K7_n.js";
5
+ import { c as _ } from "./backend_util-TE7aTPhZ.js";
6
+ import { f as y } from "./gpgpu_math-DGNLNL4I.js";
7
+ import { g as G, b as L } from "./kernel_funcs_utils-BYKWV8Aa.js";
8
8
  /**
9
9
  * @license
10
10
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,5 +1,5 @@
1
- import { j as c, a2 as C, l as f, K as R } from "./index-UdZhlibC.js";
2
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DffelNS-.js";
1
+ import { j as c, a3 as C, l as f, K as R } from "./index-BoWRt-10.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as l, i as m } from "./gpgpu_math-DGNLNL4I.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { j as h, a2 as d, l as c, K as m } from "./index-UdZhlibC.js";
1
+ import { j as h, a3 as d, l as c, K as m } from "./index-BoWRt-10.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2021 Google LLC. All Rights Reserved.
@@ -1,17 +1,17 @@
1
1
  import { defaultConfig as _ } from "./config.js";
2
2
  import f from "./NanoGPTModel.js";
3
- import { saveModel as u } from "./utilities/save.js";
4
- import { loadModel as d } from "./loader/load.js";
5
- import l from "./Generator.js";
3
+ import { saveModel as d } from "./utilities/save.js";
4
+ import { loadModel as l } from "./loader/load.js";
5
+ import u from "./Generator.js";
6
6
  import p from "./Trainer.js";
7
- import { E as g } from "./index-Dwqa6Zy2.js";
7
+ import { E as c } from "./index-Dwqa6Zy2.js";
8
8
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
9
- import c from "./tokeniser/CharTokeniser.js";
9
+ import g from "./tokeniser/CharTokeniser.js";
10
10
  import k from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-UdZhlibC.js";
14
+ import "./index-BoWRt-10.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";
@@ -22,9 +22,9 @@ import "./ops/grads/attentionMask.js";
22
22
  import "./ops/cpu/qkv.js";
23
23
  import "./ops/webgl/qkv.js";
24
24
  import "./ops/grads/qkv.js";
25
- import "./random_width-DbSpgl4o.js";
26
- import "./register_all_kernels-DnbAyBXt.js";
27
- import "./dataset-U3PrjwgU.js";
25
+ import "./random_width-sZORGo5k.js";
26
+ import "./register_all_kernels-BwDSRN-f.js";
27
+ import "./dataset-CtdBYwjo.js";
28
28
  import "./ops/cpu/rope.js";
29
29
  import "./ops/webgl/rope.js";
30
30
  import "./ops/grads/rope.js";
@@ -38,7 +38,7 @@ import "./ops/webgl/matMulGelu.js";
38
38
  import "./ops/grads/matMulGelu.js";
39
39
  import "./ops/cpu/gelu.js";
40
40
  import "./ops/webgl/gelu.js";
41
- import "./gelu-BFwVnd1r.js";
41
+ import "./gelu-C-dPj6Ku.js";
42
42
  import "./ops/cpu/normRMS.js";
43
43
  import "./ops/webgl/normRMS.js";
44
44
  import "./ops/grads/normRMS.js";
@@ -49,7 +49,7 @@ import "./ops/cpu/adamAdjust.js";
49
49
  import "./ops/webgl/adamAdjust.js";
50
50
  import w from "./utilities/profile.js";
51
51
  class a {
52
- ee = new g();
52
+ ee = new c();
53
53
  _config;
54
54
  _model;
55
55
  _tokeniser;
@@ -92,8 +92,8 @@ class a {
92
92
  return this._status === "busy" || this._status === "training";
93
93
  }
94
94
  estimateTrainingMemoryUsage(t) {
95
- const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, i = e.perBatch * t, o = e.gradients;
96
- return i * 0.66 + o * 4;
95
+ const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, r = e.perBatch * t, o = e.gradients;
96
+ return r * 0.66 + o * 4;
97
97
  }
98
98
  setStatus(t) {
99
99
  this._status !== t && (this._status = t, this.ee.emit("status", t));
@@ -101,32 +101,32 @@ class a {
101
101
  saveModel(t) {
102
102
  if (!this._model || !this._tokeniser)
103
103
  throw new Error("model_or_tokeniser_not_initialized.");
104
- return u(this._model, this._tokeniser, {
104
+ return d(this._model, this._tokeniser, {
105
105
  ...t,
106
106
  name: t?.name || this.meta.name
107
107
  });
108
108
  }
109
109
  static loadModel(t) {
110
110
  const e = new a();
111
- return d(t).then(({ model: i, tokeniser: o, name: s }) => {
112
- e._model = i, e._tokeniser = o, e._config = i.config, s && (e.meta.name = s), e.setStatus("warmup"), m(i).then((r) => {
113
- e._memoryRequirements = r, e.setStatus("ready"), e.ee.emit("loaded");
114
- }).catch((r) => {
115
- e.setStatus("error"), e.ee.emit("error", r);
111
+ return l(t).then(({ model: r, tokeniser: o, name: s }) => {
112
+ e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
113
+ e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
114
+ }).catch((i) => {
115
+ e.setStatus("error"), e.ee.emit("error", i);
116
116
  });
117
- }).catch((i) => {
118
- e.setStatus("error"), e.ee.emit("error", i);
117
+ }).catch((r) => {
118
+ e.setStatus("error"), e.ee.emit("error", r);
119
119
  }), e;
120
120
  }
121
121
  static create(t, e = {}) {
122
- const i = { ..._, ...e }, o = t === "char" ? new c(i.vocabSize) : new k(i.vocabSize), s = new f(i), r = new a(o, s);
123
- return r.setStatus("warmup"), m(s).then((n) => {
124
- r._memoryRequirements = n, r.tokeniser.trained ? (r.setStatus("ready"), r.ee.emit("loaded")) : (r.setStatus("awaitingTokens"), r.ee.emit("loaded"), r.tokeniser.once("trainStatus", (h) => {
125
- h === "trained" && r.setStatus("ready");
122
+ const r = { ..._, ...e }, o = t === "char" ? new g(r.vocabSize) : new k(r.vocabSize), s = new f(r), i = new a(o, s);
123
+ return i.setStatus("warmup"), m(s).then((n) => {
124
+ i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
125
+ h === "trained" && i.setStatus("ready");
126
126
  }));
127
127
  }).catch((n) => {
128
- r.setStatus("error"), r.ee.emit("error", n);
129
- }), r;
128
+ i.setStatus("error"), i.ee.emit("error", n);
129
+ }), i;
130
130
  }
131
131
  getProfiler() {
132
132
  return this._model?.getProfiler();
@@ -149,14 +149,15 @@ class a {
149
149
  if (!this._model || !this._tokeniser)
150
150
  throw new Error("model_or_tokeniser_not_initialized.");
151
151
  const t = new p(this._model, this._tokeniser);
152
- return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, i) => {
152
+ return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
153
153
  const o = this.ee.listeners("trainStep");
154
154
  for (const s of o)
155
- await s(e, i);
155
+ await s(e, r);
156
156
  }), t;
157
157
  }
158
- train(t, e) {
159
- return this.trainer().train(t, e);
158
+ async train(t, e) {
159
+ const r = this.trainer();
160
+ await r.prepare(t, e), await r.train(e);
160
161
  }
161
162
  async trainTokeniser(t) {
162
163
  if (!this._tokeniser)
@@ -167,7 +168,7 @@ class a {
167
168
  generator() {
168
169
  if (!this._model || !this._tokeniser)
169
170
  throw new Error("model_or_tokeniser_not_initialized.");
170
- const t = new l(this._model, this._tokeniser);
171
+ const t = new u(this._model, this._tokeniser);
171
172
  return t.on("start", () => {
172
173
  this.status === "ready" && this.setStatus("busy");
173
174
  }), t.on("stop", () => {
@@ -1,8 +1,8 @@
1
- import { R as a, d as s } from "./random_width-DbSpgl4o.js";
2
- import "./index-UdZhlibC.js";
1
+ import { R as a, d as s } from "./random_width-sZORGo5k.js";
2
+ import "./index-BoWRt-10.js";
3
3
  import o from "./layers/BaseLayer.js";
4
- import { v as m } from "./variable-BTBkayv_.js";
5
- import { g as d } from "./gather-Bbh8DHhM.js";
4
+ import { v as m } from "./variable-BuddVFLa.js";
5
+ import { g as d } from "./gather-CMMy2KEG.js";
6
6
  /**
7
7
  * @license
8
8
  * Copyright 2018 Google LLC
package/dist/Trainer.d.ts CHANGED
@@ -14,8 +14,13 @@ export interface ITrainerOptions {
14
14
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
15
15
  private trainer;
16
16
  private hasTrained;
17
+ private trainDataset?;
18
+ private validationDataset?;
19
+ private totalSamples;
17
20
  constructor(model: NanoGPT, tokeniser: ITokeniser);
18
21
  stop(): void;
19
22
  reset(): void;
20
- train(text: string[], options?: ITrainerOptions): Promise<void>;
23
+ prepare(text: string[], options?: ITrainerOptions): Promise<void>;
24
+ train(options?: ITrainerOptions): Promise<void>;
25
+ step(options?: ITrainerOptions): Promise<void>;
21
26
  }
package/dist/Trainer.js CHANGED
@@ -1,10 +1,13 @@
1
- import { E as h } from "./index-Dwqa6Zy2.js";
2
- import m from "./training/FullTrainer.js";
3
- class p extends h {
1
+ import { E as l } from "./index-Dwqa6Zy2.js";
2
+ import h from "./training/FullTrainer.js";
3
+ class p extends l {
4
4
  trainer;
5
5
  hasTrained = !1;
6
- constructor(e, t) {
7
- super(), this.trainer = new m(e, t, 1e-3);
6
+ trainDataset;
7
+ validationDataset;
8
+ totalSamples = 0;
9
+ constructor(t, e) {
10
+ super(), this.trainer = new h(t, e, 1e-3);
8
11
  }
9
12
  stop() {
10
13
  this.trainer.stop();
@@ -12,36 +15,67 @@ class p extends h {
12
15
  reset() {
13
16
  this.hasTrained = !1, this.trainer.reset();
14
17
  }
15
- async train(e, t) {
16
- const { trainDataset: s, validationDataset: n } = await this.trainer.createTrainValidationSplit(
17
- e,
18
- t?.batchSize || 32,
19
- t?.validationSplit || 0.1
20
- ), r = e.reduce((i, a) => i + a.length, 0) * (1 - (t?.validationSplit || 0));
18
+ async prepare(t, e) {
19
+ const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
20
+ t,
21
+ e?.batchSize || 32,
22
+ e?.validationSplit || 0.1
23
+ ), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
24
+ this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
25
+ }
26
+ async train(t) {
27
+ if (!this.trainDataset || !this.validationDataset)
28
+ throw new Error("Datasets not prepared");
21
29
  this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
22
- s,
30
+ this.trainDataset,
23
31
  {
24
32
  prompt: t?.prompt,
25
33
  logInterval: t?.logInterval || 10,
26
34
  desiredLoss: t?.desiredLoss || 0.01,
27
35
  maxSteps: t?.maxSteps || 1e3,
28
36
  advancedMetrics: t?.advancedMetrics || !1,
29
- onStep: async (i, a) => {
30
- const l = this.listeners("log");
31
- for (const d of l)
32
- await d(i, {
37
+ onStep: async (e, a) => {
38
+ const s = this.listeners("log");
39
+ for (const i of s)
40
+ await i(e, {
33
41
  ...a,
34
- progress: a.totalSamples / r,
42
+ progress: a.totalSamples / this.totalSamples,
35
43
  remaining: Math.max(
36
44
  0,
37
- (r - a.totalSamples) / a.totalSamples * a.duration
45
+ (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
38
46
  )
39
47
  });
40
48
  }
41
49
  },
42
- n
50
+ this.validationDataset
43
51
  ), this.emit("stop");
44
52
  }
53
+ async step(t) {
54
+ if (!this.trainDataset || !this.validationDataset)
55
+ throw new Error("Datasets not prepared");
56
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
57
+ const { log: e, progress: a } = await this.trainer.stepDataset(
58
+ this.trainDataset,
59
+ {
60
+ prompt: t?.prompt,
61
+ logInterval: t?.logInterval || 10,
62
+ desiredLoss: t?.desiredLoss || 0.01,
63
+ maxSteps: t?.maxSteps || 1e3,
64
+ advancedMetrics: t?.advancedMetrics || !1
65
+ },
66
+ this.validationDataset
67
+ ), s = this.listeners("log");
68
+ for (const i of s)
69
+ await i(e, {
70
+ ...a,
71
+ progress: a.totalSamples / this.totalSamples,
72
+ remaining: Math.max(
73
+ 0,
74
+ (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
75
+ )
76
+ });
77
+ this.emit("stop");
78
+ }
45
79
  }
46
80
  export {
47
81
  p as default
@@ -1,4 +1,4 @@
1
- import { l as c } from "./index-UdZhlibC.js";
1
+ import { l as c } from "./index-BoWRt-10.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
package/dist/backend.js CHANGED
@@ -1,6 +1,6 @@
1
- import { g as a, s as i, r as o } from "./index-UdZhlibC.js";
1
+ import { g as a, s as i, r as o } from "./index-BoWRt-10.js";
2
2
  async function e(t) {
3
- a() !== t && (t === "webgpu" && (await import("./index-DYD_yPa-.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
3
+ a() !== t && (t === "webgpu" && (await import("./index-CLthM0TO.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
4
4
  }
5
5
  export {
6
6
  e as selectBackend
@@ -1,7 +1,7 @@
1
- import { j as m, a1 as O, l as g, aK as $, aL as R, aM as M, k as _, aa as y, aw as D, aN as T, u as b, aO as F } from "./index-UdZhlibC.js";
2
- import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-5DTW2tFV.js";
3
- import { S as U, a as B, b as V, c as j, d as k, e as G, f as H, g as q, h as Z, i as K, j as X, k as J, l as Y, m as Q, s as ee, n as te, o as ne, t as se } from "./selu_util-Ce6pu9IM.js";
4
- import { c as re, v as oe, a as ae } from "./scatter_nd_util-OjyAxku2.js";
1
+ import { j as m, a1 as O, l as g, aK as $, aL as R, aM as M, k as _, aa as y, aw as D, aN as T, u as b, aO as F } from "./index-BoWRt-10.js";
2
+ import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-BzbKo31C.js";
3
+ import { S as U, a as B, b as V, c as j, d as k, e as G, f as H, g as q, h as Z, i as K, j as X, k as J, l as Y, m as Q, s as ee, n as te, o as ne, t as se } from "./selu_util-BJEXVvjX.js";
4
+ import { c as re, v as oe, a as ae } from "./scatter_nd_util-DUstGbU1.js";
5
5
  function ie(e, n) {
6
6
  const r = e.shape.length, t = n.shape.length;
7
7
  if (r < 1)
@@ -1,5 +1,5 @@
1
- import { B as h, C as f, F as p, M as g, E as u, N as b } from "./index-UdZhlibC.js";
2
- import { r as T } from "./reshape-BR0eoLYN.js";
1
+ import { B as h, C as f, F as p, M as g, E as u, N as b } from "./index-BoWRt-10.js";
2
+ import { r as T } from "./reshape-CdBq1WJ6.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { B as s, l as a, D as p, M as i, E as l, Q as f } from "./index-UdZhlibC.js";
1
+ import { B as s, l as a, D as p, M as i, E as l, Q as f } from "./index-BoWRt-10.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,7 +1,7 @@
1
- import { ag as S, T as h, ac as N, d as v, ah as o, ai as p, aj as g, l as k, t as y } from "./index-UdZhlibC.js";
1
+ import { ag as S, T as h, ac as N, d as v, ah as o, ai as p, aj as g, l as k, t as y } from "./index-BoWRt-10.js";
2
2
  import { s as R } from "./index-C4L8Cm77.js";
3
- import { s as $ } from "./stack-ikk2Y8_P.js";
4
- import { t as B } from "./tensor-Do9PKbIE.js";
3
+ import { s as $ } from "./stack-B17UN2nn.js";
4
+ import { t as B } from "./tensor-JwS7ZYY6.js";
5
5
  /**
6
6
  * @license
7
7
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { B as l, C as h, E as m, ak as p, F as c, al as d, ab as g, l as u, T as V, n as v, o as N, a as w } from "./index-UdZhlibC.js";
1
+ import { B as l, C as h, E as m, ak as p, F as c, al as d, ab as g, l as u, T as V, n as v, o as N, a as w } from "./index-BoWRt-10.js";
2
2
  import { s as f } from "./index-C4L8Cm77.js";
3
3
  /**
4
4
  * @license
@@ -1,4 +1,4 @@
1
- import { B as g, C as t, E as h, G as p } from "./index-UdZhlibC.js";
1
+ import { B as g, C as t, E as h, G as p } from "./index-BoWRt-10.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { i as t, e as n } from "./index-UdZhlibC.js";
1
+ import { i as t, e as n } from "./index-BoWRt-10.js";
2
2
  import "./ops/cpu/gelu.js";
3
3
  import "./ops/webgl/gelu.js";
4
4
  const a = {