@genai-fi/nanogpt 0.15.0 → 0.15.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (248) hide show
  1. package/dist/Generator.js +33 -31
  2. package/dist/{RealDiv-B2Tyc34U.js → RealDiv-CJpH9Bif.js} +13 -13
  3. package/dist/{Reshape-Bqk-z_7-.js → Reshape-C4ZzbS5c.js} +3 -3
  4. package/dist/{Reshape-D973Ba8R.js → Reshape-CKzb2DIN.js} +4 -4
  5. package/dist/TeachableLLM.d.ts +5 -0
  6. package/dist/TeachableLLM.js +30 -18
  7. package/dist/Trainer.d.ts +1 -0
  8. package/dist/Trainer.js +65 -62
  9. package/dist/{axis_util-RrJzDQJc.js → axis_util-BBaWKQoo.js} +1 -1
  10. package/dist/backend.js +2 -2
  11. package/dist/{backend_util-9wV3yg0r.js → backend_util-DLIicY0X.js} +50 -50
  12. package/dist/{backend_webgpu-CnFoGvzK.js → backend_webgpu-BwfUOSiJ.js} +21 -21
  13. package/dist/{broadcast_to-hAMmZJpr.js → broadcast_to-CxKUM6zp.js} +2 -2
  14. package/dist/checks/appendCache.js +2 -2
  15. package/dist/checks/attentionMask.js +3 -3
  16. package/dist/checks/gelu.js +2 -2
  17. package/dist/checks/matMulGelu.js +2 -2
  18. package/dist/checks/normRMS.js +6 -6
  19. package/dist/checks/normRMSGrad.js +3 -3
  20. package/dist/checks/packUnpack.js +2 -2
  21. package/dist/checks/qkv.js +2 -2
  22. package/dist/checks/rope.js +2 -2
  23. package/dist/clip_by_value-lDwNWeyI.js +12 -0
  24. package/dist/{complex-BDvCF_r9.js → complex-NXAORdbW.js} +1 -1
  25. package/dist/{concat-B9WckkXa.js → concat-DCm6KW65.js} +1 -1
  26. package/dist/{concat_util-DVNU-Nn3.js → concat_util-DT0Mofs3.js} +1 -1
  27. package/dist/{dataset-ZUdlBUXV.js → dataset-Bwcib9pp.js} +3 -3
  28. package/dist/dropout_util-Crmm4aOV.js +27 -0
  29. package/dist/{expand_dims-DoiHvcDw.js → expand_dims-DgU0Vlpg.js} +1 -1
  30. package/dist/{exports_initializers-8SQOHjAF.js → exports_initializers-VKuLTIiX.js} +1 -1
  31. package/dist/floor-Bhmfrtly.js +9 -0
  32. package/dist/{gather-BYhIiO5e.js → gather-FIoUa4Zd.js} +1 -1
  33. package/dist/{gelu-9_DFp2Q5.js → gelu-CmkPheOK.js} +1 -1
  34. package/dist/{gpgpu_math-Dzx_EUJa.js → gpgpu_math-D83bWKYw.js} +25 -25
  35. package/dist/{index-3FfEY3tm.js → index-D0b5F1JD.js} +58 -58
  36. package/dist/{index-B8eBIyjS.js → index-nwvWLdRt.js} +89 -89
  37. package/dist/{kernel_funcs_utils-BLvDeLPe.js → kernel_funcs_utils-Bu6bS4D_.js} +11 -11
  38. package/dist/layers/BaseLayer.d.ts +4 -0
  39. package/dist/layers/BaseLayer.js +11 -7
  40. package/dist/layers/CausalSelfAttention.js +55 -51
  41. package/dist/layers/LoRA.js +4 -4
  42. package/dist/layers/MLP.d.ts +1 -1
  43. package/dist/layers/MLP.js +20 -19
  44. package/dist/layers/PositionEmbedding.js +10 -10
  45. package/dist/layers/RMSNorm.js +3 -3
  46. package/dist/layers/RoPECache.js +4 -4
  47. package/dist/layers/TiedEmbedding.js +6 -6
  48. package/dist/layers/TransformerBlock.js +1 -1
  49. package/dist/layers/WeightStore.js +3 -3
  50. package/dist/loader/loadTransformers.js +1 -1
  51. package/dist/loader/oldZipLoad.js +20 -18
  52. package/dist/loader/save.js +6 -5
  53. package/dist/loader/types.d.ts +2 -0
  54. package/dist/main.js +9 -9
  55. package/dist/{matMul16-Bp17gt56.js → matMul16-bI7XM831.js} +3 -3
  56. package/dist/{matMulGelu-Bdxn3VPX.js → matMulGelu-Cbtq3pxJ.js} +21 -21
  57. package/dist/{mat_mul-BUuYg3qo.js → mat_mul-BQY_GSqm.js} +1 -1
  58. package/dist/{mod-4q-X1J5l.js → mod-ChddM4vN.js} +1 -1
  59. package/dist/models/NanoGPTV1.js +9 -9
  60. package/dist/models/NanoGPTV2.js +12 -10
  61. package/dist/models/model.d.ts +1 -1
  62. package/dist/models/model.js +14 -12
  63. package/dist/not_equal-duCIyEXv.js +64 -0
  64. package/dist/{ones-aGZXepq3.js → ones-Piv0gZxv.js} +3 -3
  65. package/dist/ops/adamAdjust.js +1 -1
  66. package/dist/ops/adamMoments.js +1 -1
  67. package/dist/ops/add16.js +1 -1
  68. package/dist/ops/appendCache.js +3 -3
  69. package/dist/ops/attentionMask.js +1 -1
  70. package/dist/ops/concat16.js +2 -2
  71. package/dist/ops/cpu/adamAdjust.js +1 -1
  72. package/dist/ops/cpu/adamMoments.js +2 -2
  73. package/dist/ops/cpu/appendCache.js +2 -2
  74. package/dist/ops/cpu/attentionMask.js +6 -6
  75. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  76. package/dist/ops/cpu/gatherSub.js +6 -6
  77. package/dist/ops/cpu/gelu.js +1 -1
  78. package/dist/ops/cpu/matMul16.js +2 -2
  79. package/dist/ops/cpu/matMulGelu.js +3 -3
  80. package/dist/ops/cpu/matMulMul.js +1 -1
  81. package/dist/ops/cpu/mulDropout.js +1 -1
  82. package/dist/ops/cpu/normRMS.js +1 -1
  83. package/dist/ops/cpu/qkv.js +3 -3
  84. package/dist/ops/cpu/rope.js +5 -5
  85. package/dist/ops/cpu/scatterSub.js +9 -9
  86. package/dist/ops/dot16.js +2 -2
  87. package/dist/ops/dropout.d.ts +2 -0
  88. package/dist/ops/dropout.js +14 -0
  89. package/dist/ops/dropout16.d.ts +2 -0
  90. package/dist/ops/dropout16.js +25 -0
  91. package/dist/ops/gatherSub.js +1 -1
  92. package/dist/ops/gelu.js +2 -2
  93. package/dist/ops/globalNorm.js +2 -2
  94. package/dist/ops/grads/add16.js +1 -1
  95. package/dist/ops/grads/attentionMask.js +2 -2
  96. package/dist/ops/grads/dropout16.d.ts +1 -0
  97. package/dist/ops/grads/dropout16.js +2 -0
  98. package/dist/ops/grads/gelu.js +2 -2
  99. package/dist/ops/grads/matMul16.js +3 -3
  100. package/dist/ops/grads/matMulGelu.js +1 -1
  101. package/dist/ops/grads/mul16.d.ts +1 -0
  102. package/dist/ops/grads/mul16.js +4 -0
  103. package/dist/ops/grads/normRMS.js +1 -1
  104. package/dist/ops/grads/pack16.js +3 -3
  105. package/dist/ops/grads/qkv.js +3 -3
  106. package/dist/ops/grads/rope.js +2 -2
  107. package/dist/ops/grads/softmax16.js +1 -1
  108. package/dist/ops/grads/unpack16.js +2 -2
  109. package/dist/ops/matMul16.js +3 -3
  110. package/dist/ops/matMulGelu.js +2 -2
  111. package/dist/ops/matMulMul.js +1 -1
  112. package/dist/ops/mul16.js +36 -5
  113. package/dist/ops/mulDrop.js +1 -1
  114. package/dist/ops/normRMS.js +13 -4
  115. package/dist/ops/pack16.js +2 -2
  116. package/dist/ops/qkv.js +1 -1
  117. package/dist/ops/reshape16.js +2 -2
  118. package/dist/ops/rope.js +2 -2
  119. package/dist/ops/scatterSub.js +1 -1
  120. package/dist/ops/slice16.js +2 -2
  121. package/dist/ops/softmax16.js +1 -1
  122. package/dist/ops/sub16.js +1 -1
  123. package/dist/ops/sum16.js +2 -2
  124. package/dist/ops/transpose16.js +3 -3
  125. package/dist/ops/unpack16.js +2 -2
  126. package/dist/ops/webgl/adamAdjust.js +2 -2
  127. package/dist/ops/webgl/adamMoments.js +1 -1
  128. package/dist/ops/webgl/appendCache.js +1 -1
  129. package/dist/ops/webgl/attentionMask.js +1 -1
  130. package/dist/ops/webgl/dropout16.d.ts +1 -0
  131. package/dist/ops/webgl/dropout16.js +11 -0
  132. package/dist/ops/webgl/fusedSoftmax.js +6 -6
  133. package/dist/ops/webgl/gatherSub.js +1 -1
  134. package/dist/ops/webgl/gelu.js +2 -2
  135. package/dist/ops/webgl/log.js +3 -3
  136. package/dist/ops/webgl/matMul16.js +5 -5
  137. package/dist/ops/webgl/matMulGelu.js +4 -4
  138. package/dist/ops/webgl/matMulMul.js +2 -2
  139. package/dist/ops/webgl/mulDropout.js +1 -1
  140. package/dist/ops/webgl/normRMS.js +2 -2
  141. package/dist/ops/webgl/qkv.js +1 -1
  142. package/dist/ops/webgl/rope.js +1 -1
  143. package/dist/ops/webgl/scatterSub.js +1 -1
  144. package/dist/ops/webgpu/adamAdjust.js +3 -3
  145. package/dist/ops/webgpu/adamMoments.js +3 -3
  146. package/dist/ops/webgpu/add16.js +1 -1
  147. package/dist/ops/webgpu/appendCache.js +3 -3
  148. package/dist/ops/webgpu/attentionMask.js +2 -2
  149. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  150. package/dist/ops/webgpu/clipScale.js +1 -1
  151. package/dist/ops/webgpu/concat16.js +12 -12
  152. package/dist/ops/webgpu/dropout16.d.ts +1 -0
  153. package/dist/ops/webgpu/dropout16.js +51 -0
  154. package/dist/ops/webgpu/gatherSub.js +3 -3
  155. package/dist/ops/webgpu/gelu.js +3 -3
  156. package/dist/ops/webgpu/index.js +1 -0
  157. package/dist/ops/webgpu/matMul16.js +14 -14
  158. package/dist/ops/webgpu/matMul16_program.js +2 -2
  159. package/dist/ops/webgpu/mul16.js +9 -9
  160. package/dist/ops/webgpu/norm2.js +1 -1
  161. package/dist/ops/webgpu/normRMS.js +2 -2
  162. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  163. package/dist/ops/webgpu/pack16.js +1 -1
  164. package/dist/ops/webgpu/pack16_program.js +2 -2
  165. package/dist/ops/webgpu/qkv.js +2 -2
  166. package/dist/ops/webgpu/rope.js +3 -3
  167. package/dist/ops/webgpu/scatterSub.js +3 -3
  168. package/dist/ops/webgpu/slice16.js +4 -4
  169. package/dist/ops/webgpu/softmax16.js +2 -2
  170. package/dist/ops/webgpu/softmax16_program.js +2 -2
  171. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  172. package/dist/ops/webgpu/softmax16grad.js +1 -1
  173. package/dist/ops/webgpu/sub16.js +1 -1
  174. package/dist/ops/webgpu/sum16.js +5 -5
  175. package/dist/ops/webgpu/transpose16.js +2 -2
  176. package/dist/ops/webgpu/transpose16_program.js +2 -2
  177. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  178. package/dist/ops/webgpu/unpack16.js +3 -3
  179. package/dist/ops/webgpu/utils/binary_op.d.ts +16 -0
  180. package/dist/ops/webgpu/utils/binary_op.js +74 -13
  181. package/dist/ops/webgpu/utils/reductions.js +5 -5
  182. package/dist/{ops-BLDakU_V.js → ops-BXr-37bF.js} +30 -30
  183. package/dist/{pack16-F9gxcBrq.js → pack16-DO9GrRdk.js} +2 -2
  184. package/dist/patches/webgpu_backend.js +9 -9
  185. package/dist/patches/webgpu_base.js +1 -1
  186. package/dist/patches/webgpu_program.js +2 -2
  187. package/dist/rand_util-CZ7yLoUm.js +50 -0
  188. package/dist/random_normal-CO9xf9dz.js +14 -0
  189. package/dist/{random_width-DSeITIFc.js → random_width-CliSj-et.js} +164 -162
  190. package/dist/{range-BvA7g6TS.js → range-Dx4PwA2-.js} +1 -1
  191. package/dist/{readers-lNVRVUDO.js → readers-DwZhCW0C.js} +2 -2
  192. package/dist/{relu-DyGjd4UV.js → relu-BnpM8PVa.js} +1 -1
  193. package/dist/{reshape-3ugLpT-p.js → reshape-DVh8yLpI.js} +1 -1
  194. package/dist/{resize_nearest_neighbor-DBPfHMkZ.js → resize_nearest_neighbor-Dl7ehaQl.js} +39 -39
  195. package/dist/{rope-D5BJXlc7.js → rope-DjON_IMj.js} +1 -1
  196. package/dist/{scatter_nd_util-6lhBuxGa.js → scatter_nd_util-SSoGmfpx.js} +1 -1
  197. package/dist/{selu_util-emNhirms.js → selu_util-C0DN3KhX.js} +5 -5
  198. package/dist/{shared-Wn4Lkf40.js → shared-CefTy5O1.js} +1 -1
  199. package/dist/{shared-DeC0UJkK.js → shared-DgNUoqSc.js} +35 -35
  200. package/dist/{slice-C1VU5kjs.js → slice-BluUPHKL.js} +1 -1
  201. package/dist/{slice_util-5UIO9Akz.js → slice_util-DK4kHJjN.js} +1 -1
  202. package/dist/{softmax-BSXRSMAA.js → softmax-HULrSwJC.js} +1 -1
  203. package/dist/{split-Z_OF59mV.js → split-QwVeUPZt.js} +1 -1
  204. package/dist/{squeeze-DuB_IYFY.js → squeeze-Brkwo5OI.js} +2 -2
  205. package/dist/{stack-CdjLGyjr.js → stack-C_8ubcjt.js} +1 -1
  206. package/dist/{step-CA-PdcE1.js → step-wz0MZ7BP.js} +1 -1
  207. package/dist/{sum-CX6lFpfv.js → sum-iKJXG43N.js} +1 -1
  208. package/dist/{tensor-BLWBtdey.js → tensor-Dfy8cN1y.js} +1 -1
  209. package/dist/{tensor1d-Dp80hTtj.js → tensor1d-CoOFcAZs.js} +1 -1
  210. package/dist/{tensor2d-DryAvP1o.js → tensor2d-C8gFDiIC.js} +1 -1
  211. package/dist/{tensor4d-BR5YioKH.js → tensor4d-Bvqzr_Wu.js} +1 -1
  212. package/dist/{tfjs_backend-BuO7pU2h.js → tfjs_backend-9QO-TAAZ.js} +275 -295
  213. package/dist/{tile-CB7Cg2Cm.js → tile-CcpklBqG.js} +1 -1
  214. package/dist/training/AdamW.js +2 -2
  215. package/dist/training/BasicTrainer.d.ts +6 -0
  216. package/dist/training/BasicTrainer.js +74 -60
  217. package/dist/training/DatasetBuilder.js +3 -3
  218. package/dist/training/Evaluator.js +2 -2
  219. package/dist/training/SFTDatasetBuilder.js +3 -3
  220. package/dist/training/SFTTrainer.js +6 -6
  221. package/dist/training/loss.d.ts +1 -1
  222. package/dist/training/loss.js +12 -8
  223. package/dist/training/orthoGrad.js +1 -1
  224. package/dist/training/sparseCrossEntropy.d.ts +2 -2
  225. package/dist/training/sparseCrossEntropy.js +54 -31
  226. package/dist/training/types.d.ts +4 -0
  227. package/dist/training/validation.js +19 -17
  228. package/dist/{transpose-COw0-lqd.js → transpose-CwEYsCv1.js} +2 -2
  229. package/dist/{unsorted_segment_sum-C23hrdi0.js → unsorted_segment_sum-DRVX2bX2.js} +22 -22
  230. package/dist/utilities/dummy.js +2 -2
  231. package/dist/utilities/multinomialCPU.js +2 -2
  232. package/dist/utilities/packed.js +1 -1
  233. package/dist/utilities/parameters.d.ts +1 -0
  234. package/dist/utilities/parameters.js +20 -15
  235. package/dist/utilities/performance.js +1 -1
  236. package/dist/utilities/profile.js +1 -1
  237. package/dist/utilities/safetensors.js +2 -2
  238. package/dist/utilities/sentences.js +5 -5
  239. package/dist/utilities/weights.js +2 -2
  240. package/dist/{variable-lnPOlwsK.js → variable-CqrRzzxM.js} +1 -1
  241. package/dist/{webgpu_program-CuMK2hhh.js → webgpu_program-BlAY4Q29.js} +1 -1
  242. package/dist/{webgpu_util-DWXgz54K.js → webgpu_util-D1Ynuktt.js} +1 -1
  243. package/dist/{zeros-BJogAj4Z.js → zeros-B8VPk-mx.js} +2 -2
  244. package/dist/{zeros_like-WQK7VrX-.js → zeros_like-DfWM-ezN.js} +90 -89
  245. package/package.json +1 -1
  246. package/dist/floor-B6EO3Z6x.js +0 -18
  247. package/dist/not_equal-BO_DB61m.js +0 -64
  248. package/dist/random_normal-dxcPUb9x.js +0 -61
package/dist/Generator.js CHANGED
@@ -1,40 +1,40 @@
1
1
  import { E as Ui } from "./index-DvYrXKkX.js";
2
- import { o as Hi, q as Xi, E as Ki, dn as Ss, am as pe, a8 as _, ar as oo, as as ao, e as Oe, a_ as Dt, ax as ro, ay as io, at as ji, au as Ft, aD as Ge, ab as co, aG as ws, aH as qi, N as G, ac as _e, aI as Yi, D as Ns, aJ as Rs, R as Qi, af as Zi, x as te, B as lo, Y as Ne, a6 as ee, bd as uo, c9 as Ts, ca as Es, cQ as po, bo as ho, ad as ue, L as Ye, bp as fo, bq as mo, cb as go, cc as Ds, cd as Fs, ce as Ps, cf as Os, cg as As, br as xo, a9 as nt, cA as Co, cR as bo, cS as Io, bu as yo, bt as ko, bf as $o, dg as vo, aj as _s, cU as So, an as wo, C as No, bv as Ro, c4 as $e, cF as To, bw as Eo, cB as Do, cV as Fo, cC as Po, bx as Ls, by as Vs, bh as Oo, bz as Ao, ai as Qe, Q as Ji, bA as _o, cD as Lo, ch as Vo, bB as Wo, cH as Mo, cI as Bo, dh as Go, ci as zo, bZ as ns, bT as St, bW as Ws, cW as yn, dv as ut, dw as Uo, cX as kn, di as ec, K as tc, bg as Ho, cY as Xo, bD as Ms, z as Ko, aX as jo, aV as mt, cr as qo, de as Yo, df as Qo, bi as Zo, cG as Jo, dk as ea, ah as ta, G as sa, cs as na, cj as Bs, ck as Gs, cl as zs, dl as oa, b5 as Us, b6 as Hs, bF as Xs, cn as Ks, cm as aa, c_ as ra, aM as sc, bG as ia, cE as ca, c$ as la, d0 as ua, dm as da, aP as pa, a$ as ha, co as fa, bS as ma, M as js, F as ga, bk as xa, bm as Ca, bl as ba, bH as Ia, d5 as ya, bI as ka, P as $a, a4 as va, bJ as Sa, d1 as qs, dx as wa, dy as Na, dz as Ra, X as Ta, cp as Ys, bb as Ea, d2 as Da, bc as Fa, d3 as Pa, bL as Oa, bj as Aa, b8 as Qs, ag as _a, dp as La, aL as Va, bN as Zs, cq as Js, bO as en, bP as tn, bE as sn, bK as Wa, dA as Ma, dB as Ba, dq as Ga, dr as za, ds as Ua, H as Ha, d4 as Xa, aK as nn, ct as Ka, dt as ja, dC as qa, dD as Ya, cu as on, bs as an, du as Qa, T as Za, cv as Ja, bn as er, a3 as rn, cw as tr, ba as sr, bQ as nr, c as or, dE as ar, dF as $n, av as vn, aw as nc, t as rr, a as oc, dG as ac, dH as rc, c2 as ic, aq as cc, bR as lc, bX as uc, S as dc, bY as pc, aQ as hc, ap as fc, bU as mc, bV as gc, b_ as xc, aS as ir, bC as Cc, aN as bc, b$ as Ic, ak as yc, c0 as kc, dj as $c, b1 as vc, b2 as Sc, b3 as wc, b4 as Nc, aO as Rc, c1 as Tc, b7 as Ec, c7 as Dc, ao as Fc, c3 as Pc, aW as cr, bM as Oc, aF as Ac, c5 as _c, b9 as Lc, c6 as Vc, k as Wc } from "./index-3FfEY3tm.js";
3
- import { n as Mc } from "./random_width-DSeITIFc.js";
4
- import { t as Bc } from "./zeros_like-WQK7VrX-.js";
2
+ import { o as Hi, q as Xi, E as Ki, dn as Ss, a5 as pe, ab as _, as as oo, at as ao, e as Oe, a_ as Dt, ay as ro, az as io, au as ji, av as Ft, aD as Ge, ae as co, aG as ws, aH as qi, U as G, af as _e, aI as Yi, H as Ns, aJ as Rs, R as Qi, aj as Zi, x as te, D as lo, _ as Ne, a9 as ee, bd as uo, c9 as Ts, ca as Es, cQ as po, bo as ho, ah as ue, Q as Ye, bp as fo, bq as mo, cb as go, cc as Ds, cd as Fs, ce as Ps, cf as Os, cg as As, br as xo, ac as nt, cA as Co, cR as bo, cS as Io, bu as yo, bt as ko, bf as $o, dg as vo, C as _s, cU as So, ao as wo, z as No, bv as Ro, c4 as $e, cF as To, bw as Eo, cB as Do, cV as Fo, cC as Po, bx as Ls, by as Vs, bh as Oo, bz as Ao, am as Qe, V as Ji, bA as _o, cD as Lo, ch as Vo, bB as Wo, cH as Mo, cI as Bo, dh as Go, ci as zo, bZ as ns, bT as St, bW as Ws, cW as yn, dv as ut, dw as Uo, cX as kn, di as ec, N as tc, bg as Ho, cY as Xo, bD as Ms, A as Ko, aX as jo, aV as mt, cr as qo, de as Yo, df as Qo, bi as Zo, cG as Jo, dk as ea, al as ta, G as sa, cs as na, cj as Bs, ck as Gs, cl as zs, dl as oa, b5 as Us, b6 as Hs, bF as Xs, cn as Ks, cm as aa, c_ as ra, aM as sc, bG as ia, cE as ca, c$ as la, d0 as ua, dm as da, aP as pa, a$ as ha, co as fa, bS as ma, M as js, I as ga, bk as xa, bm as Ca, bl as ba, bH as Ia, d5 as ya, bI as ka, P as $a, a7 as va, bJ as Sa, d1 as qs, dx as wa, dy as Na, dz as Ra, Z as Ta, cp as Ys, bb as Ea, d2 as Da, bc as Fa, d3 as Pa, bL as Oa, bj as Aa, b8 as Qs, ak as _a, dp as La, aL as Va, bN as Zs, cq as Js, bO as en, bP as tn, bE as sn, bK as Wa, dA as Ma, dB as Ba, dq as Ga, dr as za, ds as Ua, J as Ha, d4 as Xa, aK as nn, ct as Ka, dt as ja, dC as qa, dD as Ya, cu as on, bs as an, du as Qa, T as Za, cv as Ja, bn as er, a6 as rn, cw as tr, ba as sr, bQ as nr, c as or, dE as ar, dF as $n, aw as vn, ax as nc, t as rr, a as oc, dG as ac, dH as rc, c2 as ic, ar as cc, bR as lc, bX as uc, S as dc, bY as pc, aQ as hc, aq as fc, bU as mc, bV as gc, b_ as xc, aS as ir, bC as Cc, aN as bc, b$ as Ic, F as yc, c0 as kc, dj as $c, b1 as vc, b2 as Sc, b3 as wc, b4 as Nc, aO as Rc, c1 as Tc, b7 as Ec, c7 as Dc, ap as Fc, c3 as Pc, aW as cr, bM as Oc, aF as Ac, c5 as _c, b9 as Lc, c6 as Vc, k as Wc } from "./index-D0b5F1JD.js";
3
+ import { n as Mc } from "./random_width-CliSj-et.js";
4
+ import { t as Bc } from "./zeros_like-DfWM-ezN.js";
5
5
  import "./index-Cp39cXWe.js";
6
- import "./dataset-ZUdlBUXV.js";
7
- import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as yl, Z as kl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-DeC0UJkK.js";
6
+ import "./dataset-Bwcib9pp.js";
7
+ import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as yl, Z as kl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-DgNUoqSc.js";
8
8
  import { m as pt, g as pr, s as Ul, c as Hl, b as Xl, d as Kl, a as jl, e as ql } from "./complex_util-Yc1A_gV1.js";
9
- import { a as ge, b as xe, d as ke, c as ve, e as Te, g as os } from "./axis_util-RrJzDQJc.js";
10
- import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-CA-PdcE1.js";
11
- import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as yr, w as kr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-9wV3yg0r.js";
12
- import { a as Dr, c as Ue } from "./concat_util-DVNU-Nn3.js";
9
+ import { a as ge, b as xe, d as ke, c as ve, e as Te, g as os } from "./axis_util-BBaWKQoo.js";
10
+ import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-wz0MZ7BP.js";
11
+ import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as yr, w as kr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-DLIicY0X.js";
12
+ import { a as Dr, c as Ue } from "./concat_util-DT0Mofs3.js";
13
13
  import { s as Jl } from "./index-CieiGp4Y.js";
14
14
  import { n as Fr, b as Pr, a as Or } from "./non_max_suppression_impl-B2W7YjZB.js";
15
- import { c as Ct } from "./scatter_nd_util-6lhBuxGa.js";
16
- import { S as Ar, a as _r } from "./selu_util-emNhirms.js";
17
- import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-5UIO9Akz.js";
18
- import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as yu, O as ku, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-Dzx_EUJa.js";
19
- import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as yd } from "./shared-Wn4Lkf40.js";
20
- import { a as ye, c as kd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-BLvDeLPe.js";
21
- import { R as Nd, r as U, a as Rd } from "./Reshape-D973Ba8R.js";
22
- import { M as qr } from "./matMulGelu-Bdxn3VPX.js";
23
- import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-B2Tyc34U.js";
24
- import { z as Od } from "./zeros-BJogAj4Z.js";
15
+ import { c as Ct } from "./scatter_nd_util-SSoGmfpx.js";
16
+ import { S as Ar, a as _r } from "./selu_util-C0DN3KhX.js";
17
+ import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-DK4kHJjN.js";
18
+ import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as yu, O as ku, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-D83bWKYw.js";
19
+ import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as yd } from "./shared-CefTy5O1.js";
20
+ import { a as ye, c as kd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-Bu6bS4D_.js";
21
+ import { R as Nd, r as U, a as Rd } from "./Reshape-CKzb2DIN.js";
22
+ import { M as qr } from "./matMulGelu-Cbtq3pxJ.js";
23
+ import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-CJpH9Bif.js";
24
+ import { z as Od } from "./zeros-B8VPk-mx.js";
25
25
  import "./ops/cpu/attentionMask.js";
26
26
  import "./ops/webgl/attentionMask.js";
27
27
  import "./ops/grads/attentionMask.js";
28
28
  import "./ops/cpu/rope.js";
29
29
  import "./ops/webgl/rope.js";
30
- import "./rope-D5BJXlc7.js";
30
+ import "./rope-DjON_IMj.js";
31
31
  import "./ops/cpu/appendCache.js";
32
32
  import "./ops/webgl/appendCache.js";
33
33
  import "./ops/grads/softmax16.js";
34
- import "./matMul16-Bp17gt56.js";
34
+ import "./matMul16-bI7XM831.js";
35
35
  import "./ops/webgl/matMul16.js";
36
36
  import "./ops/cpu/matMul16.js";
37
- import "./pack16-F9gxcBrq.js";
37
+ import "./pack16-DO9GrRdk.js";
38
38
  import "./ops/transpose16.js";
39
39
  import "./ops/reshape16.js";
40
40
  import "./ops/cpu/qkv.js";
@@ -43,6 +43,8 @@ import "./ops/grads/qkv.js";
43
43
  import "./ops/cpu/normRMS.js";
44
44
  import "./ops/webgl/normRMS.js";
45
45
  import "./ops/grads/normRMS.js";
46
+ import "./ops/dropout16.js";
47
+ import "./ops/webgl/dropout16.js";
46
48
  import "./ops/grads/add16.js";
47
49
  import "./jszip.min-Bz5-11Bk.js";
48
50
  import Ad from "./tokeniser/CharTokeniser.js";
@@ -62,17 +64,17 @@ import "./ops/cpu/matMulGelu.js";
62
64
  import "./ops/grads/matMulGelu.js";
63
65
  import "./ops/cpu/gelu.js";
64
66
  import "./ops/webgl/gelu.js";
65
- import "./gelu-9_DFp2Q5.js";
67
+ import "./gelu-CmkPheOK.js";
66
68
  import "./ops/webgl/log.js";
67
69
  import "./checks/normRMS.js";
68
70
  import "./checks/normRMSGrad.js";
69
71
  import Wd from "./utilities/multinomialCPU.js";
70
- import { r as Dn } from "./reshape-3ugLpT-p.js";
71
- import { t as Kt } from "./tensor2d-DryAvP1o.js";
72
- import { z as Md } from "./unsorted_segment_sum-C23hrdi0.js";
73
- import { s as gs } from "./softmax-BSXRSMAA.js";
74
- import { g as Bd } from "./gather-BYhIiO5e.js";
75
- import { c as Gd } from "./concat-B9WckkXa.js";
72
+ import { r as Dn } from "./reshape-DVh8yLpI.js";
73
+ import { t as Kt } from "./tensor2d-C8gFDiIC.js";
74
+ import { z as Md } from "./unsorted_segment_sum-DRVX2bX2.js";
75
+ import { s as gs } from "./softmax-HULrSwJC.js";
76
+ import { g as Bd } from "./gather-FIoUa4Zd.js";
77
+ import { c as Gd } from "./concat-DCm6KW65.js";
76
78
  function zd(a, t, e, n = !1) {
77
79
  const s = Xi(a, "logits", "multinomial"), o = s.size, r = s.rank;
78
80
  if (o < 2)
@@ -11676,7 +11678,7 @@ const lv = [
11676
11678
  function uv(a, t) {
11677
11679
  return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
11678
11680
  }
11679
- class FS extends Ui {
11681
+ class OS extends Ui {
11680
11682
  constructor(t, e) {
11681
11683
  super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
11682
11684
  }
@@ -11872,6 +11874,6 @@ class FS extends Ui {
11872
11874
  }
11873
11875
  }
11874
11876
  export {
11875
- FS as default,
11877
+ OS as default,
11876
11878
  cv as isConversation
11877
11879
  };
@@ -1,10 +1,10 @@
1
- import { aE as E, a8 as T, ad as O, N as V, aW as B, K as F, aM as K, aX as W } from "./index-3FfEY3tm.js";
2
- import { r as $ } from "./Reshape-D973Ba8R.js";
3
- import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-RrJzDQJc.js";
4
- import { t as U, m as _ } from "./shared-Wn4Lkf40.js";
5
- import { c as j } from "./backend_util-9wV3yg0r.js";
6
- import { f as y } from "./gpgpu_math-Dzx_EUJa.js";
7
- import { g as G, b as L } from "./kernel_funcs_utils-BLvDeLPe.js";
1
+ import { aE as E, ab as T, ah as O, U as V, aW as B, N as F, aM as U, aX as W } from "./index-D0b5F1JD.js";
2
+ import { r as $ } from "./Reshape-CKzb2DIN.js";
3
+ import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-BBaWKQoo.js";
4
+ import { t as K, m as _ } from "./shared-CefTy5O1.js";
5
+ import { c as j } from "./backend_util-DLIicY0X.js";
6
+ import { f as y } from "./gpgpu_math-D83bWKYw.js";
7
+ import { g as G, b as L } from "./kernel_funcs_utils-Bu6bS4D_.js";
8
8
  class w {
9
9
  constructor(s, e) {
10
10
  this.variableNames = ["x"];
@@ -273,7 +273,7 @@ function Q(a, s, e, t) {
273
273
  const [p, h] = N(u.shape, i);
274
274
  let d = p;
275
275
  e && (d = R(p, r));
276
- const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = B(a.dtype), I = M(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
276
+ const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), S = B(a.dtype), I = M(x, S, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
277
277
  return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
278
278
  }
279
279
  function Z(a) {
@@ -299,7 +299,7 @@ function te(a) {
299
299
  const I = e.texData.get(d.dataId).values, m = new Array(i);
300
300
  for (let v = 0; v < m.length; v++)
301
301
  m[v] = n.shape[u[v]];
302
- const z = U(I, n.shape, n.dtype, u, m);
302
+ const z = K(I, n.shape, n.dtype, u, m);
303
303
  d = e.makeTensorInfo(m, n.dtype);
304
304
  const D = e.texData.get(d.dataId);
305
305
  D.values = z;
@@ -308,21 +308,21 @@ function te(a) {
308
308
  o = k(o.length, i);
309
309
  }
310
310
  C("max", o, i);
311
- const [f, S] = N(d.shape, o);
311
+ const [f, b] = N(d.shape, o);
312
312
  let g = f;
313
313
  r && (g = R(f, c));
314
314
  let x;
315
315
  if (h) {
316
- const I = e.texData.get(d.dataId).values, m = _(I, V(S), g, n.dtype);
316
+ const I = e.texData.get(d.dataId).values, m = _(I, V(b), g, n.dtype);
317
317
  x = e.makeTensorInfo(g, n.dtype);
318
318
  const z = e.texData.get(x.dataId);
319
319
  z.values = m;
320
320
  } else
321
- x = ee(d, S, g, e);
321
+ x = ee(d, b, g, e);
322
322
  return p && e.disposeIntermediateTensorInfo(d), x;
323
323
  }
324
324
  const he = {
325
- kernelName: K,
325
+ kernelName: U,
326
326
  backendName: "webgl",
327
327
  kernelFunc: te
328
328
  };
@@ -1,14 +1,14 @@
1
- import { N as h, af as d, x as c, R as m } from "./index-3FfEY3tm.js";
1
+ import { U as h, aj as d, x as c, R as m } from "./index-D0b5F1JD.js";
2
2
  function i(n) {
3
3
  const { inputs: p, attrs: o } = n, { x: e } = p, { shape: r } = o, a = h(e.shape), s = d(r, a), t = h(s);
4
4
  return c(a === t, () => `The new shape (${s}) has ${t} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), n.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
5
5
  }
6
- const f = {
6
+ const u = {
7
7
  kernelName: m,
8
8
  backendName: "webgpu",
9
9
  kernelFunc: i
10
10
  };
11
11
  export {
12
- f as a,
12
+ u as a,
13
13
  i as r
14
14
  };
@@ -1,5 +1,5 @@
1
- import { R as C, N as c, af as f, x as R } from "./index-3FfEY3tm.js";
2
- import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-Dzx_EUJa.js";
1
+ import { R as C, U as c, aj as R, x as f } from "./index-D0b5F1JD.js";
2
+ import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-D83bWKYw.js";
3
3
  class S {
4
4
  constructor(t, i) {
5
5
  this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = g(this.outputShape.length);
@@ -62,8 +62,8 @@ function b(s, t, i) {
62
62
  return { dataId: h.dataId, shape: t, dtype: h.dtype };
63
63
  }
64
64
  function y(s) {
65
- const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = f(o, p), h = c(n);
66
- R(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
65
+ const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = R(o, p), h = c(n);
66
+ f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
67
67
  const d = r.texData.get(e.dataId);
68
68
  return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
69
69
  }
@@ -8,6 +8,7 @@ import { default as MemoryProfiler } from './utilities/profile';
8
8
  import { default as Model, ModelForwardAttributes } from './models/model';
9
9
  import { Task } from './training/tasks/Task';
10
10
  import { TrainingLogEntry, TrainingOptions } from './training/types';
11
+ import { ModelPhase } from './loader/types';
11
12
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
12
13
  interface TeachableLLMMeta {
13
14
  name?: string;
@@ -26,6 +27,8 @@ export default class TeachableLLM {
26
27
  private _trainer;
27
28
  constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes, GPTConfig>);
28
29
  get vocab(): string[];
30
+ get phase(): ModelPhase;
31
+ set phase(phase: ModelPhase);
29
32
  /** Model is fully loaded */
30
33
  get loaded(): boolean;
31
34
  get config(): GPTConfig;
@@ -52,10 +55,12 @@ export default class TeachableLLM {
52
55
  generateText(options?: IGenerateOptions): Promise<Conversation[]>;
53
56
  dispose(): void;
54
57
  on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
58
+ on(event: 'phase', listener: (phase: ModelPhase) => void): void;
55
59
  on(event: 'error', listener: (error: Error) => void): void;
56
60
  on(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
57
61
  on(event: 'loaded', listener: () => void): void;
58
62
  off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
63
+ off(event: 'phase', listener: (phase: ModelPhase) => void): void;
59
64
  off(event: 'error', listener: (error: Error) => void): void;
60
65
  off(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
61
66
  off(event: 'loaded', listener: () => void): void;
@@ -1,28 +1,28 @@
1
1
  import { validateConfig as m } from "./models/config.js";
2
2
  import { saveModel as d } from "./loader/save.js";
3
- import { loadModel as u } from "./loader/load.js";
4
- import p from "./Generator.js";
3
+ import { loadModel as p } from "./loader/load.js";
4
+ import u from "./Generator.js";
5
5
  import h from "./Trainer.js";
6
6
  import { E as f } from "./index-DvYrXKkX.js";
7
7
  import { dummyPassTrainAsync as l } from "./utilities/dummy.js";
8
- import "./index-3FfEY3tm.js";
9
- import "./random_width-DSeITIFc.js";
10
- import "./zeros_like-WQK7VrX-.js";
8
+ import "./index-D0b5F1JD.js";
9
+ import "./random_width-CliSj-et.js";
10
+ import "./zeros_like-DfWM-ezN.js";
11
11
  import "./index-Cp39cXWe.js";
12
- import "./dataset-ZUdlBUXV.js";
12
+ import "./dataset-Bwcib9pp.js";
13
13
  import "./ops/cpu/attentionMask.js";
14
14
  import "./ops/webgl/attentionMask.js";
15
15
  import "./ops/grads/attentionMask.js";
16
16
  import "./ops/cpu/rope.js";
17
17
  import "./ops/webgl/rope.js";
18
- import "./rope-D5BJXlc7.js";
18
+ import "./rope-DjON_IMj.js";
19
19
  import "./ops/cpu/appendCache.js";
20
20
  import "./ops/webgl/appendCache.js";
21
21
  import "./ops/grads/softmax16.js";
22
- import "./matMul16-Bp17gt56.js";
22
+ import "./matMul16-bI7XM831.js";
23
23
  import "./ops/webgl/matMul16.js";
24
24
  import "./ops/cpu/matMul16.js";
25
- import "./pack16-F9gxcBrq.js";
25
+ import "./pack16-DO9GrRdk.js";
26
26
  import "./ops/transpose16.js";
27
27
  import "./ops/reshape16.js";
28
28
  import "./ops/cpu/qkv.js";
@@ -31,6 +31,8 @@ import "./ops/grads/qkv.js";
31
31
  import "./ops/cpu/normRMS.js";
32
32
  import "./ops/webgl/normRMS.js";
33
33
  import "./ops/grads/normRMS.js";
34
+ import "./ops/dropout16.js";
35
+ import "./ops/webgl/dropout16.js";
34
36
  import "./ops/grads/add16.js";
35
37
  import c from "./tokeniser/CharTokeniser.js";
36
38
  import g from "./tokeniser/bpe.js";
@@ -41,11 +43,11 @@ import "./ops/webgl/gatherSub.js";
41
43
  import "./ops/cpu/scatterSub.js";
42
44
  import "./ops/webgl/scatterSub.js";
43
45
  import "./ops/cpu/matMulGelu.js";
44
- import "./matMulGelu-Bdxn3VPX.js";
46
+ import "./matMulGelu-Cbtq3pxJ.js";
45
47
  import "./ops/grads/matMulGelu.js";
46
48
  import "./ops/cpu/gelu.js";
47
49
  import "./ops/webgl/gelu.js";
48
- import "./gelu-9_DFp2Q5.js";
50
+ import "./gelu-CmkPheOK.js";
49
51
  import "./ops/webgl/log.js";
50
52
  import "./ops/cpu/adamMoments.js";
51
53
  import "./ops/webgl/adamMoments.js";
@@ -70,6 +72,14 @@ class a {
70
72
  get vocab() {
71
73
  return this._tokeniser?.getVocab() || [];
72
74
  }
75
+ get phase() {
76
+ return this._model?.metaData?.phase ?? "untrained";
77
+ }
78
+ set phase(t) {
79
+ if (!this._model)
80
+ throw new Error("model_not_initialized.");
81
+ this._model.metaData.phase = t, this.ee.emit("phase", t);
82
+ }
73
83
  /** Model is fully loaded */
74
84
  get loaded() {
75
85
  return !!this._model && !!this._tokeniser && !!this._config;
@@ -116,9 +126,9 @@ class a {
116
126
  }
117
127
  static loadModel(t, r) {
118
128
  const e = new a();
119
- return u(t, r).then(({ model: o, tokeniser: n, metaData: i }) => {
129
+ return p(t, r).then(({ model: o, tokeniser: n, metaData: i }) => {
120
130
  m(o.config), e._model = o, e._tokeniser = n, e._config = o.config, i?.name && (e.meta.name = i.name), e.setStatus("warmup"), l(o).then((s) => {
121
- e._memoryRequirements = s, e.setStatus("ready"), e.ee.emit("loaded");
131
+ e._memoryRequirements = s, e.setStatus("ready"), e.ee.emit("loaded"), e.ee.emit("phase", e.phase);
122
132
  }).catch((s) => {
123
133
  e.setStatus("error"), e.ee.emit("error", s), console.error("Error during warmup:", s);
124
134
  });
@@ -130,7 +140,7 @@ class a {
130
140
  m(r);
131
141
  const e = r, o = t === "char" ? new c(e.vocabSize) : new g(e.vocabSize), n = k(e), i = new a(o, n);
132
142
  return i.setStatus("warmup"), l(n).then((s) => {
133
- i._memoryRequirements = s, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (_) => {
143
+ i._memoryRequirements = s, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded"), i.ee.emit("phase", i.phase)) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.ee.emit("phase", i.phase), i.tokeniser.once("trainStatus", (_) => {
134
144
  _ === "trained" && i.setStatus("ready");
135
145
  }));
136
146
  }).catch((s) => {
@@ -159,11 +169,13 @@ class a {
159
169
  throw new Error("model_or_tokeniser_not_initialized.");
160
170
  this._trainer && t && this._trainer.trainingType !== t && (this._trainer.dispose(), this._trainer = null);
161
171
  const e = this._trainer === null ? new h(this._model, this._tokeniser, t, r) : new h(this._trainer, r);
162
- return e.on("start", () => this.setStatus("training")), e.on("stop", () => this.setStatus("ready")), e.on("log", async (o) => {
172
+ return e.on("start", () => {
173
+ this.setStatus("training"), this.phase = t === "sft" ? "finetuned" : "pretrained";
174
+ }), e.on("stop", () => this.setStatus("ready")), e.on("log", async (o) => {
163
175
  const n = this.ee.listeners("trainStep");
164
176
  for (const i of n)
165
177
  await i(o);
166
- }), this._trainer = e, e;
178
+ }), this._trainer && this._trainer !== e && this._trainer.dispose(), this._trainer = e, e;
167
179
  }
168
180
  async train(t, r, e) {
169
181
  const o = this.trainer(e, r);
@@ -178,7 +190,7 @@ class a {
178
190
  generator() {
179
191
  if (!this._model || !this._tokeniser)
180
192
  throw new Error("model_or_tokeniser_not_initialized.");
181
- const t = new p(this._model, this._tokeniser);
193
+ const t = new u(this._model, this._tokeniser);
182
194
  return t.on("start", () => {
183
195
  this.status === "ready" && this.setStatus("busy");
184
196
  }), t.on("stop", () => {
@@ -189,7 +201,7 @@ class a {
189
201
  return Array.isArray(t) ? this.generator().generate(t, r) : this.generator().generate([], r);
190
202
  }
191
203
  dispose() {
192
- this._model?.dispose(), this.ee.removeAllListeners();
204
+ this._trainer && (this._trainer.dispose(), this._trainer = null), this._model?.dispose(), this.ee.removeAllListeners();
193
205
  }
194
206
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
195
207
  on(t, r) {
package/dist/Trainer.d.ts CHANGED
@@ -20,6 +20,7 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
20
20
  log: TrainingLogEntry[];
21
21
  private progress;
22
22
  options: TrainingOptions;
23
+ protected tokenizer: ITokeniser;
23
24
  constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser, trainingType?: TrainingType, options?: TrainingOptions);
24
25
  constructor(trainer: Trainer, options?: TrainingOptions);
25
26
  get model(): Model<ModelForwardAttributes>;
package/dist/Trainer.js CHANGED
@@ -1,8 +1,8 @@
1
1
  import { E as g } from "./index-DvYrXKkX.js";
2
- import s from "./training/PreTrainer.js";
2
+ import o from "./training/PreTrainer.js";
3
3
  import { createTrainValidationSplit as p } from "./training/validation.js";
4
- import n from "./training/SFTTrainer.js";
5
- class o extends g {
4
+ import h from "./training/SFTTrainer.js";
5
+ class l extends g {
6
6
  trainer;
7
7
  trainingType = "pretraining";
8
8
  hasTrained = !1;
@@ -16,19 +16,22 @@ class o extends g {
16
16
  sftMode: "full",
17
17
  logInterval: 10
18
18
  };
19
- constructor(i, t, e = "pretraining", a) {
20
- if (super(), i instanceof o) {
21
- this.trainer = i.trainer, this.trainingType = i.trainingType, this.options = t ?? i.options, this.trainer.updateOptimizer(this.options), this.log = i.log, this.progress = i.progress, this.totalSamples = i.totalSamples;
19
+ tokenizer;
20
+ constructor(t, i, e = "pretraining", a) {
21
+ if (super(), t instanceof l) {
22
+ const r = i || t.options, n = t.options;
23
+ let s = !1;
24
+ t.trainingType === "sft" && r.sftMode !== n.sftMode && (s = !0), e !== t.trainingType && (s = !0), s ? (t.trainingType === "sft" ? this.trainer = new h(t.model, t.tokenizer, r) : this.trainer = new o(t.model, t.tokenizer, r), this.trainingType = e, this.options = r, this.tokenizer = t.tokenizer) : (this.trainer = t.trainer, this.trainingType = e, this.options = r, this.trainer.updateOptimizer(this.options), this.log = t.log, this.progress = t.progress, this.totalSamples = t.totalSamples, this.tokenizer = t.tokenizer, r.batchSize === n.batchSize && (this.trainDataset = t.trainDataset, this.validationDataset = t.validationDataset));
22
25
  return;
23
26
  }
24
- if (!t)
25
- throw new Error("Tokeniser must be provided when initializing Trainer with a model");
26
27
  if (!i)
28
+ throw new Error("Tokeniser must be provided when initializing Trainer with a model");
29
+ if (!t)
27
30
  throw new Error("Model must be provided when initializing Trainer");
28
31
  this.options = a || {
29
32
  batchSize: 32,
30
33
  sftMode: "full"
31
- }, e === "sft" ? this.trainer = new n(i, t, a) : this.trainer = new s(i, t, a), this.trainingType = e;
34
+ }, e === "sft" ? this.trainer = new h(t, i, a) : this.trainer = new o(t, i, a), this.trainingType = e, this.tokenizer = i;
32
35
  }
33
36
  get model() {
34
37
  return this.trainer.model;
@@ -48,110 +51,110 @@ class o extends g {
48
51
  getTotalSamples() {
49
52
  return this.totalSamples;
50
53
  }
51
- setOptions(i) {
52
- const t = new Set(
53
- Object.keys(i).filter(
54
- (e) => i[e] !== this.options[e]
54
+ setOptions(t) {
55
+ const i = new Set(
56
+ Object.keys(t).filter(
57
+ (e) => t[e] !== this.options[e]
55
58
  )
56
59
  );
57
60
  if (this.trainer.isRunning) {
58
- if (t.has("batchSize"))
61
+ if (i.has("batchSize"))
59
62
  throw new Error("Cannot change batch size during training");
60
- if (t.has("sftMode"))
63
+ if (i.has("sftMode"))
61
64
  throw new Error("Cannot change SFT mode during training");
62
- if (t.has("loraConfig"))
65
+ if (i.has("loraConfig"))
63
66
  throw new Error("Cannot change LoRA configuration during training");
64
- if (t.has("validationSplit"))
67
+ if (i.has("validationSplit"))
65
68
  throw new Error("Cannot change validation split during training");
66
- if (t.has("trainableWeights"))
69
+ if (i.has("trainableWeights"))
67
70
  throw new Error("Cannot change trainable weights during training");
68
- if (t.has("mixedPrecision"))
71
+ if (i.has("mixedPrecision"))
69
72
  throw new Error("Cannot change mixed precision setting during training");
70
- if (t.has("gradientCheckpointing"))
73
+ if (i.has("gradientCheckpointing"))
71
74
  throw new Error("Cannot change gradient checkpointing setting during training");
72
75
  }
73
76
  this.options = {
74
77
  ...this.options,
75
- ...i
76
- }, this.trainer.updateOptimizer(this.options), t.has("metrics") && this.trainer.setMetrics(i.metrics || []);
78
+ ...t
79
+ }, this.trainer.updateOptimizer(this.options), i.has("metrics") && this.trainer.setMetrics(t.metrics || []);
77
80
  }
78
- async prepare(i = []) {
79
- const t = this.options;
80
- if (this.trainingType === "pretraining" && this.trainer instanceof s) {
81
- const { trainDataset: e, validationDataset: a, size: r, trainState: h } = await p(
82
- i,
81
+ async prepare(t = []) {
82
+ const i = this.options;
83
+ if (this.trainingType === "pretraining" && this.trainer instanceof o) {
84
+ const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await p(
85
+ t,
83
86
  this.trainer.tokenizer,
84
87
  this.trainer.datasetBuilder,
85
- t?.batchSize || 32,
86
- t?.validationSplit || 0.1
87
- ), l = r * (1 - (t?.validationSplit || 0));
88
- this.trainDataset = e, this.validationDataset = a, this.totalSamples = l, this.options.epochSteps = Math.ceil(h.shuffledIndexes.length / (t?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
89
- } else if (this.trainingType === "sft" && this.trainer instanceof n) {
90
- if (i instanceof Uint16Array)
88
+ i?.batchSize || 32,
89
+ i?.validationSplit || 0.1
90
+ ), s = r * (1 - (i?.validationSplit || 0));
91
+ this.trainDataset = e, this.validationDataset = a, this.totalSamples = s, this.options.epochSteps = Math.ceil(n.shuffledIndexes.length / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
92
+ } else if (this.trainingType === "sft" && this.trainer instanceof h) {
93
+ if (t instanceof Uint16Array)
91
94
  throw new Error("SFT training requires Task[] input");
92
95
  const e = await this.trainer.datasetBuilder.createSFTDataset(
93
- i,
94
- t?.batchSize || 32,
96
+ t,
97
+ i?.batchSize || 32,
95
98
  -100
96
99
  );
97
- this.trainDataset = e, this.totalSamples = i.reduce((a, r) => a + r.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (t?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
100
+ this.trainDataset = e, this.totalSamples = t.reduce((a, r) => a + r.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
98
101
  }
99
102
  }
100
- configureModel(i) {
101
- const t = i?.sftMode || "full";
103
+ configureModel(t) {
104
+ const i = t?.sftMode || "full";
102
105
  if (this.trainingType === "pretraining" && (this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA(), this.trainer.model.weightStore.setTrainable(["*"])), this.trainingType === "sft") {
103
- if (t === "lora") {
104
- if (!i?.loraConfig)
106
+ if (i === "lora") {
107
+ if (!t?.loraConfig)
105
108
  throw new Error("LoRA configuration must be provided for lora mode");
106
109
  if (this.trainer.model.hasLoRA()) {
107
110
  const e = this.trainer.model.lora;
108
- (e.alpha !== i.loraConfig.alpha || e.rank !== i.loraConfig.rank) && (this.trainer.model.detachLoRA(), this.trainer.model.attachLoRA(i.loraConfig));
111
+ (e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) && (this.trainer.model.detachLoRA(), this.trainer.model.attachLoRA(t.loraConfig));
109
112
  } else
110
- this.trainer.model.attachLoRA(i.loraConfig);
113
+ this.trainer.model.attachLoRA(t.loraConfig);
111
114
  } else
112
115
  this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA();
113
- t === "last-layer" ? this.trainer.model.weightStore.setTrainable([
116
+ i === "last-layer" ? this.trainer.model.weightStore.setTrainable([
114
117
  `block_${this.trainer.model.config.nLayer - 1}_*`,
115
118
  "token_embedding"
116
- ]) : t === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
119
+ ]) : i === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
117
120
  }
118
- i?.trainableWeights && this.trainer.model.weightStore.setTrainable(i.trainableWeights);
121
+ t?.trainableWeights && this.trainer.model.weightStore.setTrainable(t.trainableWeights);
119
122
  }
120
123
  async train() {
121
- const i = this.options;
124
+ const t = this.options;
122
125
  if (!this.trainDataset)
123
126
  throw new Error("Dataset not prepared");
124
- this.hasTrained || this.trainer.setLearningRate(i?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(i?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(i?.mixedPrecision || !1), this.configureModel(i), await this.trainer.trainOnDataset(
127
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), this.trainer.setLabelSmoothing(t?.labelSmoothing || 0), this.trainer.setDropout(t?.dropout || 0), this.trainer.setLayerDrop(t?.layerDrop || 0), this.configureModel(t), await this.trainer.trainOnDataset(
125
128
  this.trainDataset,
126
129
  {
127
- ...i,
128
- onStep: async (t) => {
129
- this.log.push(t), this.progress = {
130
- lastLog: t,
131
- progress: t.totalSamples / this.totalSamples,
130
+ ...t,
131
+ onStep: async (i) => {
132
+ this.log.push(i), this.progress = {
133
+ lastLog: i,
134
+ progress: i.totalSamples / this.totalSamples,
132
135
  remaining: Math.max(
133
136
  0,
134
- (this.totalSamples - t.totalSamples) / t.totalSamples * t.duration
137
+ (this.totalSamples - i.totalSamples) / i.totalSamples * i.duration
135
138
  )
136
139
  };
137
140
  const e = this.listeners("log");
138
141
  for (const a of e)
139
- await a(t, this.progress);
142
+ await a(i, this.progress);
140
143
  }
141
144
  },
142
145
  this.validationDataset
143
146
  ), this.emit("stop");
144
147
  }
145
- async step(i) {
148
+ async step(t) {
146
149
  if (!this.trainDataset)
147
150
  throw new Error("Dataset not prepared");
148
- this.hasTrained || this.trainer.setLearningRate(i?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
149
- const { log: t } = await this.trainer.stepDataset(this.trainDataset, i || {}, this.validationDataset), e = this.listeners("log");
151
+ this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
152
+ const { log: i } = await this.trainer.stepDataset(this.trainDataset, t || {}, this.validationDataset), e = this.listeners("log");
150
153
  for (const a of e)
151
- await a(t, {
152
- lastLog: t,
153
- progress: t.totalSamples / this.totalSamples,
154
- remaining: Math.max(0, (this.totalSamples - t.totalSamples) / t.totalSamples * t.duration)
154
+ await a(i, {
155
+ lastLog: i,
156
+ progress: i.totalSamples / this.totalSamples,
157
+ remaining: Math.max(0, (this.totalSamples - i.totalSamples) / i.totalSamples * i.duration)
155
158
  });
156
159
  this.emit("stop");
157
160
  }
@@ -166,5 +169,5 @@ class o extends g {
166
169
  }
167
170
  }
168
171
  export {
169
- o as default
172
+ l as default
170
173
  };
@@ -1,4 +1,4 @@
1
- import { x as c } from "./index-3FfEY3tm.js";
1
+ import { x as c } from "./index-D0b5F1JD.js";
2
2
  function i(e, n) {
3
3
  for (let t = 0; t < e.length; ++t)
4
4
  if (e[e.length - t - 1] !== n - 1 - t)
package/dist/backend.js CHANGED
@@ -1,9 +1,9 @@
1
- import { g as o, s as e, r as s } from "./index-3FfEY3tm.js";
1
+ import { g as o, s as e, r as s } from "./index-D0b5F1JD.js";
2
2
  async function c(t, a) {
3
3
  if (o() !== t) {
4
4
  if (t === "webgpu") {
5
5
  const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
6
- i(a), await import("./index-B8eBIyjS.js"), await import("./ops/webgpu/index.js");
6
+ i(a), await import("./index-nwvWLdRt.js"), await import("./ops/webgpu/index.js");
7
7
  }
8
8
  await e(t), await s(), console.log(`Backend set to ${t}`);
9
9
  }