@genai-fi/nanogpt 0.15.0 → 0.15.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +33 -31
- package/dist/{RealDiv-B2Tyc34U.js → RealDiv-CJpH9Bif.js} +13 -13
- package/dist/{Reshape-Bqk-z_7-.js → Reshape-C4ZzbS5c.js} +3 -3
- package/dist/{Reshape-D973Ba8R.js → Reshape-CKzb2DIN.js} +4 -4
- package/dist/TeachableLLM.d.ts +5 -0
- package/dist/TeachableLLM.js +30 -18
- package/dist/Trainer.d.ts +1 -0
- package/dist/Trainer.js +65 -62
- package/dist/{axis_util-RrJzDQJc.js → axis_util-BBaWKQoo.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-9wV3yg0r.js → backend_util-DLIicY0X.js} +50 -50
- package/dist/{backend_webgpu-CnFoGvzK.js → backend_webgpu-BwfUOSiJ.js} +21 -21
- package/dist/{broadcast_to-hAMmZJpr.js → broadcast_to-CxKUM6zp.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +2 -2
- package/dist/checks/normRMS.js +6 -6
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +2 -2
- package/dist/checks/qkv.js +2 -2
- package/dist/checks/rope.js +2 -2
- package/dist/clip_by_value-lDwNWeyI.js +12 -0
- package/dist/{complex-BDvCF_r9.js → complex-NXAORdbW.js} +1 -1
- package/dist/{concat-B9WckkXa.js → concat-DCm6KW65.js} +1 -1
- package/dist/{concat_util-DVNU-Nn3.js → concat_util-DT0Mofs3.js} +1 -1
- package/dist/{dataset-ZUdlBUXV.js → dataset-Bwcib9pp.js} +3 -3
- package/dist/dropout_util-Crmm4aOV.js +27 -0
- package/dist/{expand_dims-DoiHvcDw.js → expand_dims-DgU0Vlpg.js} +1 -1
- package/dist/{exports_initializers-8SQOHjAF.js → exports_initializers-VKuLTIiX.js} +1 -1
- package/dist/floor-Bhmfrtly.js +9 -0
- package/dist/{gather-BYhIiO5e.js → gather-FIoUa4Zd.js} +1 -1
- package/dist/{gelu-9_DFp2Q5.js → gelu-CmkPheOK.js} +1 -1
- package/dist/{gpgpu_math-Dzx_EUJa.js → gpgpu_math-D83bWKYw.js} +25 -25
- package/dist/{index-3FfEY3tm.js → index-D0b5F1JD.js} +58 -58
- package/dist/{index-B8eBIyjS.js → index-nwvWLdRt.js} +89 -89
- package/dist/{kernel_funcs_utils-BLvDeLPe.js → kernel_funcs_utils-Bu6bS4D_.js} +11 -11
- package/dist/layers/BaseLayer.d.ts +4 -0
- package/dist/layers/BaseLayer.js +11 -7
- package/dist/layers/CausalSelfAttention.js +55 -51
- package/dist/layers/LoRA.js +4 -4
- package/dist/layers/MLP.d.ts +1 -1
- package/dist/layers/MLP.js +20 -19
- package/dist/layers/PositionEmbedding.js +10 -10
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/layers/WeightStore.js +3 -3
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +20 -18
- package/dist/loader/save.js +6 -5
- package/dist/loader/types.d.ts +2 -0
- package/dist/main.js +9 -9
- package/dist/{matMul16-Bp17gt56.js → matMul16-bI7XM831.js} +3 -3
- package/dist/{matMulGelu-Bdxn3VPX.js → matMulGelu-Cbtq3pxJ.js} +21 -21
- package/dist/{mat_mul-BUuYg3qo.js → mat_mul-BQY_GSqm.js} +1 -1
- package/dist/{mod-4q-X1J5l.js → mod-ChddM4vN.js} +1 -1
- package/dist/models/NanoGPTV1.js +9 -9
- package/dist/models/NanoGPTV2.js +12 -10
- package/dist/models/model.d.ts +1 -1
- package/dist/models/model.js +14 -12
- package/dist/not_equal-duCIyEXv.js +64 -0
- package/dist/{ones-aGZXepq3.js → ones-Piv0gZxv.js} +3 -3
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/concat16.js +2 -2
- package/dist/ops/cpu/adamAdjust.js +1 -1
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +6 -6
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMul16.js +2 -2
- package/dist/ops/cpu/matMulGelu.js +3 -3
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +9 -9
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/dropout.d.ts +2 -0
- package/dist/ops/dropout.js +14 -0
- package/dist/ops/dropout16.d.ts +2 -0
- package/dist/ops/dropout16.js +25 -0
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/globalNorm.js +2 -2
- package/dist/ops/grads/add16.js +1 -1
- package/dist/ops/grads/attentionMask.js +2 -2
- package/dist/ops/grads/dropout16.d.ts +1 -0
- package/dist/ops/grads/dropout16.js +2 -0
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMul16.js +3 -3
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/mul16.d.ts +1 -0
- package/dist/ops/grads/mul16.js +4 -0
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/pack16.js +3 -3
- package/dist/ops/grads/qkv.js +3 -3
- package/dist/ops/grads/rope.js +2 -2
- package/dist/ops/grads/softmax16.js +1 -1
- package/dist/ops/grads/unpack16.js +2 -2
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +2 -2
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.js +36 -5
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +13 -4
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/reshape16.js +2 -2
- package/dist/ops/rope.js +2 -2
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +2 -2
- package/dist/ops/transpose16.js +3 -3
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/dropout16.d.ts +1 -0
- package/dist/ops/webgl/dropout16.js +11 -0
- package/dist/ops/webgl/fusedSoftmax.js +6 -6
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMul16.js +5 -5
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +2 -2
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/add16.js +1 -1
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +2 -2
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/clipScale.js +1 -1
- package/dist/ops/webgpu/concat16.js +12 -12
- package/dist/ops/webgpu/dropout16.d.ts +1 -0
- package/dist/ops/webgpu/dropout16.js +51 -0
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/index.js +1 -0
- package/dist/ops/webgpu/matMul16.js +14 -14
- package/dist/ops/webgpu/matMul16_program.js +2 -2
- package/dist/ops/webgpu/mul16.js +9 -9
- package/dist/ops/webgpu/norm2.js +1 -1
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +4 -4
- package/dist/ops/webgpu/pack16.js +1 -1
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +2 -2
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/slice16.js +4 -4
- package/dist/ops/webgpu/softmax16.js +2 -2
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +1 -1
- package/dist/ops/webgpu/sub16.js +1 -1
- package/dist/ops/webgpu/sum16.js +5 -5
- package/dist/ops/webgpu/transpose16.js +2 -2
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
- package/dist/ops/webgpu/unpack16.js +3 -3
- package/dist/ops/webgpu/utils/binary_op.d.ts +16 -0
- package/dist/ops/webgpu/utils/binary_op.js +74 -13
- package/dist/ops/webgpu/utils/reductions.js +5 -5
- package/dist/{ops-BLDakU_V.js → ops-BXr-37bF.js} +30 -30
- package/dist/{pack16-F9gxcBrq.js → pack16-DO9GrRdk.js} +2 -2
- package/dist/patches/webgpu_backend.js +9 -9
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +2 -2
- package/dist/rand_util-CZ7yLoUm.js +50 -0
- package/dist/random_normal-CO9xf9dz.js +14 -0
- package/dist/{random_width-DSeITIFc.js → random_width-CliSj-et.js} +164 -162
- package/dist/{range-BvA7g6TS.js → range-Dx4PwA2-.js} +1 -1
- package/dist/{readers-lNVRVUDO.js → readers-DwZhCW0C.js} +2 -2
- package/dist/{relu-DyGjd4UV.js → relu-BnpM8PVa.js} +1 -1
- package/dist/{reshape-3ugLpT-p.js → reshape-DVh8yLpI.js} +1 -1
- package/dist/{resize_nearest_neighbor-DBPfHMkZ.js → resize_nearest_neighbor-Dl7ehaQl.js} +39 -39
- package/dist/{rope-D5BJXlc7.js → rope-DjON_IMj.js} +1 -1
- package/dist/{scatter_nd_util-6lhBuxGa.js → scatter_nd_util-SSoGmfpx.js} +1 -1
- package/dist/{selu_util-emNhirms.js → selu_util-C0DN3KhX.js} +5 -5
- package/dist/{shared-Wn4Lkf40.js → shared-CefTy5O1.js} +1 -1
- package/dist/{shared-DeC0UJkK.js → shared-DgNUoqSc.js} +35 -35
- package/dist/{slice-C1VU5kjs.js → slice-BluUPHKL.js} +1 -1
- package/dist/{slice_util-5UIO9Akz.js → slice_util-DK4kHJjN.js} +1 -1
- package/dist/{softmax-BSXRSMAA.js → softmax-HULrSwJC.js} +1 -1
- package/dist/{split-Z_OF59mV.js → split-QwVeUPZt.js} +1 -1
- package/dist/{squeeze-DuB_IYFY.js → squeeze-Brkwo5OI.js} +2 -2
- package/dist/{stack-CdjLGyjr.js → stack-C_8ubcjt.js} +1 -1
- package/dist/{step-CA-PdcE1.js → step-wz0MZ7BP.js} +1 -1
- package/dist/{sum-CX6lFpfv.js → sum-iKJXG43N.js} +1 -1
- package/dist/{tensor-BLWBtdey.js → tensor-Dfy8cN1y.js} +1 -1
- package/dist/{tensor1d-Dp80hTtj.js → tensor1d-CoOFcAZs.js} +1 -1
- package/dist/{tensor2d-DryAvP1o.js → tensor2d-C8gFDiIC.js} +1 -1
- package/dist/{tensor4d-BR5YioKH.js → tensor4d-Bvqzr_Wu.js} +1 -1
- package/dist/{tfjs_backend-BuO7pU2h.js → tfjs_backend-9QO-TAAZ.js} +275 -295
- package/dist/{tile-CB7Cg2Cm.js → tile-CcpklBqG.js} +1 -1
- package/dist/training/AdamW.js +2 -2
- package/dist/training/BasicTrainer.d.ts +6 -0
- package/dist/training/BasicTrainer.js +74 -60
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/Evaluator.js +2 -2
- package/dist/training/SFTDatasetBuilder.js +3 -3
- package/dist/training/SFTTrainer.js +6 -6
- package/dist/training/loss.d.ts +1 -1
- package/dist/training/loss.js +12 -8
- package/dist/training/orthoGrad.js +1 -1
- package/dist/training/sparseCrossEntropy.d.ts +2 -2
- package/dist/training/sparseCrossEntropy.js +54 -31
- package/dist/training/types.d.ts +4 -0
- package/dist/training/validation.js +19 -17
- package/dist/{transpose-COw0-lqd.js → transpose-CwEYsCv1.js} +2 -2
- package/dist/{unsorted_segment_sum-C23hrdi0.js → unsorted_segment_sum-DRVX2bX2.js} +22 -22
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +1 -1
- package/dist/utilities/parameters.d.ts +1 -0
- package/dist/utilities/parameters.js +20 -15
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-lnPOlwsK.js → variable-CqrRzzxM.js} +1 -1
- package/dist/{webgpu_program-CuMK2hhh.js → webgpu_program-BlAY4Q29.js} +1 -1
- package/dist/{webgpu_util-DWXgz54K.js → webgpu_util-D1Ynuktt.js} +1 -1
- package/dist/{zeros-BJogAj4Z.js → zeros-B8VPk-mx.js} +2 -2
- package/dist/{zeros_like-WQK7VrX-.js → zeros_like-DfWM-ezN.js} +90 -89
- package/package.json +1 -1
- package/dist/floor-B6EO3Z6x.js +0 -18
- package/dist/not_equal-BO_DB61m.js +0 -64
- package/dist/random_normal-dxcPUb9x.js +0 -61
package/dist/Generator.js
CHANGED
|
@@ -1,40 +1,40 @@
|
|
|
1
1
|
import { E as Ui } from "./index-DvYrXKkX.js";
|
|
2
|
-
import { o as Hi, q as Xi, E as Ki, dn as Ss,
|
|
3
|
-
import { n as Mc } from "./random_width-
|
|
4
|
-
import { t as Bc } from "./zeros_like-
|
|
2
|
+
import { o as Hi, q as Xi, E as Ki, dn as Ss, a5 as pe, ab as _, as as oo, at as ao, e as Oe, a_ as Dt, ay as ro, az as io, au as ji, av as Ft, aD as Ge, ae as co, aG as ws, aH as qi, U as G, af as _e, aI as Yi, H as Ns, aJ as Rs, R as Qi, aj as Zi, x as te, D as lo, _ as Ne, a9 as ee, bd as uo, c9 as Ts, ca as Es, cQ as po, bo as ho, ah as ue, Q as Ye, bp as fo, bq as mo, cb as go, cc as Ds, cd as Fs, ce as Ps, cf as Os, cg as As, br as xo, ac as nt, cA as Co, cR as bo, cS as Io, bu as yo, bt as ko, bf as $o, dg as vo, C as _s, cU as So, ao as wo, z as No, bv as Ro, c4 as $e, cF as To, bw as Eo, cB as Do, cV as Fo, cC as Po, bx as Ls, by as Vs, bh as Oo, bz as Ao, am as Qe, V as Ji, bA as _o, cD as Lo, ch as Vo, bB as Wo, cH as Mo, cI as Bo, dh as Go, ci as zo, bZ as ns, bT as St, bW as Ws, cW as yn, dv as ut, dw as Uo, cX as kn, di as ec, N as tc, bg as Ho, cY as Xo, bD as Ms, A as Ko, aX as jo, aV as mt, cr as qo, de as Yo, df as Qo, bi as Zo, cG as Jo, dk as ea, al as ta, G as sa, cs as na, cj as Bs, ck as Gs, cl as zs, dl as oa, b5 as Us, b6 as Hs, bF as Xs, cn as Ks, cm as aa, c_ as ra, aM as sc, bG as ia, cE as ca, c$ as la, d0 as ua, dm as da, aP as pa, a$ as ha, co as fa, bS as ma, M as js, I as ga, bk as xa, bm as Ca, bl as ba, bH as Ia, d5 as ya, bI as ka, P as $a, a7 as va, bJ as Sa, d1 as qs, dx as wa, dy as Na, dz as Ra, Z as Ta, cp as Ys, bb as Ea, d2 as Da, bc as Fa, d3 as Pa, bL as Oa, bj as Aa, b8 as Qs, ak as _a, dp as La, aL as Va, bN as Zs, cq as Js, bO as en, bP as tn, bE as sn, bK as Wa, dA as Ma, dB as Ba, dq as Ga, dr as za, ds as Ua, J as Ha, d4 as Xa, aK as nn, ct as Ka, dt as ja, dC as qa, dD as Ya, cu as on, bs as an, du as Qa, T as Za, cv as Ja, bn as er, a6 as rn, cw as tr, ba as sr, bQ as nr, c as or, dE as ar, dF as $n, aw as vn, ax as nc, t as rr, a as oc, dG as ac, dH as rc, c2 as ic, ar as cc, bR as lc, bX as uc, S as dc, bY as pc, aQ as hc, aq as fc, bU as mc, bV as gc, b_ as xc, aS as ir, bC as Cc, aN as bc, b$ as Ic, F as yc, c0 as kc, dj as $c, b1 as vc, b2 as Sc, b3 as wc, b4 as Nc, aO as Rc, c1 as Tc, b7 as Ec, c7 as Dc, ap as Fc, c3 as Pc, aW as cr, bM as Oc, aF as Ac, c5 as _c, b9 as Lc, c6 as Vc, k as Wc } from "./index-D0b5F1JD.js";
|
|
3
|
+
import { n as Mc } from "./random_width-CliSj-et.js";
|
|
4
|
+
import { t as Bc } from "./zeros_like-DfWM-ezN.js";
|
|
5
5
|
import "./index-Cp39cXWe.js";
|
|
6
|
-
import "./dataset-
|
|
7
|
-
import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as yl, Z as kl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-
|
|
6
|
+
import "./dataset-Bwcib9pp.js";
|
|
7
|
+
import { a as j, u as ae, c as ot, i as at, b as Gc, d as wt, t as Re, e as gt, f as dt, g as lr, r as Nt, h as Ae, j as zc, k as Uc, l as cn, z as Hc, m as ln, n as ur, o as Xc, p as Kc, q as jc, v as qc, w as Yc, x as Qc, y as Zc, A as Jc, B as el, C as tl, D as lt, E as sl, F as nl, G as dr, H as ol, I as al, J as rl, K as il, L as cl, M as ll, N as ul, O as dl, P as pl, Q as hl, R as fl, S as ml, T as gl, U as xl, V as Cl, W as bl, X as Il, Y as yl, Z as kl, _ as $l, $ as vl, a0 as Sl, a1 as wl, a2 as Nl, a3 as Rl, a4 as Tl, a5 as El, a6 as Dl, a7 as Fl, a8 as Pl, a9 as Ol, aa as Al, ab as _l, ac as Ll, ad as Vl, ae as Wl, af as Ml, ag as Bl, ah as Gl, ai as zl } from "./shared-DgNUoqSc.js";
|
|
8
8
|
import { m as pt, g as pr, s as Ul, c as Hl, b as Xl, d as Kl, a as jl, e as ql } from "./complex_util-Yc1A_gV1.js";
|
|
9
|
-
import { a as ge, b as xe, d as ke, c as ve, e as Te, g as os } from "./axis_util-
|
|
10
|
-
import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-
|
|
11
|
-
import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as yr, w as kr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-
|
|
12
|
-
import { a as Dr, c as Ue } from "./concat_util-
|
|
9
|
+
import { a as ge, b as xe, d as ke, c as ve, e as Te, g as os } from "./axis_util-BBaWKQoo.js";
|
|
10
|
+
import { k as Ze, h as Le, i as Je, j as rt, b as Se, d as xt, g as as } from "./step-wz0MZ7BP.js";
|
|
11
|
+
import { z as rs, A as is, B as cs, C as hr, D as fr, F as mr, G as gr, H as xr, I as Cr, J as br, y as Ir, x as yr, w as kr, u as $r, t as vr, E as Sr, K as wr, L as Nr, M as Rr, N as Tr, c as Er, f as Yl, O as Ql, P as Zl } from "./backend_util-DLIicY0X.js";
|
|
12
|
+
import { a as Dr, c as Ue } from "./concat_util-DT0Mofs3.js";
|
|
13
13
|
import { s as Jl } from "./index-CieiGp4Y.js";
|
|
14
14
|
import { n as Fr, b as Pr, a as Or } from "./non_max_suppression_impl-B2W7YjZB.js";
|
|
15
|
-
import { c as Ct } from "./scatter_nd_util-
|
|
16
|
-
import { S as Ar, a as _r } from "./selu_util-
|
|
17
|
-
import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-
|
|
18
|
-
import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as yu, O as ku, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-
|
|
19
|
-
import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as yd } from "./shared-
|
|
20
|
-
import { a as ye, c as kd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-
|
|
21
|
-
import { R as Nd, r as U, a as Rd } from "./Reshape-
|
|
22
|
-
import { M as qr } from "./matMulGelu-
|
|
23
|
-
import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-
|
|
24
|
-
import { z as Od } from "./zeros-
|
|
15
|
+
import { c as Ct } from "./scatter_nd_util-SSoGmfpx.js";
|
|
16
|
+
import { S as Ar, a as _r } from "./selu_util-C0DN3KhX.js";
|
|
17
|
+
import { b as Lr, d as Vr, p as eu, a as tu, i as su, c as nu } from "./slice_util-DK4kHJjN.js";
|
|
18
|
+
import { h as Sn, j as ou, k as au, l as ru, m as iu, n as cu, o as lu, P as un, p as Ve, u as Pe, q as Wr, c as Mr, T as De, E as Br, g as Gr, a as zr, r as uu, s as du, t as Y, v as Pt, w as pu, x as wn, y as hu, z as fu, A as Ot, B as mu, C as gu, D as bs, F as Gt, G as zt, H as xu, I as Cu, J as Nn, K as bu, L as Iu, M as fs, N as yu, O as ku, Q as $u, R as Ut, S as ms, U as vu, f as he, V as be, W as Ht, X as Xt, Y as Su, d as Rn, e as Tn, i as Ur, Z as wu, _ as Nu, $ as Ru, a0 as Tu, a1 as Eu, a2 as Du, a3 as At } from "./gpgpu_math-D83bWKYw.js";
|
|
19
|
+
import { s as Hr, a as Fu, t as Xr, b as Pu, c as Ou, d as Kr, e as Au, n as _u, f as Lu, g as Vu, h as Wu, i as Mu, j as Bu, k as Gu, l as zu, o as Uu, p as Hu, q as Xu, r as Ku, u as ju, v as qu, w as Yu, x as Qu, y as Zu, z as Ju, A as ed, B as td, C as sd, D as nd, E as od, F as ad, G as rd, H as id, I as cd, J as ld, K as ud, L as dd, M as jr, N as pd, O as hd, P as fd, Q as md, R as gd, S as xd, T as Cd, U as bd, V as Id, W as yd } from "./shared-CefTy5O1.js";
|
|
20
|
+
import { a as ye, c as kd, U as st, d as qe, e as ze, A as En, f as bt, B as dn, h as pn, m as Rt, u as se, C as We, b as Ce, i as Fe, j as hn, k as it, l as It, n as $d, o as vd, p as Sd, q as wd } from "./kernel_funcs_utils-Bu6bS4D_.js";
|
|
21
|
+
import { R as Nd, r as U, a as Rd } from "./Reshape-CKzb2DIN.js";
|
|
22
|
+
import { M as qr } from "./matMulGelu-Cbtq3pxJ.js";
|
|
23
|
+
import { t as Yr, s as fn, a as _t, m as Td, r as Ed, b as Dd, c as Fd, d as Pd } from "./RealDiv-CJpH9Bif.js";
|
|
24
|
+
import { z as Od } from "./zeros-B8VPk-mx.js";
|
|
25
25
|
import "./ops/cpu/attentionMask.js";
|
|
26
26
|
import "./ops/webgl/attentionMask.js";
|
|
27
27
|
import "./ops/grads/attentionMask.js";
|
|
28
28
|
import "./ops/cpu/rope.js";
|
|
29
29
|
import "./ops/webgl/rope.js";
|
|
30
|
-
import "./rope-
|
|
30
|
+
import "./rope-DjON_IMj.js";
|
|
31
31
|
import "./ops/cpu/appendCache.js";
|
|
32
32
|
import "./ops/webgl/appendCache.js";
|
|
33
33
|
import "./ops/grads/softmax16.js";
|
|
34
|
-
import "./matMul16-
|
|
34
|
+
import "./matMul16-bI7XM831.js";
|
|
35
35
|
import "./ops/webgl/matMul16.js";
|
|
36
36
|
import "./ops/cpu/matMul16.js";
|
|
37
|
-
import "./pack16-
|
|
37
|
+
import "./pack16-DO9GrRdk.js";
|
|
38
38
|
import "./ops/transpose16.js";
|
|
39
39
|
import "./ops/reshape16.js";
|
|
40
40
|
import "./ops/cpu/qkv.js";
|
|
@@ -43,6 +43,8 @@ import "./ops/grads/qkv.js";
|
|
|
43
43
|
import "./ops/cpu/normRMS.js";
|
|
44
44
|
import "./ops/webgl/normRMS.js";
|
|
45
45
|
import "./ops/grads/normRMS.js";
|
|
46
|
+
import "./ops/dropout16.js";
|
|
47
|
+
import "./ops/webgl/dropout16.js";
|
|
46
48
|
import "./ops/grads/add16.js";
|
|
47
49
|
import "./jszip.min-Bz5-11Bk.js";
|
|
48
50
|
import Ad from "./tokeniser/CharTokeniser.js";
|
|
@@ -62,17 +64,17 @@ import "./ops/cpu/matMulGelu.js";
|
|
|
62
64
|
import "./ops/grads/matMulGelu.js";
|
|
63
65
|
import "./ops/cpu/gelu.js";
|
|
64
66
|
import "./ops/webgl/gelu.js";
|
|
65
|
-
import "./gelu-
|
|
67
|
+
import "./gelu-CmkPheOK.js";
|
|
66
68
|
import "./ops/webgl/log.js";
|
|
67
69
|
import "./checks/normRMS.js";
|
|
68
70
|
import "./checks/normRMSGrad.js";
|
|
69
71
|
import Wd from "./utilities/multinomialCPU.js";
|
|
70
|
-
import { r as Dn } from "./reshape-
|
|
71
|
-
import { t as Kt } from "./tensor2d-
|
|
72
|
-
import { z as Md } from "./unsorted_segment_sum-
|
|
73
|
-
import { s as gs } from "./softmax-
|
|
74
|
-
import { g as Bd } from "./gather-
|
|
75
|
-
import { c as Gd } from "./concat-
|
|
72
|
+
import { r as Dn } from "./reshape-DVh8yLpI.js";
|
|
73
|
+
import { t as Kt } from "./tensor2d-C8gFDiIC.js";
|
|
74
|
+
import { z as Md } from "./unsorted_segment_sum-DRVX2bX2.js";
|
|
75
|
+
import { s as gs } from "./softmax-HULrSwJC.js";
|
|
76
|
+
import { g as Bd } from "./gather-FIoUa4Zd.js";
|
|
77
|
+
import { c as Gd } from "./concat-DCm6KW65.js";
|
|
76
78
|
function zd(a, t, e, n = !1) {
|
|
77
79
|
const s = Xi(a, "logits", "multinomial"), o = s.size, r = s.rank;
|
|
78
80
|
if (o < 2)
|
|
@@ -11676,7 +11678,7 @@ const lv = [
|
|
|
11676
11678
|
function uv(a, t) {
|
|
11677
11679
|
return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
|
|
11678
11680
|
}
|
|
11679
|
-
class
|
|
11681
|
+
class OS extends Ui {
|
|
11680
11682
|
constructor(t, e) {
|
|
11681
11683
|
super(), this.model = t, this.tokeniser = e, this.actualTokeniser = e;
|
|
11682
11684
|
}
|
|
@@ -11872,6 +11874,6 @@ class FS extends Ui {
|
|
|
11872
11874
|
}
|
|
11873
11875
|
}
|
|
11874
11876
|
export {
|
|
11875
|
-
|
|
11877
|
+
OS as default,
|
|
11876
11878
|
cv as isConversation
|
|
11877
11879
|
};
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import { aE as E,
|
|
2
|
-
import { r as $ } from "./Reshape-
|
|
3
|
-
import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-
|
|
4
|
-
import { t as
|
|
5
|
-
import { c as j } from "./backend_util-
|
|
6
|
-
import { f as y } from "./gpgpu_math-
|
|
7
|
-
import { g as G, b as L } from "./kernel_funcs_utils-
|
|
1
|
+
import { aE as E, ab as T, ah as O, U as V, aW as B, N as F, aM as U, aX as W } from "./index-D0b5F1JD.js";
|
|
2
|
+
import { r as $ } from "./Reshape-CKzb2DIN.js";
|
|
3
|
+
import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-BBaWKQoo.js";
|
|
4
|
+
import { t as K, m as _ } from "./shared-CefTy5O1.js";
|
|
5
|
+
import { c as j } from "./backend_util-DLIicY0X.js";
|
|
6
|
+
import { f as y } from "./gpgpu_math-D83bWKYw.js";
|
|
7
|
+
import { g as G, b as L } from "./kernel_funcs_utils-Bu6bS4D_.js";
|
|
8
8
|
class w {
|
|
9
9
|
constructor(s, e) {
|
|
10
10
|
this.variableNames = ["x"];
|
|
@@ -273,7 +273,7 @@ function Q(a, s, e, t) {
|
|
|
273
273
|
const [p, h] = N(u.shape, i);
|
|
274
274
|
let d = p;
|
|
275
275
|
e && (d = R(p, r));
|
|
276
|
-
const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }),
|
|
276
|
+
const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), S = B(a.dtype), I = M(x, S, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
|
|
277
277
|
return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
|
|
278
278
|
}
|
|
279
279
|
function Z(a) {
|
|
@@ -299,7 +299,7 @@ function te(a) {
|
|
|
299
299
|
const I = e.texData.get(d.dataId).values, m = new Array(i);
|
|
300
300
|
for (let v = 0; v < m.length; v++)
|
|
301
301
|
m[v] = n.shape[u[v]];
|
|
302
|
-
const z =
|
|
302
|
+
const z = K(I, n.shape, n.dtype, u, m);
|
|
303
303
|
d = e.makeTensorInfo(m, n.dtype);
|
|
304
304
|
const D = e.texData.get(d.dataId);
|
|
305
305
|
D.values = z;
|
|
@@ -308,21 +308,21 @@ function te(a) {
|
|
|
308
308
|
o = k(o.length, i);
|
|
309
309
|
}
|
|
310
310
|
C("max", o, i);
|
|
311
|
-
const [f,
|
|
311
|
+
const [f, b] = N(d.shape, o);
|
|
312
312
|
let g = f;
|
|
313
313
|
r && (g = R(f, c));
|
|
314
314
|
let x;
|
|
315
315
|
if (h) {
|
|
316
|
-
const I = e.texData.get(d.dataId).values, m = _(I, V(
|
|
316
|
+
const I = e.texData.get(d.dataId).values, m = _(I, V(b), g, n.dtype);
|
|
317
317
|
x = e.makeTensorInfo(g, n.dtype);
|
|
318
318
|
const z = e.texData.get(x.dataId);
|
|
319
319
|
z.values = m;
|
|
320
320
|
} else
|
|
321
|
-
x = ee(d,
|
|
321
|
+
x = ee(d, b, g, e);
|
|
322
322
|
return p && e.disposeIntermediateTensorInfo(d), x;
|
|
323
323
|
}
|
|
324
324
|
const he = {
|
|
325
|
-
kernelName:
|
|
325
|
+
kernelName: U,
|
|
326
326
|
backendName: "webgl",
|
|
327
327
|
kernelFunc: te
|
|
328
328
|
};
|
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { U as h, aj as d, x as c, R as m } from "./index-D0b5F1JD.js";
|
|
2
2
|
function i(n) {
|
|
3
3
|
const { inputs: p, attrs: o } = n, { x: e } = p, { shape: r } = o, a = h(e.shape), s = d(r, a), t = h(s);
|
|
4
4
|
return c(a === t, () => `The new shape (${s}) has ${t} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), n.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
|
|
5
5
|
}
|
|
6
|
-
const
|
|
6
|
+
const u = {
|
|
7
7
|
kernelName: m,
|
|
8
8
|
backendName: "webgpu",
|
|
9
9
|
kernelFunc: i
|
|
10
10
|
};
|
|
11
11
|
export {
|
|
12
|
-
|
|
12
|
+
u as a,
|
|
13
13
|
i as r
|
|
14
14
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { R as C,
|
|
2
|
-
import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-
|
|
1
|
+
import { R as C, U as c, aj as R, x as f } from "./index-D0b5F1JD.js";
|
|
2
|
+
import { u as g, g as I, a as x, b as F, c as $, d as u, e as m, i as l } from "./gpgpu_math-D83bWKYw.js";
|
|
3
3
|
class S {
|
|
4
4
|
constructor(t, i) {
|
|
5
5
|
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = g(this.outputShape.length);
|
|
@@ -62,8 +62,8 @@ function b(s, t, i) {
|
|
|
62
62
|
return { dataId: h.dataId, shape: t, dtype: h.dtype };
|
|
63
63
|
}
|
|
64
64
|
function y(s) {
|
|
65
|
-
const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n =
|
|
66
|
-
|
|
65
|
+
const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = c(e.shape), n = R(o, p), h = c(n);
|
|
66
|
+
f(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
|
|
67
67
|
const d = r.texData.get(e.dataId);
|
|
68
68
|
return d.isPacked && !l(e.shape, n) && !(d.texture !== null && l(d.shape, n)) ? b(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
|
|
69
69
|
}
|
package/dist/TeachableLLM.d.ts
CHANGED
|
@@ -8,6 +8,7 @@ import { default as MemoryProfiler } from './utilities/profile';
|
|
|
8
8
|
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
9
9
|
import { Task } from './training/tasks/Task';
|
|
10
10
|
import { TrainingLogEntry, TrainingOptions } from './training/types';
|
|
11
|
+
import { ModelPhase } from './loader/types';
|
|
11
12
|
type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
|
|
12
13
|
interface TeachableLLMMeta {
|
|
13
14
|
name?: string;
|
|
@@ -26,6 +27,8 @@ export default class TeachableLLM {
|
|
|
26
27
|
private _trainer;
|
|
27
28
|
constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes, GPTConfig>);
|
|
28
29
|
get vocab(): string[];
|
|
30
|
+
get phase(): ModelPhase;
|
|
31
|
+
set phase(phase: ModelPhase);
|
|
29
32
|
/** Model is fully loaded */
|
|
30
33
|
get loaded(): boolean;
|
|
31
34
|
get config(): GPTConfig;
|
|
@@ -52,10 +55,12 @@ export default class TeachableLLM {
|
|
|
52
55
|
generateText(options?: IGenerateOptions): Promise<Conversation[]>;
|
|
53
56
|
dispose(): void;
|
|
54
57
|
on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
58
|
+
on(event: 'phase', listener: (phase: ModelPhase) => void): void;
|
|
55
59
|
on(event: 'error', listener: (error: Error) => void): void;
|
|
56
60
|
on(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
|
|
57
61
|
on(event: 'loaded', listener: () => void): void;
|
|
58
62
|
off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
63
|
+
off(event: 'phase', listener: (phase: ModelPhase) => void): void;
|
|
59
64
|
off(event: 'error', listener: (error: Error) => void): void;
|
|
60
65
|
off(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
|
|
61
66
|
off(event: 'loaded', listener: () => void): void;
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import { validateConfig as m } from "./models/config.js";
|
|
2
2
|
import { saveModel as d } from "./loader/save.js";
|
|
3
|
-
import { loadModel as
|
|
4
|
-
import
|
|
3
|
+
import { loadModel as p } from "./loader/load.js";
|
|
4
|
+
import u from "./Generator.js";
|
|
5
5
|
import h from "./Trainer.js";
|
|
6
6
|
import { E as f } from "./index-DvYrXKkX.js";
|
|
7
7
|
import { dummyPassTrainAsync as l } from "./utilities/dummy.js";
|
|
8
|
-
import "./index-
|
|
9
|
-
import "./random_width-
|
|
10
|
-
import "./zeros_like-
|
|
8
|
+
import "./index-D0b5F1JD.js";
|
|
9
|
+
import "./random_width-CliSj-et.js";
|
|
10
|
+
import "./zeros_like-DfWM-ezN.js";
|
|
11
11
|
import "./index-Cp39cXWe.js";
|
|
12
|
-
import "./dataset-
|
|
12
|
+
import "./dataset-Bwcib9pp.js";
|
|
13
13
|
import "./ops/cpu/attentionMask.js";
|
|
14
14
|
import "./ops/webgl/attentionMask.js";
|
|
15
15
|
import "./ops/grads/attentionMask.js";
|
|
16
16
|
import "./ops/cpu/rope.js";
|
|
17
17
|
import "./ops/webgl/rope.js";
|
|
18
|
-
import "./rope-
|
|
18
|
+
import "./rope-DjON_IMj.js";
|
|
19
19
|
import "./ops/cpu/appendCache.js";
|
|
20
20
|
import "./ops/webgl/appendCache.js";
|
|
21
21
|
import "./ops/grads/softmax16.js";
|
|
22
|
-
import "./matMul16-
|
|
22
|
+
import "./matMul16-bI7XM831.js";
|
|
23
23
|
import "./ops/webgl/matMul16.js";
|
|
24
24
|
import "./ops/cpu/matMul16.js";
|
|
25
|
-
import "./pack16-
|
|
25
|
+
import "./pack16-DO9GrRdk.js";
|
|
26
26
|
import "./ops/transpose16.js";
|
|
27
27
|
import "./ops/reshape16.js";
|
|
28
28
|
import "./ops/cpu/qkv.js";
|
|
@@ -31,6 +31,8 @@ import "./ops/grads/qkv.js";
|
|
|
31
31
|
import "./ops/cpu/normRMS.js";
|
|
32
32
|
import "./ops/webgl/normRMS.js";
|
|
33
33
|
import "./ops/grads/normRMS.js";
|
|
34
|
+
import "./ops/dropout16.js";
|
|
35
|
+
import "./ops/webgl/dropout16.js";
|
|
34
36
|
import "./ops/grads/add16.js";
|
|
35
37
|
import c from "./tokeniser/CharTokeniser.js";
|
|
36
38
|
import g from "./tokeniser/bpe.js";
|
|
@@ -41,11 +43,11 @@ import "./ops/webgl/gatherSub.js";
|
|
|
41
43
|
import "./ops/cpu/scatterSub.js";
|
|
42
44
|
import "./ops/webgl/scatterSub.js";
|
|
43
45
|
import "./ops/cpu/matMulGelu.js";
|
|
44
|
-
import "./matMulGelu-
|
|
46
|
+
import "./matMulGelu-Cbtq3pxJ.js";
|
|
45
47
|
import "./ops/grads/matMulGelu.js";
|
|
46
48
|
import "./ops/cpu/gelu.js";
|
|
47
49
|
import "./ops/webgl/gelu.js";
|
|
48
|
-
import "./gelu-
|
|
50
|
+
import "./gelu-CmkPheOK.js";
|
|
49
51
|
import "./ops/webgl/log.js";
|
|
50
52
|
import "./ops/cpu/adamMoments.js";
|
|
51
53
|
import "./ops/webgl/adamMoments.js";
|
|
@@ -70,6 +72,14 @@ class a {
|
|
|
70
72
|
get vocab() {
|
|
71
73
|
return this._tokeniser?.getVocab() || [];
|
|
72
74
|
}
|
|
75
|
+
get phase() {
|
|
76
|
+
return this._model?.metaData?.phase ?? "untrained";
|
|
77
|
+
}
|
|
78
|
+
set phase(t) {
|
|
79
|
+
if (!this._model)
|
|
80
|
+
throw new Error("model_not_initialized.");
|
|
81
|
+
this._model.metaData.phase = t, this.ee.emit("phase", t);
|
|
82
|
+
}
|
|
73
83
|
/** Model is fully loaded */
|
|
74
84
|
get loaded() {
|
|
75
85
|
return !!this._model && !!this._tokeniser && !!this._config;
|
|
@@ -116,9 +126,9 @@ class a {
|
|
|
116
126
|
}
|
|
117
127
|
static loadModel(t, r) {
|
|
118
128
|
const e = new a();
|
|
119
|
-
return
|
|
129
|
+
return p(t, r).then(({ model: o, tokeniser: n, metaData: i }) => {
|
|
120
130
|
m(o.config), e._model = o, e._tokeniser = n, e._config = o.config, i?.name && (e.meta.name = i.name), e.setStatus("warmup"), l(o).then((s) => {
|
|
121
|
-
e._memoryRequirements = s, e.setStatus("ready"), e.ee.emit("loaded");
|
|
131
|
+
e._memoryRequirements = s, e.setStatus("ready"), e.ee.emit("loaded"), e.ee.emit("phase", e.phase);
|
|
122
132
|
}).catch((s) => {
|
|
123
133
|
e.setStatus("error"), e.ee.emit("error", s), console.error("Error during warmup:", s);
|
|
124
134
|
});
|
|
@@ -130,7 +140,7 @@ class a {
|
|
|
130
140
|
m(r);
|
|
131
141
|
const e = r, o = t === "char" ? new c(e.vocabSize) : new g(e.vocabSize), n = k(e), i = new a(o, n);
|
|
132
142
|
return i.setStatus("warmup"), l(n).then((s) => {
|
|
133
|
-
i._memoryRequirements = s, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded")) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.tokeniser.once("trainStatus", (_) => {
|
|
143
|
+
i._memoryRequirements = s, i.tokeniser.trained ? (i.setStatus("ready"), i.ee.emit("loaded"), i.ee.emit("phase", i.phase)) : (i.setStatus("awaitingTokens"), i.ee.emit("loaded"), i.ee.emit("phase", i.phase), i.tokeniser.once("trainStatus", (_) => {
|
|
134
144
|
_ === "trained" && i.setStatus("ready");
|
|
135
145
|
}));
|
|
136
146
|
}).catch((s) => {
|
|
@@ -159,11 +169,13 @@ class a {
|
|
|
159
169
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
160
170
|
this._trainer && t && this._trainer.trainingType !== t && (this._trainer.dispose(), this._trainer = null);
|
|
161
171
|
const e = this._trainer === null ? new h(this._model, this._tokeniser, t, r) : new h(this._trainer, r);
|
|
162
|
-
return e.on("start", () =>
|
|
172
|
+
return e.on("start", () => {
|
|
173
|
+
this.setStatus("training"), this.phase = t === "sft" ? "finetuned" : "pretrained";
|
|
174
|
+
}), e.on("stop", () => this.setStatus("ready")), e.on("log", async (o) => {
|
|
163
175
|
const n = this.ee.listeners("trainStep");
|
|
164
176
|
for (const i of n)
|
|
165
177
|
await i(o);
|
|
166
|
-
}), this._trainer = e, e;
|
|
178
|
+
}), this._trainer && this._trainer !== e && this._trainer.dispose(), this._trainer = e, e;
|
|
167
179
|
}
|
|
168
180
|
async train(t, r, e) {
|
|
169
181
|
const o = this.trainer(e, r);
|
|
@@ -178,7 +190,7 @@ class a {
|
|
|
178
190
|
generator() {
|
|
179
191
|
if (!this._model || !this._tokeniser)
|
|
180
192
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
181
|
-
const t = new
|
|
193
|
+
const t = new u(this._model, this._tokeniser);
|
|
182
194
|
return t.on("start", () => {
|
|
183
195
|
this.status === "ready" && this.setStatus("busy");
|
|
184
196
|
}), t.on("stop", () => {
|
|
@@ -189,7 +201,7 @@ class a {
|
|
|
189
201
|
return Array.isArray(t) ? this.generator().generate(t, r) : this.generator().generate([], r);
|
|
190
202
|
}
|
|
191
203
|
dispose() {
|
|
192
|
-
this._model?.dispose(), this.ee.removeAllListeners();
|
|
204
|
+
this._trainer && (this._trainer.dispose(), this._trainer = null), this._model?.dispose(), this.ee.removeAllListeners();
|
|
193
205
|
}
|
|
194
206
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
195
207
|
on(t, r) {
|
package/dist/Trainer.d.ts
CHANGED
|
@@ -20,6 +20,7 @@ export default class Trainer extends EE<'start' | 'stop' | 'log'> {
|
|
|
20
20
|
log: TrainingLogEntry[];
|
|
21
21
|
private progress;
|
|
22
22
|
options: TrainingOptions;
|
|
23
|
+
protected tokenizer: ITokeniser;
|
|
23
24
|
constructor(model: Model<ModelForwardAttributes>, tokeniser: ITokeniser, trainingType?: TrainingType, options?: TrainingOptions);
|
|
24
25
|
constructor(trainer: Trainer, options?: TrainingOptions);
|
|
25
26
|
get model(): Model<ModelForwardAttributes>;
|
package/dist/Trainer.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import { E as g } from "./index-DvYrXKkX.js";
|
|
2
|
-
import
|
|
2
|
+
import o from "./training/PreTrainer.js";
|
|
3
3
|
import { createTrainValidationSplit as p } from "./training/validation.js";
|
|
4
|
-
import
|
|
5
|
-
class
|
|
4
|
+
import h from "./training/SFTTrainer.js";
|
|
5
|
+
class l extends g {
|
|
6
6
|
trainer;
|
|
7
7
|
trainingType = "pretraining";
|
|
8
8
|
hasTrained = !1;
|
|
@@ -16,19 +16,22 @@ class o extends g {
|
|
|
16
16
|
sftMode: "full",
|
|
17
17
|
logInterval: 10
|
|
18
18
|
};
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
19
|
+
tokenizer;
|
|
20
|
+
constructor(t, i, e = "pretraining", a) {
|
|
21
|
+
if (super(), t instanceof l) {
|
|
22
|
+
const r = i || t.options, n = t.options;
|
|
23
|
+
let s = !1;
|
|
24
|
+
t.trainingType === "sft" && r.sftMode !== n.sftMode && (s = !0), e !== t.trainingType && (s = !0), s ? (t.trainingType === "sft" ? this.trainer = new h(t.model, t.tokenizer, r) : this.trainer = new o(t.model, t.tokenizer, r), this.trainingType = e, this.options = r, this.tokenizer = t.tokenizer) : (this.trainer = t.trainer, this.trainingType = e, this.options = r, this.trainer.updateOptimizer(this.options), this.log = t.log, this.progress = t.progress, this.totalSamples = t.totalSamples, this.tokenizer = t.tokenizer, r.batchSize === n.batchSize && (this.trainDataset = t.trainDataset, this.validationDataset = t.validationDataset));
|
|
22
25
|
return;
|
|
23
26
|
}
|
|
24
|
-
if (!t)
|
|
25
|
-
throw new Error("Tokeniser must be provided when initializing Trainer with a model");
|
|
26
27
|
if (!i)
|
|
28
|
+
throw new Error("Tokeniser must be provided when initializing Trainer with a model");
|
|
29
|
+
if (!t)
|
|
27
30
|
throw new Error("Model must be provided when initializing Trainer");
|
|
28
31
|
this.options = a || {
|
|
29
32
|
batchSize: 32,
|
|
30
33
|
sftMode: "full"
|
|
31
|
-
}, e === "sft" ? this.trainer = new
|
|
34
|
+
}, e === "sft" ? this.trainer = new h(t, i, a) : this.trainer = new o(t, i, a), this.trainingType = e, this.tokenizer = i;
|
|
32
35
|
}
|
|
33
36
|
get model() {
|
|
34
37
|
return this.trainer.model;
|
|
@@ -48,110 +51,110 @@ class o extends g {
|
|
|
48
51
|
getTotalSamples() {
|
|
49
52
|
return this.totalSamples;
|
|
50
53
|
}
|
|
51
|
-
setOptions(
|
|
52
|
-
const
|
|
53
|
-
Object.keys(
|
|
54
|
-
(e) =>
|
|
54
|
+
setOptions(t) {
|
|
55
|
+
const i = new Set(
|
|
56
|
+
Object.keys(t).filter(
|
|
57
|
+
(e) => t[e] !== this.options[e]
|
|
55
58
|
)
|
|
56
59
|
);
|
|
57
60
|
if (this.trainer.isRunning) {
|
|
58
|
-
if (
|
|
61
|
+
if (i.has("batchSize"))
|
|
59
62
|
throw new Error("Cannot change batch size during training");
|
|
60
|
-
if (
|
|
63
|
+
if (i.has("sftMode"))
|
|
61
64
|
throw new Error("Cannot change SFT mode during training");
|
|
62
|
-
if (
|
|
65
|
+
if (i.has("loraConfig"))
|
|
63
66
|
throw new Error("Cannot change LoRA configuration during training");
|
|
64
|
-
if (
|
|
67
|
+
if (i.has("validationSplit"))
|
|
65
68
|
throw new Error("Cannot change validation split during training");
|
|
66
|
-
if (
|
|
69
|
+
if (i.has("trainableWeights"))
|
|
67
70
|
throw new Error("Cannot change trainable weights during training");
|
|
68
|
-
if (
|
|
71
|
+
if (i.has("mixedPrecision"))
|
|
69
72
|
throw new Error("Cannot change mixed precision setting during training");
|
|
70
|
-
if (
|
|
73
|
+
if (i.has("gradientCheckpointing"))
|
|
71
74
|
throw new Error("Cannot change gradient checkpointing setting during training");
|
|
72
75
|
}
|
|
73
76
|
this.options = {
|
|
74
77
|
...this.options,
|
|
75
|
-
...
|
|
76
|
-
}, this.trainer.updateOptimizer(this.options),
|
|
78
|
+
...t
|
|
79
|
+
}, this.trainer.updateOptimizer(this.options), i.has("metrics") && this.trainer.setMetrics(t.metrics || []);
|
|
77
80
|
}
|
|
78
|
-
async prepare(
|
|
79
|
-
const
|
|
80
|
-
if (this.trainingType === "pretraining" && this.trainer instanceof
|
|
81
|
-
const { trainDataset: e, validationDataset: a, size: r, trainState:
|
|
82
|
-
|
|
81
|
+
async prepare(t = []) {
|
|
82
|
+
const i = this.options;
|
|
83
|
+
if (this.trainingType === "pretraining" && this.trainer instanceof o) {
|
|
84
|
+
const { trainDataset: e, validationDataset: a, size: r, trainState: n } = await p(
|
|
85
|
+
t,
|
|
83
86
|
this.trainer.tokenizer,
|
|
84
87
|
this.trainer.datasetBuilder,
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
),
|
|
88
|
-
this.trainDataset = e, this.validationDataset = a, this.totalSamples =
|
|
89
|
-
} else if (this.trainingType === "sft" && this.trainer instanceof
|
|
90
|
-
if (
|
|
88
|
+
i?.batchSize || 32,
|
|
89
|
+
i?.validationSplit || 0.1
|
|
90
|
+
), s = r * (1 - (i?.validationSplit || 0));
|
|
91
|
+
this.trainDataset = e, this.validationDataset = a, this.totalSamples = s, this.options.epochSteps = Math.ceil(n.shuffledIndexes.length / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
|
|
92
|
+
} else if (this.trainingType === "sft" && this.trainer instanceof h) {
|
|
93
|
+
if (t instanceof Uint16Array)
|
|
91
94
|
throw new Error("SFT training requires Task[] input");
|
|
92
95
|
const e = await this.trainer.datasetBuilder.createSFTDataset(
|
|
93
|
-
|
|
94
|
-
|
|
96
|
+
t,
|
|
97
|
+
i?.batchSize || 32,
|
|
95
98
|
-100
|
|
96
99
|
);
|
|
97
|
-
this.trainDataset = e, this.totalSamples =
|
|
100
|
+
this.trainDataset = e, this.totalSamples = t.reduce((a, r) => a + r.length, 0), this.options.epochSteps = Math.ceil(this.totalSamples / (i?.batchSize || 32)), this.trainer.updateOptimizer(this.options);
|
|
98
101
|
}
|
|
99
102
|
}
|
|
100
|
-
configureModel(
|
|
101
|
-
const
|
|
103
|
+
configureModel(t) {
|
|
104
|
+
const i = t?.sftMode || "full";
|
|
102
105
|
if (this.trainingType === "pretraining" && (this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA(), this.trainer.model.weightStore.setTrainable(["*"])), this.trainingType === "sft") {
|
|
103
|
-
if (
|
|
104
|
-
if (!
|
|
106
|
+
if (i === "lora") {
|
|
107
|
+
if (!t?.loraConfig)
|
|
105
108
|
throw new Error("LoRA configuration must be provided for lora mode");
|
|
106
109
|
if (this.trainer.model.hasLoRA()) {
|
|
107
110
|
const e = this.trainer.model.lora;
|
|
108
|
-
(e.alpha !==
|
|
111
|
+
(e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) && (this.trainer.model.detachLoRA(), this.trainer.model.attachLoRA(t.loraConfig));
|
|
109
112
|
} else
|
|
110
|
-
this.trainer.model.attachLoRA(
|
|
113
|
+
this.trainer.model.attachLoRA(t.loraConfig);
|
|
111
114
|
} else
|
|
112
115
|
this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA();
|
|
113
|
-
|
|
116
|
+
i === "last-layer" ? this.trainer.model.weightStore.setTrainable([
|
|
114
117
|
`block_${this.trainer.model.config.nLayer - 1}_*`,
|
|
115
118
|
"token_embedding"
|
|
116
|
-
]) :
|
|
119
|
+
]) : i === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
|
|
117
120
|
}
|
|
118
|
-
|
|
121
|
+
t?.trainableWeights && this.trainer.model.weightStore.setTrainable(t.trainableWeights);
|
|
119
122
|
}
|
|
120
123
|
async train() {
|
|
121
|
-
const
|
|
124
|
+
const t = this.options;
|
|
122
125
|
if (!this.trainDataset)
|
|
123
126
|
throw new Error("Dataset not prepared");
|
|
124
|
-
this.hasTrained || this.trainer.setLearningRate(
|
|
127
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), this.trainer.setLabelSmoothing(t?.labelSmoothing || 0), this.trainer.setDropout(t?.dropout || 0), this.trainer.setLayerDrop(t?.layerDrop || 0), this.configureModel(t), await this.trainer.trainOnDataset(
|
|
125
128
|
this.trainDataset,
|
|
126
129
|
{
|
|
127
|
-
...
|
|
128
|
-
onStep: async (
|
|
129
|
-
this.log.push(
|
|
130
|
-
lastLog:
|
|
131
|
-
progress:
|
|
130
|
+
...t,
|
|
131
|
+
onStep: async (i) => {
|
|
132
|
+
this.log.push(i), this.progress = {
|
|
133
|
+
lastLog: i,
|
|
134
|
+
progress: i.totalSamples / this.totalSamples,
|
|
132
135
|
remaining: Math.max(
|
|
133
136
|
0,
|
|
134
|
-
(this.totalSamples -
|
|
137
|
+
(this.totalSamples - i.totalSamples) / i.totalSamples * i.duration
|
|
135
138
|
)
|
|
136
139
|
};
|
|
137
140
|
const e = this.listeners("log");
|
|
138
141
|
for (const a of e)
|
|
139
|
-
await a(
|
|
142
|
+
await a(i, this.progress);
|
|
140
143
|
}
|
|
141
144
|
},
|
|
142
145
|
this.validationDataset
|
|
143
146
|
), this.emit("stop");
|
|
144
147
|
}
|
|
145
|
-
async step(
|
|
148
|
+
async step(t) {
|
|
146
149
|
if (!this.trainDataset)
|
|
147
150
|
throw new Error("Dataset not prepared");
|
|
148
|
-
this.hasTrained || this.trainer.setLearningRate(
|
|
149
|
-
const { log:
|
|
151
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
|
|
152
|
+
const { log: i } = await this.trainer.stepDataset(this.trainDataset, t || {}, this.validationDataset), e = this.listeners("log");
|
|
150
153
|
for (const a of e)
|
|
151
|
-
await a(
|
|
152
|
-
lastLog:
|
|
153
|
-
progress:
|
|
154
|
-
remaining: Math.max(0, (this.totalSamples -
|
|
154
|
+
await a(i, {
|
|
155
|
+
lastLog: i,
|
|
156
|
+
progress: i.totalSamples / this.totalSamples,
|
|
157
|
+
remaining: Math.max(0, (this.totalSamples - i.totalSamples) / i.totalSamples * i.duration)
|
|
155
158
|
});
|
|
156
159
|
this.emit("stop");
|
|
157
160
|
}
|
|
@@ -166,5 +169,5 @@ class o extends g {
|
|
|
166
169
|
}
|
|
167
170
|
}
|
|
168
171
|
export {
|
|
169
|
-
|
|
172
|
+
l as default
|
|
170
173
|
};
|
package/dist/backend.js
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import { g as o, s as e, r as s } from "./index-
|
|
1
|
+
import { g as o, s as e, r as s } from "./index-D0b5F1JD.js";
|
|
2
2
|
async function c(t, a) {
|
|
3
3
|
if (o() !== t) {
|
|
4
4
|
if (t === "webgpu") {
|
|
5
5
|
const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
|
|
6
|
-
i(a), await import("./index-
|
|
6
|
+
i(a), await import("./index-nwvWLdRt.js"), await import("./ops/webgpu/index.js");
|
|
7
7
|
}
|
|
8
8
|
await e(t), await s(), console.log(`Backend set to ${t}`);
|
|
9
9
|
}
|