@genai-fi/nanogpt 0.17.4 → 0.18.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +2 -15
- package/dist/Generator.js +45 -34
- package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
- package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
- package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
- package/dist/TeachableLLM.d.ts +3 -8
- package/dist/TeachableLLM.js +61 -44
- package/dist/Trainer.d.ts +6 -4
- package/dist/Trainer.js +107 -92
- package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
- package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
- package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +5 -5
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +2 -2
- package/dist/checks/normRMS.js +6 -6
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +6 -6
- package/dist/checks/qkv.js +2 -2
- package/dist/checks/rope.js +2 -2
- package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
- package/dist/complex-3DpPEG9B.js +11 -0
- package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
- package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
- package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
- package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
- package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
- package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
- package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
- package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
- package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
- package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
- package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
- package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
- package/dist/inference/types.d.ts +16 -0
- package/dist/inference/types.js +1 -0
- package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
- package/dist/layers/BaseLayer.js +4 -4
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/LoRA.js +4 -4
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/layers/WeightStore.js +2 -2
- package/dist/loader/load.d.ts +2 -8
- package/dist/loader/loadTransformers.d.ts +2 -8
- package/dist/loader/loadTransformers.js +13 -11
- package/dist/loader/newZipLoad.d.ts +2 -8
- package/dist/loader/newZipLoad.js +25 -10
- package/dist/loader/oldZipLoad.js +13 -13
- package/dist/loader/save.d.ts +9 -2
- package/dist/loader/save.js +64 -55
- package/dist/loader/types.d.ts +29 -1
- package/dist/main.d.ts +2 -0
- package/dist/main.js +45 -43
- package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
- package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
- package/dist/mat_mul-DP86qZtZ.js +11 -0
- package/dist/mod-BXjLYwvM.js +11 -0
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/NanoGPTV2.js +2 -2
- package/dist/models/model.d.ts +3 -2
- package/dist/models/model.js +13 -13
- package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
- package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
- package/dist/ops/adamAdjust.js +3 -3
- package/dist/ops/adamMoments.js +3 -3
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +6 -6
- package/dist/ops/attentionMask.js +3 -3
- package/dist/ops/concat16.js +3 -3
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +5 -5
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +4 -4
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +4 -4
- package/dist/ops/cpu/matMul16.js +2 -2
- package/dist/ops/cpu/matMulGelu.js +7 -7
- package/dist/ops/cpu/matMulMul.js +2 -2
- package/dist/ops/cpu/mulDropout.js +5 -5
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +5 -5
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/dropout.js +6 -6
- package/dist/ops/dropout16.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/globalNorm.js +7 -7
- package/dist/ops/grads/add16.js +1 -1
- package/dist/ops/grads/attentionMask.js +2 -2
- package/dist/ops/grads/dropout16.js +1 -1
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMul16.js +3 -3
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/mul16.js +1 -1
- package/dist/ops/grads/normRMS.js +7 -7
- package/dist/ops/grads/pack16.js +3 -3
- package/dist/ops/grads/qkv.js +11 -11
- package/dist/ops/grads/rope.js +2 -2
- package/dist/ops/grads/softmax16.js +1 -1
- package/dist/ops/grads/unpack16.js +2 -2
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +6 -6
- package/dist/ops/matMulMul.js +3 -3
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +3 -3
- package/dist/ops/normRMS.js +4 -4
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +3 -3
- package/dist/ops/reshape16.js +6 -6
- package/dist/ops/rope.js +2 -2
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +6 -6
- package/dist/ops/transpose16.js +3 -3
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/dropout16.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +7 -7
- package/dist/ops/webgl/gatherSub.js +3 -3
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMul16.js +13 -13
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +2 -2
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +2 -2
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/add16.js +6 -6
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +2 -2
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/clipScale.js +7 -7
- package/dist/ops/webgpu/concat16.js +5 -5
- package/dist/ops/webgpu/dropout16.js +6 -6
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +8 -8
- package/dist/ops/webgpu/matMul16.js +16 -16
- package/dist/ops/webgpu/matMul16_program.js +2 -2
- package/dist/ops/webgpu/mul16.js +5 -5
- package/dist/ops/webgpu/norm2.js +1 -1
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +4 -4
- package/dist/ops/webgpu/pack16.js +4 -4
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +2 -2
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/slice16.js +4 -4
- package/dist/ops/webgpu/softmax16.js +4 -4
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +4 -4
- package/dist/ops/webgpu/sub16.js +6 -6
- package/dist/ops/webgpu/sum16.js +3 -3
- package/dist/ops/webgpu/transpose16.js +8 -8
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
- package/dist/ops/webgpu/unpack16.js +3 -3
- package/dist/ops/webgpu/utils/binary_op.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +5 -5
- package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
- package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
- package/dist/patches/webgpu_backend.js +6 -6
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +2 -2
- package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
- package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
- package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
- package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
- package/dist/relu-DTvZKBsZ.js +9 -0
- package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
- package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
- package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
- package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
- package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
- package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
- package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
- package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
- package/dist/slice-BvItlgXu.js +12 -0
- package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
- package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
- package/dist/split-BN9LkEgS.js +9 -0
- package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
- package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
- package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
- package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
- package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
- package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
- package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
- package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
- package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
- package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
- package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
- package/dist/tokeniser/BaseTokeniser.js +21 -5
- package/dist/tokeniser/CharTokeniser.d.ts +1 -1
- package/dist/tokeniser/CharTokeniser.js +62 -50
- package/dist/tokeniser/bpe.d.ts +1 -1
- package/dist/tokeniser/bpe.js +41 -35
- package/dist/tokeniser/type.d.ts +3 -1
- package/dist/training/AdamW.d.ts +3 -0
- package/dist/training/AdamW.js +59 -30
- package/dist/training/BasicTrainer.d.ts +1 -0
- package/dist/training/BasicTrainer.js +112 -92
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/Evaluator.js +2 -2
- package/dist/training/LRScheduler.d.ts +1 -0
- package/dist/training/LRScheduler.js +18 -12
- package/dist/training/PreTrainer.js +3 -3
- package/dist/training/SFTDatasetBuilder.js +3 -3
- package/dist/training/SFTTrainer.js +1 -1
- package/dist/training/orthoGrad.js +1 -1
- package/dist/training/sparseCrossEntropy.js +30 -30
- package/dist/training/types.d.ts +5 -3
- package/dist/training/validation.js +13 -13
- package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
- package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
- package/dist/utilities/datasetID.d.ts +2 -0
- package/dist/utilities/datasetID.js +21 -0
- package/dist/utilities/dummy.js +6 -6
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
- package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
- package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
- package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
- package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
- package/package.json +3 -3
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
package/dist/Trainer.js
CHANGED
|
@@ -1,35 +1,36 @@
|
|
|
1
|
-
import { E as
|
|
2
|
-
import
|
|
3
|
-
import { createTrainValidationSplit as
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
const
|
|
7
|
-
for (let
|
|
8
|
-
|
|
9
|
-
function
|
|
10
|
-
return (n[
|
|
1
|
+
import { E as m } from "./index-DvYrXKkX.js";
|
|
2
|
+
import g from "./training/PreTrainer.js";
|
|
3
|
+
import { createTrainValidationSplit as u } from "./training/validation.js";
|
|
4
|
+
import c from "./training/SFTTrainer.js";
|
|
5
|
+
import p from "./training/tasks/splitter.js";
|
|
6
|
+
const r = [];
|
|
7
|
+
for (let n = 0; n < 256; ++n)
|
|
8
|
+
r.push((n + 256).toString(16).slice(1));
|
|
9
|
+
function w(n, t = 0) {
|
|
10
|
+
return (r[n[t + 0]] + r[n[t + 1]] + r[n[t + 2]] + r[n[t + 3]] + "-" + r[n[t + 4]] + r[n[t + 5]] + "-" + r[n[t + 6]] + r[n[t + 7]] + "-" + r[n[t + 8]] + r[n[t + 9]] + "-" + r[n[t + 10]] + r[n[t + 11]] + r[n[t + 12]] + r[n[t + 13]] + r[n[t + 14]] + r[n[t + 15]]).toLowerCase();
|
|
11
11
|
}
|
|
12
|
-
const
|
|
13
|
-
function
|
|
14
|
-
return crypto.getRandomValues(
|
|
12
|
+
const T = new Uint8Array(16);
|
|
13
|
+
function D() {
|
|
14
|
+
return crypto.getRandomValues(T);
|
|
15
15
|
}
|
|
16
|
-
function
|
|
17
|
-
return crypto.randomUUID ? crypto.randomUUID() :
|
|
16
|
+
function d(n, t, a) {
|
|
17
|
+
return crypto.randomUUID ? crypto.randomUUID() : k(n);
|
|
18
18
|
}
|
|
19
|
-
function
|
|
20
|
-
|
|
21
|
-
const
|
|
22
|
-
if (
|
|
19
|
+
function k(n, t, a) {
|
|
20
|
+
n = n || {};
|
|
21
|
+
const i = n.random ?? n.rng?.() ?? D();
|
|
22
|
+
if (i.length < 16)
|
|
23
23
|
throw new Error("Random bytes length must be >= 16");
|
|
24
|
-
return
|
|
24
|
+
return i[6] = i[6] & 15 | 64, i[8] = i[8] & 63 | 128, w(i);
|
|
25
25
|
}
|
|
26
|
-
class
|
|
26
|
+
class f extends m {
|
|
27
27
|
trainer;
|
|
28
28
|
trainingType = "pretraining";
|
|
29
29
|
hasTrained = !1;
|
|
30
30
|
trainDataset;
|
|
31
31
|
validationDataset;
|
|
32
|
-
|
|
32
|
+
totalTokens = 0;
|
|
33
|
+
tokensProcessed = 0;
|
|
33
34
|
log = [];
|
|
34
35
|
progress = null;
|
|
35
36
|
options = {
|
|
@@ -38,21 +39,21 @@ class d extends f {
|
|
|
38
39
|
logInterval: 10
|
|
39
40
|
};
|
|
40
41
|
tokenizer;
|
|
41
|
-
constructor(t,
|
|
42
|
-
if (super(), t instanceof
|
|
43
|
-
const
|
|
44
|
-
let
|
|
45
|
-
t.trainingType === "sft" &&
|
|
42
|
+
constructor(t, a, i = "pretraining", e, s) {
|
|
43
|
+
if (super(), t instanceof f) {
|
|
44
|
+
const o = a || t.options, h = t.options;
|
|
45
|
+
let l = !1;
|
|
46
|
+
t.trainingType === "sft" && o.sftMode !== h.sftMode && (l = !0), i !== t.trainingType && (l = !0), l ? (t.trainingType === "sft" ? this.trainer = new c(t.model, t.tokenizer, o) : this.trainer = new g(t.model, t.tokenizer, o), this.trainingType = i, this.options = o, this.tokenizer = t.tokenizer) : (this.trainer = t.trainer, this.trainingType = i, this.options = o, this.trainer.updateOptimizer(this.options), this.log = t.log, this.progress = t.progress, this.totalTokens = t.totalTokens, this.tokenizer = t.tokenizer, o.batchSize === h.batchSize && (this.trainDataset = t.trainDataset, this.validationDataset = t.validationDataset));
|
|
46
47
|
return;
|
|
47
48
|
}
|
|
48
|
-
if (!
|
|
49
|
+
if (!a)
|
|
49
50
|
throw new Error("Tokeniser must be provided when initializing Trainer with a model");
|
|
50
51
|
if (!t)
|
|
51
52
|
throw new Error("Model must be provided when initializing Trainer");
|
|
52
|
-
this.options =
|
|
53
|
+
this.options = e || {
|
|
53
54
|
batchSize: 32,
|
|
54
55
|
sftMode: "full"
|
|
55
|
-
},
|
|
56
|
+
}, i === "sft" ? this.trainer = new c(t, a, e, s) : this.trainer = new g(t, a, e, s), this.trainingType = i, this.tokenizer = a;
|
|
56
57
|
}
|
|
57
58
|
get model() {
|
|
58
59
|
return this.trainer.model;
|
|
@@ -69,61 +70,66 @@ class d extends f {
|
|
|
69
70
|
dispose() {
|
|
70
71
|
this.trainer.dispose(), this.removeAllListeners();
|
|
71
72
|
}
|
|
72
|
-
|
|
73
|
-
return this.
|
|
73
|
+
getTotalTokens() {
|
|
74
|
+
return this.totalTokens;
|
|
74
75
|
}
|
|
75
76
|
setOptions(t) {
|
|
76
|
-
const
|
|
77
|
+
const a = new Set(
|
|
77
78
|
Object.keys(t).filter(
|
|
78
|
-
(
|
|
79
|
+
(i) => t[i] !== this.options[i]
|
|
79
80
|
)
|
|
80
81
|
);
|
|
81
82
|
if (this.trainer.isRunning) {
|
|
82
|
-
if (
|
|
83
|
+
if (a.has("batchSize"))
|
|
83
84
|
throw new Error("Cannot change batch size during training");
|
|
84
|
-
if (
|
|
85
|
+
if (a.has("sftMode"))
|
|
85
86
|
throw new Error("Cannot change SFT mode during training");
|
|
86
|
-
if (
|
|
87
|
+
if (a.has("loraConfig"))
|
|
87
88
|
throw new Error("Cannot change LoRA configuration during training");
|
|
88
|
-
if (
|
|
89
|
+
if (a.has("validationSplit"))
|
|
89
90
|
throw new Error("Cannot change validation split during training");
|
|
90
|
-
if (
|
|
91
|
+
if (a.has("trainableWeights"))
|
|
91
92
|
throw new Error("Cannot change trainable weights during training");
|
|
92
|
-
if (
|
|
93
|
+
if (a.has("mixedPrecision"))
|
|
93
94
|
throw new Error("Cannot change mixed precision setting during training");
|
|
94
|
-
if (
|
|
95
|
+
if (a.has("gradientCheckpointing"))
|
|
95
96
|
throw new Error("Cannot change gradient checkpointing setting during training");
|
|
96
97
|
}
|
|
97
98
|
this.options = {
|
|
98
99
|
...this.options,
|
|
99
100
|
...t
|
|
100
|
-
}, this.trainer.updateOptimizer(this.options),
|
|
101
|
+
}, this.trainer.updateOptimizer(this.options), a.has("metrics") && this.trainer.setMetrics(t.metrics || []);
|
|
101
102
|
}
|
|
102
|
-
async prepare(t = []) {
|
|
103
|
+
async prepare(t = [], a) {
|
|
103
104
|
const i = this.options;
|
|
104
|
-
if (
|
|
105
|
-
|
|
105
|
+
if (a && (this.model.metaData.pretrainingData = a.map((e) => ({
|
|
106
|
+
id: e.id,
|
|
107
|
+
name: e.name
|
|
108
|
+
}))), this.trainingType === "pretraining" && this.trainer instanceof g) {
|
|
109
|
+
const { trainDataset: e, validationDataset: s, size: o } = await u(
|
|
106
110
|
t,
|
|
107
111
|
this.trainer.tokenizer,
|
|
108
112
|
this.trainer.datasetBuilder,
|
|
109
113
|
i?.batchSize || 32,
|
|
110
114
|
i?.validationSplit || 0.1
|
|
111
|
-
),
|
|
112
|
-
this.trainDataset = e, this.validationDataset =
|
|
113
|
-
|
|
115
|
+
), h = o * (1 - (i?.validationSplit || 0));
|
|
116
|
+
this.trainDataset = e, this.validationDataset = s, this.totalTokens = h, this.options.epochSteps = Math.ceil(
|
|
117
|
+
this.totalTokens / ((i?.batchSize || 32) * this.model.config.blockSize)
|
|
118
|
+
), this.trainer.updateOptimizer(this.options);
|
|
119
|
+
} else if (this.trainingType === "sft" && this.trainer instanceof c) {
|
|
114
120
|
if (t instanceof Uint16Array)
|
|
115
121
|
throw new Error("SFT training requires Task[] input");
|
|
116
122
|
if (i?.validationSplit && i.validationSplit > 0) {
|
|
117
|
-
const e =
|
|
123
|
+
const e = p(t, i?.validationSplit), s = await this.trainer.datasetBuilder.createSFTDataset(
|
|
118
124
|
[e.training],
|
|
119
125
|
i?.batchSize || 32,
|
|
120
126
|
-100
|
|
121
|
-
),
|
|
127
|
+
), o = await this.trainer.datasetBuilder.createSFTDataset(
|
|
122
128
|
[e.validation],
|
|
123
129
|
i?.batchSize || 32,
|
|
124
130
|
-100
|
|
125
131
|
);
|
|
126
|
-
this.validationDataset =
|
|
132
|
+
this.validationDataset = o, this.trainDataset = s;
|
|
127
133
|
} else {
|
|
128
134
|
const e = await this.trainer.datasetBuilder.createSFTDataset(
|
|
129
135
|
t,
|
|
@@ -132,45 +138,47 @@ class d extends f {
|
|
|
132
138
|
);
|
|
133
139
|
this.trainDataset = e;
|
|
134
140
|
}
|
|
135
|
-
this.
|
|
141
|
+
this.totalTokens = t.reduce((e, s) => e + s.length, 0), this.options.epochSteps = Math.ceil(
|
|
142
|
+
this.totalTokens / ((i?.batchSize || 32) * this.model.config.blockSize)
|
|
143
|
+
), this.trainer.updateOptimizer(this.options);
|
|
136
144
|
}
|
|
137
145
|
}
|
|
138
146
|
configureModel(t) {
|
|
139
|
-
const
|
|
147
|
+
const a = t?.sftMode || "full";
|
|
140
148
|
if (this.trainingType === "pretraining" && (this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA(), this.trainer.model.weightStore.setTrainable(["*"])), this.trainingType === "sft") {
|
|
141
|
-
if (
|
|
142
|
-
const
|
|
149
|
+
if (a === "lora") {
|
|
150
|
+
const i = this.trainer.model;
|
|
143
151
|
if (t?.loraName)
|
|
144
|
-
if (
|
|
145
|
-
if (
|
|
146
|
-
const
|
|
147
|
-
(
|
|
152
|
+
if (i.hasLoRA(t.loraName)) {
|
|
153
|
+
if (i.attachLoRA(t.loraName), t.loraConfig) {
|
|
154
|
+
const e = i.lora;
|
|
155
|
+
(e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) && (i.detachLoRA(), i.deleteLoRA(t.loraName), i.createLoRA(t.loraName, t.loraConfig), i.attachLoRA(t.loraName), console.warn("Resetting LoRA with new configuration."));
|
|
148
156
|
}
|
|
149
157
|
} else if (t.loraConfig)
|
|
150
|
-
|
|
158
|
+
i.createLoRA(t.loraName, t.loraConfig), i.attachLoRA(t.loraName);
|
|
151
159
|
else
|
|
152
160
|
throw new Error(
|
|
153
161
|
`LoRA configuration must be provided to create LoRA with name ${t.loraName}`
|
|
154
162
|
);
|
|
155
163
|
else if (t?.loraConfig)
|
|
156
|
-
if (
|
|
157
|
-
const
|
|
158
|
-
if (
|
|
159
|
-
|
|
160
|
-
const s = t.loraName ||
|
|
161
|
-
|
|
164
|
+
if (i.hasLoRA()) {
|
|
165
|
+
const e = i.lora;
|
|
166
|
+
if (e.alpha !== t.loraConfig.alpha || e.rank !== t.loraConfig.rank) {
|
|
167
|
+
i.detachLoRA();
|
|
168
|
+
const s = t.loraName || d();
|
|
169
|
+
i.createLoRA(s, t.loraConfig), i.attachLoRA(s);
|
|
162
170
|
}
|
|
163
171
|
} else {
|
|
164
|
-
const
|
|
165
|
-
|
|
172
|
+
const e = t.loraName || d();
|
|
173
|
+
i.createLoRA(e, t.loraConfig), i.attachLoRA(e);
|
|
166
174
|
}
|
|
167
|
-
else if (!
|
|
175
|
+
else if (!i.hasLoRA()) throw new Error("LoRA configuration must be provided for lora SFT mode");
|
|
168
176
|
} else
|
|
169
177
|
this.trainer.model.hasLoRA() && this.trainer.model.detachLoRA();
|
|
170
|
-
|
|
178
|
+
a === "last-layer" ? this.trainer.model.weightStore.setTrainable([
|
|
171
179
|
`block_${this.trainer.model.config.nLayer - 1}_*`,
|
|
172
180
|
"token_embedding"
|
|
173
|
-
]) :
|
|
181
|
+
]) : a === "full" && this.trainer.model.weightStore.setTrainable(["*"]);
|
|
174
182
|
}
|
|
175
183
|
t?.trainableWeights && this.trainer.model.weightStore.setTrainable(t.trainableWeights);
|
|
176
184
|
}
|
|
@@ -178,37 +186,44 @@ class d extends f {
|
|
|
178
186
|
const t = this.options;
|
|
179
187
|
if (!this.trainDataset)
|
|
180
188
|
throw new Error("Dataset not prepared");
|
|
181
|
-
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.
|
|
189
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.model.metaData.pretrainingSettings = t;
|
|
190
|
+
const a = Date.now();
|
|
191
|
+
this.log.length > 0 && this.trainer.resumeFromLog(this.log[this.log.length - 1]), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), this.trainer.setLabelSmoothing(t?.labelSmoothing || 0), this.trainer.setDropout(t?.dropout || 0), this.trainer.setLayerDrop(t?.layerDrop || 0), this.configureModel(t), await this.trainer.trainOnDataset(
|
|
182
192
|
this.trainDataset,
|
|
183
193
|
{
|
|
184
194
|
...t,
|
|
185
|
-
onStep: async (
|
|
186
|
-
this.log.push(
|
|
187
|
-
lastLog:
|
|
188
|
-
progress:
|
|
189
|
-
remaining: Math.max(
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
const e = this.listeners("log");
|
|
195
|
-
for (const r of e)
|
|
196
|
-
await r(i, this.progress);
|
|
195
|
+
onStep: async (e) => {
|
|
196
|
+
this.log.push(e), this.progress = {
|
|
197
|
+
lastLog: e,
|
|
198
|
+
progress: e.totalTokens / this.totalTokens,
|
|
199
|
+
remaining: Math.max(0, (this.totalTokens - e.totalTokens) / e.totalTokens * e.duration)
|
|
200
|
+
}, this.tokensProcessed = e.totalTokens;
|
|
201
|
+
const s = this.listeners("log");
|
|
202
|
+
for (const o of s)
|
|
203
|
+
await o(e, this.progress);
|
|
197
204
|
}
|
|
198
205
|
},
|
|
199
206
|
this.validationDataset
|
|
200
|
-
), this.
|
|
207
|
+
), this.model.metaData.actionLog = this.model.metaData.actionLog || [];
|
|
208
|
+
const i = Date.now();
|
|
209
|
+
this.model.metaData.actionLog.push({
|
|
210
|
+
action: "pretrain",
|
|
211
|
+
timestamp: i,
|
|
212
|
+
duration: i - a,
|
|
213
|
+
tokensProcessed: this.tokensProcessed,
|
|
214
|
+
options: t
|
|
215
|
+
}), this.emit("stop");
|
|
201
216
|
}
|
|
202
217
|
async step(t) {
|
|
203
218
|
if (!this.trainDataset)
|
|
204
219
|
throw new Error("Dataset not prepared");
|
|
205
220
|
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start");
|
|
206
|
-
const { log:
|
|
207
|
-
for (const
|
|
208
|
-
await
|
|
209
|
-
lastLog:
|
|
210
|
-
progress:
|
|
211
|
-
remaining: Math.max(0, (this.
|
|
221
|
+
const { log: a } = await this.trainer.stepDataset(this.trainDataset, t || {}, this.validationDataset), i = this.listeners("log");
|
|
222
|
+
for (const e of i)
|
|
223
|
+
await e(a, {
|
|
224
|
+
lastLog: a,
|
|
225
|
+
progress: a.totalTokens / this.totalTokens,
|
|
226
|
+
remaining: Math.max(0, (this.totalTokens - a.totalTokens) / a.totalTokens * a.duration)
|
|
212
227
|
});
|
|
213
228
|
this.emit("stop");
|
|
214
229
|
}
|
|
@@ -223,5 +238,5 @@ class d extends f {
|
|
|
223
238
|
}
|
|
224
239
|
}
|
|
225
240
|
export {
|
|
226
|
-
|
|
241
|
+
f as default
|
|
227
242
|
};
|
package/dist/backend.js
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import { g as o, s as e, r as s } from "./index-
|
|
1
|
+
import { g as o, s as e, r as s } from "./index-DSGwv2Yx.js";
|
|
2
2
|
async function c(t, a) {
|
|
3
3
|
if (o() !== t) {
|
|
4
4
|
if (t === "webgpu") {
|
|
5
5
|
const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
|
|
6
|
-
i(a), await import("./index-
|
|
6
|
+
i(a), await import("./index-BQvB7LCC.js"), await import("./ops/webgpu/index.js");
|
|
7
7
|
}
|
|
8
8
|
await e(t), await s(), console.log(`Backend set to ${t}`);
|
|
9
9
|
}
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { d as T, f as L, h as
|
|
3
|
-
import { a as z, c as
|
|
4
|
-
import { c as
|
|
5
|
-
import { S as k, a as Q, b as ee, g as te, c as se, s as ne } from "./selu_util-
|
|
6
|
-
import { c as re, v as oe, a as ie } from "./scatter_nd_util-
|
|
1
|
+
import { N as d, a9 as A, a8 as O, v as g, av as _, az as w, ad as D, _ as x, $ as b, am as y, aY as M } from "./index-DSGwv2Yx.js";
|
|
2
|
+
import { d as T, f as L, h as v, c as W, e as F, a as N, b as C, g as P } from "./axis_util-QWWgLjut.js";
|
|
3
|
+
import { a as z, c as B } from "./concat_util-C1Mxe27t.js";
|
|
4
|
+
import { c as U, b as H, d as V, f as G, g as Z, h as j, i as q, j as J, k as K, m as X, t as Y } from "./step-DQY6_ABw.js";
|
|
5
|
+
import { S as k, a as Q, b as ee, g as te, c as se, s as ne } from "./selu_util-BXdhy_W6.js";
|
|
6
|
+
import { c as re, v as oe, a as ie } from "./scatter_nd_util-C-x73Cj6.js";
|
|
7
7
|
import { a as ae, c as ue, b as ce, e as pe, d as le, g as fe, m as he, s as ge } from "./complex_util-Yc1A_gV1.js";
|
|
8
8
|
function de(e, t) {
|
|
9
9
|
const r = e.shape.length, n = t.shape.length;
|
|
@@ -146,10 +146,10 @@ function De(e, t, r) {
|
|
|
146
146
|
return n;
|
|
147
147
|
}
|
|
148
148
|
const xe = 0.3275911, be = 0.254829592, ye = -0.284496736, Me = 1.421413741, Te = -1.453152027, Le = 1.061405429;
|
|
149
|
-
const I = "->",
|
|
150
|
-
function
|
|
149
|
+
const I = "->", ve = /->/g, E = ",", $ = "...";
|
|
150
|
+
function We(e, t) {
|
|
151
151
|
e = e.replace(/\s/g, "");
|
|
152
|
-
const r = (e.length - e.replace(
|
|
152
|
+
const r = (e.length - e.replace(ve, "").length) / I.length;
|
|
153
153
|
if (r < 1)
|
|
154
154
|
throw new Error("Equations without an arrow are not supported.");
|
|
155
155
|
if (r > 1)
|
|
@@ -226,7 +226,7 @@ function ze(e, t) {
|
|
|
226
226
|
(e[n].length === 0 || e[n].indexOf(t) !== -1 || t === -1) && r.push(n);
|
|
227
227
|
return r;
|
|
228
228
|
}
|
|
229
|
-
function
|
|
229
|
+
function Be(e, t, r = 0) {
|
|
230
230
|
let n = [];
|
|
231
231
|
if (typeof t == "number")
|
|
232
232
|
g(e.shape[r] % t === 0, () => "Number of splits must evenly divide the axis."), n = new Array(t).fill(e.shape[r] / t);
|
|
@@ -242,7 +242,7 @@ function Ue(e, t, r = 0) {
|
|
|
242
242
|
}
|
|
243
243
|
return n;
|
|
244
244
|
}
|
|
245
|
-
function
|
|
245
|
+
function Ue(e) {
|
|
246
246
|
return `Received SparseTensor with denseShape[0] = 0 but
|
|
247
247
|
indices.shape[0] = ${e}`;
|
|
248
248
|
}
|
|
@@ -314,8 +314,8 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
314
314
|
axesAreInnerMostDims: L,
|
|
315
315
|
calculateShapes: re,
|
|
316
316
|
checkEinsumDimSizes: Ne,
|
|
317
|
-
checkPadOnDimRoundingMode:
|
|
318
|
-
combineLocations:
|
|
317
|
+
checkPadOnDimRoundingMode: U,
|
|
318
|
+
combineLocations: v,
|
|
319
319
|
combineRaggedTensorToTensorShapes: me,
|
|
320
320
|
complexWithEvenIndex: ue,
|
|
321
321
|
complexWithOddIndex: ce,
|
|
@@ -324,12 +324,12 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
324
324
|
computeDefaultPad: G,
|
|
325
325
|
computeDilation2DInfo: Z,
|
|
326
326
|
computeOptimalWindowSize: Se,
|
|
327
|
-
computeOutAndReduceShapes:
|
|
328
|
-
computeOutShape:
|
|
327
|
+
computeOutAndReduceShapes: W,
|
|
328
|
+
computeOutShape: B,
|
|
329
329
|
computePool2DInfo: j,
|
|
330
330
|
computePool3DInfo: q,
|
|
331
331
|
convertConv2DDataFormat: J,
|
|
332
|
-
decodeEinsumEquation:
|
|
332
|
+
decodeEinsumEquation: We,
|
|
333
333
|
eitherStridesOrDilationsAreOne: K,
|
|
334
334
|
expandShapeToKeepDim: F,
|
|
335
335
|
exponent: pe,
|
|
@@ -353,7 +353,7 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
353
353
|
getRowPartitionTypesHelper: Ie,
|
|
354
354
|
getSliceBeginCoords: we,
|
|
355
355
|
getSliceSize: De,
|
|
356
|
-
getSparseFillEmptyRowsIndicesDenseShapeMismatch:
|
|
356
|
+
getSparseFillEmptyRowsIndicesDenseShapeMismatch: Ue,
|
|
357
357
|
getSparseFillEmptyRowsNegativeIndexErrorMessage: He,
|
|
358
358
|
getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: Ve,
|
|
359
359
|
getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: je,
|
|
@@ -369,7 +369,7 @@ const ut = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
|
|
|
369
369
|
isIdentityPermutation: Pe,
|
|
370
370
|
mergeRealAndImagArrays: he,
|
|
371
371
|
prepareAndValidate: de,
|
|
372
|
-
prepareSplitSize:
|
|
372
|
+
prepareSplitSize: Be,
|
|
373
373
|
shouldFuse: ne,
|
|
374
374
|
splitRealAndImagArrays: ge,
|
|
375
375
|
stridesOrDilationsArePositive: X,
|
|
@@ -386,14 +386,14 @@ export {
|
|
|
386
386
|
we as C,
|
|
387
387
|
De as D,
|
|
388
388
|
xe as E,
|
|
389
|
-
|
|
389
|
+
We as F,
|
|
390
390
|
Ne as G,
|
|
391
391
|
Ce as H,
|
|
392
392
|
Fe as I,
|
|
393
393
|
Pe as J,
|
|
394
394
|
de as K,
|
|
395
395
|
Re as L,
|
|
396
|
-
|
|
396
|
+
Be as M,
|
|
397
397
|
S as P,
|
|
398
398
|
f as R,
|
|
399
399
|
Ee as a,
|
|
@@ -403,7 +403,7 @@ export {
|
|
|
403
403
|
et as e,
|
|
404
404
|
Qe as f,
|
|
405
405
|
Ie as g,
|
|
406
|
-
|
|
406
|
+
Ue as h,
|
|
407
407
|
He as i,
|
|
408
408
|
Ve as j,
|
|
409
409
|
Ge as k,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { ab as g, as as $, at as K,
|
|
2
|
-
import { m as te, f as se, P as re } from "./webgpu_program-
|
|
3
|
-
import { i as ne, G as
|
|
4
|
-
import { m as
|
|
1
|
+
import { ab as g, as as $, at as K, e as D, v as _, au as O, N as x, av as Z, a5 as W, aw as F, ax as j, ay as X, az as J, af as ee, a9 as k } from "./index-DSGwv2Yx.js";
|
|
2
|
+
import { m as te, f as se, P as re } from "./webgpu_program-CbjdYLYk.js";
|
|
3
|
+
import { i as ne, G as N } from "./webgpu_util-DuofJBMo.js";
|
|
4
|
+
import { m as q } from "./complex_util-Yc1A_gV1.js";
|
|
5
5
|
const d = g();
|
|
6
6
|
d.registerFlag("WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE", () => 15);
|
|
7
7
|
d.registerFlag("WEBGPU_CPU_FORWARD", () => !0);
|
|
@@ -248,7 +248,7 @@ class R extends $ {
|
|
|
248
248
|
if (s != null || t.dtype === "string")
|
|
249
249
|
return s;
|
|
250
250
|
if (t.dtype === "complex64") {
|
|
251
|
-
const E = this.readSync(n.real.dataId), B = this.readSync(n.imag.dataId), y = O(
|
|
251
|
+
const E = this.readSync(n.real.dataId), B = this.readSync(n.imag.dataId), y = O(q(E, B).buffer, "float32");
|
|
252
252
|
return this.convertAndCacheOnCPU(e, y), y;
|
|
253
253
|
}
|
|
254
254
|
this.hasReadSyncWarned || (this.hasReadSyncWarned = !0, console.warn("The performance of synchronously reading data from GPU to CPU is poor on the webgpu backend, please use asynchronous APIs instead."));
|
|
@@ -309,7 +309,7 @@ class R extends $ {
|
|
|
309
309
|
this.read(t.complexTensorInfos.real.dataId),
|
|
310
310
|
this.read(t.complexTensorInfos.imag.dataId)
|
|
311
311
|
]), a = r[0], i = r[1];
|
|
312
|
-
n =
|
|
312
|
+
n = q(a, i);
|
|
313
313
|
} else {
|
|
314
314
|
const r = await this.getBufferData(t.resource);
|
|
315
315
|
n = O(r, t.dtype);
|
|
@@ -337,7 +337,7 @@ class R extends $ {
|
|
|
337
337
|
refCount: 1,
|
|
338
338
|
external: e.zeroCopy
|
|
339
339
|
});
|
|
340
|
-
const a = this.tensorMap.get(r), i =
|
|
340
|
+
const a = this.tensorMap.get(r), i = N(a.dtype) * x(a.shape);
|
|
341
341
|
if (e.buffer.size < i)
|
|
342
342
|
throw new Error(`GPUBuffer size(${e.buffer.size}) is smaller than tensor size(${i})!`);
|
|
343
343
|
if ((e.buffer.usage & (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) !== (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC))
|
|
@@ -398,7 +398,7 @@ class R extends $ {
|
|
|
398
398
|
const t = this.tensorMap.get(e);
|
|
399
399
|
if (t.resource != null)
|
|
400
400
|
return;
|
|
401
|
-
const s =
|
|
401
|
+
const s = N(t.dtype) * x(t.shape);
|
|
402
402
|
let n;
|
|
403
403
|
const r = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
|
|
404
404
|
if (t.values) {
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { o as h,
|
|
2
|
-
import { r as b } from "./reshape-
|
|
1
|
+
import { o as h, n as f, q as p, u as g, E as u, T } from "./index-DSGwv2Yx.js";
|
|
2
|
+
import { r as b } from "./reshape-BIN71H3p.js";
|
|
3
3
|
function m(e, r) {
|
|
4
4
|
let n = f(e, "broadcastTo", "x");
|
|
5
5
|
const a = n.shape;
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { s as i,
|
|
2
|
-
import { t } from "../tensor4d-
|
|
3
|
-
import { t as
|
|
1
|
+
import { s as i, e } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t } from "../tensor4d-Dj4rDssL.js";
|
|
3
|
+
import { t as a } from "../tensor2d-DTtQ1QcT.js";
|
|
4
4
|
async function k(n) {
|
|
5
5
|
await i(n);
|
|
6
6
|
const s = t(
|
|
@@ -23,14 +23,14 @@ async function k(n) {
|
|
|
23
23
|
]
|
|
24
24
|
],
|
|
25
25
|
[1, 1, 2, 4]
|
|
26
|
-
), r =
|
|
26
|
+
), r = a(
|
|
27
27
|
[
|
|
28
28
|
[0, -1 / 0, -1 / 0, -1 / 0],
|
|
29
29
|
[0, 0, 0, -1 / 0]
|
|
30
30
|
],
|
|
31
31
|
[2, 4]
|
|
32
32
|
);
|
|
33
|
-
return await
|
|
33
|
+
return await e().runKernel("AttentionMask", { q: s, k: o, mask: r }, { divisor: 0.5, pastLen: 0 }).array();
|
|
34
34
|
}
|
|
35
35
|
export {
|
|
36
36
|
k as execute
|
package/dist/checks/gelu.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { s as e,
|
|
2
|
-
import { t as s } from "../tensor2d-
|
|
1
|
+
import { s as e, e as o } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t as s } from "../tensor2d-DTtQ1QcT.js";
|
|
3
3
|
async function m(t) {
|
|
4
4
|
await e(t);
|
|
5
5
|
const r = s(
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { s as o,
|
|
2
|
-
import { t as e } from "../tensor2d-
|
|
1
|
+
import { s as o, e as s } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t as e } from "../tensor2d-DTtQ1QcT.js";
|
|
3
3
|
async function i(t) {
|
|
4
4
|
await o(t);
|
|
5
5
|
const r = e(
|
package/dist/checks/normRMS.js
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
import { s as u, a0 as A,
|
|
2
|
-
import { a as
|
|
3
|
-
import { t as p } from "../tensor1d-
|
|
4
|
-
import { t as r } from "../tensor-
|
|
1
|
+
import { s as u, a0 as A, e as y } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { a as h } from "../ops-CURIZSVt.js";
|
|
3
|
+
import { t as p } from "../tensor1d-BMl0eZYV.js";
|
|
4
|
+
import { t as r } from "../tensor-D8e0Gd7c.js";
|
|
5
5
|
const w = Array.from({ length: 2048 * 192 }, () => Math.random()), x = Array.from({ length: 192 }, () => Math.random()), M = Array.from({ length: 2048 * 192 }, () => Math.random());
|
|
6
6
|
async function k(t) {
|
|
7
7
|
await u(t);
|
|
8
8
|
const o = p(x, "float32"), n = r(w, [16, 128, 192], "float32"), s = r(M, [16, 128, 192], "float32"), e = (d, g) => {
|
|
9
|
-
const i =
|
|
10
|
-
return
|
|
9
|
+
const i = y().runKernel("RMSNorm", { x: d, gamma: g });
|
|
10
|
+
return h.meanSquaredError(i, s);
|
|
11
11
|
}, { value: m, grads: a } = A(e)([n, o]), c = await m.array(), f = await a[0].array(), l = await a[1].array();
|
|
12
12
|
return [c, f, l];
|
|
13
13
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { s as c,
|
|
2
|
-
import { t as f } from "../tensor1d-
|
|
3
|
-
import { t as r } from "../tensor-
|
|
1
|
+
import { s as c, e as d } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t as f } from "../tensor1d-BMl0eZYV.js";
|
|
3
|
+
import { t as r } from "../tensor-D8e0Gd7c.js";
|
|
4
4
|
const y = Array.from({ length: 2048 * 192 }, () => Math.random()), i = Array.from({ length: 192 }, () => Math.random()), l = Array.from({ length: 2048 * 192 }, () => Math.random());
|
|
5
5
|
async function x(t) {
|
|
6
6
|
await c(t);
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { s as a,
|
|
2
|
-
import { t as c } from "../tensor2d-
|
|
3
|
-
async function i(
|
|
4
|
-
await a(
|
|
1
|
+
import { s as a, e } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t as c } from "../tensor2d-DTtQ1QcT.js";
|
|
3
|
+
async function i(n) {
|
|
4
|
+
await a(n);
|
|
5
5
|
const r = c(
|
|
6
6
|
[
|
|
7
7
|
[0.1, 0.2, 0, 0, 1230, 1232331234, -12234234],
|
|
@@ -10,8 +10,8 @@ async function i(e) {
|
|
|
10
10
|
[0, 0, 0, 0, -0.1, 1e-3, 0]
|
|
11
11
|
],
|
|
12
12
|
[4, 7]
|
|
13
|
-
), t =
|
|
14
|
-
return await
|
|
13
|
+
), t = e().runKernel("Pack16", { x: r });
|
|
14
|
+
return await e().runKernel("Unpack16", { x: t }).array();
|
|
15
15
|
}
|
|
16
16
|
export {
|
|
17
17
|
i as execute
|
package/dist/checks/qkv.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { t as f } from "../tensor2d-
|
|
1
|
+
import { U as i, V as u, W as c, s as l, e as h } from "../index-DSGwv2Yx.js";
|
|
2
|
+
import { t as f } from "../tensor2d-DTtQ1QcT.js";
|
|
3
3
|
function m(t, e, n) {
|
|
4
4
|
if (i(t), e != null && e.length !== 3)
|
|
5
5
|
throw new Error("tensor3d() requires shape to have three numbers");
|
package/dist/checks/rope.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import s from "../layers/RoPECache.js";
|
|
2
|
-
import { s as c,
|
|
3
|
-
import { t as p } from "../tensor4d-
|
|
2
|
+
import { s as c, e as i } from "../index-DSGwv2Yx.js";
|
|
3
|
+
import { t as p } from "../tensor4d-Dj4rDssL.js";
|
|
4
4
|
async function f(r) {
|
|
5
5
|
await c(r);
|
|
6
6
|
const n = p(
|