@genai-fi/nanogpt 0.17.4 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. package/dist/Generator.d.ts +2 -15
  2. package/dist/Generator.js +45 -34
  3. package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
  4. package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
  5. package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
  6. package/dist/TeachableLLM.d.ts +3 -8
  7. package/dist/TeachableLLM.js +61 -44
  8. package/dist/Trainer.d.ts +6 -4
  9. package/dist/Trainer.js +107 -92
  10. package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
  13. package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
  14. package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +5 -5
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +2 -2
  19. package/dist/checks/normRMS.js +6 -6
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +6 -6
  22. package/dist/checks/qkv.js +2 -2
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
  25. package/dist/complex-3DpPEG9B.js +11 -0
  26. package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
  27. package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
  28. package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
  29. package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
  30. package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
  31. package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
  32. package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
  33. package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
  34. package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
  35. package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
  36. package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
  37. package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
  38. package/dist/inference/types.d.ts +16 -0
  39. package/dist/inference/types.js +1 -0
  40. package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
  41. package/dist/layers/BaseLayer.js +4 -4
  42. package/dist/layers/CausalSelfAttention.js +6 -6
  43. package/dist/layers/LoRA.js +4 -4
  44. package/dist/layers/MLP.js +4 -4
  45. package/dist/layers/PositionEmbedding.js +5 -5
  46. package/dist/layers/RMSNorm.js +3 -3
  47. package/dist/layers/RoPECache.js +4 -4
  48. package/dist/layers/TiedEmbedding.js +6 -6
  49. package/dist/layers/TransformerBlock.js +1 -1
  50. package/dist/layers/WeightStore.js +2 -2
  51. package/dist/loader/load.d.ts +2 -8
  52. package/dist/loader/loadTransformers.d.ts +2 -8
  53. package/dist/loader/loadTransformers.js +13 -11
  54. package/dist/loader/newZipLoad.d.ts +2 -8
  55. package/dist/loader/newZipLoad.js +25 -10
  56. package/dist/loader/oldZipLoad.js +13 -13
  57. package/dist/loader/save.d.ts +9 -2
  58. package/dist/loader/save.js +64 -55
  59. package/dist/loader/types.d.ts +29 -1
  60. package/dist/main.d.ts +2 -0
  61. package/dist/main.js +45 -43
  62. package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
  63. package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
  64. package/dist/mat_mul-DP86qZtZ.js +11 -0
  65. package/dist/mod-BXjLYwvM.js +11 -0
  66. package/dist/models/NanoGPTV1.js +2 -2
  67. package/dist/models/NanoGPTV2.js +2 -2
  68. package/dist/models/model.d.ts +3 -2
  69. package/dist/models/model.js +13 -13
  70. package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
  71. package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
  72. package/dist/ops/adamAdjust.js +3 -3
  73. package/dist/ops/adamMoments.js +3 -3
  74. package/dist/ops/add16.js +1 -1
  75. package/dist/ops/appendCache.js +6 -6
  76. package/dist/ops/attentionMask.js +3 -3
  77. package/dist/ops/concat16.js +3 -3
  78. package/dist/ops/cpu/adamAdjust.js +9 -9
  79. package/dist/ops/cpu/adamMoments.js +5 -5
  80. package/dist/ops/cpu/appendCache.js +2 -2
  81. package/dist/ops/cpu/attentionMask.js +6 -6
  82. package/dist/ops/cpu/fusedSoftmax.js +4 -4
  83. package/dist/ops/cpu/gatherSub.js +5 -5
  84. package/dist/ops/cpu/gelu.js +4 -4
  85. package/dist/ops/cpu/matMul16.js +2 -2
  86. package/dist/ops/cpu/matMulGelu.js +7 -7
  87. package/dist/ops/cpu/matMulMul.js +2 -2
  88. package/dist/ops/cpu/mulDropout.js +5 -5
  89. package/dist/ops/cpu/normRMS.js +1 -1
  90. package/dist/ops/cpu/qkv.js +3 -3
  91. package/dist/ops/cpu/rope.js +5 -5
  92. package/dist/ops/cpu/scatterSub.js +5 -5
  93. package/dist/ops/dot16.js +2 -2
  94. package/dist/ops/dropout.js +6 -6
  95. package/dist/ops/dropout16.js +1 -1
  96. package/dist/ops/gatherSub.js +1 -1
  97. package/dist/ops/gelu.js +2 -2
  98. package/dist/ops/globalNorm.js +7 -7
  99. package/dist/ops/grads/add16.js +1 -1
  100. package/dist/ops/grads/attentionMask.js +2 -2
  101. package/dist/ops/grads/dropout16.js +1 -1
  102. package/dist/ops/grads/gelu.js +2 -2
  103. package/dist/ops/grads/matMul16.js +3 -3
  104. package/dist/ops/grads/matMulGelu.js +1 -1
  105. package/dist/ops/grads/mul16.js +1 -1
  106. package/dist/ops/grads/normRMS.js +7 -7
  107. package/dist/ops/grads/pack16.js +3 -3
  108. package/dist/ops/grads/qkv.js +11 -11
  109. package/dist/ops/grads/rope.js +2 -2
  110. package/dist/ops/grads/softmax16.js +1 -1
  111. package/dist/ops/grads/unpack16.js +2 -2
  112. package/dist/ops/matMul16.js +3 -3
  113. package/dist/ops/matMulGelu.js +6 -6
  114. package/dist/ops/matMulMul.js +3 -3
  115. package/dist/ops/mul16.js +1 -1
  116. package/dist/ops/mulDrop.js +3 -3
  117. package/dist/ops/normRMS.js +4 -4
  118. package/dist/ops/pack16.js +2 -2
  119. package/dist/ops/qkv.js +3 -3
  120. package/dist/ops/reshape16.js +6 -6
  121. package/dist/ops/rope.js +2 -2
  122. package/dist/ops/scatterSub.js +1 -1
  123. package/dist/ops/slice16.js +2 -2
  124. package/dist/ops/softmax16.js +1 -1
  125. package/dist/ops/sub16.js +1 -1
  126. package/dist/ops/sum16.js +6 -6
  127. package/dist/ops/transpose16.js +3 -3
  128. package/dist/ops/unpack16.js +2 -2
  129. package/dist/ops/webgl/adamAdjust.js +2 -2
  130. package/dist/ops/webgl/adamMoments.js +1 -1
  131. package/dist/ops/webgl/appendCache.js +1 -1
  132. package/dist/ops/webgl/attentionMask.js +1 -1
  133. package/dist/ops/webgl/dropout16.js +1 -1
  134. package/dist/ops/webgl/fusedSoftmax.js +7 -7
  135. package/dist/ops/webgl/gatherSub.js +3 -3
  136. package/dist/ops/webgl/gelu.js +2 -2
  137. package/dist/ops/webgl/log.js +3 -3
  138. package/dist/ops/webgl/matMul16.js +13 -13
  139. package/dist/ops/webgl/matMulGelu.js +4 -4
  140. package/dist/ops/webgl/matMulMul.js +2 -2
  141. package/dist/ops/webgl/mulDropout.js +1 -1
  142. package/dist/ops/webgl/normRMS.js +2 -2
  143. package/dist/ops/webgl/qkv.js +1 -1
  144. package/dist/ops/webgl/rope.js +1 -1
  145. package/dist/ops/webgl/scatterSub.js +2 -2
  146. package/dist/ops/webgpu/adamAdjust.js +3 -3
  147. package/dist/ops/webgpu/adamMoments.js +3 -3
  148. package/dist/ops/webgpu/add16.js +6 -6
  149. package/dist/ops/webgpu/appendCache.js +3 -3
  150. package/dist/ops/webgpu/attentionMask.js +2 -2
  151. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  152. package/dist/ops/webgpu/clipScale.js +7 -7
  153. package/dist/ops/webgpu/concat16.js +5 -5
  154. package/dist/ops/webgpu/dropout16.js +6 -6
  155. package/dist/ops/webgpu/gatherSub.js +3 -3
  156. package/dist/ops/webgpu/gelu.js +8 -8
  157. package/dist/ops/webgpu/matMul16.js +16 -16
  158. package/dist/ops/webgpu/matMul16_program.js +2 -2
  159. package/dist/ops/webgpu/mul16.js +5 -5
  160. package/dist/ops/webgpu/norm2.js +1 -1
  161. package/dist/ops/webgpu/normRMS.js +2 -2
  162. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  163. package/dist/ops/webgpu/pack16.js +4 -4
  164. package/dist/ops/webgpu/pack16_program.js +2 -2
  165. package/dist/ops/webgpu/qkv.js +2 -2
  166. package/dist/ops/webgpu/rope.js +3 -3
  167. package/dist/ops/webgpu/scatterSub.js +3 -3
  168. package/dist/ops/webgpu/slice16.js +4 -4
  169. package/dist/ops/webgpu/softmax16.js +4 -4
  170. package/dist/ops/webgpu/softmax16_program.js +2 -2
  171. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  172. package/dist/ops/webgpu/softmax16grad.js +4 -4
  173. package/dist/ops/webgpu/sub16.js +6 -6
  174. package/dist/ops/webgpu/sum16.js +3 -3
  175. package/dist/ops/webgpu/transpose16.js +8 -8
  176. package/dist/ops/webgpu/transpose16_program.js +2 -2
  177. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  178. package/dist/ops/webgpu/unpack16.js +3 -3
  179. package/dist/ops/webgpu/utils/binary_op.js +3 -3
  180. package/dist/ops/webgpu/utils/reductions.js +5 -5
  181. package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
  182. package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
  183. package/dist/patches/webgpu_backend.js +6 -6
  184. package/dist/patches/webgpu_base.js +1 -1
  185. package/dist/patches/webgpu_program.js +2 -2
  186. package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
  187. package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
  188. package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
  189. package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
  190. package/dist/relu-DTvZKBsZ.js +9 -0
  191. package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
  192. package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
  193. package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
  194. package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
  195. package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
  196. package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
  197. package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
  198. package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
  199. package/dist/slice-BvItlgXu.js +12 -0
  200. package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
  201. package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
  202. package/dist/split-BN9LkEgS.js +9 -0
  203. package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
  204. package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
  205. package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
  206. package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
  207. package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
  208. package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
  209. package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
  210. package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
  211. package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
  212. package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
  213. package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
  214. package/dist/tokeniser/BaseTokeniser.js +21 -5
  215. package/dist/tokeniser/CharTokeniser.d.ts +1 -1
  216. package/dist/tokeniser/CharTokeniser.js +62 -50
  217. package/dist/tokeniser/bpe.d.ts +1 -1
  218. package/dist/tokeniser/bpe.js +41 -35
  219. package/dist/tokeniser/type.d.ts +3 -1
  220. package/dist/training/AdamW.d.ts +3 -0
  221. package/dist/training/AdamW.js +59 -30
  222. package/dist/training/BasicTrainer.d.ts +1 -0
  223. package/dist/training/BasicTrainer.js +112 -92
  224. package/dist/training/DatasetBuilder.js +3 -3
  225. package/dist/training/Evaluator.js +2 -2
  226. package/dist/training/LRScheduler.d.ts +1 -0
  227. package/dist/training/LRScheduler.js +18 -12
  228. package/dist/training/PreTrainer.js +3 -3
  229. package/dist/training/SFTDatasetBuilder.js +3 -3
  230. package/dist/training/SFTTrainer.js +1 -1
  231. package/dist/training/orthoGrad.js +1 -1
  232. package/dist/training/sparseCrossEntropy.js +30 -30
  233. package/dist/training/types.d.ts +5 -3
  234. package/dist/training/validation.js +13 -13
  235. package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
  236. package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
  237. package/dist/utilities/datasetID.d.ts +2 -0
  238. package/dist/utilities/datasetID.js +21 -0
  239. package/dist/utilities/dummy.js +6 -6
  240. package/dist/utilities/multinomialCPU.js +2 -2
  241. package/dist/utilities/packed.js +1 -1
  242. package/dist/utilities/performance.js +1 -1
  243. package/dist/utilities/profile.js +1 -1
  244. package/dist/utilities/safetensors.js +2 -2
  245. package/dist/utilities/sentences.js +5 -5
  246. package/dist/utilities/weights.js +2 -2
  247. package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
  248. package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
  249. package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
  250. package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
  251. package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
  252. package/package.json +3 -3
  253. package/dist/complex-DI35Q-gW.js +0 -11
  254. package/dist/mat_mul-DhG0Newp.js +0 -11
  255. package/dist/mod-CSdCpRjf.js +0 -11
  256. package/dist/relu-J_X6MUzx.js +0 -9
  257. package/dist/slice-BzS11Qh0.js +0 -12
  258. package/dist/split-C2Sj255c.js +0 -9
@@ -1,15 +1,15 @@
1
- import { yieldIfNeeded as f } from "../utilities/yielder.js";
1
+ import { yieldIfNeeded as p } from "../utilities/yielder.js";
2
2
  import m from "../utilities/tokenParse.js";
3
- import z, { SPECIALS as k } from "./BaseTokeniser.js";
4
- function p(o, e) {
3
+ import T, { SPECIALS as S } from "./BaseTokeniser.js";
4
+ function g(o, e) {
5
5
  return `${o}-::-${e}`;
6
6
  }
7
- function w(o) {
7
+ function y(o) {
8
8
  const e = /* @__PURE__ */ new Map();
9
9
  for (let s = 0; s < o.length; s++) {
10
10
  const t = o[s];
11
11
  for (let n = 0; n < t.length - 1; n++) {
12
- const r = p(t[n], t[n + 1]), a = e.get(r) || {
12
+ const r = g(t[n], t[n + 1]), a = e.get(r) || {
13
13
  a: t[n],
14
14
  b: t[n + 1],
15
15
  count: 0,
@@ -20,21 +20,21 @@ function w(o) {
20
20
  }
21
21
  return { pairs: e, tokens: o };
22
22
  }
23
- function d(o, e, s, t, n) {
24
- const r = p(e, s);
23
+ function f(o, e, s, t, n) {
24
+ const r = g(e, s);
25
25
  if (o.pairs.has(r)) {
26
26
  const a = o.pairs.get(r);
27
27
  a.count += n, n > 0 ? a.instances.add(t) : a.count <= 0 ? o.pairs.delete(r) : a.instances.delete(t);
28
28
  } else
29
29
  o.pairs.set(r, { a: e, b: s, count: n, instances: /* @__PURE__ */ new Set([t]) });
30
30
  }
31
- function T(o) {
31
+ function I(o) {
32
32
  let e = null, s = 0;
33
33
  for (const t of o.pairs.values())
34
34
  t.count > s && (s = t.count, e = t);
35
35
  return e;
36
36
  }
37
- function y(o, e) {
37
+ function x(o, e) {
38
38
  return o.map((s) => {
39
39
  const t = [];
40
40
  for (let n = 0; n < s.length; n++)
@@ -42,19 +42,19 @@ function y(o, e) {
42
42
  return t;
43
43
  });
44
44
  }
45
- function I(o, e) {
45
+ function A(o, e) {
46
46
  e.instances.forEach((s) => {
47
47
  const t = o.tokens[s], n = [];
48
48
  for (let r = 0; r < t.length; r++)
49
49
  if (r < t.length - 1 && t[r] === e.a && t[r + 1] === e.b) {
50
50
  const a = e.a + e.b;
51
- n.push(a), r > 0 && (d(o, t[r - 1], e.a, s, -1), d(o, t[r - 1], a, s, 1)), r++, r < t.length - 1 && (d(o, e.b, t[r + 1], s, -1), d(o, a, t[r + 1], s, 1));
51
+ n.push(a), r > 0 && (f(o, t[r - 1], e.a, s, -1), f(o, t[r - 1], a, s, 1)), r++, r < t.length - 1 && (f(o, e.b, t[r + 1], s, -1), f(o, a, t[r + 1], s, 1));
52
52
  } else
53
53
  n.push(t[r]);
54
54
  o.tokens[s] = n;
55
- }), o.pairs.delete(p(e.a, e.b));
55
+ }), o.pairs.delete(g(e.a, e.b));
56
56
  }
57
- class E extends z {
57
+ class P extends T {
58
58
  targetSize;
59
59
  vocab = /* @__PURE__ */ new Set();
60
60
  vocabIndex = /* @__PURE__ */ new Map();
@@ -63,7 +63,7 @@ class E extends z {
63
63
  constructor(e, s) {
64
64
  super(), Array.isArray(e) ? (e.forEach((t, n) => {
65
65
  this.vocab.add(t), this.vocabIndex.set(t, n);
66
- }), s && (this.merges = s), this.targetSize = e.length, k.forEach((t) => {
66
+ }), s && (this.merges = s), this.targetSize = e.length, S.forEach((t) => {
67
67
  const n = e.indexOf(t);
68
68
  n !== -1 && this.addSpecialToken(t, n);
69
69
  })) : (this.addSpecialTokens(), this.targetSize = e);
@@ -81,7 +81,7 @@ class E extends z {
81
81
  this.vocab.clear(), this.vocabIndex.clear(), this.merges = [], this.pretokenMap.clear();
82
82
  }
83
83
  get trained() {
84
- return this.vocab.size > k.length && this.vocab.size <= this.targetSize;
84
+ return this.vocab.size > S.length && this.vocab.size <= this.targetSize;
85
85
  }
86
86
  get vocabSize() {
87
87
  return this.vocab.size;
@@ -95,42 +95,48 @@ class E extends z {
95
95
  get unkToken() {
96
96
  return this.vocabIndex.get("") ?? 1;
97
97
  }
98
- async train(e = [], s) {
99
- let t = performance.now();
100
- const n = e.map((i) => i.map((h) => m(h.content))).flat(2);
101
- t = await f(t, s, this.vocab.size);
102
- const r = new Set(n);
98
+ async train(e = [], s, t) {
99
+ this.datasetID = t;
100
+ let n = performance.now();
101
+ const r = new Array(e.length);
102
+ for (let i = 0; i < e.length; i++) {
103
+ const h = e[i], l = new Array(h.length);
104
+ for (let d = 0; d < h.length; d++)
105
+ l[d] = m(h[d].content);
106
+ n = await p(n, s, this.vocab.size), r[i] = l;
107
+ }
108
+ const a = r.flat(2), z = new Set(a);
103
109
  this.vocab = /* @__PURE__ */ new Set(), this.pretokenMap.clear(), this.merges = [], this.addSpecialTokens();
104
- const a = Array.from(r), b = a.map((i) => Array.from(i).map((l) => (this.vocab.add(l), l))), g = w(b);
105
- if (t = await f(t, s, this.vocab.size), this.vocab.size >= this.targetSize) {
110
+ const b = Array.from(z), v = b.map((i) => Array.from(i).map((l) => (this.vocab.add(l), l))), k = y(v);
111
+ if (n = await p(n, s, this.vocab.size), this.vocab.size >= this.targetSize) {
106
112
  console.warn("Initial vocab size is greater than or equal to target size. No merges will be performed.");
107
113
  const i = /* @__PURE__ */ new Map();
108
- n.forEach((c) => {
114
+ a.forEach((c) => {
109
115
  Array.from(c).forEach((u) => {
110
116
  i.set(u, (i.get(u) || 0) + 1);
111
117
  });
112
118
  });
113
119
  const h = Array.from(i.entries()).sort((c, u) => u[1] - c[1]);
114
120
  this.vocab = /* @__PURE__ */ new Set(), this.addSpecialTokens(), h.slice(0, this.targetSize - this.vocab.size).map(([c]) => c).forEach((c) => this.vocab.add(c)), this.vocabIndex.clear();
115
- let S = 0;
121
+ let d = 0;
116
122
  for (const c of this.vocab.keys())
117
- this.vocabIndex.set(c, S++);
118
- return this.emit("trainStatus", "trained"), this.vocab.size;
123
+ this.vocabIndex.set(c, d++);
124
+ return this.generateID(), this.emit("trainStatus", "trained"), this.vocab.size;
119
125
  }
120
126
  for (; this.vocab.size < this.targetSize && this.merges.length < this.targetSize; ) {
121
- const i = T(g);
127
+ const i = I(k);
122
128
  if (!i)
123
129
  break;
124
- this.merges.push([i.a, i.b]), this.vocab.add(i.a + i.b), I(g, i), t = await f(t, s, this.vocab.size);
130
+ this.merges.push([i.a, i.b]), this.vocab.add(i.a + i.b), A(k, i), n = await p(n, s, this.vocab.size);
125
131
  }
126
- a.forEach((i, h) => {
127
- const l = b[h];
132
+ b.forEach((i, h) => {
133
+ const l = v[h];
128
134
  this.pretokenMap.set(i, l);
129
135
  }), this.vocabIndex.clear();
130
- let v = 0;
136
+ let w = 0;
131
137
  for (const i of this.vocab.keys())
132
- this.vocabIndex.set(i, v++);
133
- return this.emit("trainStatus", "trained"), this.vocab.size;
138
+ this.vocabIndex.set(i, w++);
139
+ return this.generateID(), this.emit("trainStatus", "trained"), this.vocab.size;
134
140
  }
135
141
  getVocab() {
136
142
  return Array.from(this.vocab);
@@ -141,7 +147,7 @@ class E extends z {
141
147
  tokeniseWord(e) {
142
148
  let s = Array.from(e);
143
149
  return this.merges.forEach((t) => {
144
- s = y([s], t)[0];
150
+ s = x([s], t)[0];
145
151
  }), this.pretokenMap.set(e, s), s;
146
152
  }
147
153
  tokeniseStrings(e) {
@@ -163,5 +169,5 @@ class E extends z {
163
169
  }
164
170
  }
165
171
  export {
166
- E as default
172
+ P as default
167
173
  };
@@ -5,7 +5,9 @@ export interface Conversation {
5
5
  content: string;
6
6
  }
7
7
  export interface ITokeniser extends EE<'trainStatus'> {
8
- train(text: Conversation[][], cb?: (vocab: number) => void): Promise<number>;
8
+ id: string;
9
+ datasetID?: string;
10
+ train(text: Conversation[][], cb?: (vocab: number) => void, datasetID?: string): Promise<number>;
9
11
  getVocab(): string[];
10
12
  getMerges(): [string, string][];
11
13
  destroy(): void;
@@ -21,6 +21,9 @@ export declare class AdamWOptimizer extends Optimizer {
21
21
  protected orthGrad: boolean;
22
22
  constructor(config: AdamWOptimizerConfig);
23
23
  get lr(): number;
24
+ saveMoments(): Promise<ArrayBuffer>;
25
+ loadMoments(momentData: ArrayBuffer): Promise<void>;
26
+ serializeConfig(): AdamWOptimizerConfig;
24
27
  private orthogonalizeGradient;
25
28
  updateConfig(newConfig: Partial<AdamWOptimizerConfig>): void;
26
29
  applyGradients(variableGradients: NamedVariableMap | NamedTensor[]): Tensor;
@@ -1,12 +1,13 @@
1
- import { adamAdjust as N } from "../ops/adamAdjust.js";
2
- import { adamMoments as S } from "../ops/adamMoments.js";
3
- import { O as R, h as b, t as h, a as w, d as B } from "../index-CUXkjxiT.js";
4
- import M from "./LRScheduler.js";
5
- import { clipScale as A } from "../ops/globalNorm.js";
6
- import { z as O } from "../zeros-DvZpK8s6.js";
7
- class G extends R {
1
+ import { adamAdjust as B } from "../ops/adamAdjust.js";
2
+ import { adamMoments as N } from "../ops/adamMoments.js";
3
+ import { O as S, e as b, t as c, b as M, l as w } from "../index-DSGwv2Yx.js";
4
+ import R from "./LRScheduler.js";
5
+ import { clipScale as f } from "../ops/globalNorm.js";
6
+ import { save_safetensors as v, load_safetensors as A } from "../utilities/safetensors.js";
7
+ import { z as O } from "../zeros-Bw0puq_w.js";
8
+ class _ extends S {
8
9
  constructor(t) {
9
- super(), this.config = t, this.accBeta1 = t.beta1, this.accBeta2 = t.beta2, this.learningRate = t.learningRate, this.beta1 = t.beta1, this.beta2 = t.beta2, this.weightDecay = t.weightDecay, this.lossScaling = t.lossScaling, this.clipNorm = t.clipNorm, this.orthGrad = t.orthoGrad ?? !1, t.epsilon === null || t.epsilon === void 0 ? this.epsilon = b().backend.epsilon() : this.epsilon = t.epsilon, this.lrScheduler = new M(t.learningRate, t);
10
+ super(), this.config = t, this.accBeta1 = t.accBeta1 ?? t.beta1, this.accBeta2 = t.accBeta2 ?? t.beta2, this.learningRate = t.learningRate, this.beta1 = t.beta1, this.beta2 = t.beta2, this.weightDecay = t.weightDecay, this.lossScaling = t.lossScaling, this.clipNorm = t.clipNorm, this.orthGrad = t.orthoGrad ?? !1, t.epsilon === null || t.epsilon === void 0 ? this.epsilon = b().backend.epsilon() : this.epsilon = t.epsilon, this.lrScheduler = new R(t.learningRate, t);
10
11
  }
11
12
  className = "AdamW";
12
13
  accBeta1 = 0;
@@ -25,10 +26,38 @@ class G extends R {
25
26
  get lr() {
26
27
  return this.learningRate;
27
28
  }
29
+ saveMoments() {
30
+ const t = {};
31
+ return this.accumulatedMoments.forEach((e) => {
32
+ t[e.originalName] = e.variable;
33
+ }), v(t);
34
+ }
35
+ async loadMoments(t) {
36
+ const e = await A(t);
37
+ Object.entries(e).forEach(([a, s]) => {
38
+ const n = s.variable(!1);
39
+ this.accumulatedMoments.push({ originalName: a, variable: n });
40
+ });
41
+ }
42
+ serializeConfig() {
43
+ return {
44
+ learningRate: this.learningRate,
45
+ beta1: this.beta1,
46
+ beta2: this.beta2,
47
+ accBeta1: this.accBeta1,
48
+ accBeta2: this.accBeta2,
49
+ epsilon: this.epsilon ?? void 0,
50
+ weightDecay: this.weightDecay,
51
+ lossScaling: this.lossScaling,
52
+ clipNorm: this.clipNorm,
53
+ orthoGrad: this.orthGrad,
54
+ ...this.lrScheduler.serializeConfig()
55
+ };
56
+ }
28
57
  orthogonalizeGradient(t, e) {
29
- return h(() => {
30
- const a = t.reshape([-1]), s = e.reshape([-1]), l = a.mul(a).sum().add(this.orthGradEpsilon), c = a.mul(s).sum().div(l), n = s.sub(a.mul(c)), o = s.norm(), i = n.norm().add(this.orthGradEpsilon);
31
- return n.mul(o.div(i)).reshape(e.shape);
58
+ return c(() => {
59
+ const a = t.reshape([-1]), s = e.reshape([-1]), n = a.mul(a).sum().add(this.orthGradEpsilon), h = a.mul(s).sum().div(n), o = s.sub(a.mul(h)), l = s.norm(), i = o.norm().add(this.orthGradEpsilon);
60
+ return o.mul(l.div(i)).reshape(e.shape);
32
61
  });
33
62
  }
34
63
  updateConfig(t) {
@@ -38,42 +67,42 @@ class G extends R {
38
67
  applyGradients(t) {
39
68
  const e = this.lrScheduler.getNextLR();
40
69
  this.learningRate = e;
41
- const a = Array.isArray(t) ? t.map((l) => l.name) : Object.keys(t), s = h(() => {
42
- const l = 1 - this.accBeta1, c = 1 - this.accBeta2;
43
- let n;
70
+ const a = Array.isArray(t) ? t.map((n) => n.name) : Object.keys(t), s = c(() => {
71
+ const n = 1 - this.accBeta1, h = 1 - this.accBeta2;
72
+ let o;
44
73
  if (this.clipNorm !== void 0) {
45
- const o = a.map((i, r) => Array.isArray(t) ? t[r].tensor : t[i]);
46
- n = A(o, 1 / this.lossScaling, this.clipNorm);
74
+ const l = a.map((i, r) => Array.isArray(t) ? t[r].tensor : t[i]);
75
+ o = f(l, 1 / this.lossScaling, this.clipNorm);
47
76
  } else
48
- n = w(1 / this.lossScaling);
49
- return a.forEach((o, i) => {
50
- const r = b().registeredVariables[o], p = !1;
77
+ o = M(1 / this.lossScaling);
78
+ return a.forEach((l, i) => {
79
+ const r = b().registeredVariables[l], p = !1;
51
80
  this.accumulatedMoments[i] == null && (this.accumulatedMoments[i] = {
52
- originalName: `${o}/m`,
53
- variable: h(() => O([...r.shape, 2]).variable(p))
81
+ originalName: `${l}/m`,
82
+ variable: c(() => O([...r.shape, 2]).variable(p))
54
83
  });
55
- const m = Array.isArray(t) ? t[i].tensor : t[o];
84
+ const m = Array.isArray(t) ? t[i].tensor : t[l];
56
85
  if (m == null)
57
86
  return;
58
- const u = this.orthGrad ? this.orthogonalizeGradient(r, m) : m, d = this.accumulatedMoments[i].variable, g = S(d, u, this.beta1, this.beta2, n);
87
+ const u = this.orthGrad ? this.orthogonalizeGradient(r, m) : m, d = this.accumulatedMoments[i].variable, g = N(d, u, this.beta1, this.beta2, o);
59
88
  d.assign(g), this.orthGrad && u.dispose();
60
- const y = N(
89
+ const y = B(
61
90
  g,
62
91
  r,
63
- l,
64
- c,
92
+ n,
93
+ h,
65
94
  this.epsilon ?? 1e-8,
66
95
  this.learningRate,
67
96
  // Only apply weight decay if the variable is multi-dimensional (e.g. weights, not biases)
68
97
  r.shape.length > 1 ? this.weightDecay : 0
69
98
  );
70
99
  r.assign(y);
71
- }), this.accBeta1 = this.accBeta1 * this.beta1, this.accBeta2 = this.accBeta2 * this.beta2, n;
100
+ }), this.accBeta1 = this.accBeta1 * this.beta1, this.accBeta2 = this.accBeta2 * this.beta2, o;
72
101
  });
73
102
  return this.incrementIterations(), s;
74
103
  }
75
104
  dispose() {
76
- this.accumulatedMoments != null && B(this.accumulatedMoments.map((t) => t.variable));
105
+ this.accumulatedMoments != null && w(this.accumulatedMoments.map((t) => t.variable));
77
106
  }
78
107
  async getWeights() {
79
108
  const t = [...this.accumulatedMoments];
@@ -82,7 +111,7 @@ class G extends R {
82
111
  );
83
112
  }
84
113
  async setWeights(t) {
85
- t = await this.extractIterations(t), h(() => {
114
+ t = await this.extractIterations(t), c(() => {
86
115
  this.accBeta1 = Math.pow(this.beta1, this.iterations_ + 1), this.accBeta2 = Math.pow(this.beta2, this.iterations_ + 1);
87
116
  });
88
117
  const e = t.length / 2, a = !1;
@@ -105,5 +134,5 @@ class G extends R {
105
134
  }
106
135
  }
107
136
  export {
108
- G as AdamWOptimizer
137
+ _ as AdamWOptimizer
109
138
  };
@@ -31,6 +31,7 @@ export default class BasicTrainer {
31
31
  get isRunning(): boolean;
32
32
  getOptimizer(): AdamWOptimizer;
33
33
  updateOptimizer(config?: Partial<AdamWOptimizerConfig>): void;
34
+ resumeFromLog(log: TrainingLogEntry): void;
34
35
  protected trainStep(state: Partial<TrainingState>, batch: {
35
36
  xs: Tensor;
36
37
  ys: Tensor;
@@ -1,16 +1,16 @@
1
- import u from "./Evaluator.js";
2
- import { t as z, v as P, k as g, d as p, a as y } from "../index-CUXkjxiT.js";
3
- import S from "../utilities/profile.js";
4
- import { createTensorStatistics as k } from "../checks/weights.js";
5
- import { calculateLoss as x, calculateAccuracy as T } from "./loss.js";
6
- import { AdamWOptimizer as N } from "./AdamW.js";
7
- import { z as w } from "../zeros-DvZpK8s6.js";
8
- const v = {
1
+ import y from "./Evaluator.js";
2
+ import { t as L, Z as k, k as u, l as p, b as S } from "../index-DSGwv2Yx.js";
3
+ import w from "../utilities/profile.js";
4
+ import { createTensorStatistics as b } from "../checks/weights.js";
5
+ import { calculateLoss as x, calculateAccuracy as P } from "./loss.js";
6
+ import { AdamWOptimizer as T } from "./AdamW.js";
7
+ import { z as v } from "../zeros-Bw0puq_w.js";
8
+ const z = {
9
9
  logInterval: 1,
10
10
  maxEpochs: 100,
11
11
  sftMode: "full",
12
12
  batchSize: 32
13
- }, b = {
13
+ }, D = {
14
14
  learningRate: 3e-4,
15
15
  beta1: 0.9,
16
16
  beta2: 0.99,
@@ -23,14 +23,14 @@ const v = {
23
23
  lossScaling: 1
24
24
  };
25
25
  class G {
26
- constructor(s, i, o, c) {
27
- this.tokenizer = i, this.model = s, this.optimizerConfig = {
28
- ...b,
29
- ...o,
26
+ constructor(s, e, n, l) {
27
+ this.tokenizer = e, this.model = s, this.optimizerConfig = {
28
+ ...D,
29
+ ...n,
30
30
  lossScaling: s.lossScaling
31
31
  };
32
- const l = c || new N(this.optimizerConfig);
33
- c && c.updateConfig(this.optimizerConfig), this.optimizer = l;
32
+ const m = l || new T(this.optimizerConfig);
33
+ l && l.updateConfig(this.optimizerConfig), this.optimizer = m;
34
34
  }
35
35
  model;
36
36
  optimizer;
@@ -80,11 +80,22 @@ class G {
80
80
  updateOptimizer(s) {
81
81
  s && (this.optimizerConfig = { ...this.optimizerConfig, ...s }), this.optimizer.updateConfig(this.optimizerConfig);
82
82
  }
83
+ resumeFromLog(s) {
84
+ (!this.lastState || this.lastState.step === 0) && (this.lastState = {
85
+ losses: [],
86
+ validationLosses: [],
87
+ logStartTime: 0,
88
+ step: s.step,
89
+ lastLoss: s.trainingMetrics.loss,
90
+ totalSteps: s.step,
91
+ trainingDuration: s.duration
92
+ });
93
+ }
83
94
  // A single forward pass, backward pass, and optimizer step
84
- trainStep(s, i, o = !1, c = !1) {
85
- return z(() => {
95
+ trainStep(s, e, n = !1, l = !1) {
96
+ return L(() => {
86
97
  this.model.getProfiler()?.startMemory();
87
- const { xs: l, ys: a } = i, d = () => {
98
+ const { xs: m, ys: i } = e, d = () => {
88
99
  const r = this.model.forward(
89
100
  {
90
101
  training: !0,
@@ -93,32 +104,32 @@ class G {
93
104
  dropout: this._dropout,
94
105
  layerDrop: this._layerDrop
95
106
  },
96
- l
97
- ), e = x(r, a, this.maskedLoss, !1, this._labelSmoothing);
98
- this.metrics.has("accuracy") && (s.accuracy = T(r, a), g(s.accuracy)), r.dispose();
99
- const m = e.mul(y(this.optimizerConfig.lossScaling));
100
- return e.dispose(), m;
101
- }, { value: t, grads: n } = P(d);
102
- if (o)
107
+ m
108
+ ), o = x(r, i, this.maskedLoss, !1, this._labelSmoothing);
109
+ this.metrics.has("accuracy") && (s.accuracy = P(r, i), u(s.accuracy)), r.dispose();
110
+ const a = o.mul(S(this.optimizerConfig.lossScaling));
111
+ return o.dispose(), a;
112
+ }, { value: t, grads: c } = k(d);
113
+ if (n)
103
114
  this.model.getProfiler()?.endMemory("Training");
104
115
  else {
105
- const r = this.optimizer.applyGradients(n);
106
- this.metrics.has("gradientNorm") ? (s.gradientNorm = r, g(r)) : (s.gradientNorm = void 0, r.dispose());
107
- const e = Object.keys(n);
108
- this.model.weightStore.touchVariables(e), this.model.getProfiler()?.endMemory("Training"), c ? (s.gradients = n, Object.values(n).forEach((m) => g(m))) : p(n);
116
+ const r = this.optimizer.applyGradients(c);
117
+ this.metrics.has("gradientNorm") ? (s.gradientNorm = r, u(r)) : (s.gradientNorm = void 0, r.dispose());
118
+ const o = Object.keys(c);
119
+ this.model.weightStore.touchVariables(o), this.model.getProfiler()?.endMemory("Training"), l ? (s.gradients = c, Object.values(c).forEach((a) => u(a))) : p(c);
109
120
  }
110
- return t.mul(y(1 / this.optimizerConfig.lossScaling));
121
+ return t.mul(S(1 / this.optimizerConfig.lossScaling));
111
122
  });
112
123
  }
113
124
  async dummyPass() {
114
- const s = w([1, this.model.config.blockSize], "int32"), i = w([1, this.model.config.blockSize], "int32");
125
+ const s = v([1, this.model.config.blockSize], "int32"), e = v([1, this.model.config.blockSize], "int32");
115
126
  try {
116
- const o = this.trainStep({}, { xs: s, ys: i }, !0);
117
- await o.data(), o.dispose();
118
- } catch (o) {
119
- console.error("Error during dummy pass:", o);
127
+ const n = this.trainStep({}, { xs: s, ys: e }, !0);
128
+ await n.data(), n.dispose();
129
+ } catch (n) {
130
+ console.error("Error during dummy pass:", n);
120
131
  } finally {
121
- s.dispose(), i.dispose();
132
+ s.dispose(), e.dispose();
122
133
  }
123
134
  }
124
135
  dispose() {
@@ -136,33 +147,40 @@ class G {
136
147
  ...this.lastState || {}
137
148
  };
138
149
  }
139
- async stepDataset(s, i, o) {
140
- const { logInterval: c = 10 } = {
141
- ...v,
142
- ...i
150
+ async stepDataset(s, e, n) {
151
+ const { logInterval: l = 10 } = {
152
+ ...z,
153
+ ...e
143
154
  };
144
- i.metrics && this.setMetrics(i.metrics);
145
- const l = Date.now(), a = this.createEmptyState();
146
- this.lastState = a, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, a.logStartTime = l;
147
- const d = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, t = await s.iterator();
155
+ e.metrics && this.setMetrics(e.metrics);
156
+ const m = Date.now(), i = this.createEmptyState();
157
+ this.lastState = i, await this.dummyPass(), this.metrics.has("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, i.logStartTime = m;
158
+ const d = n ? new y(this.model, n, void 0, this.maskedLoss) : void 0, t = await s.iterator();
148
159
  try {
149
160
  for (; this.running; ) {
150
- const n = await t.next();
151
- if (n.done) break;
152
- const r = n.value, e = this.trainStep(a, r, !1);
153
- r.xs.dispose(), r.ys.dispose(), a.step++, a.totalSteps++, a.step % c === 0 ? await this.performLogging(e, r.xs.shape[0], i, d) : (a.gradientNorm && (a.gradientNorm.dispose(), a.gradientNorm = void 0), a.accuracy && (a.accuracy.dispose(), a.accuracy = void 0)), e.dispose();
161
+ const c = await t.next();
162
+ if (c.done) break;
163
+ const r = c.value, o = this.trainStep(i, r, !1);
164
+ r.xs.dispose(), r.ys.dispose(), i.step++, i.totalSteps++, i.step % l === 0 ? await this.performLogging(o, r.xs.shape[0], e, d) : (i.gradientNorm && (i.gradientNorm.dispose(), i.gradientNorm = void 0), i.accuracy && (i.accuracy.dispose(), i.accuracy = void 0)), o.dispose();
154
165
  }
155
- } catch (n) {
156
- throw console.error("Training error:", n), p(), n;
166
+ } catch (c) {
167
+ throw console.error("Training error:", c), c;
157
168
  }
158
- throw p(), this.running = !1, new Error("No log returned before training stopped.");
159
- }
160
- async performLogging(s, i, o, c) {
161
- const l = o?.onStep, a = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
169
+ throw this.model.trainingState = {
170
+ steps: i.totalSteps,
171
+ learningRate: this.optimizer.lr,
172
+ batchSize: e.batchSize || 32,
173
+ loss: i.lastLoss,
174
+ tokensProcessed: i.totalSteps * (e.batchSize || 32) * this.model.config.blockSize,
175
+ duration: i.trainingDuration
176
+ }, p(), this.running = !1, new Error("No log returned before training stopped.");
177
+ }
178
+ async performLogging(s, e, n, l) {
179
+ const m = n?.onStep, i = this.metrics.has("gradientStatistics"), d = (await s.data())[0], t = this.lastState;
162
180
  t.lastLoss = d;
163
- const n = Date.now();
164
- t.trainingDuration += n - t.logStartTime;
165
- const r = {
181
+ const c = Date.now();
182
+ t.trainingDuration += c - t.logStartTime;
183
+ const r = t.totalSteps * e * this.model.config.blockSize, o = {
166
184
  trainingMetrics: {
167
185
  loss: t.lastLoss,
168
186
  perplexity: this.metrics.has("perplexity") ? Math.exp(t.lastLoss) : void 0,
@@ -171,55 +189,57 @@ class G {
171
189
  step: t.step,
172
190
  time: Date.now() - t.logStartTime,
173
191
  gradientNorm: t.gradientNorm ? (await t.gradientNorm.data())[1] : void 0,
174
- batchSize: i,
192
+ batchSize: e,
175
193
  learningRate: this.metrics.has("learningRate") ? this.optimizer.lr : void 0,
176
194
  duration: t.trainingDuration,
177
- totalSamples: t.totalSteps * i,
178
- samplesPerSecond: t.totalSteps * i / (t.trainingDuration / 1e3),
195
+ totalTokens: r,
196
+ tokensPerSecond: r / (t.trainingDuration / 1e3),
179
197
  memoryUsage: this.metrics.has("memoryUsage") ? this.model.getProfiler()?.getPeakMemory() || 0 : void 0
180
198
  };
181
- if (this.metrics.has("tokensPerSecond") && (r.tokensPerSecond = r.samplesPerSecond * this.model.config.blockSize), t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0), this.model.trainingState = {
199
+ if (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0), this.model.trainingState = {
182
200
  steps: t.totalSteps,
183
201
  learningRate: this.optimizer.lr,
184
- batchSize: i,
185
- loss: t.lastLoss
186
- }, a && t.gradients) {
187
- const e = /* @__PURE__ */ new Map();
188
- for (const [m, h] of Object.entries(t.gradients))
189
- e.set(m, await k(h)), h.dispose();
190
- r.gradientMetrics = e;
202
+ batchSize: e,
203
+ loss: t.lastLoss,
204
+ tokensProcessed: r,
205
+ duration: t.trainingDuration
206
+ }, i && t.gradients) {
207
+ const a = /* @__PURE__ */ new Map();
208
+ for (const [h, g] of Object.entries(t.gradients))
209
+ a.set(h, await b(g)), g.dispose();
210
+ o.gradientMetrics = a;
191
211
  }
192
- if (c)
212
+ if (l)
193
213
  try {
194
- const e = await c.evaluate(5);
195
- Array.isArray(e) ? r.validationMetrics = { loss: e[0].loss, accuracy: e[0].accuracy } : (t.validationLosses.push(e.loss), r.validationMetrics = {
196
- accuracy: e.accuracy,
197
- loss: e.loss,
198
- perplexity: this.metrics.has("perplexity") ? Math.exp(e.loss) : void 0
214
+ const a = await l.evaluate(5);
215
+ Array.isArray(a) ? o.validationMetrics = { loss: a[0].loss, accuracy: a[0].accuracy } : (t.validationLosses.push(a.loss), o.validationMetrics = {
216
+ accuracy: a.accuracy,
217
+ loss: a.loss,
218
+ perplexity: this.metrics.has("perplexity") ? Math.exp(a.loss) : void 0
199
219
  });
200
- } catch (e) {
201
- console.error("Validation error:", e);
220
+ } catch (a) {
221
+ console.error("Validation error:", a);
202
222
  }
203
- l && await l(r), t.logStartTime = Date.now();
204
- }
205
- async trainOnDataset(s, i, o) {
206
- const { logInterval: c = 10, maxEpochs: l = 1 / 0 } = {
207
- ...v,
208
- ...i
209
- }, a = l * (i?.epochSteps || 1e3);
210
- i.metrics && this.setMetrics(i.metrics);
223
+ m && await m(o), t.logStartTime = Date.now();
224
+ }
225
+ async trainOnDataset(s, e, n) {
226
+ const { logInterval: l = 10, maxEpochs: m = 1 / 0 } = {
227
+ ...z,
228
+ ...e
229
+ }, i = m * (e?.epochSteps || 1e3);
230
+ e.metrics && this.setMetrics(e.metrics);
211
231
  const d = Date.now(), t = this.createEmptyState();
212
- this.lastState = t, await this.dummyPass(), i?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new S())), this.running = !0, t.logStartTime = d;
213
- const n = o ? new u(this.model, o, void 0, this.maskedLoss) : void 0, r = await s.iterator();
232
+ this.lastState = t, await this.dummyPass(), e?.metrics?.includes("memoryUsage") && (this.model.getProfiler() || this.model.setProfiler(new w())), this.running = !0, t.logStartTime = d;
233
+ const c = n ? new y(this.model, n, void 0, this.maskedLoss) : void 0, r = await s.iterator();
214
234
  try {
215
235
  for (; this.running; ) {
216
- const e = await r.next();
217
- if (e.done) break;
218
- const m = e.value, h = t.step % c === 0, L = (i?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, m, !1, L);
219
- m.xs.dispose(), m.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, m.xs.shape[0], i, n) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= a && this.stop();
236
+ const o = await r.next();
237
+ if (o.done) break;
238
+ const a = o.value, h = t.step % l === 0, g = (e?.metrics?.includes("gradientStatistics") || !1) && h, f = this.trainStep(t, a, !1, g);
239
+ a.xs.dispose(), a.ys.dispose(), t.step++, t.totalSteps++, h ? await this.performLogging(f, a.xs.shape[0], e, c) : (t.gradientNorm && (t.gradientNorm.dispose(), t.gradientNorm = void 0), t.accuracy && (t.accuracy.dispose(), t.accuracy = void 0)), f.dispose(), t.step >= i && this.stop();
220
240
  }
221
- } catch (e) {
222
- throw console.error("Training error:", e), p(), e;
241
+ } catch (o) {
242
+ throw console.error("Training error:", o), p(), o;
223
243
  }
224
244
  return p(), this.running = !1, { losses: t.losses, validationLosses: t.validationLosses };
225
245
  }
@@ -1,6 +1,6 @@
1
- import { t as f } from "../index-CUXkjxiT.js";
2
- import "../dataset-CGGp1z9P.js";
3
- import { g as a } from "../readers-iz5u3HBo.js";
1
+ import { t as f } from "../index-DSGwv2Yx.js";
2
+ import "../dataset-DlqAN81i.js";
3
+ import { g as a } from "../readers-17HLdxVM.js";
4
4
  import "../index-Cp39cXWe.js";
5
5
  const g = 8;
6
6
  async function p(n, e) {
@@ -1,7 +1,7 @@
1
- import { t as p } from "../index-CUXkjxiT.js";
1
+ import { t as p } from "../index-DSGwv2Yx.js";
2
2
  import { calculateLoss as d, calculateAccuracy as m } from "./loss.js";
3
3
  import { buildSFTExample as x } from "./SFTDatasetBuilder.js";
4
- import { t as h } from "../tensor-BWFldCso.js";
4
+ import { t as h } from "../tensor-D8e0Gd7c.js";
5
5
  class k {
6
6
  constructor(i, t, o, c) {
7
7
  if (this.model = i, this.masked = !!c, Array.isArray(t)) {