@genai-fi/nanogpt 0.9.1 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +352 -14
- package/dist/Generator.js +69 -78
- package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
- package/dist/Reshape-CF6odzV4.js +16 -0
- package/dist/Reshape-_kILl6tK.js +81 -0
- package/dist/TeachableLLM.js +28 -22
- package/dist/Trainer.d.ts +2 -0
- package/dist/Trainer.js +3 -2
- package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
- package/dist/backend.d.ts +2 -1
- package/dist/backend.js +10 -4
- package/dist/backend_util-D-rUb2ty.js +474 -0
- package/dist/backend_webgpu-B0u2ndUn.js +547 -0
- package/dist/binary_op_util-pKXltfxI.js +192 -0
- package/dist/broadcast_to-CwF7XIeu.js +30 -0
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/check.d.ts +1 -1
- package/dist/checks/check.js +8 -8
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/index.d.ts +2 -0
- package/dist/checks/index.js +7 -5
- package/dist/checks/matMulGelu.js +6 -6
- package/dist/checks/normRMS.js +7 -7
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +18 -0
- package/dist/checks/qkv.js +12 -27
- package/dist/checks/rope.js +2 -2
- package/dist/checks/weights.js +18 -16
- package/dist/complex-CSlYz-2T.js +13 -0
- package/dist/complex_util-Yc1A_gV1.js +55 -0
- package/dist/concat-BHlIJeyT.js +19 -0
- package/dist/concat_util-DcJk7YHS.js +22 -0
- package/dist/data/docx.js +1 -1
- package/dist/data/parquet.js +2 -2
- package/dist/data/pdf.js +1 -1
- package/dist/data/textLoader.js +1 -1
- package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
- package/dist/dropout-C1pM3f11.js +99 -0
- package/dist/expand_dims-BPG4fwBP.js +13 -0
- package/dist/exports_initializers-xuidcwI4.js +7 -0
- package/dist/gather-DykLGqmW.js +10 -0
- package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
- package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
- package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
- package/dist/index-CjOj7j-u.js +7308 -0
- package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
- package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
- package/dist/index-ZyQhjEPo.js +2157 -0
- package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
- package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
- package/dist/layers/BaseLayer.d.ts +1 -0
- package/dist/layers/BaseLayer.js +7 -6
- package/dist/layers/CausalSelfAttention.d.ts +0 -1
- package/dist/layers/CausalSelfAttention.js +56 -55
- package/dist/layers/MLP.js +15 -16
- package/dist/layers/PositionEmbedding.js +5 -14
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.d.ts +2 -0
- package/dist/layers/RoPECache.js +22 -17
- package/dist/layers/TiedEmbedding.js +22 -17
- package/dist/layers/TransformerBlock.js +21 -20
- package/dist/loader/load.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +39 -33
- package/dist/loader/save.js +1 -1
- package/dist/log_sum_exp-DWI-76TI.js +41 -0
- package/dist/main.d.ts +8 -0
- package/dist/main.js +63 -52
- package/dist/matMul16--R5hOwDG.js +77 -0
- package/dist/mat_mul-DeAh4uTH.js +12 -0
- package/dist/mod-Gt1rMB4n.js +12 -0
- package/dist/models/NanoGPTV1.js +40 -31
- package/dist/models/model.d.ts +2 -0
- package/dist/models/model.js +37 -29
- package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
- package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
- package/dist/ones-CAMiP4I2.js +15 -0
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.d.ts +1 -1
- package/dist/ops/adamMoments.js +4 -4
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +9 -0
- package/dist/ops/appendCache.js +16 -9
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +9 -0
- package/dist/ops/cpu/adamAdjust.js +14 -13
- package/dist/ops/cpu/adamMoments.js +10 -9
- package/dist/ops/cpu/appendCache.js +9 -8
- package/dist/ops/cpu/attentionMask.js +15 -14
- package/dist/ops/cpu/fusedSoftmax.js +13 -12
- package/dist/ops/cpu/gatherSub.js +9 -24
- package/dist/ops/cpu/gelu.js +13 -12
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +16 -0
- package/dist/ops/cpu/matMulGelu.js +18 -16
- package/dist/ops/cpu/matMulMul.js +8 -7
- package/dist/ops/cpu/mulDropout.js +4 -3
- package/dist/ops/cpu/normRMS.js +11 -10
- package/dist/ops/cpu/qkv.js +17 -13
- package/dist/ops/cpu/rope.js +23 -22
- package/dist/ops/cpu/scatterSub.js +16 -30
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +42 -0
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.js +12 -19
- package/dist/ops/grads/gelu.js +4 -3
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +9 -0
- package/dist/ops/grads/matMulGelu.js +8 -7
- package/dist/ops/grads/normRMS.js +8 -7
- package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
- package/dist/ops/grads/pack16.js +7 -0
- package/dist/ops/grads/qkv.d.ts +3 -1
- package/dist/ops/grads/qkv.js +28 -22
- package/dist/ops/grads/rope.d.ts +2 -1
- package/dist/ops/grads/rope.js +6 -13
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +26 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +6 -0
- package/dist/ops/grads/utils.d.ts +3 -0
- package/dist/ops/grads/utils.js +10 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +13 -0
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +8 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +6 -0
- package/dist/ops/qkv.d.ts +1 -1
- package/dist/ops/qkv.js +8 -4
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +43 -0
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +8 -10
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +9 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +12 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +8 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +41 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +6 -0
- package/dist/ops/webgl/adamAdjust.js +3 -2
- package/dist/ops/webgl/adamMoments.js +2 -1
- package/dist/ops/webgl/appendCache.js +2 -1
- package/dist/ops/webgl/attentionMask.js +5 -4
- package/dist/ops/webgl/fusedSoftmax.js +6 -4
- package/dist/ops/webgl/gatherSub.js +7 -6
- package/dist/ops/webgl/gelu.js +3 -2
- package/dist/ops/webgl/log.js +12 -27
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.js +17 -15
- package/dist/ops/webgl/matMulMul.js +13 -12
- package/dist/ops/webgl/mulDropout.js +9 -8
- package/dist/ops/webgl/normRMS.js +8 -7
- package/dist/ops/webgl/qkv.js +6 -5
- package/dist/ops/webgl/rope.js +11 -10
- package/dist/ops/webgl/scatterSub.js +6 -5
- package/dist/ops/webgpu/adamAdjust.js +12 -10
- package/dist/ops/webgpu/adamMoments.js +27 -22
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.js +64 -17
- package/dist/ops/webgpu/attentionMask.js +19 -62
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +54 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +128 -0
- package/dist/ops/webgpu/gatherSub.js +9 -7
- package/dist/ops/webgpu/gelu.js +78 -31
- package/dist/ops/webgpu/index.js +12 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +58 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +336 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/normRMS.js +21 -40
- package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS16_program.js +24 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS32_program.js +24 -0
- package/dist/ops/webgpu/normRMSGrad.js +113 -64
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +19 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +92 -0
- package/dist/ops/webgpu/qkv.js +20 -55
- package/dist/ops/webgpu/rope.js +77 -22
- package/dist/ops/webgpu/scatterSub.js +9 -7
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +71 -0
- package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
- package/dist/ops/webgpu/softmax16.js +23 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +73 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +38 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +40 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +35 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +50 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +49 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
- package/dist/ops/webgpu/utils/binary_op.js +79 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
- package/dist/ops/webgpu/utils/reductions.js +236 -45
- package/dist/ops-CNI3TwqM.js +645 -0
- package/dist/pack16-CFUqumar.js +41 -0
- package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
- package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
- package/dist/patches/PackedTensor.d.ts +12 -0
- package/dist/patches/PackedTensor.js +11 -0
- package/dist/patches/engine.d.ts +261 -0
- package/dist/patches/engine.js +10 -0
- package/dist/patches/tape.d.ts +12 -0
- package/dist/patches/tape.js +5 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +57 -0
- package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
- package/dist/patches/webgpu_base.js +34 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +401 -0
- package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
- package/dist/random_width-DY6Kk2Dl.js +10051 -0
- package/dist/range-BMS52eQi.js +11 -0
- package/dist/reciprocal-CTmshQ9J.js +10 -0
- package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
- package/dist/relu-yZ2-7WxU.js +10 -0
- package/dist/reshape-DevtBWtf.js +10 -0
- package/dist/rope-B5UUMsPi.js +32 -0
- package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
- package/dist/selu_util-D1w6yyTO.js +303 -0
- package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
- package/dist/shared-BuAXb4CI.js +2145 -0
- package/dist/sin-BGfy2HZo.js +16 -0
- package/dist/slice-D_gkkqZK.js +13 -0
- package/dist/slice_util-DtEldBfK.js +261 -0
- package/dist/softmax-ZHVebtR1.js +13 -0
- package/dist/split-DrfihRpZ.js +10 -0
- package/dist/squeeze-DZEpeblb.js +11 -0
- package/dist/stack-yOIAalTq.js +13 -0
- package/dist/sum-_fzj5ZTB.js +12 -0
- package/dist/tensor-DdQUJZlz.js +909 -0
- package/dist/tensor-f35l8Odg.js +8 -0
- package/dist/tensor1d-CeZuc-Rv.js +12 -0
- package/dist/tensor2d-G4Ys2GxX.js +15 -0
- package/dist/tensor4d-B8roDgtc.js +15 -0
- package/dist/tensor_util-DV-FP5Q3.js +523 -0
- package/dist/tfjs_backend-kNyO5L2d.js +653 -0
- package/dist/tile-BzyEiF-F.js +13 -0
- package/dist/tokeniser/CharTokeniser.js +1 -1
- package/dist/tokeniser/bpe.js +1 -1
- package/dist/training/Adam.d.ts +2 -1
- package/dist/training/Adam.js +12 -28
- package/dist/training/AdamExt.d.ts +1 -0
- package/dist/training/AdamExt.js +2 -2
- package/dist/training/DatasetBuilder.js +3 -20
- package/dist/training/FullTrainer.js +55 -48
- package/dist/training/Trainer.d.ts +11 -6
- package/dist/training/Trainer.js +51 -39
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/transpose-DKELTqhe.js +38 -0
- package/dist/utilities/arrayClose.js +7 -7
- package/dist/utilities/dummy.js +35 -27
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.d.ts +7 -0
- package/dist/utilities/packed.js +716 -0
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +41 -0
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-Bhn5bHYv.js +7 -0
- package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
- package/dist/webgpu_util-BBCnKm2X.js +65 -0
- package/dist/zeros-2gldETuK.js +14 -0
- package/package.json +4 -3
- package/dist/Reshape-Bowtk9BP.js +0 -127
- package/dist/Reshape-DUqYftGC.js +0 -30
- package/dist/backend_util-CJIiDoV1.js +0 -749
- package/dist/broadcast_to-DzlNweb8.js +0 -44
- package/dist/concat-B912vBbo.js +0 -33
- package/dist/dropout-C-csYCLj.js +0 -193
- package/dist/exports_initializers-B8iZMgQ0.js +0 -16
- package/dist/gather-Dnpgw-YQ.js +0 -25
- package/dist/index-BzFyqcy-.js +0 -4457
- package/dist/index-C1rx_Ajs.js +0 -12076
- package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
- package/dist/log_sum_exp-DO6z8tSE.js +0 -103
- package/dist/mat_mul-DzjTFx-u.js +0 -27
- package/dist/mod-Dobti4j4.js +0 -27
- package/dist/ones-tIJeHlq-.js +0 -29
- package/dist/ops/fusedSoftmax.d.ts +0 -2
- package/dist/ops/fusedSoftmax.js +0 -10
- package/dist/ops/grads/fusedSoftmax.js +0 -22
- package/dist/ops-LuCMAnmM.js +0 -1525
- package/dist/random_width-CXVRloNK.js +0 -13670
- package/dist/range-CWcz7xFA.js +0 -26
- package/dist/reciprocal-C4rNcM-S.js +0 -25
- package/dist/relu-BjCh_SYb.js +0 -25
- package/dist/reshape-CnIwVG1c.js +0 -25
- package/dist/selu_util-OtRzVwW5.js +0 -719
- package/dist/shared-DmRsFyaJ.js +0 -3134
- package/dist/sin-gpDNRxE0.js +0 -47
- package/dist/slice-d0Vo9XTN.js +0 -28
- package/dist/softmax-D7Jj3p_P.js +0 -28
- package/dist/split-DK2k5eHf.js +0 -25
- package/dist/stack-DFatutCx.js +0 -27
- package/dist/sum-CJ0ULhmt.js +0 -27
- package/dist/tensor1d-vML0r3q6.js +0 -27
- package/dist/tensor2d-D76QGjF3.js +0 -30
- package/dist/tensor4d-Df1WlVDY.js +0 -30
- package/dist/webgpu_util-pLEV9tks.js +0 -80
- package/dist/zeros-Bj5rMYA7.js +0 -52
package/dist/models/NanoGPTV1.js
CHANGED
```diff
@@ -1,12 +1,14 @@
-import { defaultConfig as
-import
-import
-import
-import
-import { t as
-import
-import
-
+import { defaultConfig as g } from "./config.js";
+import b from "../layers/TransformerBlock.js";
+import k from "../layers/TiedEmbedding.js";
+import w from "../layers/RoPECache.js";
+import E from "../layers/RMSNorm.js";
+import { t as l, k as u } from "../index-ZyQhjEPo.js";
+import C from "./model.js";
+import P from "../layers/PositionEmbedding.js";
+import { packingSupported as _ } from "../utilities/packed.js";
+import { p as y, u as M } from "../pack16-CFUqumar.js";
+class I extends C {
   wte;
   // Token embeddings
   wpe;
@@ -17,56 +19,63 @@ class R extends w {
   // Final layer norm
   ropeCache;
   constructor(e = {}) {
-    super({ ...
+    super({ ...g, ...e }), this.wte = new k(this.config, "token_embedding", this), this.config.useRope === !1 ? this.wpe = new P(this.config, "positional_embedding", this) : this.ropeCache = new w(this.config), this.blocks = [];
     for (let i = 0; i < this.config.nLayer; i++)
-      this.blocks.push(new
-    this.lnF = new
+      this.blocks.push(new b(i, this.config, this));
+    this.lnF = new E(this.config, "final_rms_norm", this);
   }
   getClassName() {
     return "GenAI_NanoGPT_v1";
   }
   inputPhase(e, i) {
-    return
-    const
+    return l(() => {
+      const n = this.wte.embed(e);
       if (this.config.useRope === !1) {
-        const o = this.wpe.call(i,
+        const o = this.wpe.call(i, n);
         if (Array.isArray(o))
           throw new Error("PositionEmbedding output should not be an array");
         return o;
       }
-      return
+      return n;
     });
   }
-  forward(e, i,
-    return this.validateInput(i), e.ropeCache = this.ropeCache, e.outputEmbeddings && (e.embeddings = []),
+  forward(e, i, n) {
+    return this.validateInput(i), e.ropeCache = this.ropeCache, e.outputEmbeddings && (e.embeddings = []), l(() => {
      this.startMemory();
      let o = this.inputPhase(i, e);
      if (e.cache && e.cache.length !== this.blocks.length)
        throw console.error("Cache", e.cache), new Error(
          `Cache length ${e.cache.length} does not match number of blocks ${this.blocks.length}`
        );
-
-
+      const t = e.mixedPrecision === !0 && _();
+      let s = t ? y(o) : o;
+      t && o !== s && o.dispose();
+      for (let r = 0; r < this.blocks.length; r++) {
+        const d = this.blocks[r], a = Math.random() * 1e9, m = {
          ...e,
-          seed:
-          pastKV: e.cache ? e.cache[
-
-
+          seed: a,
+          pastKV: e.cache ? e.cache[r] : void 0,
+          mixedPrecision: t
+        }, f = e.checkpointing && e.training ? d.callCheckpoint(m, s) : d.call(m, s);
+        e.outputEmbeddings ? (u(s), e.embeddings.push({ name: `block_output_${r}`, tensor: s })) : s.dispose(), s = f;
      }
-      o = this.lnF.call(e,
-
-
-
-
+      if (o = this.lnF.call({ ...e, mixedPrecision: t }, s), s.dispose(), e.skipLogits)
+        return this.endMemory("Forward"), [o];
+      const c = this.wte.project(o);
+      e.outputEmbeddings ? (u(o), e.embeddings.push({ name: "final_norm_output", tensor: o })) : o.dispose();
+      const h = t ? M(c) : c;
+      t && c !== h && c.dispose();
+      let p;
+      return n && (p = this.calculateLoss(h, n)), this.endMemory("Forward"), p ? [h, p] : [h];
    });
  }
  project(e) {
-    return
+    return l(() => this.wte.project(e));
  }
  dispose() {
    this.wte.dispose(), this.wpe && this.wpe.dispose(), this.blocks.forEach((e) => e.dispose()), this.lnF.dispose();
  }
}
export {
-
+  I as default
};
```
package/dist/models/model.d.ts
CHANGED
```diff
@@ -5,6 +5,7 @@ export interface ModelForwardAttributes extends ForwardAttributes {
   cache?: KVCache[];
   attentionScores?: AttentionScores;
   seed?: number;
+  skipLogits?: boolean;
 }
 interface TrainingState {
   steps: number;
@@ -13,6 +14,7 @@ interface TrainingState {
   loss: number;
 }
 export default abstract class Model<T extends ModelForwardAttributes> extends BaseLayer<T> {
+  lossScaling: number;
   trainingState: TrainingState | null;
   abstract getClassName(): string;
   abstract forward(attrs: T, idx: Tensor, targets?: Tensor): Tensor[];
```
package/dist/models/model.js
CHANGED
```diff
@@ -1,53 +1,60 @@
-import
-import "../
+import m from "../layers/BaseLayer.js";
+import "../utilities/packed.js";
+import "../index-ZyQhjEPo.js";
 import "../ops/cpu/attentionMask.js";
 import "../ops/webgl/attentionMask.js";
 import "../ops/grads/attentionMask.js";
-import "../
-import "../
-import "../
-import "../
-import "../register_all_kernels-DIGpEwcf.js";
-import "../index-Tf7vU29b.js";
-import "../dataset-DlZtKmBq.js";
+import "../random_width-DY6Kk2Dl.js";
+import "../register_all_kernels-Bwu1PTuU.js";
+import "../index-Cp39cXWe.js";
+import "../dataset-0xP8GjwI.js";
 import "../ops/cpu/rope.js";
 import "../ops/webgl/rope.js";
-import "../
+import "../rope-B5UUMsPi.js";
 import "../ops/cpu/appendCache.js";
 import "../ops/webgl/appendCache.js";
-import "../ops/
-import "../
-import "../ops/
-import "../ops/cpu/
-import "../
-import "../ops/
+import "../ops/grads/softmax16.js";
+import "../matMul16--R5hOwDG.js";
+import "../ops/webgl/matMul16.js";
+import "../ops/cpu/matMul16.js";
+import "../pack16-CFUqumar.js";
+import "../ops/transpose16.js";
+import "../ops/reshape16.js";
+import "../ops/cpu/qkv.js";
+import "../ops/webgl/qkv.js";
+import "../ops/grads/qkv.js";
 import "../ops/cpu/normRMS.js";
 import "../ops/webgl/normRMS.js";
 import "../ops/grads/normRMS.js";
-import "../
-import "../
+import "../ops/grads/add16.js";
+import "../jszip.min-Bz5-11Bk.js";
+import "../index-DvYrXKkX.js";
 import "../Generator.js";
 import "../ops/cpu/adamAdjust.js";
 import "../ops/webgl/adamAdjust.js";
 import "../ops/cpu/adamMoments.js";
 import "../ops/webgl/adamMoments.js";
-import "../papaparse.min-
-import { estimateParameterCount as
+import "../papaparse.min-C0cScC2i.js";
+import { estimateParameterCount as e } from "../utilities/parameters.js";
 import "../ops/cpu/scatterSub.js";
 import "../ops/webgl/scatterSub.js";
 import "../ops/cpu/gatherSub.js";
 import "../ops/webgl/gatherSub.js";
+import "../ops/cpu/matMulGelu.js";
+import "../ops/webgl/matMulGelu.js";
+import "../ops/grads/matMulGelu.js";
 import "../ops/cpu/gelu.js";
 import "../ops/webgl/gelu.js";
-import "../gelu-
+import "../gelu-CNLFZWea.js";
 import "../ops/webgl/log.js";
 import "../checks/normRMS.js";
 import "../checks/normRMSGrad.js";
-import { createSoftmaxCrossEntropyWithGrad as
-class
+import { createSoftmaxCrossEntropyWithGrad as s } from "../training/sparseCrossEntropy.js";
+class st extends m {
+  lossScaling = 128;
   trainingState = null;
   getNumParams() {
-    return
+    return e(this.config);
   }
   validateInput(t) {
     if (t.shape.length !== 2)
@@ -57,14 +64,15 @@ class x extends i {
     if (t.dtype !== "int32")
       throw new Error(`Input tensor must be of type int32, got ${t.dtype}`);
   }
-  calculateLoss(t,
+  calculateLoss(t, i) {
    try {
-
-
-
+      const r = s()(t, i), p = r.mean();
+      return r.dispose(), p;
+    } catch (o) {
+      throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
    }
  }
}
export {
-
+  st as default
};
```
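`model.js` now carries `lossScaling = 128` on the base Model class, and the AdamMoments kernel (further down) accepts a matching `lossScaling` attribute. That is the shape of static loss scaling for fp16 training; a hedged sketch of the general technique, not the package's exact wiring:

```ts
import * as tf from "@tensorflow/tfjs-core";

// Static loss scaling: multiply the loss before backprop so small fp16
// gradients don't flush to zero, then divide by the same factor before
// the optimizer's moment update.
const lossScaling = 128; // default added to the Model class in this release

function scaledLossGrads(lossFn: () => tf.Scalar, vars: tf.Variable[]) {
  const { grads } = tf.variableGrads(
    () => tf.mul(lossFn(), tf.scalar(lossScaling)) as tf.Scalar,
    vars
  );
  // Unscale, mirroring the lossScaling attribute now passed to AdamMoments.
  for (const name of Object.keys(grads)) {
    const g = grads[name];
    grads[name] = tf.div(g, tf.scalar(lossScaling));
    g.dispose();
  }
  return grads;
}
```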
package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js}
CHANGED

```diff
@@ -1,20 +1,4 @@
-import { u as z } from "./gpgpu_math-
-/**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
+import { u as z } from "./gpgpu_math-DDVJCn6-.js";
 class g {
   constructor(e, s, v, a = !1, r = !1, c = !1, t = null, o = !1, l = !1) {
     this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = v, this.enableShapeUniforms = z(this.outputShape.length);
```
@@ -1,19 +1,3 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @license
|
|
3
|
-
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
4
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
* you may not use this file except in compliance with the License.
|
|
6
|
-
* You may obtain a copy of the License at
|
|
7
|
-
*
|
|
8
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
*
|
|
10
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
* See the License for the specific language governing permissions and
|
|
14
|
-
* limitations under the License.
|
|
15
|
-
* =============================================================================
|
|
16
|
-
*/
|
|
17
1
|
function J(e, t, o) {
|
|
18
2
|
const n = k(e, t, o), s = n < 0 ? -(n + 1) : n;
|
|
19
3
|
e.splice(s, 0, t);
|
|
@@ -33,22 +17,6 @@ function A(e, t, o) {
|
|
|
33
17
|
}
|
|
34
18
|
return u ? n : -n - 1;
|
|
35
19
|
}
|
|
36
|
-
/**
|
|
37
|
-
* @license
|
|
38
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
39
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
40
|
-
* you may not use this file except in compliance with the License.
|
|
41
|
-
* You may obtain a copy of the License at
|
|
42
|
-
*
|
|
43
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
44
|
-
*
|
|
45
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
46
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
47
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
48
|
-
* See the License for the specific language governing permissions and
|
|
49
|
-
* limitations under the License.
|
|
50
|
-
* =============================================================================
|
|
51
|
-
*/
|
|
52
20
|
function O(e, t, o, n, s) {
|
|
53
21
|
return y(
|
|
54
22
|
e,
|
|
package/dist/ones-CAMiP4I2.js
ADDED

```diff
@@ -0,0 +1,15 @@
+import { E as a } from "./index-ZyQhjEPo.js";
+import { d as t, m as n, s as i } from "./tensor-DdQUJZlz.js";
+import { c as f } from "./complex-CSlYz-2T.js";
+import { z as c } from "./zeros-2gldETuK.js";
+function l(o, r = "float32") {
+  if (t(o), r === "complex64") {
+    const e = l(o, "float32"), m = c(o, "float32");
+    return f(e, m);
+  }
+  const s = n(i(o), r);
+  return a.makeTensor(s, o, r);
+}
+export {
+  l as o
+};
```
package/dist/ops/adamMoments.d.ts
CHANGED

```diff
@@ -1,2 +1,2 @@
 import { Tensor } from '@tensorflow/tfjs-core';
-export declare function adamMoments(moments: Tensor, gradient: Tensor, beta1: number, beta2: number): Tensor;
+export declare function adamMoments(moments: Tensor, gradient: Tensor, beta1: number, beta2: number, lossScaling: number): Tensor;
```
package/dist/ops/adamMoments.js
CHANGED
```diff
@@ -1,9 +1,9 @@
-import { e as
+import { e as t } from "../index-ZyQhjEPo.js";
 import "./cpu/adamMoments.js";
 import "./webgl/adamMoments.js";
-function
-  return
+function s(e, n, r, m, o) {
+  return t().runKernel("AdamMoments", { moments: e, gradient: n }, { beta1: r, beta2: m, lossScaling: o });
 }
 export {
-
+  s as adamMoments
 };
```
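A hypothetical call site for the widened signature, matching the declaration fixed in `adamMoments.d.ts` above; the deep-import path is an assumption:

```ts
import type { Tensor } from "@tensorflow/tfjs-core";
// Hypothetical deep import; the public entry point may re-export this op.
import { adamMoments } from "@genai-fi/nanogpt/dist/ops/adamMoments.js";

declare const moments: Tensor;  // [..., 2] holding the (m, v) Adam moment pair
declare const gradient: Tensor; // same leading shape as a single moment slice

// beta1 = 0.9, beta2 = 0.999; the new fifth argument forwards the model's
// loss-scaling factor to the AdamMoments kernel (128 by default per model.js).
const updated = adamMoments(moments, gradient, 0.9, 0.999, 128);
```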
package/dist/ops/add16.js
ADDED

```diff
@@ -0,0 +1,9 @@
+import { n as t, e as o } from "../index-ZyQhjEPo.js";
+import { isPackedTensor as n } from "../utilities/packed.js";
+import "./grads/add16.js";
+function m(r, e) {
+  return !n(r) && !n(e) ? t(r, e) : o().runKernel("Add16", { a: r, b: e });
+}
+export {
+  m as add16
+};
```
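`add16` dispatches on packing: if neither operand is a packed tensor it falls through to the ordinary add, otherwise it runs the `Add16` kernel. A hedged usage sketch (deep-import paths assumed):

```ts
import type { Tensor } from "@tensorflow/tfjs-core";
// Hypothetical deep imports mirroring the dist layout in this diff.
import { add16 } from "@genai-fi/nanogpt/dist/ops/add16.js";
import { isPackedTensor } from "@genai-fi/nanogpt/dist/utilities/packed.js";

declare const a: Tensor;
declare const b: Tensor;

// The packed branch only runs when at least one operand is pack16-encoded;
// plain f32 tensors take the standard add path.
const sum = add16(a, b);
console.log(isPackedTensor(sum));
```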
package/dist/ops/appendCache.js
CHANGED
```diff
@@ -1,15 +1,22 @@
-import { e as a } from "../index-
+import { e as a } from "../index-ZyQhjEPo.js";
 import "./cpu/appendCache.js";
 import "./webgl/appendCache.js";
-import {
-import {
-
-
-
-
+import { isPackedTensor as c } from "../utilities/packed.js";
+import { c as t } from "../concat-BHlIJeyT.js";
+import { z as f } from "../zeros-2gldETuK.js";
+function C(r, o, n, p) {
+  if (!p) {
+    const e = r.shape[2], s = c(r);
+    return t(
+      [
+        r,
+        f([r.shape[0], r.shape[1], o - e, r.shape[3]], s ? "int32" : r.dtype)
+      ],
+      2
+    );
   }
-  return a().runKernel("AppendCache", { cache:
+  return a().runKernel("AppendCache", { cache: p, item: r }, { maxSize: o, pastLen: n });
 }
 export {
-
+  C as appendCache
 };
```
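The new non-kernel branch pre-allocates the KV cache: on the first call (no `cache` argument) the incoming item is zero-padded along the sequence axis (axis 2) up to `maxSize`, so subsequent `AppendCache` calls can write into the reserved region. A shape-level sketch of that first call:

```ts
import * as tf from "@tensorflow/tfjs-core";

// First call: item [B, H, T, D] becomes a [B, H, maxSize, D] cache by
// zero-padding the sequence axis, as appendCache now does when no
// existing cache tensor is passed.
function initCache(item: tf.Tensor4D, maxSize: number): tf.Tensor4D {
  const [b, h, t, d] = item.shape;
  const pad = tf.zeros([b, h, maxSize - t, d], item.dtype);
  return tf.concat([item, pad], 2) as tf.Tensor4D;
}

const item = tf.zeros([1, 4, 8, 16]) as tf.Tensor4D;
console.log(initCache(item, 64).shape); // [1, 4, 64, 16]
```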
package/dist/ops/attentionMask.js
CHANGED

```diff
@@ -1,10 +1,10 @@
-import { e as
+import { e as r } from "../index-ZyQhjEPo.js";
 import "./cpu/attentionMask.js";
 import "./webgl/attentionMask.js";
 import "./grads/attentionMask.js";
-function
-  return
+function u(t, n, e, o) {
+  return r().runKernel("AttentionMask", { q: t, k: n }, { divisor: e, pastLen: o || 0 });
 }
 export {
-
+  u as attentionMask
 };
```
package/dist/ops/concat16.js
ADDED

```diff
@@ -0,0 +1,9 @@
+import { isPackedTensor as o } from "../utilities/packed.js";
+import { e } from "../index-ZyQhjEPo.js";
+import { c } from "../concat-BHlIJeyT.js";
+function p(r, n) {
+  return o(r[0]) ? e().runKernel("Concat16", r, { axis: n ?? -1 }) : c(r, n);
+}
+export {
+  p as concat16
+};
```
package/dist/ops/cpu/adamAdjust.js
CHANGED

```diff
@@ -1,18 +1,19 @@
-import {
-
-
-n
-
-
-
-
-
-r
+import { l as t, n as r, m as k, o as z } from "../../index-ZyQhjEPo.js";
+import { r as A } from "../../tensor_util-DV-FP5Q3.js";
+function C(o) {
+  const { moments: n, value: i } = o.inputs, { beta1: l, beta2: m, epsilon: u, learningRate: d } = o.attrs, e = n.shape.length, c = new Array(e).fill(0), s = n.shape.slice();
+  s[e - 1] = 1;
+  const a = c.slice();
+  a[e - 1] = 1;
+  const p = s.slice(), b = n.slice(c, s).squeeze([e - 1]), M = n.slice(a, p).squeeze([e - 1]), f = t(b, l), g = t(M, m);
+  return r(
+    k(t(f, r(z(g), u ?? 1e-8)), -d),
+    i
   );
 }
-const
+const h = {
   kernelName: "AdamAdjust",
   backendName: "cpu",
-  kernelFunc:
+  kernelFunc: C
 };
-
+A(h);
```
package/dist/ops/cpu/adamMoments.js
CHANGED

```diff
@@ -1,16 +1,17 @@
-import
-import {
-
-
+import "../../index-ZyQhjEPo.js";
+import { r as p } from "../../tensor_util-DV-FP5Q3.js";
+import { s as b } from "../../stack-yOIAalTq.js";
+function f(t) {
+  const { moments: n, gradient: o } = t.inputs, { beta1: c, beta2: m } = t.attrs, e = n.shape.length, a = new Array(e).fill(0), s = n.shape.slice();
   s[e - 1] = 1;
-  const
-
-  const
-  return
+  const r = a.slice();
+  r[e - 1] = 1;
+  const i = s.slice(), l = n.slice(a, s).squeeze([e - 1]), u = n.slice(r, i).squeeze([e - 1]), M = l.mul(c).add(o.mul(1 - c)), d = u.mul(m).add(o.square().mul(1 - m));
+  return b([M, d], -1);
 }
 const g = {
   kernelName: "AdamMoments",
   backendName: "cpu",
-  kernelFunc:
+  kernelFunc: f
 };
 p(g);
```
package/dist/ops/cpu/appendCache.js
CHANGED

```diff
@@ -1,13 +1,14 @@
-import
-import {
+import "../../index-ZyQhjEPo.js";
+import { r as d } from "../../tensor_util-DV-FP5Q3.js";
+import { c as h } from "../../concat-BHlIJeyT.js";
 function u(p) {
-  const { cache: n, item: s } = p.inputs, { maxSize:
-  if (c + e <=
-    const
-    return
+  const { cache: n, item: s } = p.inputs, { maxSize: a, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], r = n.shape[3], e = s.shape[2];
+  if (c + e <= a) {
+    const m = n.slice([0, 0, 0, 0], [t, o, c, r]), f = n.slice([0, 0, c + e, 0], [t, o, a - c - e, r]), i = e < e ? s.slice([0, 0, 0, 0], [t, o, e, r]) : s, k = h([m, i, f], 2);
+    return m.dispose(), f.dispose(), i !== s && i.dispose(), k;
   }
-  const
-  return
+  const l = n.slice([0, 0, e, 0], [t, o, a - e, r]), C = h([l, s], 2);
+  return l.dispose(), C;
 }
 const w = {
   kernelName: "AppendCache",
```
package/dist/ops/cpu/attentionMask.js
CHANGED

```diff
@@ -1,21 +1,22 @@
-import {
-import {
-import {
-import {
-import {
-
-
-
+import { i as d, b as u } from "../../index-ZyQhjEPo.js";
+import { r as o } from "../../tensor_util-DV-FP5Q3.js";
+import { l as N, w as b } from "../../ops-CNI3TwqM.js";
+import { o as A } from "../../ones-CAMiP4I2.js";
+import { z as I } from "../../zeros-2gldETuK.js";
+import { m as g } from "../../mat_mul-DeAh4uTH.js";
+function a(n) {
+  const { q: s, k: e } = n.inputs, { divisor: r } = n.attrs, c = s.shape[2], t = e.shape[2], m = N.bandPart(A([t, t]), -1, 0).cast("bool"), i = I([t, t]), l = d([t, t], Number.NEGATIVE_INFINITY), f = b(m, i, l), k = g(s, e, !1, !0).mul(u(r)), p = f.slice([0, 0], [c, t]).expandDims(0).expandDims(0);
+  return k.add(p);
 }
-const
+const w = {
   kernelName: "AttentionMask",
   backendName: "cpu",
-  kernelFunc:
+  kernelFunc: a
 };
-
-const
+o(w);
+const M = {
   kernelName: "AttentionMask",
   backendName: "tensorflow",
-  kernelFunc:
+  kernelFunc: a
 };
-
+o(M);
```
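The CPU kernel builds an additive causal mask: 0 on and below the diagonal, `-Infinity` above it, added to the `divisor`-scaled `q·kᵀ` scores so softmax assigns zero weight to future positions. The same construction in plain tfjs:

```ts
import * as tf from "@tensorflow/tfjs-core";

// Causal additive mask as built in the CPU AttentionMask kernel: 0 on and
// below the diagonal, -Infinity above it, later broadcast over [batch, heads].
function causalMask(t: number): tf.Tensor2D {
  const keep = tf.linalg.bandPart(tf.ones([t, t]), -1, 0).cast("bool");
  return tf.where(
    keep,
    tf.zeros([t, t]),
    tf.fill([t, t], Number.NEGATIVE_INFINITY)
  ) as tf.Tensor2D;
}

// scores = (q @ kᵀ) * divisor + mask, so softmax zeroes out positions a
// query is not allowed to attend to.
```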
package/dist/ops/cpu/fusedSoftmax.js
CHANGED

```diff
@@ -1,29 +1,30 @@
-import
-import {
-
-
-
+import "../../index-ZyQhjEPo.js";
+import { r as e } from "../../tensor_util-DV-FP5Q3.js";
+import { s as m } from "../../softmax-ZHVebtR1.js";
+function o(t) {
+  const { inputs: s, attrs: a } = t, { logits: n } = s, { dim: i, dropoutRate: r } = a;
+  if (!n)
     throw new Error("Error in softmax: input logits is null");
-  return r !== void 0 && r > 0 && console.warn("Dropout in fusedSoftmax not implemented for CPU backend, skipping dropout."),
+  return r !== void 0 && r > 0 && console.warn("Dropout in fusedSoftmax not implemented for CPU backend, skipping dropout."), m(n, i);
 }
-const
+const f = {
   kernelName: "FusedSoftmax",
   backendName: "cpu",
-  kernelFunc:
+  kernelFunc: o
 };
-e(
+e(f);
 const u = {
   kernelName: "FusedSoftmax",
   backendName: "tensorflow",
-  kernelFunc:
+  kernelFunc: o
 };
 e(u);
 const l = {
   kernelName: "FusedSoftmax",
   backendName: "webgpu",
-  kernelFunc:
+  kernelFunc: o
 };
 e(l);
 export {
-
+  o as softmaxCPU
 };
```
package/dist/ops/cpu/gatherSub.js
CHANGED

```diff
@@ -1,34 +1,19 @@
-import {
-import { r as
-import {
-/**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
+import { A as u, B as c, E as m, c as g } from "../../index-ZyQhjEPo.js";
+import { k as p, r as h } from "../../tensor_util-DV-FP5Q3.js";
+import { r as f } from "../../range-BMS52eQi.js";
+import { s as l } from "../../stack-yOIAalTq.js";
 function N(e, t) {
-  const n = c(t, "indices", "gatherND", "int32"),
-  return
+  const n = c(t, "indices", "gatherND", "int32"), r = { params: c(e, "x", "gatherND", "string_or_numeric"), indices: n };
+  return m.runKernel(p, r);
 }
 const b = /* @__PURE__ */ u({ gatherND_: N });
 function d(e) {
-  const { values: t, labels: n, logits:
-  return
+  const { values: t, labels: n, logits: s } = e.inputs, r = n.shape[0], o = f(0, r, 1, "int32"), i = l([o, n], 1), a = b(s, i);
+  return g(t, a);
 }
 const k = {
   kernelName: "EfficientGatherSub",
   backendName: "cpu",
   kernelFunc: d
 };
-
+h(k);
```