@genai-fi/nanogpt 0.17.4 → 0.18.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. package/dist/Generator.d.ts +2 -15
  2. package/dist/Generator.js +45 -34
  3. package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
  4. package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
  5. package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
  6. package/dist/TeachableLLM.d.ts +3 -8
  7. package/dist/TeachableLLM.js +61 -44
  8. package/dist/Trainer.d.ts +6 -4
  9. package/dist/Trainer.js +107 -92
  10. package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
  13. package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
  14. package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +5 -5
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +2 -2
  19. package/dist/checks/normRMS.js +6 -6
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +6 -6
  22. package/dist/checks/qkv.js +2 -2
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
  25. package/dist/complex-3DpPEG9B.js +11 -0
  26. package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
  27. package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
  28. package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
  29. package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
  30. package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
  31. package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
  32. package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
  33. package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
  34. package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
  35. package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
  36. package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
  37. package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
  38. package/dist/inference/types.d.ts +16 -0
  39. package/dist/inference/types.js +1 -0
  40. package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
  41. package/dist/layers/BaseLayer.js +4 -4
  42. package/dist/layers/CausalSelfAttention.js +6 -6
  43. package/dist/layers/LoRA.js +4 -4
  44. package/dist/layers/MLP.js +4 -4
  45. package/dist/layers/PositionEmbedding.js +5 -5
  46. package/dist/layers/RMSNorm.js +3 -3
  47. package/dist/layers/RoPECache.js +4 -4
  48. package/dist/layers/TiedEmbedding.js +6 -6
  49. package/dist/layers/TransformerBlock.js +1 -1
  50. package/dist/layers/WeightStore.js +2 -2
  51. package/dist/loader/load.d.ts +2 -8
  52. package/dist/loader/loadTransformers.d.ts +2 -8
  53. package/dist/loader/loadTransformers.js +13 -11
  54. package/dist/loader/newZipLoad.d.ts +2 -8
  55. package/dist/loader/newZipLoad.js +25 -10
  56. package/dist/loader/oldZipLoad.js +13 -13
  57. package/dist/loader/save.d.ts +9 -2
  58. package/dist/loader/save.js +64 -55
  59. package/dist/loader/types.d.ts +29 -1
  60. package/dist/main.d.ts +2 -0
  61. package/dist/main.js +45 -43
  62. package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
  63. package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
  64. package/dist/mat_mul-DP86qZtZ.js +11 -0
  65. package/dist/mod-BXjLYwvM.js +11 -0
  66. package/dist/models/NanoGPTV1.js +2 -2
  67. package/dist/models/NanoGPTV2.js +2 -2
  68. package/dist/models/model.d.ts +3 -2
  69. package/dist/models/model.js +13 -13
  70. package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
  71. package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
  72. package/dist/ops/adamAdjust.js +3 -3
  73. package/dist/ops/adamMoments.js +3 -3
  74. package/dist/ops/add16.js +1 -1
  75. package/dist/ops/appendCache.js +6 -6
  76. package/dist/ops/attentionMask.js +3 -3
  77. package/dist/ops/concat16.js +3 -3
  78. package/dist/ops/cpu/adamAdjust.js +9 -9
  79. package/dist/ops/cpu/adamMoments.js +5 -5
  80. package/dist/ops/cpu/appendCache.js +2 -2
  81. package/dist/ops/cpu/attentionMask.js +6 -6
  82. package/dist/ops/cpu/fusedSoftmax.js +4 -4
  83. package/dist/ops/cpu/gatherSub.js +5 -5
  84. package/dist/ops/cpu/gelu.js +4 -4
  85. package/dist/ops/cpu/matMul16.js +2 -2
  86. package/dist/ops/cpu/matMulGelu.js +7 -7
  87. package/dist/ops/cpu/matMulMul.js +2 -2
  88. package/dist/ops/cpu/mulDropout.js +5 -5
  89. package/dist/ops/cpu/normRMS.js +1 -1
  90. package/dist/ops/cpu/qkv.js +3 -3
  91. package/dist/ops/cpu/rope.js +5 -5
  92. package/dist/ops/cpu/scatterSub.js +5 -5
  93. package/dist/ops/dot16.js +2 -2
  94. package/dist/ops/dropout.js +6 -6
  95. package/dist/ops/dropout16.js +1 -1
  96. package/dist/ops/gatherSub.js +1 -1
  97. package/dist/ops/gelu.js +2 -2
  98. package/dist/ops/globalNorm.js +7 -7
  99. package/dist/ops/grads/add16.js +1 -1
  100. package/dist/ops/grads/attentionMask.js +2 -2
  101. package/dist/ops/grads/dropout16.js +1 -1
  102. package/dist/ops/grads/gelu.js +2 -2
  103. package/dist/ops/grads/matMul16.js +3 -3
  104. package/dist/ops/grads/matMulGelu.js +1 -1
  105. package/dist/ops/grads/mul16.js +1 -1
  106. package/dist/ops/grads/normRMS.js +7 -7
  107. package/dist/ops/grads/pack16.js +3 -3
  108. package/dist/ops/grads/qkv.js +11 -11
  109. package/dist/ops/grads/rope.js +2 -2
  110. package/dist/ops/grads/softmax16.js +1 -1
  111. package/dist/ops/grads/unpack16.js +2 -2
  112. package/dist/ops/matMul16.js +3 -3
  113. package/dist/ops/matMulGelu.js +6 -6
  114. package/dist/ops/matMulMul.js +3 -3
  115. package/dist/ops/mul16.js +1 -1
  116. package/dist/ops/mulDrop.js +3 -3
  117. package/dist/ops/normRMS.js +4 -4
  118. package/dist/ops/pack16.js +2 -2
  119. package/dist/ops/qkv.js +3 -3
  120. package/dist/ops/reshape16.js +6 -6
  121. package/dist/ops/rope.js +2 -2
  122. package/dist/ops/scatterSub.js +1 -1
  123. package/dist/ops/slice16.js +2 -2
  124. package/dist/ops/softmax16.js +1 -1
  125. package/dist/ops/sub16.js +1 -1
  126. package/dist/ops/sum16.js +6 -6
  127. package/dist/ops/transpose16.js +3 -3
  128. package/dist/ops/unpack16.js +2 -2
  129. package/dist/ops/webgl/adamAdjust.js +2 -2
  130. package/dist/ops/webgl/adamMoments.js +1 -1
  131. package/dist/ops/webgl/appendCache.js +1 -1
  132. package/dist/ops/webgl/attentionMask.js +1 -1
  133. package/dist/ops/webgl/dropout16.js +1 -1
  134. package/dist/ops/webgl/fusedSoftmax.js +7 -7
  135. package/dist/ops/webgl/gatherSub.js +3 -3
  136. package/dist/ops/webgl/gelu.js +2 -2
  137. package/dist/ops/webgl/log.js +3 -3
  138. package/dist/ops/webgl/matMul16.js +13 -13
  139. package/dist/ops/webgl/matMulGelu.js +4 -4
  140. package/dist/ops/webgl/matMulMul.js +2 -2
  141. package/dist/ops/webgl/mulDropout.js +1 -1
  142. package/dist/ops/webgl/normRMS.js +2 -2
  143. package/dist/ops/webgl/qkv.js +1 -1
  144. package/dist/ops/webgl/rope.js +1 -1
  145. package/dist/ops/webgl/scatterSub.js +2 -2
  146. package/dist/ops/webgpu/adamAdjust.js +3 -3
  147. package/dist/ops/webgpu/adamMoments.js +3 -3
  148. package/dist/ops/webgpu/add16.js +6 -6
  149. package/dist/ops/webgpu/appendCache.js +3 -3
  150. package/dist/ops/webgpu/attentionMask.js +2 -2
  151. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  152. package/dist/ops/webgpu/clipScale.js +7 -7
  153. package/dist/ops/webgpu/concat16.js +5 -5
  154. package/dist/ops/webgpu/dropout16.js +6 -6
  155. package/dist/ops/webgpu/gatherSub.js +3 -3
  156. package/dist/ops/webgpu/gelu.js +8 -8
  157. package/dist/ops/webgpu/matMul16.js +16 -16
  158. package/dist/ops/webgpu/matMul16_program.js +2 -2
  159. package/dist/ops/webgpu/mul16.js +5 -5
  160. package/dist/ops/webgpu/norm2.js +1 -1
  161. package/dist/ops/webgpu/normRMS.js +2 -2
  162. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  163. package/dist/ops/webgpu/pack16.js +4 -4
  164. package/dist/ops/webgpu/pack16_program.js +2 -2
  165. package/dist/ops/webgpu/qkv.js +2 -2
  166. package/dist/ops/webgpu/rope.js +3 -3
  167. package/dist/ops/webgpu/scatterSub.js +3 -3
  168. package/dist/ops/webgpu/slice16.js +4 -4
  169. package/dist/ops/webgpu/softmax16.js +4 -4
  170. package/dist/ops/webgpu/softmax16_program.js +2 -2
  171. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  172. package/dist/ops/webgpu/softmax16grad.js +4 -4
  173. package/dist/ops/webgpu/sub16.js +6 -6
  174. package/dist/ops/webgpu/sum16.js +3 -3
  175. package/dist/ops/webgpu/transpose16.js +8 -8
  176. package/dist/ops/webgpu/transpose16_program.js +2 -2
  177. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  178. package/dist/ops/webgpu/unpack16.js +3 -3
  179. package/dist/ops/webgpu/utils/binary_op.js +3 -3
  180. package/dist/ops/webgpu/utils/reductions.js +5 -5
  181. package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
  182. package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
  183. package/dist/patches/webgpu_backend.js +6 -6
  184. package/dist/patches/webgpu_base.js +1 -1
  185. package/dist/patches/webgpu_program.js +2 -2
  186. package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
  187. package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
  188. package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
  189. package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
  190. package/dist/relu-DTvZKBsZ.js +9 -0
  191. package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
  192. package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
  193. package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
  194. package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
  195. package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
  196. package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
  197. package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
  198. package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
  199. package/dist/slice-BvItlgXu.js +12 -0
  200. package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
  201. package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
  202. package/dist/split-BN9LkEgS.js +9 -0
  203. package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
  204. package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
  205. package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
  206. package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
  207. package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
  208. package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
  209. package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
  210. package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
  211. package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
  212. package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
  213. package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
  214. package/dist/tokeniser/BaseTokeniser.js +21 -5
  215. package/dist/tokeniser/CharTokeniser.d.ts +1 -1
  216. package/dist/tokeniser/CharTokeniser.js +62 -50
  217. package/dist/tokeniser/bpe.d.ts +1 -1
  218. package/dist/tokeniser/bpe.js +41 -35
  219. package/dist/tokeniser/type.d.ts +3 -1
  220. package/dist/training/AdamW.d.ts +3 -0
  221. package/dist/training/AdamW.js +59 -30
  222. package/dist/training/BasicTrainer.d.ts +1 -0
  223. package/dist/training/BasicTrainer.js +112 -92
  224. package/dist/training/DatasetBuilder.js +3 -3
  225. package/dist/training/Evaluator.js +2 -2
  226. package/dist/training/LRScheduler.d.ts +1 -0
  227. package/dist/training/LRScheduler.js +18 -12
  228. package/dist/training/PreTrainer.js +3 -3
  229. package/dist/training/SFTDatasetBuilder.js +3 -3
  230. package/dist/training/SFTTrainer.js +1 -1
  231. package/dist/training/orthoGrad.js +1 -1
  232. package/dist/training/sparseCrossEntropy.js +30 -30
  233. package/dist/training/types.d.ts +5 -3
  234. package/dist/training/validation.js +13 -13
  235. package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
  236. package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
  237. package/dist/utilities/datasetID.d.ts +2 -0
  238. package/dist/utilities/datasetID.js +21 -0
  239. package/dist/utilities/dummy.js +6 -6
  240. package/dist/utilities/multinomialCPU.js +2 -2
  241. package/dist/utilities/packed.js +1 -1
  242. package/dist/utilities/performance.js +1 -1
  243. package/dist/utilities/profile.js +1 -1
  244. package/dist/utilities/safetensors.js +2 -2
  245. package/dist/utilities/sentences.js +5 -5
  246. package/dist/utilities/weights.js +2 -2
  247. package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
  248. package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
  249. package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
  250. package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
  251. package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
  252. package/package.json +3 -3
  253. package/dist/complex-DI35Q-gW.js +0 -11
  254. package/dist/mat_mul-DhG0Newp.js +0 -11
  255. package/dist/mod-CSdCpRjf.js +0 -11
  256. package/dist/relu-J_X6MUzx.js +0 -9
  257. package/dist/slice-BzS11Qh0.js +0 -12
  258. package/dist/split-C2Sj255c.js +0 -9
package/dist/Trainer.js CHANGED
@@ -1,35 +1,36 @@
1
- import { E as f } from "./index-DvYrXKkX.js";
2
- import l from "./training/PreTrainer.js";
3
- import { createTrainValidationSplit as p } from "./training/validation.js";
4
- import g from "./training/SFTTrainer.js";
5
- import m from "./training/tasks/splitter.js";
6
- const n = [];
7
- for (let a = 0; a < 256; ++a)
8
- n.push((a + 256).toString(16).slice(1));
9
- function u(a, t = 0) {
10
- return (n[a[t + 0]] + n[a[t + 1]] + n[a[t + 2]] + n[a[t + 3]] + "-" + n[a[t + 4]] + n[a[t + 5]] + "-" + n[a[t + 6]] + n[a[t + 7]] + "-" + n[a[t + 8]] + n[a[t + 9]] + "-" + n[a[t + 10]] + n[a[t + 11]] + n[a[t + 12]] + n[a[t + 13]] + n[a[t + 14]] + n[a[t + 15]]).toLowerCase();
1
+ import { E as m } from "./index-DvYrXKkX.js";
2
+ import g from "./training/PreTrainer.js";
3
+ import { createTrainValidationSplit as u } from "./training/validation.js";
4
+ import c from "./training/SFTTrainer.js";
5
+ import p from "./training/tasks/splitter.js";
6
+ const r = [];
7
+ for (let n = 0; n < 256; ++n)
8
+ r.push((n + 256).toString(16).slice(1));
9
+ function w(n, t = 0) {
10
+ return (r[n[t + 0]] + r[n[t + 1]] + r[n[t + 2]] + r[n[t + 3]] + "-" + r[n[t + 4]] + r[n[t + 5]] + "-" + r[n[t + 6]] + r[n[t + 7]] + "-" + r[n[t + 8]] + r[n[t + 9]] + "-" + r[n[t + 10]] + r[n[t + 11]] + r[n[t + 12]] + r[n[t + 13]] + r[n[t + 14]] + r[n[t + 15]]).toLowerCase();
11
11
  }
12
- const w = new Uint8Array(16);
13
- function S() {
14
- return crypto.getRandomValues(w);
12
+ const T = new Uint8Array(16);
13
+ function D() {
14
+ return crypto.getRandomValues(T);
15
15
  }
16
- function c(a, t, i) {
17
- return crypto.randomUUID ? crypto.randomUUID() : D(a);
16
+ function d(n, t, a) {
17
+ return crypto.randomUUID ? crypto.randomUUID() : k(n);
18
18
  }
19
- function D(a, t, i) {
20
- a = a || {};
21
- const e = a.random ?? a.rng?.() ?? S();
22
- if (e.length < 16)
19
+ function k(n, t, a) {
20
+ n = n || {};
21
+ const i = n.random ?? n.rng?.() ?? D();
22
+ if (i.length < 16)
23
23
  throw new Error("Random bytes length must be >= 16");
24
- return e[6] = e[6] & 15 | 64, e[8] = e[8] & 63 | 128, u(e);
24
+ return i[6] = i[6] & 15 | 64, i[8] = i[8] & 63 | 128, w(i);
25
25
  }
26
- class d extends f {
26
+ class f extends m {
27
27
  trainer;
28
28
  trainingType = "pretraining";
29
29
  hasTrained = !1;
30
30
  trainDataset;
31
31
  validationDataset;
32
- totalSamples = 0;
32
+ totalTokens = 0;
33
+ tokensProcessed = 0;
33
34
  log = [];
34
35
  progress = null;
35
36
  options = {
@@ -38,21 +39,21 @@ class d extends f {
38
39
  logInterval: 10
39
40
  };
40
41
  tokenizer;
41
- constructor(t, i, e = "pretraining", r) {
42
- if (super(), t instanceof d) {
43
- const s = i || t.options, h = t.options;
44
- let o = !1;
45
- t.trainingType === "sft" && s.sftMode !== h.sftMode && (o = !0), e !== t.trainingType && (o = !0), o ? (t.trainingType === "sft" ? this.trainer = new g(t.model, t.tokenizer, s) : this.trainer = new l(t.model, t.tokenizer, s), this.trainingType = e, this.options = s, this.tokenizer = t.tokenizer) : (this.trainer = t.trainer, this.trainingType = e, this.options = s, this.trainer.updateOptimizer(this.options), this.log = t.log, this.progress = t.progress, this.totalSamples = t.totalSamples, this.tokenizer = t.tokenizer, s.batchSize === h.batchSize && (this.trainDataset = t.trainDataset, this.validationDataset = t.validationDataset));
42
+ constructor(t, a, i = "pretraining", e, s) {
43
+ if (super(), t instanceof f) {
44
+ const o = a || t.options, h = t.options;
45
+ let l = !1;
46
+ t.trainingType === "sft" && o.sftMode !== h.sftMode && (l = !0), i !== t.trainingType && (l = !0), l ? (t.trainingType === "sft" ? this.trainer = new c(t.model, t.tokenizer, o) : this.trainer = new g(t.model, t.tokenizer, o), this.trainingType = i, this.options = o, this.tokenizer = t.tokenizer) : (this.trainer = t.trainer, this.trainingType = i, this.options = o, this.trainer.updateOptimizer(this.options), this.log = t.log, this.progress = t.progress, this.totalTokens = t.totalTokens, this.tokenizer = t.tokenizer, o.batchSize === h.batchSize && (this.trainDataset = t.trainDataset, this.validationDataset = t.validationDataset));
46
47
  return;
47
48
  }
48
- if (!i)
49
+ if (!a)
49
50
  throw new Error("Tokeniser must be provided when initializing Trainer with a model");
50
51
  if (!t)
51
52
  throw new Error("Model must be provided when initializing Trainer");
52
- this.options = r || {
53
+ this.options = e || {
53
54
  batchSize: 32,
54
55
  sftMode: "full"
55
- }, e === "sft" ? this.trainer = new g(t, i, r) : this.trainer = new l(t, i, r), this.trainingType = e, this.tokenizer = i;
56
+ }, i === "sft" ? this.trainer = new c(t, a, e, s) : this.trainer = new g(t, a, e, s), this.trainingType = i, this.tokenizer = a;
56
57
  }
57
58
  get model() {
58
59
  return this.trainer.model;
@@ -69,61 +70,66 @@ class d extends f {
69
70
  dispose() {
70
71
  this.trainer.dispose(), this.removeAllListeners();
71
72
  }
72
- getTotalSamples() {
73
- return this.totalSamples;
73
+ getTotalTokens() {
74
+ return this.totalTokens;
74
75
  }
75
76
  setOptions(t) {
76
- const i = new Set(
77
+ const a = new Set(
77
78
  Object.keys(t).filter(
78
- (e) => t[e] !== this.options[e]
79
+ (i) => t[i] !== this.options[i]
79
80
  )
80
81
  );
81
82
  if (this.trainer.isRunning) {
82
- if (i.has("batchSize"))
83
+ if (a.has("batchSize"))
83
84
  throw new Error("Cannot change batch size during training");
84
- if (i.has("sftMode"))
85
+ if (a.has("sftMode"))
85
86
  throw new Error("Cannot change SFT mode during training");
86
- if (i.has("loraConfig"))
87
+ if (a.has("loraConfig"))
87
88
  throw new Error("Cannot change LoRA configuration during training");
88
- if (i.has("validationSplit"))
89
+ if (a.has("validationSplit"))
89
90
  throw new Error("Cannot change validation split during training");
90
- if (i.has("trainableWeights"))
91
+ if (a.has("trainableWeights"))
91
92
  throw new Error("Cannot change trainable weights during training");
92
- if (i.has("mixedPrecision"))
93
+ if (a.has("mixedPrecision"))
93
94
  throw new Error("Cannot change mixed precision setting during training");
94
- if (i.has("gradientCheckpointing"))
95
+ if (a.has("gradientCheckpointing"))
95
96
  throw new Error("Cannot change gradient checkpointing setting during training");
96
97
  }
97
98
  this.options = {
98
99
  ...this.options,
99
100
  ...t
100
- }, this.trainer.updateOptimizer(this.options), i.has("metrics") && this.trainer.setMetrics(t.metrics || []);
101
+ }, this.trainer.updateOptimizer(this.options), a.has("metrics") && this.trainer.setMetrics(t.metrics || []);
101
102
  }
102
- async prepare(t = []) {
103
+ async prepare(t = [], a) {
103
104
  const i = this.options;
104
- if (this.trainingType === "pretraining" && this.trainer instanceof l) {
105
- const { trainDataset: e, validationDataset: r, size: s, trainState: h } = await p(
105
+ if (a && (this.model.metaData.pretrainingData = a.map((e) => ({
106
+ id: e.id,
107
+ name: e.name
108
+ }))), this.trainingType === "pretraining" && this.trainer instanceof g) {
109
+ const { trainDataset: e, validationDataset: s, size: o } = await u(
106
110
  t,
107
111
  this.trainer.tokenizer,
108
112
  this.trainer.datasetBuilder,
109
113
  i?.batchSize || 32,
110
114
  i?.validationSplit || 0.1
111
- ), o = s * (1 - (i?.validationSplit || 0));
112
- this.trainDataset = e, this.validationDataset = r, this.totalSamples = o, this.options.epochSteps = Math.ceil(h.shuffledIndexes.length / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
113
- } else if (this.trainingType === "sft" && this.trainer instanceof g) {
115
+ ), h = o * (1 - (i?.validationSplit || 0));
116
+ this.trainDataset = e, this.validationDataset = s, this.totalTokens = h, this.options.epochSteps = Math.ceil(
117
+ this.totalTokens / ((i?.batchSize || 32) * this.model.config.blockSize)
118
+ ), this.trainer.updateOptimizer(this.options);
119
+ } else if (this.trainingType === "sft" && this.trainer instanceof c) {
114
120
  if (t instanceof Uint16Array)
115
121
  throw new Error("SFT training requires Task[] input");
116
122
  if (i?.validationSplit && i.validationSplit > 0) {
117
- const e = m(t, i?.validationSplit), r = await this.trainer.datasetBuilder.createSFTDataset(
123
+ const e = p(t, i?.validationSplit), s = await this.trainer.datasetBuilder.createSFTDataset(
118
124
  [e.training],
119
125
  i?.batchSize || 32,
120
126
  -100
121
- ), s = await this.trainer.datasetBuilder.createSFTDataset(
127
+ ), o = await this.trainer.datasetBuilder.createSFTDataset(
122
128
  [e.validation],
123
129
  i?.batchSize || 32,
124
130
  -100
125
131
  );
126
- this.validationDataset = s, this.trainDataset = r;
132
+ this.validationDataset = o, this.trainDataset = s;
127
133
  } else {
128
134
  const e = await this.trainer.datasetBuilder.createSFTDataset(
129
135
  t,
@@ -132,45 +138,47 @@ class d extends f {
132
138
  );
133
139
  this.trainDataset = e;
134
140
  }
135
- this.totalSamples = t.reduce((e, r) => e + r.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
141
+ this.totalTokens = t.reduce((e, s) => e + s.length, 0), this.options.epochSteps = Math.ceil(
142
+ this.totalTokens / ((i?.batchSize || 32) * this.model.config.blockSize)
143
+ ), this.trainer.updateOptimizer(this.options);
136
144
  }
137
145
  }
138
146
  configureModel(t) {
139
- const i = t?.sftMode || "full";
147
+ const a = t?.sftMode || "full";
140
148
  if (this.trainingType === "pretraining" && (this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA(), this.trainer.model.weightStore.setTrainable(["*"])), this.trainingType === "sft") {
141
- if (i === "lora") {
142
- const e = this.trainer.model;
149
+ if (a === "lora") {
150
+ const i = this.trainer.model;
143
151
  if (t?.loraName)
144
- if (e.hasLoRA(t.loraName)) {
145
- if (e.attachLoRA(t.loraName), t.loraConfig) {
146
- const r = e.lora;
147
- (r.alpha !== t.loraConfig.alpha || r.rank !== t.loraConfig.rank) && (e.detachLoRA(), e.deleteLoRA(t.loraName), e.createLoRA(t.loraName, t.loraConfig), e.attachLoRA(t.loraName), console.warn("Resetting LoRA with new configuration."));
152
+ if (i.hasLoRA(t.loraName)) {
153
+ if (i.attachLoRA(t.loraName), t.loraConfig) {
154
+ const e = i.lora;
155
+ (e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) && (i.detachLoRA(), i.deleteLoRA(t.loraName), i.createLoRA(t.loraName, t.loraConfig), i.attachLoRA(t.loraName), console.warn("Resetting LoRA with new configuration."));
148
156
  }
149
157
  } else if (t.loraConfig)
150
- e.createLoRA(t.loraName, t.loraConfig), e.attachLoRA(t.loraName);
158
+ i.createLoRA(t.loraName, t.loraConfig), i.attachLoRA(t.loraName);
151
159
  else
152
160
  throw new Error(
153
161
  `LoRA configuration must be provided to create LoRA with name ${t.loraName}`
154
162
  );
155
163
  else if (t?.loraConfig)
156
- if (e.hasLoRA()) {
157
- const r = e.lora;
158
- if (r.alpha !== t.loraConfig.alpha || r.rank !== t.loraConfig.rank) {
159
- e.detachLoRA();
160
- const s = t.loraName || c();
161
- e.createLoRA(s, t.loraConfig), e.attachLoRA(s);
164
+ if (i.hasLoRA()) {
165
+ const e = i.lora;
166
+ if (e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) {
167
+ i.detachLoRA();
168
+ const s = t.loraName || d();
169
+ i.createLoRA(s, t.loraConfig), i.attachLoRA(s);
162
170
  }
163
171
  } else {
164
- const r = t.loraName || c();
165
- e.createLoRA(r, t.loraConfig), e.attachLoRA(r);
172
+ const e = t.loraName || d();
173
+ i.createLoRA(e, t.loraConfig), i.attachLoRA(e);
166
174
  }
167
- else if (!e.hasLoRA()) throw new Error("LoRA configuration must be provided for lora SFT mode");
175
+ else if (!i.hasLoRA()) throw new Error("LoRA configuration must be provided for lora SFT mode");
168
176
  } else
169
177
  this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA();
170
- i === "last-layer" ? this.trainer.model.weightStore.setTrainable([
178
+ a === "last-layer" ? this.trainer.model.weightStore.setTrainable([
171
179
  `block_${this.trainer.model.config.nLayer - 1}_*`,
172
180
  "token_embedding"
173
- ]) : i === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
181
+ ]) : a === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
174
182
  }
175
183
  t?.trainableWeights && this.trainer.model.weightStore.setTrainable(t.trainableWeights);
176
184
  }
@@ -178,37 +186,44 @@ class d extends f {
178
186
  const t = this.options;
179
187
  if (!this.trainDataset)
180
188
  throw new Error("Dataset not prepared");
181
- this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), this.trainer.setLabelSmoothing(t?.labelSmoothing || 0), this.trainer.setDropout(t?.dropout || 0), this.trainer.setLayerDrop(t?.layerDrop || 0), this.configureModel(t), await this.trainer.trainOnDataset(
189
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.model.metaData.pretrainingSettings = t;
190
+ const a = Date.now();
191
+ this.log.length > 0 && this.trainer.resumeFromLog(this.log[this.log.length - 1]), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), this.trainer.setLabelSmoothing(t?.labelSmoothing || 0), this.trainer.setDropout(t?.dropout || 0), this.trainer.setLayerDrop(t?.layerDrop || 0), this.configureModel(t), await this.trainer.trainOnDataset(
182
192
  this.trainDataset,
183
193
  {
184
194
  ...t,
185
- onStep: async (i) => {
186
- this.log.push(i), this.progress = {
187
- lastLog: i,
188
- progress: i.totalSamples / this.totalSamples,
189
- remaining: Math.max(
190
- 0,
191
- (this.totalSamples - i.totalSamples) / i.totalSamples * i.duration
192
- )
193
- };
194
- const e = this.listeners("log");
195
- for (const r of e)
196
- await r(i, this.progress);
195
+ onStep: async (e) => {
196
+ this.log.push(e), this.progress = {
197
+ lastLog: e,
198
+ progress: e.totalTokens / this.totalTokens,
199
+ remaining: Math.max(0, (this.totalTokens - e.totalTokens) / e.totalTokens * e.duration)
200
+ }, this.tokensProcessed = e.totalTokens;
201
+ const s = this.listeners("log");
202
+ for (const o of s)
203
+ await o(e, this.progress);
197
204
  }
198
205
  },
199
206
  this.validationDataset
200
- ), this.emit("stop");
207
+ ), this.model.metaData.actionLog = this.model.metaData.actionLog || [];
208
+ const i = Date.now();
209
+ this.model.metaData.actionLog.push({
210
+ action: "pretrain",
211
+ timestamp: i,
212
+ duration: i - a,
213
+ tokensProcessed: this.tokensProcessed,
214
+ options: t
215
+ }), this.emit("stop");
201
216
  }
202
217
  async step(t) {
203
218
  if (!this.trainDataset)
204
219
  throw new Error("Dataset not prepared");
205
220
  this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
206
- const { log: i } = await this.trainer.stepDataset(this.trainDataset, t || {}, this.validationDataset), e = this.listeners("log");
207
- for (const r of e)
208
- await r(i, {
209
- lastLog: i,
210
- progress: i.totalSamples / this.totalSamples,
211
- remaining: Math.max(0, (this.totalSamples - i.totalSamples) / i.totalSamples * i.duration)
221
+ const { log: a } = await this.trainer.stepDataset(this.trainDataset, t || {}, this.validationDataset), i = this.listeners("log");
222
+ for (const e of i)
223
+ await e(a, {
224
+ lastLog: a,
225
+ progress: a.totalTokens / this.totalTokens,
226
+ remaining: Math.max(0, (this.totalTokens - a.totalTokens) / a.totalTokens * a.duration)
212
227
  });
213
228
  this.emit("stop");
214
229
  }
@@ -223,5 +238,5 @@ class d extends f {
223
238
  }
224
239
  }
225
240
  export {
226
- d as default
241
+ f as default
227
242
  };
@@ -1,4 +1,4 @@
1
- import { x as c } from "./index-CUXkjxiT.js";
1
+ import { v as c } from "./index-DSGwv2Yx.js";
2
2
  function i(e, n) {
3
3
  for (let t = 0; t < e.length; ++t)
4
4
  if (e[e.length - t - 1] !== n - 1 - t)
package/dist/backend.js CHANGED
@@ -1,9 +1,9 @@
1
- import { g as o, s as e, r as s } from "./index-CUXkjxiT.js";
1
+ import { g as o, s as e, r as s } from "./index-DSGwv2Yx.js";
2
2
  async function c(t, a) {
3
3
  if (o() !== t) {
4
4
  if (t === "webgpu") {
5
5
  const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
6
- i(a), await import("./index-CjOWnMXP.js"), await import("./ops/webgpu/index.js");
6
+ i(a), await import("./index-BQvB7LCC.js"), await import("./ops/webgpu/index.js");
7
7
  }
8
8
  await e(t), await s(), console.log(`Backend set to ${t}`);
9
9
  }
@@ -1,9 +1,9 @@
1
- import { U as d, a9 as A, a8 as O, x as g, av as _, az as w, ad as D, _ as x, $ as b, am as y, aY as M } from "./index-CUXkjxiT.js";
2
- import { d as T, f as L, h as W, c as v, e as F, a as N, b as C, g as P } from "./axis_util-GTVlo58H.js";
3
- import { a as z, c as U } from "./concat_util-D18dJ4fD.js";
4
- import { c as B, b as H, d as V, f as G, g as Z, h as j, i as q, j as J, k as K, m as X, t as Y } from "./step-BTxPtq1r.js";
5
- import { S as k, a as Q, b as ee, g as te, c as se, s as ne } from "./selu_util-BLhIqRkw.js";
6
- import { c as re, v as oe, a as ie } from "./scatter_nd_util-ByNJaL6I.js";
1
+ import { N as d, a9 as A, a8 as O, v as g, av as _, az as w, ad as D, _ as x, $ as b, am as y, aY as M } from "./index-DSGwv2Yx.js";
2
+ import { d as T, f as L, h as v, c as W, e as F, a as N, b as C, g as P } from "./axis_util-QWWgLjut.js";
3
+ import { a as z, c as B } from "./concat_util-C1Mxe27t.js";
4
+ import { c as U, b as H, d as V, f as G, g as Z, h as j, i as q, j as J, k as K, m as X, t as Y } from "./step-DQY6_ABw.js";
5
+ import { S as k, a as Q, b as ee, g as te, c as se, s as ne } from "./selu_util-BXdhy_W6.js";
6
+ import { c as re, v as oe, a as ie } from "./scatter_nd_util-C-x73Cj6.js";
7
7
  import { a as ae, c as ue, b as ce, e as pe, d as le, g as fe, m as he, s as ge } from "./complex_util-Yc1A_gV1.js";
8
8
  function de(e, t) {
9
9
  const r = e.shape.length, n = t.shape.length;
@@ -146,10 +146,10 @@ function De(e, t, r) {
146
146
  return n;
147
147
  }
148
148
  const xe = 0.3275911, be = 0.254829592, ye = -0.284496736, Me = 1.421413741, Te = -1.453152027, Le = 1.061405429;
149
- const I = "->", We = /->/g, E = ",", $ = "...";
150
- function ve(e, t) {
149
+ const I = "->", ve = /->/g, E = ",", $ = "...";
150
+ function We(e, t) {
151
151
  e = e.replace(/\s/g, "");
152
- const r = (e.length - e.replace(We, "").length) / I.length;
152
+ const r = (e.length - e.replace(ve, "").length) / I.length;
153
153
  if (r < 1)
154
154
  throw new Error("Equations without an arrow are not supported.");
155
155
  if (r > 1)
@@ -226,7 +226,7 @@ function ze(e, t) {
226
226
  (e[n].length === 0 || e[n].indexOf(t) !== -1 || t === -1) && r.push(n);
227
227
  return r;
228
228
  }
229
- function Ue(e, t, r = 0) {
229
+ function Be(e, t, r = 0) {
230
230
  let n = [];
231
231
  if (typeof t == "number")
232
232
  g(e.shape[r] % t === 0, () => "Number of splits must evenly divide the axis."), n = new Array(t).fill(e.shape[r] / t);
@@ -242,7 +242,7 @@ function Ue(e, t, r = 0) {
242
242
  }
243
243
  return n;
244
244
  }
245
- function Be(e) {
245
+ function Ue(e) {
246
246
  return `Received SparseTensor with denseShape[0] = 0 but
247
247
  indices.shape[0] = ${e}`;
248
248
  }
@@ -314,8 +314,8 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
314
314
  axesAreInnerMostDims: L,
315
315
  calculateShapes: re,
316
316
  checkEinsumDimSizes: Ne,
317
- checkPadOnDimRoundingMode: B,
318
- combineLocations: W,
317
+ checkPadOnDimRoundingMode: U,
318
+ combineLocations: v,
319
319
  combineRaggedTensorToTensorShapes: me,
320
320
  complexWithEvenIndex: ue,
321
321
  complexWithOddIndex: ce,
@@ -324,12 +324,12 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
324
324
  computeDefaultPad: G,
325
325
  computeDilation2DInfo: Z,
326
326
  computeOptimalWindowSize: Se,
327
- computeOutAndReduceShapes: v,
328
- computeOutShape: U,
327
+ computeOutAndReduceShapes: W,
328
+ computeOutShape: B,
329
329
  computePool2DInfo: j,
330
330
  computePool3DInfo: q,
331
331
  convertConv2DDataFormat: J,
332
- decodeEinsumEquation: ve,
332
+ decodeEinsumEquation: We,
333
333
  eitherStridesOrDilationsAreOne: K,
334
334
  expandShapeToKeepDim: F,
335
335
  exponent: pe,
@@ -353,7 +353,7 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
353
353
  getRowPartitionTypesHelper: Ie,
354
354
  getSliceBeginCoords: we,
355
355
  getSliceSize: De,
356
- getSparseFillEmptyRowsIndicesDenseShapeMismatch: Be,
356
+ getSparseFillEmptyRowsIndicesDenseShapeMismatch: Ue,
357
357
  getSparseFillEmptyRowsNegativeIndexErrorMessage: He,
358
358
  getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: Ve,
359
359
  getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: je,
@@ -369,7 +369,7 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
369
369
  isIdentityPermutation: Pe,
370
370
  mergeRealAndImagArrays: he,
371
371
  prepareAndValidate: de,
372
- prepareSplitSize: Ue,
372
+ prepareSplitSize: Be,
373
373
  shouldFuse: ne,
374
374
  splitRealAndImagArrays: ge,
375
375
  stridesOrDilationsArePositive: X,
@@ -386,14 +386,14 @@ export {
386
386
  we as C,
387
387
  De as D,
388
388
  xe as E,
389
- ve as F,
389
+ We as F,
390
390
  Ne as G,
391
391
  Ce as H,
392
392
  Fe as I,
393
393
  Pe as J,
394
394
  de as K,
395
395
  Re as L,
396
- Ue as M,
396
+ Be as M,
397
397
  S as P,
398
398
  f as R,
399
399
  Ee as a,
@@ -403,7 +403,7 @@ export {
403
403
  et as e,
404
404
  Qe as f,
405
405
  Ie as g,
406
- Be as h,
406
+ Ue as h,
407
407
  He as i,
408
408
  Ve as j,
409
409
  Ge as k,
@@ -1,7 +1,7 @@
1
- import { ab as g, as as $, at as K, h as D, x as _, au as O, U as x, av as Z, a5 as W, aw as F, ax as j, ay as X, az as J, af as ee, a9 as k } from "./index-CUXkjxiT.js";
2
- import { m as te, f as se, P as re } from "./webgpu_program-B4HmApL1.js";
3
- import { i as ne, G as q } from "./webgpu_util-DYlGSwOJ.js";
4
- import { m as N } from "./complex_util-Yc1A_gV1.js";
1
+ import { ab as g, as as $, at as K, e as D, v as _, au as O, N as x, av as Z, a5 as W, aw as F, ax as j, ay as X, az as J, af as ee, a9 as k } from "./index-DSGwv2Yx.js";
2
+ import { m as te, f as se, P as re } from "./webgpu_program-CbjdYLYk.js";
3
+ import { i as ne, G as N } from "./webgpu_util-DuofJBMo.js";
4
+ import { m as q } from "./complex_util-Yc1A_gV1.js";
5
5
  const d = g();
6
6
  d.registerFlag("WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE", () => 15);
7
7
  d.registerFlag("WEBGPU_CPU_FORWARD", () => !0);
@@ -248,7 +248,7 @@ class R extends $ {
248
248
  if (s != null || t.dtype === "string")
249
249
  return s;
250
250
  if (t.dtype === "complex64") {
251
- const E = this.readSync(n.real.dataId), B = this.readSync(n.imag.dataId), y = O(N(E, B).buffer, "float32");
251
+ const E = this.readSync(n.real.dataId), B = this.readSync(n.imag.dataId), y = O(q(E, B).buffer, "float32");
252
252
  return this.convertAndCacheOnCPU(e, y), y;
253
253
  }
254
254
  this.hasReadSyncWarned || (this.hasReadSyncWarned = !0, console.warn("The performance of synchronously reading data from GPU to CPU is poor on the webgpu backend, please use asynchronous APIs instead."));
@@ -309,7 +309,7 @@ class R extends $ {
309
309
  this.read(t.complexTensorInfos.real.dataId),
310
310
  this.read(t.complexTensorInfos.imag.dataId)
311
311
  ]), a = r[0], i = r[1];
312
- n = N(a, i);
312
+ n = q(a, i);
313
313
  } else {
314
314
  const r = await this.getBufferData(t.resource);
315
315
  n = O(r, t.dtype);
@@ -337,7 +337,7 @@ class R extends $ {
337
337
  refCount: 1,
338
338
  external: e.zeroCopy
339
339
  });
340
- const a = this.tensorMap.get(r), i = q(a.dtype) * x(a.shape);
340
+ const a = this.tensorMap.get(r), i = N(a.dtype) * x(a.shape);
341
341
  if (e.buffer.size < i)
342
342
  throw new Error(`GPUBuffer size(${e.buffer.size}) is smaller than tensor size(${i})!`);
343
343
  if ((e.buffer.usage & (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) !== (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC))
@@ -398,7 +398,7 @@ class R extends $ {
398
398
  const t = this.tensorMap.get(e);
399
399
  if (t.resource != null)
400
400
  return;
401
- const s = q(t.dtype) * x(t.shape);
401
+ const s = N(t.dtype) * x(t.shape);
402
402
  let n;
403
403
  const r = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
404
404
  if (t.values) {
@@ -1,5 +1,5 @@
1
- import { o as h, q as f, u as p, w as g, E as u, T } from "./index-CUXkjxiT.js";
2
- import { r as b } from "./reshape-BDOuCSNW.js";
1
+ import { o as h, n as f, q as p, u as g, E as u, T } from "./index-DSGwv2Yx.js";
2
+ import { r as b } from "./reshape-BIN71H3p.js";
3
3
  function m(e, r) {
4
4
  let n = f(e, "broadcastTo", "x");
5
5
  const a = n.shape;
@@ -1,5 +1,5 @@
1
- import { s, h as a } from "../index-CUXkjxiT.js";
2
- import { t } from "../tensor4d-C6UCG_u8.js";
1
+ import { s, e as a } from "../index-DSGwv2Yx.js";
2
+ import { t } from "../tensor4d-Dj4rDssL.js";
3
3
  async function u(e) {
4
4
  await s(e);
5
5
  const n = t(
@@ -1,6 +1,6 @@
1
- import { s as i, h as a } from "../index-CUXkjxiT.js";
2
- import { t } from "../tensor4d-C6UCG_u8.js";
3
- import { t as e } from "../tensor2d-BnXMKScO.js";
1
+ import { s as i, e } from "../index-DSGwv2Yx.js";
2
+ import { t } from "../tensor4d-Dj4rDssL.js";
3
+ import { t as a } from "../tensor2d-DTtQ1QcT.js";
4
4
  async function k(n) {
5
5
  await i(n);
6
6
  const s = t(
@@ -23,14 +23,14 @@ async function k(n) {
23
23
  ]
24
24
  ],
25
25
  [1, 1, 2, 4]
26
- ), r = e(
26
+ ), r = a(
27
27
  [
28
28
  [0, -1 / 0, -1 / 0, -1 / 0],
29
29
  [0, 0, 0, -1 / 0]
30
30
  ],
31
31
  [2, 4]
32
32
  );
33
- return await a().runKernel("AttentionMask", { q: s, k: o, mask: r }, { divisor: 0.5, pastLen: 0 }).array();
33
+ return await e().runKernel("AttentionMask", { q: s, k: o, mask: r }, { divisor: 0.5, pastLen: 0 }).array();
34
34
  }
35
35
  export {
36
36
  k as execute
@@ -1,5 +1,5 @@
1
- import { s as e, h as o } from "../index-CUXkjxiT.js";
2
- import { t as s } from "../tensor2d-BnXMKScO.js";
1
+ import { s as e, e as o } from "../index-DSGwv2Yx.js";
2
+ import { t as s } from "../tensor2d-DTtQ1QcT.js";
3
3
  async function m(t) {
4
4
  await e(t);
5
5
  const r = s(
@@ -1,5 +1,5 @@
1
- import { s as o, h as s } from "../index-CUXkjxiT.js";
2
- import { t as e } from "../tensor2d-BnXMKScO.js";
1
+ import { s as o, e as s } from "../index-DSGwv2Yx.js";
2
+ import { t as e } from "../tensor2d-DTtQ1QcT.js";
3
3
  async function i(t) {
4
4
  await o(t);
5
5
  const r = e(
@@ -1,13 +1,13 @@
1
- import { s as u, a0 as A, h } from "../index-CUXkjxiT.js";
2
- import { a as y } from "../ops-CsXeTq1P.js";
3
- import { t as p } from "../tensor1d-LMGMIUlr.js";
4
- import { t as r } from "../tensor-BWFldCso.js";
1
+ import { s as u, a0 as A, e as y } from "../index-DSGwv2Yx.js";
2
+ import { a as h } from "../ops-CURIZSVt.js";
3
+ import { t as p } from "../tensor1d-BMl0eZYV.js";
4
+ import { t as r } from "../tensor-D8e0Gd7c.js";
5
5
  const w = Array.from({ length: 2048 * 192 }, () => Math.random()), x = Array.from({ length: 192 }, () => Math.random()), M = Array.from({ length: 2048 * 192 }, () => Math.random());
6
6
  async function k(t) {
7
7
  await u(t);
8
8
  const o = p(x, "float32"), n = r(w, [16, 128, 192], "float32"), s = r(M, [16, 128, 192], "float32"), e = (d, g) => {
9
- const i = h().runKernel("RMSNorm", { x: d, gamma: g });
10
- return y.meanSquaredError(i, s);
9
+ const i = y().runKernel("RMSNorm", { x: d, gamma: g });
10
+ return h.meanSquaredError(i, s);
11
11
  }, { value: m, grads: a } = A(e)([n, o]), c = await m.array(), f = await a[0].array(), l = await a[1].array();
12
12
  return [c, f, l];
13
13
  }
@@ -1,6 +1,6 @@
1
- import { s as c, h as d } from "../index-CUXkjxiT.js";
2
- import { t as f } from "../tensor1d-LMGMIUlr.js";
3
- import { t as r } from "../tensor-BWFldCso.js";
1
+ import { s as c, e as d } from "../index-DSGwv2Yx.js";
2
+ import { t as f } from "../tensor1d-BMl0eZYV.js";
3
+ import { t as r } from "../tensor-D8e0Gd7c.js";
4
4
  const y = Array.from({ length: 2048 * 192 }, () => Math.random()), i = Array.from({ length: 192 }, () => Math.random()), l = Array.from({ length: 2048 * 192 }, () => Math.random());
5
5
  async function x(t) {
6
6
  await c(t);
@@ -1,7 +1,7 @@
1
- import { s as a, h as n } from "../index-CUXkjxiT.js";
2
- import { t as c } from "../tensor2d-BnXMKScO.js";
3
- async function i(e) {
4
- await a(e);
1
+ import { s as a, e } from "../index-DSGwv2Yx.js";
2
+ import { t as c } from "../tensor2d-DTtQ1QcT.js";
3
+ async function i(n) {
4
+ await a(n);
5
5
  const r = c(
6
6
  [
7
7
  [0.1, 0.2, 0, 0, 1230, 1232331234, -12234234],
@@ -10,8 +10,8 @@ async function i(e) {
10
10
  [0, 0, 0, 0, -0.1, 1e-3, 0]
11
11
  ],
12
12
  [4, 7]
13
- ), t = n().runKernel("Pack16", { x: r });
14
- return await n().runKernel("Unpack16", { x: t }).array();
13
+ ), t = e().runKernel("Pack16", { x: r });
14
+ return await e().runKernel("Unpack16", { x: t }).array();
15
15
  }
16
16
  export {
17
17
  i as execute
@@ -1,5 +1,5 @@
1
- import { W as i, X as u, Y as c, s as l, h } from "../index-CUXkjxiT.js";
2
- import { t as f } from "../tensor2d-BnXMKScO.js";
1
+ import { U as i, V as u, W as c, s as l, e as h } from "../index-DSGwv2Yx.js";
2
+ import { t as f } from "../tensor2d-DTtQ1QcT.js";
3
3
  function m(t, e, n) {
4
4
  if (i(t), e != null && e.length !== 3)
5
5
  throw new Error("tensor3d() requires shape to have three numbers");
@@ -1,6 +1,6 @@
1
1
  import s from "../layers/RoPECache.js";
2
- import { s as c, h as i } from "../index-CUXkjxiT.js";
3
- import { t as p } from "../tensor4d-C6UCG_u8.js";
2
+ import { s as c, e as i } from "../index-DSGwv2Yx.js";
3
+ import { t as p } from "../tensor4d-Dj4rDssL.js";
4
4
  async function f(r) {
5
5
  await c(r);
6
6
  const n = p(