@genai-fi/nanogpt 0.9.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +352 -14
- package/dist/Generator.js +69 -78
- package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
- package/dist/Reshape-CF6odzV4.js +16 -0
- package/dist/Reshape-_kILl6tK.js +81 -0
- package/dist/TeachableLLM.js +28 -22
- package/dist/Trainer.d.ts +2 -0
- package/dist/Trainer.js +3 -2
- package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
- package/dist/backend.d.ts +2 -1
- package/dist/backend.js +10 -4
- package/dist/backend_util-D-rUb2ty.js +474 -0
- package/dist/backend_webgpu-B0u2ndUn.js +547 -0
- package/dist/binary_op_util-pKXltfxI.js +192 -0
- package/dist/broadcast_to-CwF7XIeu.js +30 -0
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/check.d.ts +1 -1
- package/dist/checks/check.js +8 -8
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/index.d.ts +2 -0
- package/dist/checks/index.js +7 -5
- package/dist/checks/matMulGelu.js +6 -6
- package/dist/checks/normRMS.js +7 -7
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +18 -0
- package/dist/checks/qkv.js +12 -27
- package/dist/checks/rope.js +2 -2
- package/dist/checks/weights.js +18 -16
- package/dist/complex-CSlYz-2T.js +13 -0
- package/dist/complex_util-Yc1A_gV1.js +55 -0
- package/dist/concat-BHlIJeyT.js +19 -0
- package/dist/concat_util-DcJk7YHS.js +22 -0
- package/dist/data/docx.js +1 -1
- package/dist/data/parquet.js +2 -2
- package/dist/data/pdf.js +1 -1
- package/dist/data/textLoader.js +1 -1
- package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
- package/dist/dropout-C1pM3f11.js +99 -0
- package/dist/expand_dims-BPG4fwBP.js +13 -0
- package/dist/exports_initializers-xuidcwI4.js +7 -0
- package/dist/gather-DykLGqmW.js +10 -0
- package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
- package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
- package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
- package/dist/index-CjOj7j-u.js +7308 -0
- package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
- package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
- package/dist/index-ZyQhjEPo.js +2157 -0
- package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
- package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
- package/dist/layers/BaseLayer.d.ts +1 -0
- package/dist/layers/BaseLayer.js +7 -6
- package/dist/layers/CausalSelfAttention.d.ts +0 -1
- package/dist/layers/CausalSelfAttention.js +56 -55
- package/dist/layers/MLP.js +15 -16
- package/dist/layers/PositionEmbedding.js +5 -14
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.d.ts +2 -0
- package/dist/layers/RoPECache.js +22 -17
- package/dist/layers/TiedEmbedding.js +22 -17
- package/dist/layers/TransformerBlock.js +21 -20
- package/dist/loader/load.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +39 -33
- package/dist/loader/save.js +1 -1
- package/dist/log_sum_exp-DWI-76TI.js +41 -0
- package/dist/main.d.ts +8 -0
- package/dist/main.js +63 -52
- package/dist/matMul16--R5hOwDG.js +77 -0
- package/dist/mat_mul-DeAh4uTH.js +12 -0
- package/dist/mod-Gt1rMB4n.js +12 -0
- package/dist/models/NanoGPTV1.js +40 -31
- package/dist/models/model.d.ts +2 -0
- package/dist/models/model.js +37 -29
- package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
- package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
- package/dist/ones-CAMiP4I2.js +15 -0
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.d.ts +1 -1
- package/dist/ops/adamMoments.js +4 -4
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +9 -0
- package/dist/ops/appendCache.js +16 -9
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +9 -0
- package/dist/ops/cpu/adamAdjust.js +14 -13
- package/dist/ops/cpu/adamMoments.js +10 -9
- package/dist/ops/cpu/appendCache.js +9 -8
- package/dist/ops/cpu/attentionMask.js +15 -14
- package/dist/ops/cpu/fusedSoftmax.js +13 -12
- package/dist/ops/cpu/gatherSub.js +9 -24
- package/dist/ops/cpu/gelu.js +13 -12
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +16 -0
- package/dist/ops/cpu/matMulGelu.js +18 -16
- package/dist/ops/cpu/matMulMul.js +8 -7
- package/dist/ops/cpu/mulDropout.js +4 -3
- package/dist/ops/cpu/normRMS.js +11 -10
- package/dist/ops/cpu/qkv.js +17 -13
- package/dist/ops/cpu/rope.js +23 -22
- package/dist/ops/cpu/scatterSub.js +16 -30
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +42 -0
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.js +12 -19
- package/dist/ops/grads/gelu.js +4 -3
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +9 -0
- package/dist/ops/grads/matMulGelu.js +8 -7
- package/dist/ops/grads/normRMS.js +8 -7
- package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
- package/dist/ops/grads/pack16.js +7 -0
- package/dist/ops/grads/qkv.d.ts +3 -1
- package/dist/ops/grads/qkv.js +28 -22
- package/dist/ops/grads/rope.d.ts +2 -1
- package/dist/ops/grads/rope.js +6 -13
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +26 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +6 -0
- package/dist/ops/grads/utils.d.ts +3 -0
- package/dist/ops/grads/utils.js +10 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +13 -0
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +8 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +6 -0
- package/dist/ops/qkv.d.ts +1 -1
- package/dist/ops/qkv.js +8 -4
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +43 -0
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +8 -10
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +9 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +12 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +8 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +41 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +6 -0
- package/dist/ops/webgl/adamAdjust.js +3 -2
- package/dist/ops/webgl/adamMoments.js +2 -1
- package/dist/ops/webgl/appendCache.js +2 -1
- package/dist/ops/webgl/attentionMask.js +5 -4
- package/dist/ops/webgl/fusedSoftmax.js +6 -4
- package/dist/ops/webgl/gatherSub.js +7 -6
- package/dist/ops/webgl/gelu.js +3 -2
- package/dist/ops/webgl/log.js +12 -27
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.js +17 -15
- package/dist/ops/webgl/matMulMul.js +13 -12
- package/dist/ops/webgl/mulDropout.js +9 -8
- package/dist/ops/webgl/normRMS.js +8 -7
- package/dist/ops/webgl/qkv.js +6 -5
- package/dist/ops/webgl/rope.js +11 -10
- package/dist/ops/webgl/scatterSub.js +6 -5
- package/dist/ops/webgpu/adamAdjust.js +12 -10
- package/dist/ops/webgpu/adamMoments.js +27 -22
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.js +64 -17
- package/dist/ops/webgpu/attentionMask.js +19 -62
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +54 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +128 -0
- package/dist/ops/webgpu/gatherSub.js +9 -7
- package/dist/ops/webgpu/gelu.js +78 -31
- package/dist/ops/webgpu/index.js +12 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +58 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +336 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/normRMS.js +21 -40
- package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS16_program.js +24 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS32_program.js +24 -0
- package/dist/ops/webgpu/normRMSGrad.js +113 -64
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +19 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +92 -0
- package/dist/ops/webgpu/qkv.js +20 -55
- package/dist/ops/webgpu/rope.js +77 -22
- package/dist/ops/webgpu/scatterSub.js +9 -7
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +71 -0
- package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
- package/dist/ops/webgpu/softmax16.js +23 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +73 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +38 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +40 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +35 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +50 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +49 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
- package/dist/ops/webgpu/utils/binary_op.js +79 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
- package/dist/ops/webgpu/utils/reductions.js +236 -45
- package/dist/ops-CNI3TwqM.js +645 -0
- package/dist/pack16-CFUqumar.js +41 -0
- package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
- package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
- package/dist/patches/PackedTensor.d.ts +12 -0
- package/dist/patches/PackedTensor.js +11 -0
- package/dist/patches/engine.d.ts +261 -0
- package/dist/patches/engine.js +10 -0
- package/dist/patches/tape.d.ts +12 -0
- package/dist/patches/tape.js +5 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +57 -0
- package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
- package/dist/patches/webgpu_base.js +34 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +401 -0
- package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
- package/dist/random_width-DY6Kk2Dl.js +10051 -0
- package/dist/range-BMS52eQi.js +11 -0
- package/dist/reciprocal-CTmshQ9J.js +10 -0
- package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
- package/dist/relu-yZ2-7WxU.js +10 -0
- package/dist/reshape-DevtBWtf.js +10 -0
- package/dist/rope-B5UUMsPi.js +32 -0
- package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
- package/dist/selu_util-D1w6yyTO.js +303 -0
- package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
- package/dist/shared-BuAXb4CI.js +2145 -0
- package/dist/sin-BGfy2HZo.js +16 -0
- package/dist/slice-D_gkkqZK.js +13 -0
- package/dist/slice_util-DtEldBfK.js +261 -0
- package/dist/softmax-ZHVebtR1.js +13 -0
- package/dist/split-DrfihRpZ.js +10 -0
- package/dist/squeeze-DZEpeblb.js +11 -0
- package/dist/stack-yOIAalTq.js +13 -0
- package/dist/sum-_fzj5ZTB.js +12 -0
- package/dist/tensor-DdQUJZlz.js +909 -0
- package/dist/tensor-f35l8Odg.js +8 -0
- package/dist/tensor1d-CeZuc-Rv.js +12 -0
- package/dist/tensor2d-G4Ys2GxX.js +15 -0
- package/dist/tensor4d-B8roDgtc.js +15 -0
- package/dist/tensor_util-DV-FP5Q3.js +523 -0
- package/dist/tfjs_backend-kNyO5L2d.js +653 -0
- package/dist/tile-BzyEiF-F.js +13 -0
- package/dist/tokeniser/CharTokeniser.js +1 -1
- package/dist/tokeniser/bpe.js +1 -1
- package/dist/training/Adam.d.ts +2 -1
- package/dist/training/Adam.js +12 -28
- package/dist/training/AdamExt.d.ts +1 -0
- package/dist/training/AdamExt.js +2 -2
- package/dist/training/DatasetBuilder.js +3 -20
- package/dist/training/FullTrainer.js +82 -64
- package/dist/training/Trainer.d.ts +11 -6
- package/dist/training/Trainer.js +51 -39
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/transpose-DKELTqhe.js +38 -0
- package/dist/utilities/arrayClose.js +7 -7
- package/dist/utilities/dummy.js +35 -27
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.d.ts +7 -0
- package/dist/utilities/packed.js +716 -0
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +41 -0
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-Bhn5bHYv.js +7 -0
- package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
- package/dist/webgpu_util-BBCnKm2X.js +65 -0
- package/dist/zeros-2gldETuK.js +14 -0
- package/package.json +4 -3
- package/dist/Reshape-Bowtk9BP.js +0 -127
- package/dist/Reshape-DUqYftGC.js +0 -30
- package/dist/backend_util-CJIiDoV1.js +0 -749
- package/dist/broadcast_to-DzlNweb8.js +0 -44
- package/dist/concat-B912vBbo.js +0 -33
- package/dist/dropout-C-csYCLj.js +0 -193
- package/dist/exports_initializers-B8iZMgQ0.js +0 -16
- package/dist/gather-Dnpgw-YQ.js +0 -25
- package/dist/index-BzFyqcy-.js +0 -4457
- package/dist/index-C1rx_Ajs.js +0 -12076
- package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
- package/dist/log_sum_exp-DO6z8tSE.js +0 -103
- package/dist/mat_mul-DzjTFx-u.js +0 -27
- package/dist/mod-Dobti4j4.js +0 -27
- package/dist/ones-tIJeHlq-.js +0 -29
- package/dist/ops/fusedSoftmax.d.ts +0 -2
- package/dist/ops/fusedSoftmax.js +0 -10
- package/dist/ops/grads/fusedSoftmax.js +0 -22
- package/dist/ops-LuCMAnmM.js +0 -1525
- package/dist/random_width-CXVRloNK.js +0 -13670
- package/dist/range-CWcz7xFA.js +0 -26
- package/dist/reciprocal-C4rNcM-S.js +0 -25
- package/dist/relu-BjCh_SYb.js +0 -25
- package/dist/reshape-CnIwVG1c.js +0 -25
- package/dist/selu_util-OtRzVwW5.js +0 -719
- package/dist/shared-DmRsFyaJ.js +0 -3134
- package/dist/sin-gpDNRxE0.js +0 -47
- package/dist/slice-d0Vo9XTN.js +0 -28
- package/dist/softmax-D7Jj3p_P.js +0 -28
- package/dist/split-DK2k5eHf.js +0 -25
- package/dist/stack-DFatutCx.js +0 -27
- package/dist/sum-CJ0ULhmt.js +0 -27
- package/dist/tensor1d-vML0r3q6.js +0 -27
- package/dist/tensor2d-D76QGjF3.js +0 -30
- package/dist/tensor4d-Df1WlVDY.js +0 -30
- package/dist/webgpu_util-pLEV9tks.js +0 -80
- package/dist/zeros-Bj5rMYA7.js +0 -52
|
@@ -1,26 +1,12 @@
|
|
|
1
|
-
import
|
|
2
|
-
import { r as $ } from "./Reshape-
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import { c as
|
|
6
|
-
import {
|
|
7
|
-
import {
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
11
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
12
|
-
* you may not use this file except in compliance with the License.
|
|
13
|
-
* You may obtain a copy of the License at
|
|
14
|
-
*
|
|
15
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
16
|
-
*
|
|
17
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
18
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
19
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
20
|
-
* See the License for the specific language governing permissions and
|
|
21
|
-
* limitations under the License.
|
|
22
|
-
* =============================================================================
|
|
23
|
-
*/
|
|
1
|
+
import "./index-ZyQhjEPo.js";
|
|
2
|
+
import { r as $ } from "./Reshape-_kILl6tK.js";
|
|
3
|
+
import { _ as T, g as E, y as B, $ as F } from "./tensor_util-DV-FP5Q3.js";
|
|
4
|
+
import { H as _, e as K, p as O, s as V } from "./tensor-DdQUJZlz.js";
|
|
5
|
+
import { a as A, b as k, d as C, c as N, e as R } from "./axis_util-BvHEw88j.js";
|
|
6
|
+
import { t as U, m as W } from "./shared-BRksrJb3.js";
|
|
7
|
+
import { c as j } from "./backend_util-D-rUb2ty.js";
|
|
8
|
+
import { f as y } from "./gpgpu_math-DDVJCn6-.js";
|
|
9
|
+
import { g as G, b as L } from "./kernel_funcs_utils-Dg_-E44D.js";
|
|
24
10
|
class w {
|
|
25
11
|
constructor(s, e) {
|
|
26
12
|
this.variableNames = ["x"];
|
|
@@ -30,7 +16,7 @@ class w {
|
|
|
30
16
|
let o = "sumValue += dot(values, ones);";
|
|
31
17
|
if (e != null) {
|
|
32
18
|
const p = 1 / e;
|
|
33
|
-
o = `sumValue += dot(values * ${
|
|
19
|
+
o = `sumValue += dot(values * ${_(p) ? p.toPrecision(2) : p}, ones);`;
|
|
34
20
|
}
|
|
35
21
|
let u = "";
|
|
36
22
|
l % t > 0 && (u = `
|
|
@@ -89,23 +75,7 @@ class w {
|
|
|
89
75
|
`;
|
|
90
76
|
}
|
|
91
77
|
}
|
|
92
|
-
|
|
93
|
-
* @license
|
|
94
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
95
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
96
|
-
* you may not use this file except in compliance with the License.
|
|
97
|
-
* You may obtain a copy of the License at
|
|
98
|
-
*
|
|
99
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
100
|
-
*
|
|
101
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
102
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
103
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
104
|
-
* See the License for the specific language governing permissions and
|
|
105
|
-
* limitations under the License.
|
|
106
|
-
* =============================================================================
|
|
107
|
-
*/
|
|
108
|
-
class q {
|
|
78
|
+
class X {
|
|
109
79
|
constructor(s, e) {
|
|
110
80
|
this.variableNames = ["x"];
|
|
111
81
|
const { windowSize: t, batchSize: n, inSize: l, outSize: r } = s;
|
|
@@ -213,26 +183,10 @@ class q {
|
|
|
213
183
|
`;
|
|
214
184
|
}
|
|
215
185
|
}
|
|
216
|
-
|
|
217
|
-
* @license
|
|
218
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
219
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
220
|
-
* you may not use this file except in compliance with the License.
|
|
221
|
-
* You may obtain a copy of the License at
|
|
222
|
-
*
|
|
223
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
224
|
-
*
|
|
225
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
226
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
227
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
228
|
-
* See the License for the specific language governing permissions and
|
|
229
|
-
* limitations under the License.
|
|
230
|
-
* =============================================================================
|
|
231
|
-
*/
|
|
232
|
-
function X(a) {
|
|
186
|
+
function q(a) {
|
|
233
187
|
const s = [];
|
|
234
188
|
for (; s.length === 0 || s[s.length - 1].outSize !== 1; ) {
|
|
235
|
-
const e = s.length ? s[s.length - 1].outSize : a[1], t =
|
|
189
|
+
const e = s.length ? s[s.length - 1].outSize : a[1], t = j(e);
|
|
236
190
|
s.push({
|
|
237
191
|
inSize: e,
|
|
238
192
|
windowSize: t,
|
|
@@ -242,39 +196,23 @@ function X(a) {
|
|
|
242
196
|
return s;
|
|
243
197
|
}
|
|
244
198
|
function P(a, s, e, t) {
|
|
245
|
-
const n =
|
|
199
|
+
const n = q(a.shape);
|
|
246
200
|
let l = a;
|
|
247
201
|
for (let r = 0; r < n.length; r++) {
|
|
248
202
|
const { inSize: i, windowSize: c, outSize: o } = n[r];
|
|
249
203
|
let u, p;
|
|
250
|
-
e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new
|
|
204
|
+
e === "mean" ? u = r === 0 ? new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, i) : new w({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }) : u = new X({ windowSize: c, inSize: i, batchSize: a.shape[0], outSize: o }, e), p = l, l = t.runWebGLProgram(u, [l], s), p.dataId !== a.dataId && t.disposeIntermediateTensorInfo(p);
|
|
251
205
|
}
|
|
252
206
|
return l;
|
|
253
207
|
}
|
|
254
|
-
|
|
255
|
-
* @license
|
|
256
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
257
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
258
|
-
* you may not use this file except in compliance with the License.
|
|
259
|
-
* You may obtain a copy of the License at
|
|
260
|
-
*
|
|
261
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
262
|
-
*
|
|
263
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
264
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
265
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
266
|
-
* See the License for the specific language governing permissions and
|
|
267
|
-
* limitations under the License.
|
|
268
|
-
* =============================================================================
|
|
269
|
-
*/
|
|
270
|
-
class Y {
|
|
208
|
+
class H {
|
|
271
209
|
constructor(s, e) {
|
|
272
210
|
this.variableNames = ["A"];
|
|
273
211
|
const t = new Array(s.length);
|
|
274
212
|
for (let r = 0; r < t.length; r++)
|
|
275
213
|
t[r] = s[e[r]];
|
|
276
214
|
this.outputShape = t, this.rank = t.length;
|
|
277
|
-
const n = y(this.rank), l =
|
|
215
|
+
const n = y(this.rank), l = Y(e);
|
|
278
216
|
this.userCode = `
|
|
279
217
|
void main() {
|
|
280
218
|
${n} resRC = getOutputCoords();
|
|
@@ -283,7 +221,7 @@ class Y {
|
|
|
283
221
|
`;
|
|
284
222
|
}
|
|
285
223
|
}
|
|
286
|
-
function
|
|
224
|
+
function Y(a) {
|
|
287
225
|
const s = a.length;
|
|
288
226
|
if (s > 6)
|
|
289
227
|
throw Error(`Transpose for rank ${s} is not yet supported`);
|
|
@@ -292,22 +230,6 @@ function H(a) {
|
|
|
292
230
|
t[a[n]] = e[n];
|
|
293
231
|
return t.join();
|
|
294
232
|
}
|
|
295
|
-
/**
|
|
296
|
-
* @license
|
|
297
|
-
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
298
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
299
|
-
* you may not use this file except in compliance with the License.
|
|
300
|
-
* You may obtain a copy of the License at
|
|
301
|
-
*
|
|
302
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
303
|
-
*
|
|
304
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
305
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
306
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
307
|
-
* See the License for the specific language governing permissions and
|
|
308
|
-
* limitations under the License.
|
|
309
|
-
* =============================================================================
|
|
310
|
-
*/
|
|
311
233
|
class J {
|
|
312
234
|
constructor(s, e) {
|
|
313
235
|
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
|
|
@@ -340,115 +262,35 @@ class J {
|
|
|
340
262
|
`;
|
|
341
263
|
}
|
|
342
264
|
}
|
|
343
|
-
/**
|
|
344
|
-
* @license
|
|
345
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
346
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
347
|
-
* you may not use this file except in compliance with the License.
|
|
348
|
-
* You may obtain a copy of the License at
|
|
349
|
-
*
|
|
350
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
351
|
-
*
|
|
352
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
353
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
354
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
355
|
-
* See the License for the specific language governing permissions and
|
|
356
|
-
* limitations under the License.
|
|
357
|
-
* =============================================================================
|
|
358
|
-
*/
|
|
359
265
|
function D(a, s, e) {
|
|
360
|
-
const t =
|
|
266
|
+
const t = K().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new J(a.shape, s) : new H(a.shape, s);
|
|
361
267
|
return e.runWebGLProgram(t, [a], a.dtype);
|
|
362
268
|
}
|
|
363
|
-
/**
|
|
364
|
-
* @license
|
|
365
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
366
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
367
|
-
* you may not use this file except in compliance with the License.
|
|
368
|
-
* You may obtain a copy of the License at
|
|
369
|
-
*
|
|
370
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
371
|
-
*
|
|
372
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
373
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
374
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
375
|
-
* See the License for the specific language governing permissions and
|
|
376
|
-
* limitations under the License.
|
|
377
|
-
* =============================================================================
|
|
378
|
-
*/
|
|
379
269
|
function Q(a, s, e, t) {
|
|
380
270
|
const n = s, l = a.shape.length, r = O(n, a.shape);
|
|
381
271
|
let i = r;
|
|
382
272
|
const c = A(i, l), o = c != null;
|
|
383
273
|
let u = a;
|
|
384
|
-
o && (u = D(a, c, t), i =
|
|
274
|
+
o && (u = D(a, c, t), i = k(i.length, l)), C("sum", i, l);
|
|
385
275
|
const [p, h] = N(u.shape, i);
|
|
386
276
|
let d = p;
|
|
387
277
|
e && (d = R(p, r));
|
|
388
|
-
const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b =
|
|
278
|
+
const f = V(h), g = V(a.shape) / f, x = $({ inputs: { x: u }, attrs: { shape: [g, f] }, backend: t }), b = T(a.dtype), I = P(x, b, "sum", t), m = $({ inputs: { x: I }, attrs: { shape: d }, backend: t });
|
|
389
279
|
return t.disposeIntermediateTensorInfo(x), t.disposeIntermediateTensorInfo(I), o && t.disposeIntermediateTensorInfo(u), m;
|
|
390
280
|
}
|
|
391
|
-
/**
|
|
392
|
-
* @license
|
|
393
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
394
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
395
|
-
* you may not use this file except in compliance with the License.
|
|
396
|
-
* You may obtain a copy of the License at
|
|
397
|
-
*
|
|
398
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
399
|
-
*
|
|
400
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
401
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
402
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
403
|
-
* See the License for the specific language governing permissions and
|
|
404
|
-
* limitations under the License.
|
|
405
|
-
* =============================================================================
|
|
406
|
-
*/
|
|
407
281
|
function Z(a) {
|
|
408
282
|
const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { axis: l, keepDims: r } = t;
|
|
409
283
|
return Q(n, l, r, e);
|
|
410
284
|
}
|
|
411
|
-
const
|
|
412
|
-
kernelName:
|
|
285
|
+
const fe = {
|
|
286
|
+
kernelName: E,
|
|
413
287
|
backendName: "webgl",
|
|
414
288
|
kernelFunc: Z
|
|
415
289
|
};
|
|
416
|
-
/**
|
|
417
|
-
* @license
|
|
418
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
419
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
420
|
-
* you may not use this file except in compliance with the License.
|
|
421
|
-
* You may obtain a copy of the License at
|
|
422
|
-
*
|
|
423
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
424
|
-
*
|
|
425
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
426
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
427
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
428
|
-
* See the License for the specific language governing permissions and
|
|
429
|
-
* limitations under the License.
|
|
430
|
-
* =============================================================================
|
|
431
|
-
*/
|
|
432
290
|
function ee(a, s, e, t) {
|
|
433
291
|
const n = V(s), r = V(a.shape) / n, i = $({ inputs: { x: a }, attrs: { shape: [r, n] }, backend: t }), c = P(i, a.dtype, "max", t), o = $({ inputs: { x: c }, attrs: { shape: e }, backend: t });
|
|
434
292
|
return t.disposeIntermediateTensorInfo(i), t.disposeIntermediateTensorInfo(c), o;
|
|
435
293
|
}
|
|
436
|
-
/**
|
|
437
|
-
* @license
|
|
438
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
439
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
440
|
-
* you may not use this file except in compliance with the License.
|
|
441
|
-
* You may obtain a copy of the License at
|
|
442
|
-
*
|
|
443
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
444
|
-
*
|
|
445
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
446
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
447
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
448
|
-
* See the License for the specific language governing permissions and
|
|
449
|
-
* limitations under the License.
|
|
450
|
-
* =============================================================================
|
|
451
|
-
*/
|
|
452
294
|
function te(a) {
|
|
453
295
|
const { inputs: s, backend: e, attrs: t } = a, { x: n } = s, { reductionIndices: l, keepDims: r } = t, i = n.shape.length, c = O(l, n.shape);
|
|
454
296
|
let o = c;
|
|
@@ -465,9 +307,9 @@ function te(a) {
|
|
|
465
307
|
M.values = z;
|
|
466
308
|
} else
|
|
467
309
|
d = D(n, u, e);
|
|
468
|
-
o =
|
|
310
|
+
o = k(o.length, i);
|
|
469
311
|
}
|
|
470
|
-
|
|
312
|
+
C("max", o, i);
|
|
471
313
|
const [f, S] = N(d.shape, o);
|
|
472
314
|
let g = f;
|
|
473
315
|
r && (g = R(f, c));
|
|
@@ -481,27 +323,11 @@ function te(a) {
|
|
|
481
323
|
x = ee(d, S, g, e);
|
|
482
324
|
return p && e.disposeIntermediateTensorInfo(d), x;
|
|
483
325
|
}
|
|
484
|
-
const
|
|
485
|
-
kernelName:
|
|
326
|
+
const me = {
|
|
327
|
+
kernelName: B,
|
|
486
328
|
backendName: "webgl",
|
|
487
329
|
kernelFunc: te
|
|
488
330
|
};
|
|
489
|
-
/**
|
|
490
|
-
* @license
|
|
491
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
492
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
493
|
-
* you may not use this file except in compliance with the License.
|
|
494
|
-
* You may obtain a copy of the License at
|
|
495
|
-
*
|
|
496
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
497
|
-
*
|
|
498
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
499
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
500
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
501
|
-
* See the License for the specific language governing permissions and
|
|
502
|
-
* limitations under the License.
|
|
503
|
-
* =============================================================================
|
|
504
|
-
*/
|
|
505
331
|
const ae = `
|
|
506
332
|
if (a == b) {
|
|
507
333
|
return 1.0;
|
|
@@ -524,16 +350,16 @@ return a / b;`, se = `
|
|
|
524
350
|
}
|
|
525
351
|
|
|
526
352
|
return result;
|
|
527
|
-
`, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }),
|
|
528
|
-
kernelName:
|
|
353
|
+
`, ne = L({ opSnippet: ae, packedOpSnippet: se, checkOutOfBounds: !0 }), xe = {
|
|
354
|
+
kernelName: F,
|
|
529
355
|
backendName: "webgl",
|
|
530
356
|
kernelFunc: ne
|
|
531
357
|
};
|
|
532
358
|
export {
|
|
533
359
|
P as a,
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
360
|
+
me as b,
|
|
361
|
+
xe as c,
|
|
362
|
+
fe as d,
|
|
537
363
|
te as m,
|
|
538
364
|
ne as r,
|
|
539
365
|
Z as s,
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import "./index-ZyQhjEPo.js";
|
|
2
|
+
import { s as p, n as m, a as d } from "./tensor-DdQUJZlz.js";
|
|
3
|
+
import { b as c } from "./tensor_util-DV-FP5Q3.js";
|
|
4
|
+
function i(t) {
|
|
5
|
+
const { inputs: h, attrs: o } = t, { x: e } = h, { shape: r } = o, a = p(e.shape), s = m(r, a), n = p(s);
|
|
6
|
+
return d(a === n, () => `The new shape (${s}) has ${n} elements and the old shape (${e.shape}) has ${a} elements. The new shape and old shape must have the same number of elements.`), t.backend.incRef(e.dataId), { dataId: e.dataId, shape: s, dtype: e.dtype };
|
|
7
|
+
}
|
|
8
|
+
const $ = {
|
|
9
|
+
kernelName: c,
|
|
10
|
+
backendName: "webgpu",
|
|
11
|
+
kernelFunc: i
|
|
12
|
+
};
|
|
13
|
+
export {
|
|
14
|
+
$ as a,
|
|
15
|
+
i as r
|
|
16
|
+
};
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import "./index-ZyQhjEPo.js";
|
|
2
|
+
import { u as C, g as f, a as R, b as g, c as I, d as c, e as u, i as m } from "./gpgpu_math-DDVJCn6-.js";
|
|
3
|
+
import { b as x } from "./tensor_util-DV-FP5Q3.js";
|
|
4
|
+
import { s as l, n as F, a as $ } from "./tensor-DdQUJZlz.js";
|
|
5
|
+
class S {
|
|
6
|
+
constructor(t, i) {
|
|
7
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = t, this.enableShapeUniforms = C(this.outputShape.length);
|
|
8
|
+
let a = "";
|
|
9
|
+
for (let e = 0; e < 4; e++) {
|
|
10
|
+
let o = "thisRC = rc;";
|
|
11
|
+
e % 2 === 1 && (o += "thisRC.z += 1;"), e > 1 && (o += "thisRC.y += 1;"), a += `
|
|
12
|
+
${o}
|
|
13
|
+
${e > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
|
|
14
|
+
int flatIndex = getFlatIndex(thisRC);
|
|
15
|
+
|
|
16
|
+
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
|
|
17
|
+
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
|
|
18
|
+
|
|
19
|
+
result[${e}] =
|
|
20
|
+
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
|
|
21
|
+
${e > 0 ? "}" : ""}
|
|
22
|
+
`;
|
|
23
|
+
}
|
|
24
|
+
this.userCode = `
|
|
25
|
+
${b(i, this.enableShapeUniforms)}
|
|
26
|
+
${this.enableShapeUniforms ? f() : R(t)}
|
|
27
|
+
|
|
28
|
+
void main() {
|
|
29
|
+
ivec3 rc = getOutputCoords();
|
|
30
|
+
|
|
31
|
+
vec4 result = vec4(0.);
|
|
32
|
+
|
|
33
|
+
ivec3 thisRC;
|
|
34
|
+
int rows = ${this.enableShapeUniforms ? "outShape[1]" : t[1]};
|
|
35
|
+
int cols = ${this.enableShapeUniforms ? "outShape[2]" : t[2]};
|
|
36
|
+
|
|
37
|
+
${a}
|
|
38
|
+
|
|
39
|
+
setOutput(result);
|
|
40
|
+
}
|
|
41
|
+
`;
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
function b(s, t) {
|
|
45
|
+
return `
|
|
46
|
+
ivec3 inputCoordsFromReshapedOutCoords(int index) {
|
|
47
|
+
${t ? g(["r", "c", "d"], "inputShape") : I(["r", "c", "d"], s)}
|
|
48
|
+
return ivec3(r, c, d);
|
|
49
|
+
}
|
|
50
|
+
`;
|
|
51
|
+
}
|
|
52
|
+
function v(s, t, i) {
|
|
53
|
+
const a = [
|
|
54
|
+
c(s.shape),
|
|
55
|
+
...u(s.shape)
|
|
56
|
+
], e = {
|
|
57
|
+
dtype: s.dtype,
|
|
58
|
+
shape: a,
|
|
59
|
+
dataId: s.dataId
|
|
60
|
+
}, o = [
|
|
61
|
+
c(t),
|
|
62
|
+
...u(t)
|
|
63
|
+
], r = new S(o, a), p = !0, n = [a], h = i.runWebGLProgram(r, [e], s.dtype, n, p);
|
|
64
|
+
return { dataId: h.dataId, shape: t, dtype: h.dtype };
|
|
65
|
+
}
|
|
66
|
+
function y(s) {
|
|
67
|
+
const { inputs: t, backend: i, attrs: a } = s, { x: e } = t, { shape: o } = a, r = i, p = l(e.shape), n = F(o, p), h = l(n);
|
|
68
|
+
$(p === h, () => `The new shape (${n}) has ${h} elements and the old shape (${e.shape}) has ${p} elements. The new shape and old shape must have the same number of elements.`);
|
|
69
|
+
const d = r.texData.get(e.dataId);
|
|
70
|
+
return d.isPacked && !m(e.shape, n) && !(d.texture !== null && m(d.shape, n)) ? v(e, n, r) : (r.incRef(e.dataId), { dataId: e.dataId, shape: n, dtype: e.dtype });
|
|
71
|
+
}
|
|
72
|
+
const O = {
|
|
73
|
+
kernelName: x,
|
|
74
|
+
backendName: "webgl",
|
|
75
|
+
kernelFunc: y
|
|
76
|
+
};
|
|
77
|
+
export {
|
|
78
|
+
S as R,
|
|
79
|
+
O as a,
|
|
80
|
+
y as r
|
|
81
|
+
};
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -2,45 +2,51 @@ import { defaultConfig as d } from "./models/config.js";
|
|
|
2
2
|
import { saveModel as l } from "./loader/save.js";
|
|
3
3
|
import { loadModel as _ } from "./loader/load.js";
|
|
4
4
|
import u from "./Generator.js";
|
|
5
|
-
import
|
|
6
|
-
import { E as
|
|
5
|
+
import p from "./Trainer.js";
|
|
6
|
+
import { E as f } from "./index-DvYrXKkX.js";
|
|
7
7
|
import { dummyPassTrainAsync as m } from "./utilities/dummy.js";
|
|
8
|
-
import "./
|
|
8
|
+
import "./utilities/packed.js";
|
|
9
|
+
import "./index-ZyQhjEPo.js";
|
|
9
10
|
import "./ops/cpu/attentionMask.js";
|
|
10
11
|
import "./ops/webgl/attentionMask.js";
|
|
11
12
|
import "./ops/grads/attentionMask.js";
|
|
12
|
-
import "./
|
|
13
|
-
import "./
|
|
14
|
-
import "./
|
|
15
|
-
import "./
|
|
16
|
-
import "./register_all_kernels-DIGpEwcf.js";
|
|
17
|
-
import "./index-Tf7vU29b.js";
|
|
18
|
-
import "./dataset-DlZtKmBq.js";
|
|
13
|
+
import "./random_width-DY6Kk2Dl.js";
|
|
14
|
+
import "./register_all_kernels-Bwu1PTuU.js";
|
|
15
|
+
import "./index-Cp39cXWe.js";
|
|
16
|
+
import "./dataset-0xP8GjwI.js";
|
|
19
17
|
import "./ops/cpu/rope.js";
|
|
20
18
|
import "./ops/webgl/rope.js";
|
|
21
|
-
import "./
|
|
19
|
+
import "./rope-B5UUMsPi.js";
|
|
22
20
|
import "./ops/cpu/appendCache.js";
|
|
23
21
|
import "./ops/webgl/appendCache.js";
|
|
24
|
-
import "./ops/
|
|
25
|
-
import "./
|
|
26
|
-
import "./ops/
|
|
27
|
-
import "./ops/cpu/
|
|
28
|
-
import "./
|
|
29
|
-
import "./ops/
|
|
22
|
+
import "./ops/grads/softmax16.js";
|
|
23
|
+
import "./matMul16--R5hOwDG.js";
|
|
24
|
+
import "./ops/webgl/matMul16.js";
|
|
25
|
+
import "./ops/cpu/matMul16.js";
|
|
26
|
+
import "./pack16-CFUqumar.js";
|
|
27
|
+
import "./ops/transpose16.js";
|
|
28
|
+
import "./ops/reshape16.js";
|
|
29
|
+
import "./ops/cpu/qkv.js";
|
|
30
|
+
import "./ops/webgl/qkv.js";
|
|
31
|
+
import "./ops/grads/qkv.js";
|
|
30
32
|
import "./ops/cpu/normRMS.js";
|
|
31
33
|
import "./ops/webgl/normRMS.js";
|
|
32
34
|
import "./ops/grads/normRMS.js";
|
|
35
|
+
import "./ops/grads/add16.js";
|
|
33
36
|
import "./ops/cpu/gatherSub.js";
|
|
34
37
|
import "./ops/webgl/gatherSub.js";
|
|
35
38
|
import "./ops/cpu/scatterSub.js";
|
|
36
39
|
import "./ops/webgl/scatterSub.js";
|
|
37
40
|
import c from "./tokeniser/CharTokeniser.js";
|
|
38
41
|
import g from "./tokeniser/bpe.js";
|
|
39
|
-
import "./papaparse.min-
|
|
40
|
-
import "./jszip.min-
|
|
42
|
+
import "./papaparse.min-C0cScC2i.js";
|
|
43
|
+
import "./jszip.min-Bz5-11Bk.js";
|
|
44
|
+
import "./ops/cpu/matMulGelu.js";
|
|
45
|
+
import "./ops/webgl/matMulGelu.js";
|
|
46
|
+
import "./ops/grads/matMulGelu.js";
|
|
41
47
|
import "./ops/cpu/gelu.js";
|
|
42
48
|
import "./ops/webgl/gelu.js";
|
|
43
|
-
import "./gelu-
|
|
49
|
+
import "./gelu-CNLFZWea.js";
|
|
44
50
|
import "./ops/webgl/log.js";
|
|
45
51
|
import "./ops/cpu/adamMoments.js";
|
|
46
52
|
import "./ops/webgl/adamMoments.js";
|
|
@@ -51,7 +57,7 @@ import "./checks/normRMSGrad.js";
|
|
|
51
57
|
import k from "./utilities/profile.js";
|
|
52
58
|
import w from "./models/factory.js";
|
|
53
59
|
class a {
|
|
54
|
-
ee = new
|
|
60
|
+
ee = new f();
|
|
55
61
|
_config;
|
|
56
62
|
_model;
|
|
57
63
|
_tokeniser;
|
|
@@ -150,7 +156,7 @@ class a {
|
|
|
150
156
|
trainer() {
|
|
151
157
|
if (!this._model || !this._tokeniser)
|
|
152
158
|
throw new Error("model_or_tokeniser_not_initialized.");
|
|
153
|
-
const t = new
|
|
159
|
+
const t = new p(this._model, this._tokeniser);
|
|
154
160
|
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e, r) => {
|
|
155
161
|
const o = this.ee.listeners("trainStep");
|
|
156
162
|
for (const s of o)
|
package/dist/Trainer.d.ts
CHANGED
|
@@ -12,6 +12,8 @@ export interface ITrainerOptions {
|
|
|
12
12
|
validationSplit?: number;
|
|
13
13
|
advancedMetrics?: boolean;
|
|
14
14
|
gradientCheckpointing?: boolean;
|
|
15
|
+
gradientMetrics?: boolean;
|
|
16
|
+
mixedPrecision?: boolean;
|
|
15
17
|
}
|
|
16
18
|
interface ExtendedTrainingProgress extends TrainingProgress {
|
|
17
19
|
progress: number;
|
package/dist/Trainer.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { E as l } from "./index-
|
|
1
|
+
import { E as l } from "./index-DvYrXKkX.js";
|
|
2
2
|
import h from "./training/FullTrainer.js";
|
|
3
3
|
class m extends l {
|
|
4
4
|
trainer;
|
|
@@ -28,7 +28,7 @@ class m extends l {
|
|
|
28
28
|
async train(t) {
|
|
29
29
|
if (!this.trainDataset || !this.validationDataset)
|
|
30
30
|
throw new Error("Datasets not prepared");
|
|
31
|
-
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), await this.trainer.trainOnDataset(
|
|
31
|
+
this.hasTrained || this.trainer.setLearningRate(t?.learningRate || 1e-3), this.hasTrained = !0, this.emit("start"), this.trainer.setGradientCheckpointing(t?.gradientCheckpointing || !1), this.trainer.setMixedPrecision(t?.mixedPrecision || !1), await this.trainer.trainOnDataset(
|
|
32
32
|
this.trainDataset,
|
|
33
33
|
{
|
|
34
34
|
prompt: t?.prompt,
|
|
@@ -36,6 +36,7 @@ class m extends l {
|
|
|
36
36
|
desiredLoss: t?.desiredLoss || 0.01,
|
|
37
37
|
maxSteps: t?.maxSteps || 1e3,
|
|
38
38
|
advancedMetrics: t?.advancedMetrics || !1,
|
|
39
|
+
gradientMetrics: t?.gradientMetrics || !1,
|
|
39
40
|
onStep: async (e, a) => {
|
|
40
41
|
this.log.push(e), this.progress = {
|
|
41
42
|
...a,
|
|
@@ -1,20 +1,4 @@
|
|
|
1
|
-
import {
|
|
2
|
-
/**
|
|
3
|
-
* @license
|
|
4
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
5
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
* you may not use this file except in compliance with the License.
|
|
7
|
-
* You may obtain a copy of the License at
|
|
8
|
-
*
|
|
9
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
*
|
|
11
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
* See the License for the specific language governing permissions and
|
|
15
|
-
* limitations under the License.
|
|
16
|
-
* =============================================================================
|
|
17
|
-
*/
|
|
1
|
+
import { a as c } from "./tensor-DdQUJZlz.js";
|
|
18
2
|
function i(e, n) {
|
|
19
3
|
for (let t = 0; t < e.length; ++t)
|
|
20
4
|
if (e[e.length - t - 1] !== n - 1 - t)
|
|
@@ -60,12 +44,12 @@ function x(e, n) {
|
|
|
60
44
|
return t;
|
|
61
45
|
}
|
|
62
46
|
export {
|
|
63
|
-
|
|
64
|
-
|
|
47
|
+
d as a,
|
|
48
|
+
x as b,
|
|
65
49
|
l as c,
|
|
66
|
-
|
|
50
|
+
m as d,
|
|
67
51
|
h as e,
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
52
|
+
i as f,
|
|
53
|
+
g,
|
|
54
|
+
a as h
|
|
71
55
|
};
|
package/dist/backend.d.ts
CHANGED
|
@@ -1 +1,2 @@
|
|
|
1
|
-
|
|
1
|
+
import { GPUOptions } from './patches/webgpu_base';
|
|
2
|
+
export declare function selectBackend(backendName: 'cpu' | 'webgl' | 'webgpu', options?: GPUOptions): Promise<void>;
|
package/dist/backend.js
CHANGED
|
@@ -1,7 +1,13 @@
|
|
|
1
|
-
import { g as
|
|
2
|
-
async function
|
|
3
|
-
|
|
1
|
+
import { g as o, s as e, r as s } from "./index-ZyQhjEPo.js";
|
|
2
|
+
async function c(t, a) {
|
|
3
|
+
if (o() !== t) {
|
|
4
|
+
if (t === "webgpu") {
|
|
5
|
+
const { registerWebGPUBackend: i } = await import("./patches/webgpu_base.js");
|
|
6
|
+
i(a), await import("./index-CjOj7j-u.js"), await import("./ops/webgpu/index.js");
|
|
7
|
+
}
|
|
8
|
+
await e(t), await s(), console.log(`Backend set to ${t}`);
|
|
9
|
+
}
|
|
4
10
|
}
|
|
5
11
|
export {
|
|
6
|
-
|
|
12
|
+
c as selectBackend
|
|
7
13
|
};
|