@genai-fi/nanogpt 0.7.2 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. package/dist/Generator.d.ts +36 -4
  2. package/dist/Generator.js +183 -69
  3. package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
  4. package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
  5. package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
  6. package/dist/TeachableLLM.d.ts +6 -6
  7. package/dist/TeachableLLM.js +51 -50
  8. package/dist/Trainer.d.ts +19 -3
  9. package/dist/Trainer.js +71 -28
  10. package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
  13. package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
  14. package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
  15. package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
  16. package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
  17. package/dist/exports_initializers-DAKM8UO9.js +16 -0
  18. package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
  19. package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
  20. package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
  21. package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
  22. package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
  23. package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
  24. package/dist/layers/BaseLayer.d.ts +8 -13
  25. package/dist/layers/BaseLayer.js +25 -13
  26. package/dist/layers/CausalSelfAttention.d.ts +3 -2
  27. package/dist/layers/CausalSelfAttention.js +28 -28
  28. package/dist/layers/MLP.d.ts +3 -2
  29. package/dist/layers/MLP.js +16 -20
  30. package/dist/layers/PositionEmbedding.d.ts +9 -0
  31. package/dist/layers/PositionEmbedding.js +45 -0
  32. package/dist/layers/RMSNorm.d.ts +3 -2
  33. package/dist/layers/RMSNorm.js +6 -6
  34. package/dist/layers/RoPECache.d.ts +1 -1
  35. package/dist/layers/RoPECache.js +4 -4
  36. package/dist/layers/TiedEmbedding.d.ts +3 -2
  37. package/dist/layers/TiedEmbedding.js +29 -7
  38. package/dist/layers/TransformerBlock.d.ts +3 -2
  39. package/dist/layers/TransformerBlock.js +1 -1
  40. package/dist/loader/load.d.ts +2 -2
  41. package/dist/loader/loadHF.d.ts +2 -2
  42. package/dist/loader/loadTransformers.d.ts +4 -2
  43. package/dist/loader/loadTransformers.js +10 -9
  44. package/dist/loader/newZipLoad.d.ts +2 -2
  45. package/dist/loader/oldZipLoad.d.ts +2 -2
  46. package/dist/loader/oldZipLoad.js +42 -51
  47. package/dist/loader/save.d.ts +8 -0
  48. package/dist/loader/save.js +62 -0
  49. package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
  50. package/dist/main.d.ts +5 -4
  51. package/dist/main.js +22 -18
  52. package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
  53. package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
  54. package/dist/mod-uUuj4gSb.js +27 -0
  55. package/dist/models/NanoGPTV1.d.ts +15 -0
  56. package/dist/models/NanoGPTV1.js +71 -0
  57. package/dist/{config.d.ts → models/config.d.ts} +1 -0
  58. package/dist/{config.js → models/config.js} +1 -0
  59. package/dist/models/factory.d.ts +3 -0
  60. package/dist/models/factory.js +14 -0
  61. package/dist/models/model.d.ts +26 -0
  62. package/dist/models/model.js +68 -0
  63. package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
  64. package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
  65. package/dist/ops/adamAdjust.js +1 -1
  66. package/dist/ops/adamMoments.js +1 -1
  67. package/dist/ops/appendCache.js +3 -3
  68. package/dist/ops/attentionMask.js +1 -1
  69. package/dist/ops/cpu/adamAdjust.js +9 -9
  70. package/dist/ops/cpu/adamMoments.js +2 -2
  71. package/dist/ops/cpu/appendCache.js +2 -2
  72. package/dist/ops/cpu/attentionMask.js +5 -5
  73. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  74. package/dist/ops/cpu/gatherSub.js +3 -3
  75. package/dist/ops/cpu/gelu.js +1 -1
  76. package/dist/ops/cpu/matMulGelu.js +2 -2
  77. package/dist/ops/cpu/matMulMul.js +1 -1
  78. package/dist/ops/cpu/mulDropout.js +1 -1
  79. package/dist/ops/cpu/normRMS.js +1 -1
  80. package/dist/ops/cpu/qkv.js +3 -3
  81. package/dist/ops/cpu/rope.js +5 -5
  82. package/dist/ops/cpu/scatterSub.js +11 -11
  83. package/dist/ops/fusedSoftmax.js +1 -1
  84. package/dist/ops/gatherSub.js +1 -1
  85. package/dist/ops/gelu.js +2 -2
  86. package/dist/ops/grads/attentionMask.js +1 -1
  87. package/dist/ops/grads/fusedSoftmax.js +2 -2
  88. package/dist/ops/grads/gelu.js +2 -2
  89. package/dist/ops/grads/matMulGelu.js +1 -1
  90. package/dist/ops/grads/normRMS.js +1 -1
  91. package/dist/ops/grads/qkv.js +1 -1
  92. package/dist/ops/grads/rope.js +1 -1
  93. package/dist/ops/matMulGelu.js +1 -1
  94. package/dist/ops/matMulMul.js +1 -1
  95. package/dist/ops/mulDrop.js +1 -1
  96. package/dist/ops/normRMS.js +1 -1
  97. package/dist/ops/qkv.js +1 -1
  98. package/dist/ops/rope.js +4 -4
  99. package/dist/ops/scatterSub.js +1 -1
  100. package/dist/ops/webgl/adamAdjust.js +2 -2
  101. package/dist/ops/webgl/adamMoments.js +1 -1
  102. package/dist/ops/webgl/appendCache.js +1 -1
  103. package/dist/ops/webgl/attentionMask.js +1 -1
  104. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  105. package/dist/ops/webgl/gatherSub.js +1 -1
  106. package/dist/ops/webgl/gelu.js +2 -2
  107. package/dist/ops/webgl/log.js +3 -3
  108. package/dist/ops/webgl/matMulGelu.js +10 -10
  109. package/dist/ops/webgl/matMulMul.js +1 -1
  110. package/dist/ops/webgl/mulDropout.js +1 -1
  111. package/dist/ops/webgl/normRMS.js +2 -2
  112. package/dist/ops/webgl/qkv.js +1 -1
  113. package/dist/ops/webgl/rope.js +1 -1
  114. package/dist/ops/webgl/scatterSub.js +1 -1
  115. package/dist/ops/webgpu/adamAdjust.js +3 -3
  116. package/dist/ops/webgpu/adamMoments.js +3 -3
  117. package/dist/ops/webgpu/appendCache.js +3 -3
  118. package/dist/ops/webgpu/attentionMask.js +3 -3
  119. package/dist/ops/webgpu/gatherSub.js +3 -3
  120. package/dist/ops/webgpu/gelu.js +3 -3
  121. package/dist/ops/webgpu/normRMS.js +2 -2
  122. package/dist/ops/webgpu/normRMSGrad.js +5 -5
  123. package/dist/ops/webgpu/qkv.js +3 -3
  124. package/dist/ops/webgpu/rope.js +3 -3
  125. package/dist/ops/webgpu/scatterSub.js +3 -3
  126. package/dist/ops/webgpu/utils/reductions.js +4 -4
  127. package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
  128. package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
  129. package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
  130. package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
  131. package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
  132. package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
  133. package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
  134. package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
  135. package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
  136. package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
  137. package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
  138. package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
  139. package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
  140. package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
  141. package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
  142. package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
  143. package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
  144. package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
  145. package/dist/training/Adam.js +2 -2
  146. package/dist/training/AdamExt.js +1 -1
  147. package/dist/training/DatasetBuilder.js +2 -2
  148. package/dist/training/Evaluator.d.ts +2 -2
  149. package/dist/training/FullTrainer.d.ts +16 -3
  150. package/dist/training/FullTrainer.js +91 -53
  151. package/dist/training/Trainer.d.ts +25 -3
  152. package/dist/training/Trainer.js +39 -47
  153. package/dist/training/sparseCrossEntropy.js +9 -9
  154. package/dist/utilities/dummy.d.ts +4 -4
  155. package/dist/utilities/dummy.js +13 -13
  156. package/dist/utilities/multinomialCPU.js +2 -2
  157. package/dist/utilities/parameters.d.ts +1 -1
  158. package/dist/utilities/performance.js +1 -1
  159. package/dist/utilities/profile.js +1 -1
  160. package/dist/utilities/safetensors.js +2 -2
  161. package/dist/utilities/weights.js +2 -2
  162. package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
  163. package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
  164. package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
  165. package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
  166. package/package.json +2 -3
  167. package/dist/NanoGPTModel.d.ts +0 -52
  168. package/dist/NanoGPTModel.js +0 -203
  169. package/dist/TiedEmbedding-BxOerUmB.js +0 -43
  170. package/dist/utilities/generate.d.ts +0 -3
  171. package/dist/utilities/generate.js +0 -22
  172. package/dist/utilities/save.d.ts +0 -9
  173. package/dist/utilities/save.js +0 -61
@@ -1,30 +1,21 @@
1
- import { defaultConfig as _ } from "./config.js";
2
- import f from "./NanoGPTModel.js";
3
- import { saveModel as u } from "./utilities/save.js";
4
- import { loadModel as d } from "./loader/load.js";
5
- import l from "./Generator.js";
6
- import p from "./Trainer.js";
7
- import { E as g } from "./index-Dwqa6Zy2.js";
1
+ import { defaultConfig as d } from "./models/config.js";
2
+ import { saveModel as l } from "./loader/save.js";
3
+ import { loadModel as _ } from "./loader/load.js";
4
+ import u from "./Generator.js";
5
+ import f from "./Trainer.js";
6
+ import { E as p } from "./index-Dwqa6Zy2.js";
8
7
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
9
- import c from "./tokeniser/CharTokeniser.js";
10
- import k from "./tokeniser/bpe.js";
11
- import "./papaparse.min-C8l2Kvo1.js";
12
- import "./index-Tf7vU29b.js";
13
- import "./jszip.min-CjP2V1VV.js";
14
- import "./index-BoWRt-10.js";
15
- import "./ops/cpu/scatterSub.js";
16
- import "./ops/webgl/scatterSub.js";
17
- import "./ops/cpu/gatherSub.js";
18
- import "./ops/webgl/gatherSub.js";
8
+ import "./index-CUQrfsw_.js";
19
9
  import "./ops/cpu/attentionMask.js";
20
10
  import "./ops/webgl/attentionMask.js";
21
11
  import "./ops/grads/attentionMask.js";
22
12
  import "./ops/cpu/qkv.js";
23
13
  import "./ops/webgl/qkv.js";
24
14
  import "./ops/grads/qkv.js";
25
- import "./random_width-sZORGo5k.js";
26
- import "./register_all_kernels-BwDSRN-f.js";
27
- import "./dataset-CtdBYwjo.js";
15
+ import "./random_width-D8Pwy_na.js";
16
+ import "./register_all_kernels-DUshvVWP.js";
17
+ import "./index-Tf7vU29b.js";
18
+ import "./dataset-CJmEGu6D.js";
28
19
  import "./ops/cpu/rope.js";
29
20
  import "./ops/webgl/rope.js";
30
21
  import "./ops/grads/rope.js";
@@ -36,20 +27,29 @@ import "./ops/grads/fusedSoftmax.js";
36
27
  import "./ops/cpu/matMulGelu.js";
37
28
  import "./ops/webgl/matMulGelu.js";
38
29
  import "./ops/grads/matMulGelu.js";
39
- import "./ops/cpu/gelu.js";
40
- import "./ops/webgl/gelu.js";
41
- import "./gelu-C-dPj6Ku.js";
42
30
  import "./ops/cpu/normRMS.js";
43
31
  import "./ops/webgl/normRMS.js";
44
32
  import "./ops/grads/normRMS.js";
33
+ import "./ops/cpu/gatherSub.js";
34
+ import "./ops/webgl/gatherSub.js";
35
+ import "./ops/cpu/scatterSub.js";
36
+ import "./ops/webgl/scatterSub.js";
37
+ import c from "./tokeniser/CharTokeniser.js";
38
+ import g from "./tokeniser/bpe.js";
39
+ import "./papaparse.min-C8l2Kvo1.js";
40
+ import "./jszip.min-CjP2V1VV.js";
41
+ import "./ops/cpu/gelu.js";
42
+ import "./ops/webgl/gelu.js";
43
+ import "./gelu-Bd3UBBxg.js";
45
44
  import "./ops/webgl/log.js";
46
45
  import "./ops/cpu/adamMoments.js";
47
46
  import "./ops/webgl/adamMoments.js";
48
47
  import "./ops/cpu/adamAdjust.js";
49
48
  import "./ops/webgl/adamAdjust.js";
50
- import w from "./utilities/profile.js";
49
+ import k from "./utilities/profile.js";
50
+ import w from "./models/factory.js";
51
51
  class a {
52
- ee = new g();
52
+ ee = new p();
53
53
  _config;
54
54
  _model;
55
55
  _tokeniser;
@@ -69,7 +69,7 @@ class a {
69
69
  get config() {
70
70
  if (!this._config)
71
71
  throw new Error("configuration_not_initialized.");
72
- return this._config.gpt;
72
+ return this._config;
73
73
  }
74
74
  get model() {
75
75
  if (!this._model)
@@ -92,8 +92,8 @@ class a {
92
92
  return this._status === "busy" || this._status === "training";
93
93
  }
94
94
  estimateTrainingMemoryUsage(t) {
95
- const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, i = e.perBatch * t, o = e.gradients;
96
- return i * 0.66 + o * 4;
95
+ const e = this._memoryRequirements ?? { perBatch: 0, gradients: 0 }, r = e.perBatch * t, o = e.gradients;
96
+ return r * 0.66 + o * 4;
97
97
  }
98
98
  setStatus(t) {
99
99
  this._status !== t && (this._status = t, this.ee.emit("status", t));
@@ -101,32 +101,32 @@ class a {
101
101
  saveModel(t) {
102
102
  if (!this._model || !this._tokeniser)
103
103
  throw new Error("model_or_tokeniser_not_initialized.");
104
- return u(this._model, this._tokeniser, {
104
+ return l(this._model, this._tokeniser, {
105
105
  ...t,
106
106
  name: t?.name || this.meta.name
107
107
  });
108
108
  }
109
109
  static loadModel(t) {
110
110
  const e = new a();
111
- return d(t).then(({ model: i, tokeniser: o, name: s }) => {
112
- e._model = i, e._tokeniser = o, e._config = i.config, s && (e.meta.name = s), e.setStatus("warmup"), m(i).then((r) => {
113
- e._memoryRequirements = r, e.setStatus("ready"), e.ee.emit("loaded");
114
- }).catch((r) => {
115
- e.setStatus("error"), e.ee.emit("error", r);
111
+ return _(t).then(({ model: r, tokeniser: o, name: s }) => {
112
+ e._model = r, e._tokeniser = o, e._config = r.config, s && (e.meta.name = s), e.setStatus("warmup"), m(r).then((i) => {
113
+ e._memoryRequirements = i, e.setStatus("ready"), e.ee.emit("loaded");
114
+ }).catch((i) => {
115
+ e.setStatus("error"), e.ee.emit("error", i);
116
116
  });
117
- }).catch((i) => {
118
- e.setStatus("error"), e.ee.emit("error", i);
117
+ }).catch((r) => {
118
+ e.setStatus("error"), e.ee.emit("error", r);
119
119
  }), e;
120
120
  }
121
121
  static create(t, e = {}) {
122
- const i = { ..._, ...e }, o = t === "char" ? new c(i.vocabSize) : new k(i.vocabSize), s = new f(i), r = new a(o, s);
123
- return r.setStatus("warmup"), m(s).then((n) => {
124
- r._memoryRequirements = n, r.tokeniser.trained ? (r.setStatus("ready"), r.ee.emit("loaded")) : (r.setStatus("awaitingTokens"), r.ee.emit("loaded"), r.tokeniser.once("trainStatus", (h) => {
125
- h === "trained" && r.setStatus("ready");
122
+ const r = { ...d, ...e }, o = t === "char" ? new c(r.vocabSize) : new g(r.vocabSize), s = w(r), i = new a(o, s);
123
+ return i.setStatus("warmup"), m(s).then((n) => {
124
+ i._memoryRequirements = n, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (h) => {
125
+ h === "trained" && i.setStatus("ready");
126
126
  }));
127
127
  }).catch((n) => {
128
- r.setStatus("error"), r.ee.emit("error", n);
129
- }), r;
128
+ i.setStatus("error"), i.ee.emit("error", n);
129
+ }), i;
130
130
  }
131
131
  getProfiler() {
132
132
  return this._model?.getProfiler();
@@ -138,9 +138,9 @@ class a {
138
138
  if (t) {
139
139
  if (!this._config)
140
140
  return;
141
- this._config.layerConfig.profiler || (this._config.layerConfig.profiler = new w());
141
+ this.model.getProfiler() || this.model.setProfiler(new k());
142
142
  } else
143
- this._config?.layerConfig.profiler && (this._config.layerConfig.profiler = void 0);
143
+ this.model.getProfiler() && this.model.setProfiler(null);
144
144
  }
145
145
  getNumParams() {
146
146
  return this._model ? this._model.getNumParams() : 0;
@@ -148,15 +148,16 @@ class a {
148
148
  trainer() {
149
149
  if (!this._model || !this._tokeniser)
150
150
  throw new Error("model_or_tokeniser_not_initialized.");
151
- const t = new p(this._model, this._tokeniser);
152
- return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, i) => {
151
+ const t = new f(this._model, this._tokeniser);
152
+ return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
153
153
  const o = this.ee.listeners("trainStep");
154
154
  for (const s of o)
155
- await s(e, i);
155
+ await s(e, r);
156
156
  }), t;
157
157
  }
158
- train(t, e) {
159
- return this.trainer().train(t, e);
158
+ async train(t, e) {
159
+ const r = this.trainer();
160
+ await r.prepare(t, e), await r.train(e);
160
161
  }
161
162
  async trainTokeniser(t) {
162
163
  if (!this._tokeniser)
@@ -167,7 +168,7 @@ class a {
167
168
  generator() {
168
169
  if (!this._model || !this._tokeniser)
169
170
  throw new Error("model_or_tokeniser_not_initialized.");
170
- const t = new l(this._model, this._tokeniser);
171
+ const t = new u(this._model, this._tokeniser);
171
172
  return t.on("start", () => {
172
173
  this.status === "ready" && this.setStatus("busy");
173
174
  }), t.on("stop", () => {
package/dist/Trainer.d.ts CHANGED
@@ -1,6 +1,7 @@
1
- import { default as NanoGPT } from './NanoGPTModel';
2
1
  import { ITokeniser } from './tokeniser/type';
3
2
  import { default as EE } from 'eventemitter3';
3
+ import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
4
+ import { default as Model, ModelForwardAttributes } from './models/model';
4
5
  export interface ITrainerOptions {
5
6
  batchSize?: number;
6
7
  learningRate?: number;
@@ -10,12 +11,27 @@ export interface ITrainerOptions {
10
11
  prompt?: string;
11
12
  validationSplit?: number;
12
13
  advancedMetrics?: boolean;
14
+ gradientCheckpointing?: boolean;
15
+ }
16
+ interface ExtendedTrainingProgress extends TrainingProgress {
17
+ progress: number;
18
+ remaining: number;
13
19
  }
14
20
  export default class Trainer extends EE<'start' | 'stop' | 'log'> {
15
21
  private trainer;
16
22
  private hasTrained;
17
- constructor(model: NanoGPT, tokeniser: ITokeniser);
23
+ private trainDataset?;
24
+ private validationDataset?;
25
+ private totalSamples;
26
+ private log;
27
+ private progress;
28
+ constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
18
29
  stop(): void;
19
30
  reset(): void;
20
- train(text: string[], options?: ITrainerOptions): Promise<void>;
31
+ prepare(text: string[], options?: ITrainerOptions): Promise<void>;
32
+ train(options?: ITrainerOptions): Promise<void>;
33
+ step(options?: ITrainerOptions): Promise<void>;
34
+ getLog(): TrainingLogEntry[];
35
+ getProgress(): ExtendedTrainingProgress | null;
21
36
  }
37
+ export {};
package/dist/Trainer.js CHANGED
@@ -1,48 +1,91 @@
1
- import { E as h } from "./index-Dwqa6Zy2.js";
2
- import m from "./training/FullTrainer.js";
3
- class p extends h {
1
+ import { E as l } from "./index-Dwqa6Zy2.js";
2
+ import h from "./training/FullTrainer.js";
3
+ class m extends l {
4
4
  trainer;
5
5
  hasTrained = !1;
6
- constructor(e, t) {
7
- super(), this.trainer = new m(e, t, 1e-3);
6
+ trainDataset;
7
+ validationDataset;
8
+ totalSamples = 0;
9
+ log = [];
10
+ progress = null;
11
+ constructor(t, e) {
12
+ super(), this.trainer = new h(t, e, 1e-3);
8
13
  }
9
14
  stop() {
10
15
  this.trainer.stop();
11
16
  }
12
17
  reset() {
13
- this.hasTrained = !1, this.trainer.reset();
14
- }
15
- async train(e, t) {
16
- const { trainDataset: s, validationDataset: n } = await this.trainer.createTrainValidationSplit(
17
- e,
18
- t?.batchSize || 32,
19
- t?.validationSplit || 0.1
20
- ), r = e.reduce((i, a) => i + a.length, 0) * (1 - (t?.validationSplit || 0));
21
- this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), await this.trainer.trainOnDataset(
22
- s,
18
+ this.hasTrained = !1, this.log = [], this.trainer.reset();
19
+ }
20
+ async prepare(t, e) {
21
+ const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
22
+ t,
23
+ e?.batchSize || 32,
24
+ e?.validationSplit || 0.1
25
+ ), i = t.reduce((r, n) => r + n.length, 0) * (1 - (e?.validationSplit || 0));
26
+ this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
27
+ }
28
+ async train(t) {
29
+ if (!this.trainDataset || !this.validationDataset)
30
+ throw new Error("Datasets not prepared");
31
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), await this.trainer.trainOnDataset(
32
+ this.trainDataset,
23
33
  {
24
34
  prompt: t?.prompt,
25
35
  logInterval: t?.logInterval || 10,
26
36
  desiredLoss: t?.desiredLoss || 0.01,
27
37
  maxSteps: t?.maxSteps || 1e3,
28
38
  advancedMetrics: t?.advancedMetrics || !1,
29
- onStep: async (i, a) => {
30
- const l = this.listeners("log");
31
- for (const d of l)
32
- await d(i, {
33
- ...a,
34
- progress: a.totalSamples / r,
35
- remaining: Math.max(
36
- 0,
37
- (r - a.totalSamples) / a.totalSamples * a.duration
38
- )
39
- });
39
+ onStep: async (e, a) => {
40
+ this.log.push(e), this.progress = {
41
+ ...a,
42
+ progress: a.totalSamples / this.totalSamples,
43
+ remaining: Math.max(
44
+ 0,
45
+ (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
46
+ )
47
+ };
48
+ const s = this.listeners("log");
49
+ for (const i of s)
50
+ await i(e, this.progress);
40
51
  }
41
52
  },
42
- n
53
+ this.validationDataset
43
54
  ), this.emit("stop");
44
55
  }
56
+ async step(t) {
57
+ if (!this.trainDataset || !this.validationDataset)
58
+ throw new Error("Datasets not prepared");
59
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
60
+ const { log: e, progress: a } = await this.trainer.stepDataset(
61
+ this.trainDataset,
62
+ {
63
+ prompt: t?.prompt,
64
+ logInterval: t?.logInterval || 10,
65
+ desiredLoss: t?.desiredLoss || 0.01,
66
+ maxSteps: t?.maxSteps || 1e3,
67
+ advancedMetrics: t?.advancedMetrics || !1
68
+ },
69
+ this.validationDataset
70
+ ), s = this.listeners("log");
71
+ for (const i of s)
72
+ await i(e, {
73
+ ...a,
74
+ progress: a.totalSamples / this.totalSamples,
75
+ remaining: Math.max(
76
+ 0,
77
+ (this.totalSamples - a.totalSamples) / a.totalSamples * a.duration
78
+ )
79
+ });
80
+ this.emit("stop");
81
+ }
82
+ getLog() {
83
+ return this.log;
84
+ }
85
+ getProgress() {
86
+ return this.progress;
87
+ }
45
88
  }
46
89
  export {
47
- p as default
90
+ m as default
48
91
  };
@@ -1,4 +1,4 @@
1
- import { l as c } from "./index-BoWRt-10.js";
1
+ import { n as c } from "./index-CUQrfsw_.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -28,7 +28,7 @@ function a(e, n, t) {
28
28
  t.indexOf(u) === -1 ? s.push(e[o++]) : s.push(n[f++]);
29
29
  return s;
30
30
  }
31
- function p(e, n) {
31
+ function l(e, n) {
32
32
  const t = [], r = e.length;
33
33
  for (let o = 0; o < r; o++)
34
34
  n.indexOf(o) === -1 && t.push(e[o]);
@@ -62,7 +62,7 @@ function x(e, n) {
62
62
  export {
63
63
  x as a,
64
64
  m as b,
65
- p as c,
65
+ l as c,
66
66
  i as d,
67
67
  h as e,
68
68
  a as f,
package/dist/backend.js CHANGED
@@ -1,6 +1,6 @@
1
- import { g as a, s as i, r as o } from "./index-BoWRt-10.js";
1
+ import { g as a, s as i, r as o } from "./index-CUQrfsw_.js";
2
2
  async function e(t) {
3
- a() !== t && (t === "webgpu" && (await import("./index-CLthM0TO.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
3
+ a() !== t && (t === "webgpu" && (await import("./index-BaPo_0H8.js"), await import("./ops/webgpu/index.js")), await i(t), await o(), console.log(`Backend set to ${t}`));
4
4
  }
5
5
  export {
6
6
  e as selectBackend
@@ -1,7 +1,7 @@
1
- import { j as m, a1 as O, l as g, aK as $, aL as R, aM as M, k as _, aa as y, aw as D, aN as T, u as b, aO as F } from "./index-BoWRt-10.js";
2
- import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-BzbKo31C.js";
3
- import { S as U, a as B, b as V, c as j, d as k, e as G, f as H, g as q, h as Z, i as K, j as X, k as J, l as Y, m as Q, s as ee, n as te, o as ne, t as se } from "./selu_util-BJEXVvjX.js";
4
- import { c as re, v as oe, a as ae } from "./scatter_nd_util-DUstGbU1.js";
1
+ import { j as m, a2 as O, n as g, aM as $, aN as R, aO as M, l as _, ad as y, ay as D, aP as T, u as b, aQ as F } from "./index-CUQrfsw_.js";
2
+ import { b as L, d as W, f as v, c as N, e as x, g as P, a as C, h as z } from "./axis_util-DubwyOhW.js";
3
+ import { S as U, a as B, b as V, c as j, d as G, e as H, f as k, g as q, h as Z, i as X, j as J, k as K, l as Q, m as Y, s as ee, n as te, o as ne, t as se } from "./selu_util-8vv5JxQV.js";
4
+ import { c as re, v as oe, a as ae } from "./scatter_nd_util-CUPPNLaA.js";
5
5
  function ie(e, n) {
6
6
  const r = e.shape.length, t = n.shape.length;
7
7
  if (r < 1)
@@ -233,7 +233,7 @@ function Ie(e, n) {
233
233
  r.push(e[t][0]);
234
234
  return r;
235
235
  }
236
- function we(e, n, r) {
236
+ function Se(e, n, r) {
237
237
  const t = e.slice(0, 1);
238
238
  for (let s = 0; s < r; ++s)
239
239
  t.push(e[s + 1] - n[s][0] - n[s][1]);
@@ -255,7 +255,7 @@ function we(e, n, r) {
255
255
  * limitations under the License.
256
256
  * =============================================================================
257
257
  */
258
- const Se = 0.3275911, Ae = 0.254829592, Oe = -0.284496736, Re = 1.421413741, Me = -1.453152027, _e = 1.061405429;
258
+ const we = 0.3275911, Ae = 0.254829592, Oe = -0.284496736, Re = 1.421413741, Me = -1.453152027, _e = 1.061405429;
259
259
  /**
260
260
  * @license
261
261
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -333,7 +333,7 @@ function ve(e, n, r) {
333
333
  * limitations under the License.
334
334
  * =============================================================================
335
335
  */
336
- const E = "->", Ne = /->/g, w = ",", S = "...";
336
+ const E = "->", Ne = /->/g, S = ",", w = "...";
337
337
  function xe(e, n) {
338
338
  e = e.replace(/\s/g, "");
339
339
  const r = (e.length - e.replace(Ne, "").length) / E.length;
@@ -342,8 +342,8 @@ function xe(e, n) {
342
342
  if (r > 1)
343
343
  throw new Error(`Equation must contain exactly one arrow ("${E}").`);
344
344
  const [t, s] = e.split(E);
345
- g(t.indexOf(S) === -1, () => `The ellipsis notation ("${S}") is not supported yet.`);
346
- const o = t.split(w), a = o.length;
345
+ g(t.indexOf(w) === -1, () => `The ellipsis notation ("${w}") is not supported yet.`);
346
+ const o = t.split(S), a = o.length;
347
347
  if (n !== a)
348
348
  throw new Error(`Expected ${a} input tensors, received ${n}`);
349
349
  if (a > 2)
@@ -357,7 +357,7 @@ function xe(e, n) {
357
357
  }
358
358
  for (let l = 0; l < t.length; ++l) {
359
359
  const f = t[l];
360
- u.indexOf(f) === -1 && f !== w && u.push(f);
360
+ u.indexOf(f) === -1 && f !== S && u.push(f);
361
361
  }
362
362
  const c = new Array(o.length);
363
363
  for (let l = 0; l < a; ++l) {
@@ -449,10 +449,10 @@ function je(e) {
449
449
  return `Received SparseTensor with denseShape[0] = 0 but
450
450
  indices.shape[0] = ${e}`;
451
451
  }
452
- function ke(e, n) {
452
+ function Ge(e, n) {
453
453
  return `indices(${e}, 0) is invalid: ${n} < 0`;
454
454
  }
455
- function Ge(e, n, r) {
455
+ function He(e, n, r) {
456
456
  return `indices(${e}, 0) is invalid: ${n} >= ${r}`;
457
457
  }
458
458
  /**
@@ -471,7 +471,7 @@ function Ge(e, n, r) {
471
471
  * limitations under the License.
472
472
  * =============================================================================
473
473
  */
474
- function He(e, n) {
474
+ function ke(e, n) {
475
475
  return `only one output dimension may be -1, not both ${e} and ${n}`;
476
476
  }
477
477
  function qe(e, n) {
@@ -480,12 +480,12 @@ function qe(e, n) {
480
480
  function Ze() {
481
481
  return "reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero";
482
482
  }
483
- function Ke(e, n) {
483
+ function Xe(e, n) {
484
484
  const r = m(e), t = m(n);
485
485
  return `Input to reshape is a SparseTensor with ${r}
486
486
  dense values, but the requested shape requires a multiple of ${t}. inputShape=${e} outputShape= ${n}`;
487
487
  }
488
- function Xe(e, n) {
488
+ function Je(e, n) {
489
489
  const r = m(e), t = m(n);
490
490
  return `Input to reshape is a tensor with ${r} dense values, but the requested shape has ${t}. inputShape=${e} outputShape=${n}`;
491
491
  }
@@ -505,13 +505,13 @@ function Xe(e, n) {
505
505
  * limitations under the License.
506
506
  * =============================================================================
507
507
  */
508
- function Je() {
508
+ function Ke() {
509
509
  return "segment ids must be >= 0";
510
510
  }
511
- function Ye() {
511
+ function Qe() {
512
512
  return "segment ids are not increasing";
513
513
  }
514
- function Qe(e, n) {
514
+ function Ye(e, n) {
515
515
  return `Segment id ${e} out of range [0, ${n}), possibly because segmentIds input is not sorted.`;
516
516
  }
517
517
  function et(e, n, r) {
@@ -608,7 +608,7 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
608
608
  ERF_A3: Re,
609
609
  ERF_A4: Me,
610
610
  ERF_A5: _e,
611
- ERF_P: Se,
611
+ ERF_P: we,
612
612
  PARALLELIZE_THRESHOLD: I,
613
613
  get RowPartitionType() {
614
614
  return p;
@@ -628,18 +628,18 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
628
628
  combineRaggedTensorToTensorShapes: ce,
629
629
  complexWithEvenIndex: Te,
630
630
  complexWithOddIndex: be,
631
- computeConv2DInfo: k,
632
- computeConv3DInfo: G,
633
- computeDefaultPad: H,
631
+ computeConv2DInfo: G,
632
+ computeConv3DInfo: H,
633
+ computeDefaultPad: k,
634
634
  computeDilation2DInfo: q,
635
635
  computeOptimalWindowSize: ge,
636
636
  computeOutAndReduceShapes: N,
637
637
  computeOutShape: le,
638
638
  computePool2DInfo: Z,
639
- computePool3DInfo: K,
640
- convertConv2DDataFormat: X,
639
+ computePool3DInfo: X,
640
+ convertConv2DDataFormat: J,
641
641
  decodeEinsumEquation: xe,
642
- eitherStridesOrDilationsAreOne: J,
642
+ eitherStridesOrDilationsAreOne: K,
643
643
  expandShapeToKeepDim: x,
644
644
  exponent: ve,
645
645
  exponents: We,
@@ -650,8 +650,8 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
650
650
  getComplexWithIndex: Fe,
651
651
  getEinsumComputePath: ze,
652
652
  getEinsumPermutation: Pe,
653
- getFusedBiasGradient: Y,
654
- getFusedDyActivation: Q,
653
+ getFusedBiasGradient: Q,
654
+ getFusedDyActivation: Y,
655
655
  getImageCenter: de,
656
656
  getInnerMostAxes: C,
657
657
  getPermuted: Ee,
@@ -661,19 +661,19 @@ const ht = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
661
661
  getReshapedPermuted: $e,
662
662
  getRowPartitionTypesHelper: he,
663
663
  getSliceBeginCoords: Ie,
664
- getSliceSize: we,
664
+ getSliceSize: Se,
665
665
  getSparseFillEmptyRowsIndicesDenseShapeMismatch: je,
666
- getSparseFillEmptyRowsNegativeIndexErrorMessage: ke,
667
- getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: Ge,
666
+ getSparseFillEmptyRowsNegativeIndexErrorMessage: Ge,
667
+ getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: He,
668
668
  getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: Ze,
669
- getSparseReshapeInputOutputMismatchErrorMessage: Xe,
670
- getSparseReshapeInputOutputMultipleErrorMessage: Ke,
671
- getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: He,
669
+ getSparseReshapeInputOutputMismatchErrorMessage: Je,
670
+ getSparseReshapeInputOutputMultipleErrorMessage: Xe,
671
+ getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: ke,
672
672
  getSparseReshapeNegativeOutputDimErrorMessage: qe,
673
673
  getSparseSegmentReductionIndicesOutOfRangeErrorMessage: et,
674
- getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Je,
675
- getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Ye,
676
- getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Qe,
674
+ getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Ke,
675
+ getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Qe,
676
+ getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Ye,
677
677
  getUndoAxesPermutation: z,
678
678
  isIdentityPermutation: Ue,
679
679
  log: T,
@@ -697,8 +697,8 @@ export {
697
697
  Ee as B,
698
698
  $e as C,
699
699
  Ie as D,
700
- Se as E,
701
- we as F,
700
+ we as E,
701
+ Se as F,
702
702
  le as G,
703
703
  ue as H,
704
704
  xe as I,
@@ -728,17 +728,17 @@ export {
728
728
  ot as f,
729
729
  he as g,
730
730
  je as h,
731
- ke as i,
732
- Ge as j,
733
- He as k,
731
+ Ge as i,
732
+ He as j,
733
+ ke as k,
734
734
  qe as l,
735
735
  ye as m,
736
736
  Ze as n,
737
- Ke as o,
738
- Xe as p,
739
- Je as q,
740
- Ye as r,
741
- Qe as s,
737
+ Xe as o,
738
+ Je as p,
739
+ Ke as q,
740
+ Qe as r,
741
+ Ye as s,
742
742
  et as t,
743
743
  Ae as u,
744
744
  pe as v,
@@ -1,5 +1,5 @@
1
- import { B as h, C as f, F as p, M as g, E as u, N as b } from "./index-BoWRt-10.js";
2
- import { r as T } from "./reshape-CdBq1WJ6.js";
1
+ import { B as h, C as f, L as p, F as g, E as u, W as b } from "./index-CUQrfsw_.js";
2
+ import { r as T } from "./reshape-DEfQGSin.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,4 +1,4 @@
1
- import { B as s, l as a, D as p, M as i, E as l, Q as f } from "./index-BoWRt-10.js";
1
+ import { B as s, n as a, D as p, F as i, E as l, H as f } from "./index-CUQrfsw_.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -1,7 +1,7 @@
1
- import { ag as S, T as h, ac as N, d as v, ah as o, ai as p, aj as g, l as k, t as y } from "./index-BoWRt-10.js";
1
+ import { ai as S, T as h, af as k, d as v, aj as o, ak as p, al as g, n as N, t as y } from "./index-CUQrfsw_.js";
2
2
  import { s as R } from "./index-C4L8Cm77.js";
3
- import { s as $ } from "./stack-B17UN2nn.js";
4
- import { t as B } from "./tensor-JwS7ZYY6.js";
3
+ import { s as $ } from "./stack-DkdFLq37.js";
4
+ import { t as B } from "./tensor-BAQdLqoU.js";
5
5
  /**
6
6
  * @license
7
7
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -75,7 +75,7 @@ function I(s) {
75
75
  }
76
76
  function c(s) {
77
77
  let t = !1;
78
- if (N().get("IS_BROWSER"))
78
+ if (k().get("IS_BROWSER"))
79
79
  t = s instanceof TextDecoder;
80
80
  else {
81
81
  const { StringDecoder: e } = require("string_decoder");
@@ -930,7 +930,7 @@ class T {
930
930
  */
931
931
  batch(t, e = !0) {
932
932
  const r = this;
933
- k(t > 0, () => `batchSize needs to be positive, but it is
933
+ N(t > 0, () => `batchSize needs to be positive, but it is
934
934
  ${t}`);
935
935
  let n;
936
936
  return this.size === 1 / 0 || this.size == null ? n = this.size : e ? n = Math.ceil(this.size / t) : n = Math.floor(this.size / t), u(async () => (await r.iterator()).columnMajorBatch(t, e, st), n);