@genai-fi/nanogpt 0.11.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236) hide show
  1. package/dist/Generator.js +29 -29
  2. package/dist/{RealDiv-Ds-jvL09.js → RealDiv-C8neBwFi.js} +17 -17
  3. package/dist/{Reshape-Cd6e-Otn.js → Reshape-Bd4V_4X7.js} +1 -1
  4. package/dist/{Reshape-Ct266DEk.js → Reshape-Ck29jQSY.js} +7 -7
  5. package/dist/TeachableLLM.d.ts +2 -1
  6. package/dist/TeachableLLM.js +9 -9
  7. package/dist/Trainer.d.ts +4 -2
  8. package/dist/Trainer.js +12 -9
  9. package/dist/{axis_util-DofAuy0p.js → axis_util-DGqbT-FX.js} +1 -1
  10. package/dist/backend.js +2 -2
  11. package/dist/{backend_util-C7NWHpv7.js → backend_util-DC3rBo_H.js} +18 -18
  12. package/dist/{backend_webgpu-B0Vls736.js → backend_webgpu-mbhNnlx9.js} +10 -10
  13. package/dist/{broadcast_to-DDaNMbX7.js → broadcast_to-D1Dmg2Oz.js} +2 -2
  14. package/dist/checks/appendCache.js +2 -2
  15. package/dist/checks/attentionMask.js +3 -3
  16. package/dist/checks/gelu.js +2 -2
  17. package/dist/checks/matMulGelu.js +2 -2
  18. package/dist/checks/normRMS.js +4 -4
  19. package/dist/checks/normRMSGrad.js +3 -3
  20. package/dist/checks/packUnpack.js +2 -2
  21. package/dist/checks/qkv.js +2 -2
  22. package/dist/checks/rope.js +2 -2
  23. package/dist/clip_by_value-fg2aKzUy.js +12 -0
  24. package/dist/{complex-DClmWqJt.js → complex-Cyg-eQeZ.js} +1 -1
  25. package/dist/concat-CSm2rMwe.js +17 -0
  26. package/dist/{concat_util-CHsJFZJJ.js → concat_util-D0je5Ppu.js} +1 -1
  27. package/dist/{dataset-DcjWqUVQ.js → dataset-CVIJu7Xa.js} +3 -3
  28. package/dist/{dropout-OxuaJz6z.js → dropout-DLhSMNTZ.js} +14 -14
  29. package/dist/expand_dims-ChkuOp6I.js +11 -0
  30. package/dist/{exports_initializers-eS9QJ6ut.js → exports_initializers-1KWPiStI.js} +1 -1
  31. package/dist/{floor-DIb-lN_u.js → floor-BRMPgeIs.js} +1 -1
  32. package/dist/gather-BSULDalH.js +9 -0
  33. package/dist/{gelu-DqTbCx5x.js → gelu-BK1k-n1i.js} +1 -1
  34. package/dist/{gpgpu_math-CJcbnKPC.js → gpgpu_math-BJSTk_mW.js} +25 -25
  35. package/dist/{index-Dj5TkmPY.js → index-BBVLAXZD.js} +14 -14
  36. package/dist/{index-D0RBWjq8.js → index-Duu1Lvvv.js} +45 -45
  37. package/dist/{kernel_funcs_utils-CSaumNDs.js → kernel_funcs_utils-BtYrPoJu.js} +8 -8
  38. package/dist/layers/BaseLayer.js +2 -2
  39. package/dist/layers/CausalSelfAttention.js +6 -6
  40. package/dist/layers/MLP.js +4 -4
  41. package/dist/layers/PositionEmbedding.js +5 -5
  42. package/dist/layers/RMSNorm.js +3 -3
  43. package/dist/layers/RoPECache.js +4 -4
  44. package/dist/layers/TiedEmbedding.js +6 -6
  45. package/dist/layers/TransformerBlock.js +1 -1
  46. package/dist/loader/loadTransformers.js +1 -1
  47. package/dist/loader/oldZipLoad.js +17 -17
  48. package/dist/{log_sum_exp-VLZgbFAH.js → log_sum_exp-CVqLsVLl.js} +4 -4
  49. package/dist/main.d.ts +9 -0
  50. package/dist/main.js +69 -58
  51. package/dist/{matMul16-cDxwemKj.js → matMul16-xswmhSuF.js} +7 -7
  52. package/dist/{matMulGelu-B2s_80-H.js → matMulGelu-BpvgnYG8.js} +26 -26
  53. package/dist/mat_mul-Bn2BDpT4.js +11 -0
  54. package/dist/{mod-PrOKlFxH.js → mod-B4AUd1Np.js} +1 -1
  55. package/dist/models/NanoGPTV1.js +2 -2
  56. package/dist/models/model.js +9 -9
  57. package/dist/{ones-BX_wEgzB.js → ones-CBI1AQjb.js} +3 -3
  58. package/dist/ops/adamAdjust.js +1 -1
  59. package/dist/ops/adamMoments.js +1 -1
  60. package/dist/ops/add16.js +1 -1
  61. package/dist/ops/appendCache.js +3 -3
  62. package/dist/ops/attentionMask.js +1 -1
  63. package/dist/ops/concat16.js +2 -2
  64. package/dist/ops/cpu/adamAdjust.js +7 -7
  65. package/dist/ops/cpu/adamMoments.js +5 -5
  66. package/dist/ops/cpu/appendCache.js +6 -6
  67. package/dist/ops/cpu/attentionMask.js +6 -6
  68. package/dist/ops/cpu/fusedSoftmax.js +5 -5
  69. package/dist/ops/cpu/gatherSub.js +7 -7
  70. package/dist/ops/cpu/gelu.js +5 -5
  71. package/dist/ops/cpu/matMul16.js +2 -2
  72. package/dist/ops/cpu/matMulGelu.js +3 -3
  73. package/dist/ops/cpu/matMulMul.js +5 -5
  74. package/dist/ops/cpu/mulDropout.js +1 -1
  75. package/dist/ops/cpu/normRMS.js +5 -5
  76. package/dist/ops/cpu/qkv.js +3 -3
  77. package/dist/ops/cpu/rope.js +9 -9
  78. package/dist/ops/cpu/scatterSub.js +5 -5
  79. package/dist/ops/dot16.js +2 -2
  80. package/dist/ops/gatherSub.js +1 -1
  81. package/dist/ops/gelu.js +2 -2
  82. package/dist/ops/grads/add16.js +1 -1
  83. package/dist/ops/grads/attentionMask.js +2 -2
  84. package/dist/ops/grads/gelu.js +2 -2
  85. package/dist/ops/grads/matMul16.js +3 -3
  86. package/dist/ops/grads/matMulGelu.js +5 -5
  87. package/dist/ops/grads/normRMS.js +6 -6
  88. package/dist/ops/grads/pack16.js +3 -3
  89. package/dist/ops/grads/qkv.js +9 -9
  90. package/dist/ops/grads/rope.js +2 -2
  91. package/dist/ops/grads/softmax16.js +1 -1
  92. package/dist/ops/grads/unpack16.js +2 -2
  93. package/dist/ops/matMul16.js +3 -3
  94. package/dist/ops/matMulGelu.js +2 -2
  95. package/dist/ops/matMulMul.js +1 -1
  96. package/dist/ops/mul16.js +1 -1
  97. package/dist/ops/mulDrop.js +1 -1
  98. package/dist/ops/normRMS.js +1 -1
  99. package/dist/ops/pack16.js +2 -2
  100. package/dist/ops/qkv.js +1 -1
  101. package/dist/ops/reshape16.js +6 -6
  102. package/dist/ops/rope.js +2 -2
  103. package/dist/ops/scatterSub.js +1 -1
  104. package/dist/ops/slice16.js +2 -2
  105. package/dist/ops/softmax16.js +1 -1
  106. package/dist/ops/sub16.js +1 -1
  107. package/dist/ops/sum16.js +2 -2
  108. package/dist/ops/transpose16.js +6 -6
  109. package/dist/ops/unpack16.js +2 -2
  110. package/dist/ops/webgl/adamAdjust.js +2 -2
  111. package/dist/ops/webgl/adamMoments.js +1 -1
  112. package/dist/ops/webgl/appendCache.js +1 -1
  113. package/dist/ops/webgl/attentionMask.js +4 -4
  114. package/dist/ops/webgl/fusedSoftmax.js +6 -6
  115. package/dist/ops/webgl/gatherSub.js +1 -1
  116. package/dist/ops/webgl/gelu.js +2 -2
  117. package/dist/ops/webgl/log.js +3 -3
  118. package/dist/ops/webgl/matMul16.js +10 -10
  119. package/dist/ops/webgl/matMulGelu.js +4 -4
  120. package/dist/ops/webgl/matMulMul.js +2 -2
  121. package/dist/ops/webgl/mulDropout.js +1 -1
  122. package/dist/ops/webgl/normRMS.js +2 -2
  123. package/dist/ops/webgl/qkv.js +1 -1
  124. package/dist/ops/webgl/rope.js +4 -4
  125. package/dist/ops/webgl/scatterSub.js +1 -1
  126. package/dist/ops/webgpu/adamAdjust.js +3 -3
  127. package/dist/ops/webgpu/adamMoments.js +5 -5
  128. package/dist/ops/webgpu/add16.js +1 -1
  129. package/dist/ops/webgpu/appendCache.js +3 -3
  130. package/dist/ops/webgpu/attentionMask.js +5 -5
  131. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  132. package/dist/ops/webgpu/concat16.js +5 -5
  133. package/dist/ops/webgpu/gatherSub.js +3 -3
  134. package/dist/ops/webgpu/gelu.js +3 -3
  135. package/dist/ops/webgpu/matMul16.js +19 -19
  136. package/dist/ops/webgpu/matMul16_program.js +2 -2
  137. package/dist/ops/webgpu/mul16.js +1 -1
  138. package/dist/ops/webgpu/normRMS.js +2 -2
  139. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  140. package/dist/ops/webgpu/pack16.js +3 -3
  141. package/dist/ops/webgpu/pack16_program.js +2 -2
  142. package/dist/ops/webgpu/qkv.js +4 -4
  143. package/dist/ops/webgpu/rope.js +3 -3
  144. package/dist/ops/webgpu/scatterSub.js +3 -3
  145. package/dist/ops/webgpu/slice16.js +4 -4
  146. package/dist/ops/webgpu/softmax16.js +4 -4
  147. package/dist/ops/webgpu/softmax16_program.js +2 -2
  148. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  149. package/dist/ops/webgpu/softmax16grad.js +1 -1
  150. package/dist/ops/webgpu/sub16.js +1 -1
  151. package/dist/ops/webgpu/sum16.js +5 -5
  152. package/dist/ops/webgpu/transpose16.js +2 -2
  153. package/dist/ops/webgpu/transpose16_program.js +2 -2
  154. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  155. package/dist/ops/webgpu/unpack16.js +5 -5
  156. package/dist/ops/webgpu/utils/binary_op.js +3 -3
  157. package/dist/ops/webgpu/utils/reductions.js +4 -4
  158. package/dist/{ops-FJapAPfm.js → ops-C2_OXuZ4.js} +35 -35
  159. package/dist/{pack16-k4jq6aMX.js → pack16-atD0eYRm.js} +6 -6
  160. package/dist/patches/webgpu_backend.js +8 -8
  161. package/dist/patches/webgpu_base.js +1 -1
  162. package/dist/patches/webgpu_program.js +2 -2
  163. package/dist/{random_width-UGQn4OWb.js → random_width-BN4wGJaW.js} +33 -33
  164. package/dist/{range-CuGvVN2c.js → range-DKmP1-OQ.js} +1 -1
  165. package/dist/relu-BsXmGzzu.js +9 -0
  166. package/dist/{reshape-CkjKPPqB.js → reshape-BI0yzp1T.js} +1 -1
  167. package/dist/{resize_nearest_neighbor-DB8k9KN_.js → resize_nearest_neighbor-BA_BX-ub.js} +25 -25
  168. package/dist/{rope-BmZmp9uP.js → rope-DJ7Y7c-u.js} +1 -1
  169. package/dist/{scatter_nd_util-BY22Cc-C.js → scatter_nd_util-k9MUVUkn.js} +1 -1
  170. package/dist/{selu_util-BuLbmbrl.js → selu_util-DyW0X1WG.js} +5 -5
  171. package/dist/{shared-B7USJZgw.js → shared-Q3BS6T03.js} +1 -1
  172. package/dist/{shared-BQboIImQ.js → shared-nnSWpC3u.js} +6 -6
  173. package/dist/{slice-Aqy7KbJh.js → slice-wBNvzVyz.js} +3 -3
  174. package/dist/{slice_util-D8CQRenR.js → slice_util-zN8KFC5I.js} +7 -7
  175. package/dist/{softmax-faLoUZVT.js → softmax-DfuYyjMh.js} +1 -1
  176. package/dist/split-BYrLboMq.js +9 -0
  177. package/dist/squeeze-Bk8Brcct.js +10 -0
  178. package/dist/{stack-WJK22CFn.js → stack-CDWShFHF.js} +1 -1
  179. package/dist/{step-dXR33iOg.js → step-BS5JXRR6.js} +14 -14
  180. package/dist/sum-BPUfDB2X.js +11 -0
  181. package/dist/{tensor-BQqrDvpx.js → tensor-CEt9Nm2s.js} +1 -1
  182. package/dist/{tensor1d-LxP9asMm.js → tensor1d-Cc_KCIDg.js} +1 -1
  183. package/dist/{tensor2d-BN1sSfQO.js → tensor2d-BN97fF71.js} +1 -1
  184. package/dist/{tensor4d-DVwr7pLF.js → tensor4d-vuDDgdUI.js} +1 -1
  185. package/dist/{tfjs_backend-Vi4JfLzT.js → tfjs_backend-806hyYve.js} +36 -36
  186. package/dist/tile-OWUvpIVt.js +11 -0
  187. package/dist/tokeniser/BaseTokeniser.d.ts +6 -8
  188. package/dist/tokeniser/BaseTokeniser.js +6 -6
  189. package/dist/tokeniser/CharTokeniser.d.ts +6 -6
  190. package/dist/tokeniser/CharTokeniser.js +26 -26
  191. package/dist/tokeniser/bpe.d.ts +6 -6
  192. package/dist/tokeniser/bpe.js +9 -9
  193. package/dist/tokeniser/type.d.ts +6 -8
  194. package/dist/training/Adam.js +2 -2
  195. package/dist/training/AdamExt.js +1 -1
  196. package/dist/training/DatasetBuilder.d.ts +1 -1
  197. package/dist/training/DatasetBuilder.js +29 -29
  198. package/dist/training/FullTrainer.js +1 -1
  199. package/dist/training/Trainer.d.ts +5 -4
  200. package/dist/training/Trainer.js +37 -40
  201. package/dist/training/sparseCrossEntropy.js +3 -3
  202. package/dist/training/tasks/ConversationTask.d.ts +11 -0
  203. package/dist/training/tasks/ConversationTask.js +26 -0
  204. package/dist/training/tasks/PretrainingTask.d.ts +11 -0
  205. package/dist/training/tasks/PretrainingTask.js +34 -0
  206. package/dist/training/tasks/StartSentenceTask.d.ts +12 -0
  207. package/dist/training/tasks/StartSentenceTask.js +42 -0
  208. package/dist/training/tasks/Task.d.ts +8 -0
  209. package/dist/training/tasks/Task.js +44 -0
  210. package/dist/{transpose-JawVKyZy.js → transpose-BUkQCJp9.js} +7 -7
  211. package/dist/{unsorted_segment_sum-LAbmE9G4.js → unsorted_segment_sum-BljxHhCY.js} +78 -78
  212. package/dist/utilities/dummy.js +3 -3
  213. package/dist/utilities/multinomialCPU.js +2 -2
  214. package/dist/utilities/packed.js +1 -1
  215. package/dist/utilities/performance.js +1 -1
  216. package/dist/utilities/profile.js +1 -1
  217. package/dist/utilities/safetensors.js +2 -2
  218. package/dist/utilities/sentences.d.ts +1 -1
  219. package/dist/utilities/sentences.js +11 -11
  220. package/dist/utilities/weights.js +2 -2
  221. package/dist/{variable-DQ9yYgEU.js → variable-DPt_Iuog.js} +1 -1
  222. package/dist/{webgpu_program-CAE4RICo.js → webgpu_program-BpWRlghH.js} +1 -1
  223. package/dist/{webgpu_util-BdovYhXr.js → webgpu_util-DMiKzzQM.js} +7 -7
  224. package/dist/{zeros-DeiE2zTa.js → zeros-5YROwwUH.js} +2 -2
  225. package/dist/{zeros_like-BAz3iKru.js → zeros_like-De4n1C3m.js} +57 -57
  226. package/package.json +1 -1
  227. package/dist/clip_by_value-Dn5tzexi.js +0 -12
  228. package/dist/concat-C6X3AAlQ.js +0 -17
  229. package/dist/expand_dims-BzfJK2uc.js +0 -11
  230. package/dist/gather-BcO5UQNJ.js +0 -9
  231. package/dist/mat_mul-DxpNTCRz.js +0 -11
  232. package/dist/relu-Cf80uA2p.js +0 -9
  233. package/dist/split-BNz5jcGc.js +0 -9
  234. package/dist/squeeze--YMgaAAf.js +0 -10
  235. package/dist/sum-BdplSvq_.js +0 -11
  236. package/dist/tile-CvN_LyVr.js +0 -11
package/dist/Generator.js CHANGED
@@ -1,40 +1,40 @@
1
1
  import { E as Ui } from "./index-DvYrXKkX.js";
2
- import { q as Hi, u as Xi, E as Ki, dn as Ss, at as pe, ab as _, au as oo, av as ao, e as Oe, aW as Dt, aA as ro, aB as io, aw as ji, ax as Ft, aF as Ge, f as co, aI as ws, aJ as qi, V as G, ae as _e, aK as Yi, I as Ns, aL as Rs, R as Qi, ah as Zi, y as te, H as lo, $ as Ne, a9 as ee, aX as uo, c9 as Ts, ca as Es, cQ as po, bo as ho, af as ue, U as Ye, bp as fo, bq as mo, cb as go, cc as Ds, cd as Fs, ce as Ps, cf as Os, cg as As, br as xo, ac as nt, cA as Co, cR as bo, cS as Io, bu as ko, bt as yo, bf as $o, dg as vo, C as _s, cU as So, ap as wo, A as No, bv as Ro, c4 as $e, cF as To, bw as Eo, cB as Do, cV as Fo, cC as Po, bx as Ls, by as Vs, bh as Oo, bz as Ao, ak as Qe, W as Ji, bA as _o, cD as Lo, ch as Vo, bB as Wo, cH as Mo, cI as Bo, dh as Go, ci as zo, bZ as ns, bT as St, bW as Ws, cW as kn, dv as ut, dw as Uo, cX as yn, di as ec, Q as tc, bg as Ho, cY as Xo, bD as Ms, B as Ko, aT as jo, aR as mt, cr as qo, de as Yo, df as Qo, bi as Zo, cG as Jo, dk as ea, aj as ta, G as sa, cs as na, cj as Bs, ck as Gs, cl as zs, dl as oa, b5 as Us, b6 as Hs, bF as Xs, cn as Ks, cm as aa, c_ as ra, am as sc, bG as ia, cE as ca, c$ as la, d0 as ua, dm as da, b7 as pa, a$ as ha, co as fa, bS as ma, M as js, J as ga, bk as xa, bm as Ca, bl as ba, bH as Ia, d5 as ka, bI as ya, P as $a, a7 as va, bJ as Sa, d1 as qs, dx as wa, dy as Na, dz as Ra, _ as Ta, cp as Ys, bd as Ea, d2 as Da, be as Fa, d3 as Pa, bL as Oa, bj as Aa, ba as Qs, ai as _a, dp as La, a_ as Va, bN as Zs, cq as Js, bO as en, bP as tn, bE as sn, bK as Wa, dA as Ma, dB as Ba, dq as Ga, dr as za, ds as Ua, K as Ha, d4 as Xa, aM as nn, ct as Ka, dt as ja, dC as qa, dD as Ya, cu as on, bs as an, du as Qa, T as Za, cv as Ja, bn as er, a6 as rn, cw as tr, bc as sr, bQ as nr, h as or, dE as ar, dF as $n, ay as vn, az as nc, t as rr, b as oc, dG as ac, dH as rc, c2 as ic, as as cc, bR as lc, bX as uc, S as dc, bY as pc, b9 as hc, ar as fc, bU as mc, bV as gc, b_ as xc, aO as ir, bC as Cc, an as bc, b$ as Ic, F as kc, c0 as yc, dj as $c, b1 as vc, b2 as Sc, b3 as wc, b4 as Nc, ao as Rc, c1 as Tc, b8 as Ec, c7 as Dc, aq as Fc, c3 as Pc, aS as cr, bM as Oc, aH as Ac, c5 as _c, bb as Lc, c6 as Vc, k as Wc } from "./index-D0RBWjq8.js";
3
- import { n as Mc } from "./random_width-UGQn4OWb.js";
4
- import { t as Bc } from "./zeros_like-BAz3iKru.js";
2
+ import { o as Hi, q as Xi, E as Ki, dn as Ss, at as pe, aa as _, au as oo, av as ao, e as Oe, aW as Dt, aA as ro, aB as io, aw as ji, ax as Ft, aF as Ge, ad as co, aI as ws, aJ as qi, U as G, ae as _e, aK as Yi, H as Ns, aL as Rs, R as Qi, ah as Zi, x as te, D as lo, _ as Ne, a8 as ee, aX as uo, c9 as Ts, ca as Es, cQ as po, bo as ho, af as ue, Q as Ye, bp as fo, bq as mo, cb as go, cc as Ds, cd as Fs, ce as Ps, cf as Os, cg as As, br as xo, ab as nt, cA as Co, cR as bo, cS as Io, bu as ko, bt as yo, bf as $o, dg as vo, C as _s, cU as So, ap as wo, z as No, bv as Ro, c4 as $e, cF as To, bw as Eo, cB as Do, cV as Fo, cC as Po, bx as Ls, by as Vs, bh as Oo, bz as Ao, ak as Qe, V as Ji, bA as _o, cD as Lo, ch as Vo, bB as Wo, cH as Mo, cI as Bo, dh as Go, ci as zo, bZ as ns, bT as St, bW as Ws, cW as kn, dv as ut, dw as Uo, cX as yn, di as ec, N as tc, bg as Ho, cY as Xo, bD as Ms, A as Ko, aT as jo, aR as mt, cr as qo, de as Yo, df as Qo, bi as Zo, cG as Jo, dk as ea, aj as ta, G as sa, cs as na, cj as Bs, ck as Gs, cl as zs, dl as oa, b5 as Us, b6 as Hs, bF as Xs, cn as Ks, cm as aa, c_ as ra, am as sc, bG as ia, cE as ca, c$ as la, d0 as ua, dm as da, b7 as pa, a$ as ha, co as fa, bS as ma, M as js, I as ga, bk as xa, bm as Ca, bl as ba, bH as Ia, d5 as ka, bI as ya, P as $a, a6 as va, bJ as Sa, d1 as qs, dx as wa, dy as Na, dz as Ra, Z as Ta, cp as Ys, bd as Ea, d2 as Da, be as Fa, d3 as Pa, bL as Oa, bj as Aa, ba as Qs, ai as _a, dp as La, a_ as Va, bN as Zs, cq as Js, bO as en, bP as tn, bE as sn, bK as Wa, dA as Ma, dB as Ba, dq as Ga, dr as za, ds as Ua, J as Ha, d4 as Xa, aM as nn, ct as Ka, dt as ja, dC as qa, dD as Ya, cu as on, bs as an, du as Qa, T as Za, cv as Ja, bn as er, a5 as rn, cw as tr, bc as sr, bQ as nr, f as or, dE as ar, dF as $n, ay as vn, az as nc, t as rr, b as oc, dG as ac, dH as rc, c2 as ic, as as cc, bR as lc, bX as uc, S as dc, bY as pc, b9 as hc, ar as fc, bU as mc, bV as gc, b_ as xc, aO as ir, bC as Cc, an as bc, b$ as Ic, F as kc, c0 as yc, dj as $c, b1 as vc, b2 as Sc, b3 as wc, b4 as Nc, ao as Rc, c1 as Tc, b8 as Ec, c7 as Dc, aq as Fc, c3 as Pc, aS as cr, bM as Oc, aH as Ac, c5 as _c, bb as Lc, c6 as Vc, k as Wc } from "./index-Duu1Lvvv.js";
3
+ import { n as Mc } from "./random_width-BN4wGJaW.js";
4
+ import { t as Bc } from "./zeros_like-De4n1C3m.js";
5
5
  import "./index-Cp39cXWe.js";
6
- import "./dataset-DcjWqUVQ.js";
7
- import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as kl, Z as yl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-BQboIImQ.js";
6
+ import "./dataset-CVIJu7Xa.js";
7
+ import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as kl, Z as yl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-nnSWpC3u.js";
8
8
  import { m as pt, g as pr, s as Ul, c as Hl, b as Xl, d as Kl, a as jl, e as ql } from "./complex_util-Yc1A_gV1.js";
9
- import { a as ge, b as xe, d as ye, c as ve, e as Te, g as os } from "./axis_util-DofAuy0p.js";
10
- import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-dXR33iOg.js";
11
- import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as kr, w as yr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-C7NWHpv7.js";
12
- import { a as Dr, c as Ue } from "./concat_util-CHsJFZJJ.js";
9
+ import { a as ge, b as xe, d as ye, c as ve, e as Te, g as os } from "./axis_util-DGqbT-FX.js";
10
+ import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-BS5JXRR6.js";
11
+ import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as kr, w as yr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-DC3rBo_H.js";
12
+ import { a as Dr, c as Ue } from "./concat_util-D0je5Ppu.js";
13
13
  import { s as Jl } from "./index-CieiGp4Y.js";
14
14
  import { n as Fr, b as Pr, a as Or } from "./non_max_suppression_impl-B2W7YjZB.js";
15
- import { c as Ct } from "./scatter_nd_util-BY22Cc-C.js";
16
- import { S as Ar, a as _r } from "./selu_util-BuLbmbrl.js";
17
- import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-D8CQRenR.js";
18
- import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as ku, O as yu, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-CJcbnKPC.js";
19
- import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as kd } from "./shared-B7USJZgw.js";
20
- import { a as ke, c as yd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-CSaumNDs.js";
21
- import { R as Nd, r as U, a as Rd } from "./Reshape-Ct266DEk.js";
22
- import { M as qr } from "./matMulGelu-B2s_80-H.js";
23
- import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-Ds-jvL09.js";
24
- import { z as Od } from "./zeros-DeiE2zTa.js";
15
+ import { c as Ct } from "./scatter_nd_util-k9MUVUkn.js";
16
+ import { S as Ar, a as _r } from "./selu_util-DyW0X1WG.js";
17
+ import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-zN8KFC5I.js";
18
+ import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as ku, O as yu, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-BJSTk_mW.js";
19
+ import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as kd } from "./shared-Q3BS6T03.js";
20
+ import { a as ke, c as yd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-BtYrPoJu.js";
21
+ import { R as Nd, r as U, a as Rd } from "./Reshape-Ck29jQSY.js";
22
+ import { M as qr } from "./matMulGelu-BpvgnYG8.js";
23
+ import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-C8neBwFi.js";
24
+ import { z as Od } from "./zeros-5YROwwUH.js";
25
25
  import "./ops/cpu/attentionMask.js";
26
26
  import "./ops/webgl/attentionMask.js";
27
27
  import "./ops/grads/attentionMask.js";
28
28
  import "./ops/cpu/rope.js";
29
29
  import "./ops/webgl/rope.js";
30
- import "./rope-BmZmp9uP.js";
30
+ import "./rope-DJ7Y7c-u.js";
31
31
  import "./ops/cpu/appendCache.js";
32
32
  import "./ops/webgl/appendCache.js";
33
33
  import "./ops/grads/softmax16.js";
34
- import "./matMul16-cDxwemKj.js";
34
+ import "./matMul16-xswmhSuF.js";
35
35
  import "./ops/webgl/matMul16.js";
36
36
  import "./ops/cpu/matMul16.js";
37
- import "./pack16-k4jq6aMX.js";
37
+ import "./pack16-atD0eYRm.js";
38
38
  import "./ops/transpose16.js";
39
39
  import "./ops/reshape16.js";
40
40
  import "./ops/cpu/qkv.js";
@@ -62,17 +62,17 @@ import "./ops/cpu/matMulGelu.js";
62
62
  import "./ops/grads/matMulGelu.js";
63
63
  import "./ops/cpu/gelu.js";
64
64
  import "./ops/webgl/gelu.js";
65
- import "./gelu-DqTbCx5x.js";
65
+ import "./gelu-BK1k-n1i.js";
66
66
  import "./ops/webgl/log.js";
67
67
  import "./checks/normRMS.js";
68
68
  import "./checks/normRMSGrad.js";
69
69
  import Wd from "./utilities/multinomialCPU.js";
70
- import { r as Dn } from "./reshape-CkjKPPqB.js";
71
- import { t as Kt } from "./tensor2d-BN1sSfQO.js";
72
- import { z as Md } from "./unsorted_segment_sum-LAbmE9G4.js";
73
- import { s as gs } from "./softmax-faLoUZVT.js";
74
- import { g as Bd } from "./gather-BcO5UQNJ.js";
75
- import { c as Gd } from "./concat-C6X3AAlQ.js";
70
+ import { r as Dn } from "./reshape-BI0yzp1T.js";
71
+ import { t as Kt } from "./tensor2d-BN97fF71.js";
72
+ import { z as Md } from "./unsorted_segment_sum-BljxHhCY.js";
73
+ import { s as gs } from "./softmax-DfuYyjMh.js";
74
+ import { g as Bd } from "./gather-BSULDalH.js";
75
+ import { c as Gd } from "./concat-CSm2rMwe.js";
76
76
  function zd(a, t, e, n = !1) {
77
77
  const s = Xi(a, "logits", "multinomial"), o = s.size, r = s.rank;
78
78
  if (o < 2)
@@ -1,10 +1,10 @@
1
- import { aG as T, ab as E, af as O, V, aS as B, Q as F, am as G, aT as K } from "./index-D0RBWjq8.js";
2
- import { r as $ } from "./Reshape-Ct266DEk.js";
3
- import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-DofAuy0p.js";
4
- import { t as U, m as W } from "./shared-B7USJZgw.js";
5
- import { c as _ } from "./backend_util-C7NWHpv7.js";
6
- import { f as y } from "./gpgpu_math-CJcbnKPC.js";
7
- import { g as j, b as L } from "./kernel_funcs_utils-CSaumNDs.js";
1
+ import { aG as T, aa as E, af as O, U as V, aS as B, N as F, am as U, aT as G } from "./index-Duu1Lvvv.js";
2
+ import { r as $ } from "./Reshape-Ck29jQSY.js";
3
+ import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-DGqbT-FX.js";
4
+ import { t as K, m as W } from "./shared-Q3BS6T03.js";
5
+ import { c as _ } from "./backend_util-DC3rBo_H.js";
6
+ import { f as y } from "./gpgpu_math-BJSTk_mW.js";
7
+ import { g as j, b as L } from "./kernel_funcs_utils-BtYrPoJu.js";
8
8
  class w {
9
9
  constructor(s, e) {
10
10
  this.variableNames = ["x"];
@@ -203,14 +203,14 @@ function P(a, s, e, t) {
203
203
  }
204
204
  return l;
205
205
  }
206
- class Q {
206
+ class Y {
207
207
  constructor(s, e) {
208
208
  this.variableNames = ["A"];
209
209
  const t = new Array(s.length);
210
210
  for (let r = 0; r < t.length; r++)
211
211
  t[r] = s[e[r]];
212
212
  this.outputShape = t, this.rank = t.length;
213
- const n = y(this.rank), l = Y(e);
213
+ const n = y(this.rank), l = H(e);
214
214
  this.userCode = `
215
215
  void main() {
216
216
  ${n} resRC = getOutputCoords();
@@ -219,7 +219,7 @@ class Q {
219
219
  `;
220
220
  }
221
221
  }
222
- function Y(a) {
222
+ function H(a) {
223
223
  const s = a.length;
224
224
  if (s > 6)
225
225
  throw Error(`Transpose for rank ${s} is not yet supported`);
@@ -228,7 +228,7 @@ function Y(a) {
228
228
  t[a[n]] = e[n];
229
229
  return t.join();
230
230
  }
231
- class H {
231
+ class J {
232
232
  constructor(s, e) {
233
233
  this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
234
234
  const t = new Array(s.length);
@@ -261,10 +261,10 @@ class H {
261
261
  }
262
262
  }
263
263
  function D(a, s, e) {
264
- const t = E().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new H(a.shape, s) : new Q(a.shape, s);
264
+ const t = E().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new Y(a.shape, s);
265
265
  return e.runWebGLProgram(t, [a], a.dtype);
266
266
  }
267
- function J(a, s, e, t) {
267
+ function Q(a, s, e, t) {
268
268
  const n = s, l = a.shape.length, r = O(n, a.shape);
269
269
  let i = r;
270
270
  const c = A(i, l), o = c != null;
@@ -278,7 +278,7 @@ function J(a, s, e, t) {
278
278
  }
279
279
  function Z(a) {
280
280
  const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { axis: l, keepDims: r } = t;
281
- return J(n, l, r, e);
281
+ return Q(n, l, r, e);
282
282
  }
283
283
  const pe = {
284
284
  kernelName: F,
@@ -299,7 +299,7 @@ function te(a) {
299
299
  const I = e.texData.get(d.dataId).values, m = new Array(i);
300
300
  for (let v = 0; v < m.length; v++)
301
301
  m[v] = n.shape[u[v]];
302
- const z = U(I, n.shape, n.dtype, u, m);
302
+ const z = K(I, n.shape, n.dtype, u, m);
303
303
  d = e.makeTensorInfo(m, n.dtype);
304
304
  const M = e.texData.get(d.dataId);
305
305
  M.values = z;
@@ -322,7 +322,7 @@ function te(a) {
322
322
  return p && e.disposeIntermediateTensorInfo(d), x;
323
323
  }
324
324
  const he = {
325
- kernelName: G,
325
+ kernelName: U,
326
326
  backendName: "webgl",
327
327
  kernelFunc: te
328
328
  };
@@ -349,7 +349,7 @@ return a / b;`, se = `
349
349
 
350
350
  return result;
351
351
  `, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), fe = {
352
- kernelName: K,
352
+ kernelName: G,
353
353
  backendName: "webgl",
354
354
  kernelFunc: ne
355
355
  };
@@ -1,4 +1,4 @@
1
- import { V as h, ah as d, y as c, R as m } from "./index-D0RBWjq8.js";
1
+ import { U as h, ah as d, x as c, R as m } from "./index-Duu1Lvvv.js";
2
2
  function i(n) {
3
3
  const { inputs: p, attrs: o } = n, { x: e } = p, { shape: r } = o, a = h(e.shape), s = d(r, a), t = h(s);
4
4
  return c(a === t, () => `The new shape (${s}) has ${t} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), n.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
@@ -1,5 +1,5 @@
1
- import { R as C, V as c, ah as R, y as f } from "./index-D0RBWjq8.js";
2
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-CJcbnKPC.js";
1
+ import { R as C, U as c, ah as R, x as f } from "./index-Duu1Lvvv.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-BJSTk_mW.js";
3
3
  class S {
4
4
  constructor(t, i) {
5
5
  this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = g(this.outputShape.length);
@@ -47,7 +47,7 @@ function v(s, t) {
47
47
  }
48
48
  `;
49
49
  }
50
- function y(s, t, i) {
50
+ function b(s, t, i) {
51
51
  const a = [
52
52
  u(s.shape),
53
53
  ...m(s.shape)
@@ -61,19 +61,19 @@ function y(s, t, i) {
61
61
  ], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
62
62
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
63
63
  }
64
- function b(s) {
64
+ function y(s) {
65
65
  const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = R(o, p), h = c(n);
66
66
  f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
67
67
  const d = r.texData.get(e.dataId);
68
- return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? y(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
68
+ return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
69
69
  }
70
70
  const U = {
71
71
  kernelName: C,
72
72
  backendName: "webgl",
73
- kernelFunc: b
73
+ kernelFunc: y
74
74
  };
75
75
  export {
76
76
  S as R,
77
77
  U as a,
78
- b as r
78
+ y as r
79
79
  };
@@ -6,6 +6,7 @@ import { default as Trainer, ITrainerOptions } from './Trainer';
6
6
  import { default as MemoryProfiler } from './utilities/profile';
7
7
  import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
8
8
  import { default as Model, ModelForwardAttributes } from './models/model';
9
+ import { Task } from './training/tasks/Task';
9
10
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
10
11
  interface TeachableLLMMeta {
11
12
  name?: string;
@@ -41,7 +42,7 @@ export default class TeachableLLM {
41
42
  set enableProfiler(value: boolean);
42
43
  getNumParams(): number;
43
44
  trainer(): Trainer;
44
- train(text: Conversation[][], options?: ITrainerOptions): Promise<void>;
45
+ train(text: Task[], options?: ITrainerOptions): Promise<void>;
45
46
  trainTokeniser(text: string[]): Promise<number>;
46
47
  generator(): Generator;
47
48
  generateText(prompt: Conversation[], options?: IGenerateOptions): Promise<Conversation[]>;
@@ -5,24 +5,24 @@ import u from "./Generator.js";
5
5
  import f from "./Trainer.js";
6
6
  import { E as p } from "./index-DvYrXKkX.js";
7
7
  import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
8
- import "./index-D0RBWjq8.js";
9
- import "./random_width-UGQn4OWb.js";
10
- import "./zeros_like-BAz3iKru.js";
8
+ import "./index-Duu1Lvvv.js";
9
+ import "./random_width-BN4wGJaW.js";
10
+ import "./zeros_like-De4n1C3m.js";
11
11
  import "./index-Cp39cXWe.js";
12
- import "./dataset-DcjWqUVQ.js";
12
+ import "./dataset-CVIJu7Xa.js";
13
13
  import "./ops/cpu/attentionMask.js";
14
14
  import "./ops/webgl/attentionMask.js";
15
15
  import "./ops/grads/attentionMask.js";
16
16
  import "./ops/cpu/rope.js";
17
17
  import "./ops/webgl/rope.js";
18
- import "./rope-BmZmp9uP.js";
18
+ import "./rope-DJ7Y7c-u.js";
19
19
  import "./ops/cpu/appendCache.js";
20
20
  import "./ops/webgl/appendCache.js";
21
21
  import "./ops/grads/softmax16.js";
22
- import "./matMul16-cDxwemKj.js";
22
+ import "./matMul16-xswmhSuF.js";
23
23
  import "./ops/webgl/matMul16.js";
24
24
  import "./ops/cpu/matMul16.js";
25
- import "./pack16-k4jq6aMX.js";
25
+ import "./pack16-atD0eYRm.js";
26
26
  import "./ops/transpose16.js";
27
27
  import "./ops/reshape16.js";
28
28
  import "./ops/cpu/qkv.js";
@@ -41,11 +41,11 @@ import g from "./tokeniser/bpe.js";
41
41
  import "./papaparse.min-C0cScC2i.js";
42
42
  import "./jszip.min-Bz5-11Bk.js";
43
43
  import "./ops/cpu/matMulGelu.js";
44
- import "./matMulGelu-B2s_80-H.js";
44
+ import "./matMulGelu-BpvgnYG8.js";
45
45
  import "./ops/grads/matMulGelu.js";
46
46
  import "./ops/cpu/gelu.js";
47
47
  import "./ops/webgl/gelu.js";
48
- import "./gelu-DqTbCx5x.js";
48
+ import "./gelu-BK1k-n1i.js";
49
49
  import "./ops/webgl/log.js";
50
50
  import "./ops/cpu/adamMoments.js";
51
51
  import "./ops/webgl/adamMoments.js";
package/dist/Trainer.d.ts CHANGED
@@ -1,7 +1,8 @@
1
- import { Conversation, ITokeniser } from './tokeniser/type';
1
+ import { ITokeniser } from './tokeniser/type';
2
2
  import { default as EE } from 'eventemitter3';
3
3
  import { TrainingLogEntry, TrainingProgress } from './training/Trainer';
4
4
  import { default as Model, ModelForwardAttributes } from './models/model';
5
+ import { Task } from './training/tasks/Task';
5
6
  export interface ITrainerOptions {
6
7
  batchSize?: number;
7
8
  learningRate?: number;
@@ -30,7 +31,8 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
30
31
  constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser);
31
32
  stop(): void;
32
33
  reset(): void;
33
- prepare(text: Conversation[][], options?: ITrainerOptions): Promise<void>;
34
+ getTotalSamples(): number;
35
+ prepare(tasks?: Task[] | Uint16Array, options?: ITrainerOptions): Promise<void>;
34
36
  train(options?: ITrainerOptions): Promise<void>;
35
37
  step(options?: ITrainerOptions): Promise<void>;
36
38
  getLog(): TrainingLogEntry[];
package/dist/Trainer.js CHANGED
@@ -1,6 +1,6 @@
1
- import { E as o } from "./index-DvYrXKkX.js";
2
- import d from "./training/FullTrainer.js";
3
- class g extends o {
1
+ import { E as n } from "./index-DvYrXKkX.js";
2
+ import l from "./training/FullTrainer.js";
3
+ class p extends n {
4
4
  trainer;
5
5
  hasTrained = !1;
6
6
  trainDataset;
@@ -9,7 +9,7 @@ class g extends o {
9
9
  log = [];
10
10
  progress = null;
11
11
  constructor(t, e) {
12
- super(), this.trainer = new d(t, e, 1e-3);
12
+ super(), this.trainer = new l(t, e, 1e-3);
13
13
  }
14
14
  stop() {
15
15
  this.trainer.stop();
@@ -17,13 +17,16 @@ class g extends o {
17
17
  reset() {
18
18
  this.hasTrained = !1, this.log = [], this.trainer.reset();
19
19
  }
20
- async prepare(t, e) {
21
- const { trainDataset: a, validationDataset: s } = await this.trainer.createTrainValidationSplit(
20
+ getTotalSamples() {
21
+ return this.totalSamples;
22
+ }
23
+ async prepare(t = [], e) {
24
+ const { trainDataset: a, validationDataset: s, size: i } = await this.trainer.createTrainValidationSplit(
22
25
  t,
23
26
  e?.batchSize || 32,
24
27
  e?.validationSplit || 0.1
25
- ), i = t.reduce((r, n) => r + n.reduce((l, h) => l + h.content.length, 0), 0) * (1 - (e?.validationSplit || 0));
26
- this.trainDataset = a, this.validationDataset = s, this.totalSamples = i;
28
+ ), r = i * (1 - (e?.validationSplit || 0));
29
+ this.trainDataset = a, this.validationDataset = s, this.totalSamples = r;
27
30
  }
28
31
  async train(t) {
29
32
  if (!this.trainDataset || !this.validationDataset)
@@ -91,5 +94,5 @@ class g extends o {
91
94
  }
92
95
  }
93
96
  export {
94
- g as default
97
+ p as default
95
98
  };
@@ -1,4 +1,4 @@
1
- import { y as c } from "./index-D0RBWjq8.js";
1
+ import { x as c } from "./index-Duu1Lvvv.js";
2
2
  function i(e, n) {
3
3
  for (let t = 0; t < e.length; ++t)
4
4
  if (e[e.length - t - 1] !== n - 1 - t)
package/dist/backend.js CHANGED
@@ -1,9 +1,9 @@
1
- import { g as o, s as e, r as s } from "./index-D0RBWjq8.js";
1
+ import { g as o, s as e, r as s } from "./index-Duu1Lvvv.js";
2
2
  async function c(t, a) {
3
3
  if (o() !== t) {
4
4
  if (t === "webgpu") {
5
5
  const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
6
- i(a), await import("./index-Dj5TkmPY.js"), await import("./ops/webgpu/index.js");
6
+ i(a), await import("./index-BBVLAXZD.js"), await import("./ops/webgpu/index.js");
7
7
  }
8
8
  await e(t), await s(), console.log(`Backend set to ${t}`);
9
9
  }
@@ -1,10 +1,10 @@
1
- import { V as m, a9 as w, aU as I, y as d, ax as A, aB as _, $ as y, ad as M, a0 as T, aV as b, ak as D, aW as x } from "./index-D0RBWjq8.js";
2
- import { d as L, f as W, h as v, c as F, e as N, a as C, b as P, g as z } from "./axis_util-DofAuy0p.js";
3
- import { a as B, c as U } from "./concat_util-CHsJFZJJ.js";
4
- import { c as V, b as G, d as H, f as j, g as q, h as Z, i as k, j as J, k as K, m as X, t as Y } from "./step-dXR33iOg.js";
5
- import { S as Q, a as ee, b as te, g as se, c as ne, s as re } from "./selu_util-BuLbmbrl.js";
6
- import { s as oe } from "./slice_util-D8CQRenR.js";
7
- import { c as ae, v as ie, a as ue } from "./scatter_nd_util-BY22Cc-C.js";
1
+ import { U as m, a8 as w, aU as I, x as d, ax as A, aB as _, _ as y, ac as M, $ as T, aV as b, ak as D, aW as x } from "./index-Duu1Lvvv.js";
2
+ import { d as L, f as W, h as v, c as F, e as N, a as C, b as P, g as U } from "./axis_util-DGqbT-FX.js";
3
+ import { a as z, c as B } from "./concat_util-D0je5Ppu.js";
4
+ import { c as V, b as G, d as H, f as j, g as q, h as Z, i as k, j as J, k as K, m as X, t as Y } from "./step-BS5JXRR6.js";
5
+ import { S as Q, a as ee, b as te, g as se, c as ne, s as re } from "./selu_util-DyW0X1WG.js";
6
+ import { s as oe } from "./slice_util-zN8KFC5I.js";
7
+ import { c as ae, v as ie, a as ue } from "./scatter_nd_util-k9MUVUkn.js";
8
8
  import { a as le, c as pe, b as ce, e as he, d as fe, g as ge, m as de, s as me } from "./complex_util-Yc1A_gV1.js";
9
9
  function Ee(e, t) {
10
10
  const r = e.shape.length, s = t.shape.length;
@@ -204,7 +204,7 @@ function Pe(e, t, r) {
204
204
  s[t[n][a]] === void 0 ? s[t[n][a]] = o[a] : d(s[t[n][a]] === o[a], () => `Expected dimension ${s[t[n][a]]} at axis ${a} of input shaped ${JSON.stringify(o)}, but got dimension ${o[a]}`);
205
205
  }
206
206
  }
207
- function ze(e, t) {
207
+ function Ue(e, t) {
208
208
  const r = e, s = [];
209
209
  let n = 0;
210
210
  e.length === 0 && r.push(-1), n = e.length + 1;
@@ -212,16 +212,16 @@ function ze(e, t) {
212
212
  s.push([]);
213
213
  const o = [];
214
214
  for (let a = 0; a < r.length; ++a) {
215
- const u = r[a], p = Ue(t, u);
215
+ const u = r[a], p = Be(t, u);
216
216
  for (const c of p)
217
217
  o.indexOf(c) === -1 && (s[a].push(c), o.push(c));
218
218
  }
219
219
  return { path: r, steps: s };
220
220
  }
221
- function Be(e) {
221
+ function ze(e) {
222
222
  return e.every((t, r) => t === r);
223
223
  }
224
- function Ue(e, t) {
224
+ function Be(e, t) {
225
225
  const r = [];
226
226
  for (let s = 0; s < e.length; ++s)
227
227
  (e[s].length === 0 || e[s].indexOf(t) !== -1 || t === -1) && r.push(s);
@@ -352,7 +352,7 @@ const dt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
352
352
  applyActivation: te,
353
353
  assertAndGetBroadcastShape: y,
354
354
  assertAxesAreInnerMostDims: L,
355
- assertParamsConsistent: B,
355
+ assertParamsConsistent: z,
356
356
  assignToTypedArray: le,
357
357
  axesAreInnerMostDims: W,
358
358
  calculateShapes: ae,
@@ -368,7 +368,7 @@ const dt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
368
368
  computeDilation2DInfo: q,
369
369
  computeOptimalWindowSize: Oe,
370
370
  computeOutAndReduceShapes: F,
371
- computeOutShape: U,
371
+ computeOutShape: B,
372
372
  computePool2DInfo: Z,
373
373
  computePool3DInfo: k,
374
374
  convertConv2DDataFormat: J,
@@ -382,7 +382,7 @@ const dt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
382
382
  getAxesPermutation: C,
383
383
  getBroadcastDims: M,
384
384
  getComplexWithIndex: ge,
385
- getEinsumComputePath: ze,
385
+ getEinsumComputePath: Ue,
386
386
  getEinsumPermutation: Ce,
387
387
  getFusedBiasGradient: se,
388
388
  getFusedDyActivation: ne,
@@ -408,8 +408,8 @@ const dt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
408
408
  getSparseSegmentReductionNegativeSegmentIdsErrorMessage: Xe,
409
409
  getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: Ye,
410
410
  getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: Qe,
411
- getUndoAxesPermutation: z,
412
- isIdentityPermutation: Be,
411
+ getUndoAxesPermutation: U,
412
+ isIdentityPermutation: ze,
413
413
  log: b,
414
414
  mergeRealAndImagArrays: de,
415
415
  prepareAndValidate: Ee,
@@ -434,9 +434,9 @@ export {
434
434
  be as E,
435
435
  Ne as F,
436
436
  Pe as G,
437
- ze as H,
437
+ Ue as H,
438
438
  Ce as I,
439
- Be as J,
439
+ ze as J,
440
440
  Ee as K,
441
441
  nt as L,
442
442
  we as M,
@@ -1,6 +1,6 @@
1
- import { ab as g, au as $, av as K, e as D, y as _, aw as O, V as x, ax as Z, at as W, ay as F, az as j, aA as X, aB as J, ae as ee, a9 as k } from "./index-D0RBWjq8.js";
2
- import { m as te, f as se, P as re } from "./webgpu_program-CAE4RICo.js";
3
- import { i as ne, G as q } from "./webgpu_util-BdovYhXr.js";
1
+ import { aa as g, au as $, av as K, e as D, x as _, aw as O, U as x, ax as Z, at as W, ay as F, az as j, aA as X, aB as J, ae as ee, a8 as k } from "./index-Duu1Lvvv.js";
2
+ import { m as te, f as se, P as re } from "./webgpu_program-BpWRlghH.js";
3
+ import { i as ne, G as q } from "./webgpu_util-DMiKzzQM.js";
4
4
  import { m as N } from "./complex_util-Yc1A_gV1.js";
5
5
  const d = g();
6
6
  d.registerFlag("WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE", () => 15);
@@ -264,7 +264,7 @@ class R extends $ {
264
264
  alphaMode: r[B]
265
265
  }), y.getCurrentTexture();
266
266
  }).map((E, B) => {
267
- const y = f * 4, G = (P, S, v) => {
267
+ const y = f * 4, b = (P, S, v) => {
268
268
  this.ensureCommandEncoderReady(), this.commandEncoder.copyBufferToTexture({
269
269
  buffer: a,
270
270
  bytesPerRow: y,
@@ -279,20 +279,20 @@ class R extends $ {
279
279
  willReadFrequently: !0
280
280
  });
281
281
  I.clearRect(0, 0, P, S), I.drawImage(h[B], 0, 0);
282
- const b = I.getImageData(0, 0, P, S).data, H = r[B], M = new Uint8ClampedArray(o, v, P * S * 4);
282
+ const G = I.getImageData(0, 0, P, S).data, H = r[B], M = new Uint8ClampedArray(o, v, P * S * 4);
283
283
  for (let p = 0; p < M.length; p += 4)
284
284
  if (H === "premultiplied")
285
- M[p + 3] = b[p + 3];
285
+ M[p + 3] = G[p + 3];
286
286
  else {
287
- const V = b[p];
288
- M[p] = b[p + 2], M[p + 1] = b[p + 1], M[p + 2] = V;
287
+ const V = G[p];
288
+ M[p] = G[p + 2], M[p + 1] = G[p + 1], M[p + 2] = V;
289
289
  }
290
290
  }, Y = Math.floor(u / (f * c));
291
291
  let T = f, U = c, C = 0;
292
292
  for (let P = 0; P < Y; P++)
293
- G(T, U, C), C += f * c * 4;
293
+ b(T, U, C), C += f * c * 4;
294
294
  const A = u % (f * c);
295
- U = Math.floor(A / f), U > 0 && (G(T, U, C), C += U * (f * 4)), T = A % f, T > 0 && G(T, 1, C);
295
+ U = Math.floor(A / f), U > 0 && (b(T, U, C), C += U * (f * 4)), T = A % f, T > 0 && b(T, 1, C);
296
296
  });
297
297
  const w = O(o, t.dtype);
298
298
  return this.convertAndCacheOnCPU(e, w), w;
@@ -1,5 +1,5 @@
1
- import { q as h, u as f, w as p, x as g, E as u, T } from "./index-D0RBWjq8.js";
2
- import { r as b } from "./reshape-CkjKPPqB.js";
1
+ import { o as h, q as f, u as p, w as g, E as u, T } from "./index-Duu1Lvvv.js";
2
+ import { r as b } from "./reshape-BI0yzp1T.js";
3
3
  function m(e, r) {
4
4
  let n = f(e, "broadcastTo", "x");
5
5
  const a = n.shape;
@@ -1,5 +1,5 @@
1
- import { s, e as a } from "../index-D0RBWjq8.js";
2
- import { t } from "../tensor4d-DVwr7pLF.js";
1
+ import { s, e as a } from "../index-Duu1Lvvv.js";
2
+ import { t } from "../tensor4d-vuDDgdUI.js";
3
3
  async function u(e) {
4
4
  await s(e);
5
5
  const n = t(
@@ -1,6 +1,6 @@
1
- import { s as i, e } from "../index-D0RBWjq8.js";
2
- import { t } from "../tensor4d-DVwr7pLF.js";
3
- import { t as a } from "../tensor2d-BN1sSfQO.js";
1
+ import { s as i, e } from "../index-Duu1Lvvv.js";
2
+ import { t } from "../tensor4d-vuDDgdUI.js";
3
+ import { t as a } from "../tensor2d-BN97fF71.js";
4
4
  async function k(n) {
5
5
  await i(n);
6
6
  const s = t(
@@ -1,5 +1,5 @@
1
- import { s as e, e as o } from "../index-D0RBWjq8.js";
2
- import { t as s } from "../tensor2d-BN1sSfQO.js";
1
+ import { s as e, e as o } from "../index-Duu1Lvvv.js";
2
+ import { t as s } from "../tensor2d-BN97fF71.js";
3
3
  async function m(t) {
4
4
  await e(t);
5
5
  const r = s(
@@ -1,5 +1,5 @@
1
- import { s as o, e as s } from "../index-D0RBWjq8.js";
2
- import { t as e } from "../tensor2d-BN1sSfQO.js";
1
+ import { s as o, e as s } from "../index-Duu1Lvvv.js";
2
+ import { t as e } from "../tensor2d-BN97fF71.js";
3
3
  async function i(t) {
4
4
  await o(t);
5
5
  const r = e(
@@ -1,7 +1,7 @@
1
- import { s as u, a1 as A, e as y } from "../index-D0RBWjq8.js";
2
- import { a as h } from "../ops-FJapAPfm.js";
3
- import { t as p } from "../tensor1d-LxP9asMm.js";
4
- import { t as r } from "../tensor-BQqrDvpx.js";
1
+ import { s as u, a0 as A, e as y } from "../index-Duu1Lvvv.js";
2
+ import { a as h } from "../ops-C2_OXuZ4.js";
3
+ import { t as p } from "../tensor1d-Cc_KCIDg.js";
4
+ import { t as r } from "../tensor-CEt9Nm2s.js";
5
5
  const w = Array.from({ length: 2048 * 192 }, () => Math.random()), x = Array.from({ length: 192 }, () => Math.random()), M = Array.from({ length: 2048 * 192 }, () => Math.random());
6
6
  async function k(t) {
7
7
  await u(t);
@@ -1,6 +1,6 @@
1
- import { s as c, e as d } from "../index-D0RBWjq8.js";
2
- import { t as f } from "../tensor1d-LxP9asMm.js";
3
- import { t as r } from "../tensor-BQqrDvpx.js";
1
+ import { s as c, e as d } from "../index-Duu1Lvvv.js";
2
+ import { t as f } from "../tensor1d-Cc_KCIDg.js";
3
+ import { t as r } from "../tensor-CEt9Nm2s.js";
4
4
  const y = Array.from({ length: 2048 * 192 }, () => Math.random()), i = Array.from({ length: 192 }, () => Math.random()), l = Array.from({ length: 2048 * 192 }, () => Math.random());
5
5
  async function x(t) {
6
6
  await c(t);
@@ -1,5 +1,5 @@
1
- import { s as a, e } from "../index-D0RBWjq8.js";
2
- import { t as c } from "../tensor2d-BN1sSfQO.js";
1
+ import { s as a, e } from "../index-Duu1Lvvv.js";
2
+ import { t as c } from "../tensor2d-BN97fF71.js";
3
3
  async function i(n) {
4
4
  await a(n);
5
5
  const r = c(
@@ -1,5 +1,5 @@
1
- import { X as i, Y as u, Z as c, s as l, e as h } from "../index-D0RBWjq8.js";
2
- import { t as f } from "../tensor2d-BN1sSfQO.js";
1
+ import { W as i, X as u, Y as c, s as l, e as h } from "../index-Duu1Lvvv.js";
2
+ import { t as f } from "../tensor2d-BN97fF71.js";
3
3
  function m(t, e, n) {
4
4
  if (i(t), e != null && e.length !== 3)
5
5
  throw new Error("tensor3d() requires shape to have three numbers");