@genai-fi/nanogpt 0.9.0 → 0.10.0
This diff shows the content changes between publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- package/README.md +352 -14
- package/dist/Generator.js +69 -78
- package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
- package/dist/Reshape-CF6odzV4.js +16 -0
- package/dist/Reshape-_kILl6tK.js +81 -0
- package/dist/TeachableLLM.js +28 -22
- package/dist/Trainer.d.ts +2 -0
- package/dist/Trainer.js +3 -2
- package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
- package/dist/backend.d.ts +2 -1
- package/dist/backend.js +10 -4
- package/dist/backend_util-D-rUb2ty.js +474 -0
- package/dist/backend_webgpu-B0u2ndUn.js +547 -0
- package/dist/binary_op_util-pKXltfxI.js +192 -0
- package/dist/broadcast_to-CwF7XIeu.js +30 -0
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/check.d.ts +1 -1
- package/dist/checks/check.js +8 -8
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/index.d.ts +2 -0
- package/dist/checks/index.js +7 -5
- package/dist/checks/matMulGelu.js +6 -6
- package/dist/checks/normRMS.js +7 -7
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.d.ts +1 -0
- package/dist/checks/packUnpack.js +18 -0
- package/dist/checks/qkv.js +12 -27
- package/dist/checks/rope.js +2 -2
- package/dist/checks/weights.js +18 -16
- package/dist/complex-CSlYz-2T.js +13 -0
- package/dist/complex_util-Yc1A_gV1.js +55 -0
- package/dist/concat-BHlIJeyT.js +19 -0
- package/dist/concat_util-DcJk7YHS.js +22 -0
- package/dist/data/docx.js +1 -1
- package/dist/data/parquet.js +2 -2
- package/dist/data/pdf.js +1 -1
- package/dist/data/textLoader.js +1 -1
- package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
- package/dist/dropout-C1pM3f11.js +99 -0
- package/dist/expand_dims-BPG4fwBP.js +13 -0
- package/dist/exports_initializers-xuidcwI4.js +7 -0
- package/dist/gather-DykLGqmW.js +10 -0
- package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
- package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
- package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
- package/dist/index-CjOj7j-u.js +7308 -0
- package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
- package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
- package/dist/index-ZyQhjEPo.js +2157 -0
- package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
- package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
- package/dist/layers/BaseLayer.d.ts +1 -0
- package/dist/layers/BaseLayer.js +7 -6
- package/dist/layers/CausalSelfAttention.d.ts +0 -1
- package/dist/layers/CausalSelfAttention.js +56 -55
- package/dist/layers/MLP.js +15 -16
- package/dist/layers/PositionEmbedding.js +5 -14
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.d.ts +2 -0
- package/dist/layers/RoPECache.js +22 -17
- package/dist/layers/TiedEmbedding.js +22 -17
- package/dist/layers/TransformerBlock.js +21 -20
- package/dist/loader/load.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +39 -33
- package/dist/loader/save.js +1 -1
- package/dist/log_sum_exp-DWI-76TI.js +41 -0
- package/dist/main.d.ts +8 -0
- package/dist/main.js +63 -52
- package/dist/matMul16--R5hOwDG.js +77 -0
- package/dist/mat_mul-DeAh4uTH.js +12 -0
- package/dist/mod-Gt1rMB4n.js +12 -0
- package/dist/models/NanoGPTV1.js +40 -31
- package/dist/models/model.d.ts +2 -0
- package/dist/models/model.js +37 -29
- package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
- package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
- package/dist/ones-CAMiP4I2.js +15 -0
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.d.ts +1 -1
- package/dist/ops/adamMoments.js +4 -4
- package/dist/ops/add16.d.ts +2 -0
- package/dist/ops/add16.js +9 -0
- package/dist/ops/appendCache.js +16 -9
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/concat16.d.ts +2 -0
- package/dist/ops/concat16.js +9 -0
- package/dist/ops/cpu/adamAdjust.js +14 -13
- package/dist/ops/cpu/adamMoments.js +10 -9
- package/dist/ops/cpu/appendCache.js +9 -8
- package/dist/ops/cpu/attentionMask.js +15 -14
- package/dist/ops/cpu/fusedSoftmax.js +13 -12
- package/dist/ops/cpu/gatherSub.js +9 -24
- package/dist/ops/cpu/gelu.js +13 -12
- package/dist/ops/cpu/matMul16.d.ts +1 -0
- package/dist/ops/cpu/matMul16.js +16 -0
- package/dist/ops/cpu/matMulGelu.js +18 -16
- package/dist/ops/cpu/matMulMul.js +8 -7
- package/dist/ops/cpu/mulDropout.js +4 -3
- package/dist/ops/cpu/normRMS.js +11 -10
- package/dist/ops/cpu/qkv.js +17 -13
- package/dist/ops/cpu/rope.js +23 -22
- package/dist/ops/cpu/scatterSub.js +16 -30
- package/dist/ops/dot16.d.ts +2 -0
- package/dist/ops/dot16.js +42 -0
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.d.ts +1 -0
- package/dist/ops/grads/add16.js +27 -0
- package/dist/ops/grads/attentionMask.js +12 -19
- package/dist/ops/grads/gelu.js +4 -3
- package/dist/ops/grads/matMul16.d.ts +2 -0
- package/dist/ops/grads/matMul16.js +9 -0
- package/dist/ops/grads/matMulGelu.js +8 -7
- package/dist/ops/grads/normRMS.js +8 -7
- package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
- package/dist/ops/grads/pack16.js +7 -0
- package/dist/ops/grads/qkv.d.ts +3 -1
- package/dist/ops/grads/qkv.js +28 -22
- package/dist/ops/grads/rope.d.ts +2 -1
- package/dist/ops/grads/rope.js +6 -13
- package/dist/ops/grads/softmax16.d.ts +2 -0
- package/dist/ops/grads/softmax16.js +26 -0
- package/dist/ops/grads/unpack16.d.ts +2 -0
- package/dist/ops/grads/unpack16.js +6 -0
- package/dist/ops/grads/utils.d.ts +3 -0
- package/dist/ops/grads/utils.js +10 -0
- package/dist/ops/matMul16.d.ts +15 -0
- package/dist/ops/matMul16.js +13 -0
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.d.ts +2 -0
- package/dist/ops/mul16.js +8 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.d.ts +2 -0
- package/dist/ops/pack16.js +6 -0
- package/dist/ops/qkv.d.ts +1 -1
- package/dist/ops/qkv.js +8 -4
- package/dist/ops/reshape16.d.ts +2 -0
- package/dist/ops/reshape16.js +43 -0
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +8 -10
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.d.ts +2 -0
- package/dist/ops/slice16.js +9 -0
- package/dist/ops/softmax16.d.ts +2 -0
- package/dist/ops/softmax16.js +12 -0
- package/dist/ops/sub16.d.ts +2 -0
- package/dist/ops/sub16.js +8 -0
- package/dist/ops/sum16.d.ts +2 -0
- package/dist/ops/sum16.js +13 -0
- package/dist/ops/transpose16.d.ts +3 -0
- package/dist/ops/transpose16.js +41 -0
- package/dist/ops/unpack16.d.ts +2 -0
- package/dist/ops/unpack16.js +6 -0
- package/dist/ops/webgl/adamAdjust.js +3 -2
- package/dist/ops/webgl/adamMoments.js +2 -1
- package/dist/ops/webgl/appendCache.js +2 -1
- package/dist/ops/webgl/attentionMask.js +5 -4
- package/dist/ops/webgl/fusedSoftmax.js +6 -4
- package/dist/ops/webgl/gatherSub.js +7 -6
- package/dist/ops/webgl/gelu.js +3 -2
- package/dist/ops/webgl/log.js +12 -27
- package/dist/ops/webgl/matMul16.d.ts +1 -0
- package/dist/ops/webgl/matMul16.js +37 -0
- package/dist/ops/webgl/matMulGelu.js +17 -15
- package/dist/ops/webgl/matMulMul.js +13 -12
- package/dist/ops/webgl/mulDropout.js +9 -8
- package/dist/ops/webgl/normRMS.js +8 -7
- package/dist/ops/webgl/qkv.js +6 -5
- package/dist/ops/webgl/rope.js +11 -10
- package/dist/ops/webgl/scatterSub.js +6 -5
- package/dist/ops/webgpu/adamAdjust.js +12 -10
- package/dist/ops/webgpu/adamMoments.js +27 -22
- package/dist/ops/webgpu/add16.d.ts +1 -0
- package/dist/ops/webgpu/add16.js +14 -0
- package/dist/ops/webgpu/appendCache.js +64 -17
- package/dist/ops/webgpu/attentionMask.js +19 -62
- package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
- package/dist/ops/webgpu/attentionMask32_program.js +54 -0
- package/dist/ops/webgpu/concat16.d.ts +19 -0
- package/dist/ops/webgpu/concat16.js +128 -0
- package/dist/ops/webgpu/gatherSub.js +9 -7
- package/dist/ops/webgpu/gelu.js +78 -31
- package/dist/ops/webgpu/index.js +12 -0
- package/dist/ops/webgpu/matMul16.d.ts +1 -0
- package/dist/ops/webgpu/matMul16.js +58 -0
- package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
- package/dist/ops/webgpu/matMul16_program.js +336 -0
- package/dist/ops/webgpu/mul16.d.ts +1 -0
- package/dist/ops/webgpu/mul16.js +14 -0
- package/dist/ops/webgpu/normRMS.js +21 -40
- package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS16_program.js +24 -0
- package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
- package/dist/ops/webgpu/normRMS32_program.js +24 -0
- package/dist/ops/webgpu/normRMSGrad.js +113 -64
- package/dist/ops/webgpu/pack16.d.ts +1 -0
- package/dist/ops/webgpu/pack16.js +19 -0
- package/dist/ops/webgpu/pack16_program.d.ts +19 -0
- package/dist/ops/webgpu/pack16_program.js +92 -0
- package/dist/ops/webgpu/qkv.js +20 -55
- package/dist/ops/webgpu/rope.js +77 -22
- package/dist/ops/webgpu/scatterSub.js +9 -7
- package/dist/ops/webgpu/slice16.d.ts +7 -0
- package/dist/ops/webgpu/slice16.js +71 -0
- package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
- package/dist/ops/webgpu/softmax16.js +23 -0
- package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
- package/dist/ops/webgpu/softmax16_program.js +73 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
- package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
- package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
- package/dist/ops/webgpu/softmax16grad.js +38 -0
- package/dist/ops/webgpu/sub16.d.ts +1 -0
- package/dist/ops/webgpu/sub16.js +14 -0
- package/dist/ops/webgpu/sum16.d.ts +1 -0
- package/dist/ops/webgpu/sum16.js +40 -0
- package/dist/ops/webgpu/transpose16.d.ts +1 -0
- package/dist/ops/webgpu/transpose16.js +35 -0
- package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
- package/dist/ops/webgpu/transpose16_program.js +50 -0
- package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
- package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
- package/dist/ops/webgpu/unpack16.d.ts +1 -0
- package/dist/ops/webgpu/unpack16.js +49 -0
- package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
- package/dist/ops/webgpu/utils/binary_op.js +79 -0
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
- package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
- package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
- package/dist/ops/webgpu/utils/reductions.js +236 -45
- package/dist/ops-CNI3TwqM.js +645 -0
- package/dist/pack16-CFUqumar.js +41 -0
- package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
- package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
- package/dist/patches/PackedTensor.d.ts +12 -0
- package/dist/patches/PackedTensor.js +11 -0
- package/dist/patches/engine.d.ts +261 -0
- package/dist/patches/engine.js +10 -0
- package/dist/patches/tape.d.ts +12 -0
- package/dist/patches/tape.js +5 -0
- package/dist/patches/webgpu_backend.d.ts +18 -0
- package/dist/patches/webgpu_backend.js +57 -0
- package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
- package/dist/patches/webgpu_base.js +34 -0
- package/dist/patches/webgpu_program.d.ts +36 -0
- package/dist/patches/webgpu_program.js +401 -0
- package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
- package/dist/random_width-DY6Kk2Dl.js +10051 -0
- package/dist/range-BMS52eQi.js +11 -0
- package/dist/reciprocal-CTmshQ9J.js +10 -0
- package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
- package/dist/relu-yZ2-7WxU.js +10 -0
- package/dist/reshape-DevtBWtf.js +10 -0
- package/dist/rope-B5UUMsPi.js +32 -0
- package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
- package/dist/selu_util-D1w6yyTO.js +303 -0
- package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
- package/dist/shared-BuAXb4CI.js +2145 -0
- package/dist/sin-BGfy2HZo.js +16 -0
- package/dist/slice-D_gkkqZK.js +13 -0
- package/dist/slice_util-DtEldBfK.js +261 -0
- package/dist/softmax-ZHVebtR1.js +13 -0
- package/dist/split-DrfihRpZ.js +10 -0
- package/dist/squeeze-DZEpeblb.js +11 -0
- package/dist/stack-yOIAalTq.js +13 -0
- package/dist/sum-_fzj5ZTB.js +12 -0
- package/dist/tensor-DdQUJZlz.js +909 -0
- package/dist/tensor-f35l8Odg.js +8 -0
- package/dist/tensor1d-CeZuc-Rv.js +12 -0
- package/dist/tensor2d-G4Ys2GxX.js +15 -0
- package/dist/tensor4d-B8roDgtc.js +15 -0
- package/dist/tensor_util-DV-FP5Q3.js +523 -0
- package/dist/tfjs_backend-kNyO5L2d.js +653 -0
- package/dist/tile-BzyEiF-F.js +13 -0
- package/dist/tokeniser/CharTokeniser.js +1 -1
- package/dist/tokeniser/bpe.js +1 -1
- package/dist/training/Adam.d.ts +2 -1
- package/dist/training/Adam.js +12 -28
- package/dist/training/AdamExt.d.ts +1 -0
- package/dist/training/AdamExt.js +2 -2
- package/dist/training/DatasetBuilder.js +3 -20
- package/dist/training/FullTrainer.js +82 -64
- package/dist/training/Trainer.d.ts +11 -6
- package/dist/training/Trainer.js +51 -39
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/transpose-DKELTqhe.js +38 -0
- package/dist/utilities/arrayClose.js +7 -7
- package/dist/utilities/dummy.js +35 -27
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.d.ts +7 -0
- package/dist/utilities/packed.js +716 -0
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.d.ts +5 -0
- package/dist/utilities/sentences.js +41 -0
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-Bhn5bHYv.js +7 -0
- package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
- package/dist/webgpu_util-BBCnKm2X.js +65 -0
- package/dist/zeros-2gldETuK.js +14 -0
- package/package.json +4 -3
- package/dist/Reshape-Bowtk9BP.js +0 -127
- package/dist/Reshape-DUqYftGC.js +0 -30
- package/dist/backend_util-CJIiDoV1.js +0 -749
- package/dist/broadcast_to-DzlNweb8.js +0 -44
- package/dist/concat-B912vBbo.js +0 -33
- package/dist/dropout-C-csYCLj.js +0 -193
- package/dist/exports_initializers-B8iZMgQ0.js +0 -16
- package/dist/gather-Dnpgw-YQ.js +0 -25
- package/dist/index-BzFyqcy-.js +0 -4457
- package/dist/index-C1rx_Ajs.js +0 -12076
- package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
- package/dist/log_sum_exp-DO6z8tSE.js +0 -103
- package/dist/mat_mul-DzjTFx-u.js +0 -27
- package/dist/mod-Dobti4j4.js +0 -27
- package/dist/ones-tIJeHlq-.js +0 -29
- package/dist/ops/fusedSoftmax.d.ts +0 -2
- package/dist/ops/fusedSoftmax.js +0 -10
- package/dist/ops/grads/fusedSoftmax.js +0 -22
- package/dist/ops-LuCMAnmM.js +0 -1525
- package/dist/random_width-CXVRloNK.js +0 -13670
- package/dist/range-CWcz7xFA.js +0 -26
- package/dist/reciprocal-C4rNcM-S.js +0 -25
- package/dist/relu-BjCh_SYb.js +0 -25
- package/dist/reshape-CnIwVG1c.js +0 -25
- package/dist/selu_util-OtRzVwW5.js +0 -719
- package/dist/shared-DmRsFyaJ.js +0 -3134
- package/dist/sin-gpDNRxE0.js +0 -47
- package/dist/slice-d0Vo9XTN.js +0 -28
- package/dist/softmax-D7Jj3p_P.js +0 -28
- package/dist/split-DK2k5eHf.js +0 -25
- package/dist/stack-DFatutCx.js +0 -27
- package/dist/sum-CJ0ULhmt.js +0 -27
- package/dist/tensor1d-vML0r3q6.js +0 -27
- package/dist/tensor2d-D76QGjF3.js +0 -30
- package/dist/tensor4d-Df1WlVDY.js +0 -30
- package/dist/webgpu_util-pLEV9tks.js +0 -80
- package/dist/zeros-Bj5rMYA7.js +0 -52
|
@@ -0,0 +1,2157 @@
|
|
|
1
|
+
import { w as J, o as je, p as ge, K as qe, I as Ee, q as pe, s as ye, t as Te, v as Ve, a3 as He, x as Ae, A as Me, a4 as Ne, a5 as xe, m as G, a6 as Xe, $ as Je, a7 as Ye, a8 as Ze, a9 as Qe, aa as et, ab as tt, ac as nt, ad as st, ae as rt, af as at } from "./tensor_util-DV-FP5Q3.js";
|
|
2
|
+
import { o as it, E as ot, q as ct, y as lt, a as g, r as ut, t as ht, T as R, u as dt, V as re, v as ae, w as Y, x as be, f as ft, m as mt, s as Z, e as B, M as gt, N as ue, D, O as Fe, B as De, P as pt, d as he, Q as yt, R as bt, S as wt, U as St, j as Ce } from "./tensor-DdQUJZlz.js";
|
|
3
|
+
import { p as Q } from "./index-xuotMAFm.js";
|
|
4
|
+
import { B as ee } from "./index-Cp39cXWe.js";
|
|
5
|
+
function ne(s) {
|
|
6
|
+
return s.kernelName != null;
|
|
7
|
+
}
|
|
8
|
+
class we {
|
|
9
|
+
constructor() {
|
|
10
|
+
this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = /* @__PURE__ */ new WeakMap(), this.profiling = !1, this.activeProfile = {
|
|
11
|
+
newBytes: 0,
|
|
12
|
+
newTensors: 0,
|
|
13
|
+
peakBytes: 0,
|
|
14
|
+
kernels: [],
|
|
15
|
+
result: null,
|
|
16
|
+
get kernelNames() {
|
|
17
|
+
return Array.from(new Set(this.kernels.map((e) => e.name)));
|
|
18
|
+
}
|
|
19
|
+
};
|
|
20
|
+
}
|
|
21
|
+
dispose() {
|
|
22
|
+
for (const e in this.registeredVariables)
|
|
23
|
+
this.registeredVariables[e].dispose();
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
class q {
|
|
27
|
+
constructor(e) {
|
|
28
|
+
this.ENV = e, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new we();
|
|
29
|
+
}
|
|
30
|
+
async ready() {
|
|
31
|
+
if (this.pendingBackendInit != null)
|
|
32
|
+
return this.pendingBackendInit.then(() => {
|
|
33
|
+
});
|
|
34
|
+
if (this.backendInstance != null)
|
|
35
|
+
return;
|
|
36
|
+
const e = this.getSortedBackends();
|
|
37
|
+
for (let t = 0; t < e.length; t++) {
|
|
38
|
+
const n = e[t];
|
|
39
|
+
if (await this.initializeBackend(n).success) {
|
|
40
|
+
await this.setBackend(n);
|
|
41
|
+
return;
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
throw new Error("Could not initialize any backends, all backend initializations failed.");
|
|
45
|
+
}
|
|
46
|
+
get backend() {
|
|
47
|
+
if (this.pendingBackendInit != null)
|
|
48
|
+
throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`);
|
|
49
|
+
if (this.backendInstance == null) {
|
|
50
|
+
const { name: e, asyncInit: t } = this.initializeBackendsAndReturnBest();
|
|
51
|
+
if (t)
|
|
52
|
+
throw new Error(`The highest priority backend '${e}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`);
|
|
53
|
+
this.setBackend(e);
|
|
54
|
+
}
|
|
55
|
+
return this.backendInstance;
|
|
56
|
+
}
|
|
57
|
+
backendNames() {
|
|
58
|
+
return Object.keys(this.registryFactory);
|
|
59
|
+
}
|
|
60
|
+
findBackend(e) {
|
|
61
|
+
if (!(e in this.registry))
|
|
62
|
+
if (e in this.registryFactory) {
|
|
63
|
+
const { asyncInit: t } = this.initializeBackend(e);
|
|
64
|
+
if (t)
|
|
65
|
+
return null;
|
|
66
|
+
} else
|
|
67
|
+
return null;
|
|
68
|
+
return this.registry[e];
|
|
69
|
+
}
|
|
70
|
+
findBackendFactory(e) {
|
|
71
|
+
return e in this.registryFactory ? this.registryFactory[e].factory : null;
|
|
72
|
+
}
|
|
73
|
+
registerBackend(e, t, n = 1) {
|
|
74
|
+
return e in this.registryFactory ? (J(`${e} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[e] = { factory: t, priority: n }, !0);
|
|
75
|
+
}
|
|
76
|
+
async setBackend(e) {
|
|
77
|
+
if (this.registryFactory[e] == null)
|
|
78
|
+
throw new Error(`Backend name '${e}' not found in registry`);
|
|
79
|
+
if (this.backendName = e, this.registry[e] == null) {
|
|
80
|
+
this.backendInstance = null;
|
|
81
|
+
const { success: t, asyncInit: n } = this.initializeBackend(e);
|
|
82
|
+
if (!(n ? await t : t))
|
|
83
|
+
return !1;
|
|
84
|
+
}
|
|
85
|
+
return this.backendInstance = this.registry[e], this.setupRegisteredKernels(), this.profiler = new je(this.backendInstance), !0;
|
|
86
|
+
}
|
|
87
|
+
setupRegisteredKernels() {
|
|
88
|
+
ge(this.backendName).forEach((t) => {
|
|
89
|
+
t.setupFunc != null && t.setupFunc(this.backendInstance);
|
|
90
|
+
});
|
|
91
|
+
}
|
|
92
|
+
disposeRegisteredKernels(e) {
|
|
93
|
+
ge(e).forEach((n) => {
|
|
94
|
+
n.disposeFunc != null && n.disposeFunc(this.registry[e]);
|
|
95
|
+
});
|
|
96
|
+
}
|
|
97
|
+
/**
|
|
98
|
+
* Initializes a backend by looking up the backend name in the factory
|
|
99
|
+
* registry and calling the factory method. Returns a boolean representing
|
|
100
|
+
* whether the initialization of the backend succeeded. Throws an error if
|
|
101
|
+
* there is no backend in the factory registry.
|
|
102
|
+
*/
|
|
103
|
+
initializeBackend(e) {
|
|
104
|
+
const t = this.registryFactory[e];
|
|
105
|
+
if (t == null)
|
|
106
|
+
throw new Error(`Cannot initialize backend ${e}, no registration found.`);
|
|
107
|
+
try {
|
|
108
|
+
const n = t.factory();
|
|
109
|
+
if (n && !(n instanceof qe) && typeof n.then == "function") {
|
|
110
|
+
const r = ++this.pendingBackendInitId, a = n.then((i) => r < this.pendingBackendInitId ? !1 : (this.registry[e] = i, this.pendingBackendInit = null, !0)).catch((i) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, J(`Initialization of backend ${e} failed`), J(i.stack || i.message)), !1));
|
|
111
|
+
return this.pendingBackendInit = a, { success: a, asyncInit: !0 };
|
|
112
|
+
} else
|
|
113
|
+
return this.registry[e] = n, { success: !0, asyncInit: !1 };
|
|
114
|
+
} catch (n) {
|
|
115
|
+
return J(`Initialization of backend ${e} failed`), J(n.stack || n.message), { success: !1, asyncInit: !1 };
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
removeBackend(e) {
|
|
119
|
+
if (!(e in this.registryFactory))
|
|
120
|
+
throw new Error(`${e} backend not found in registry`);
|
|
121
|
+
this.backendName === e && this.pendingBackendInit != null && this.pendingBackendInitId++, e in this.registry && (this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e]), delete this.registryFactory[e], this.backendName === e && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
|
|
122
|
+
}
|
|
123
|
+
getSortedBackends() {
|
|
124
|
+
if (Object.keys(this.registryFactory).length === 0)
|
|
125
|
+
throw new Error("No backend found in registry.");
|
|
126
|
+
return Object.keys(this.registryFactory).sort((e, t) => this.registryFactory[t].priority - this.registryFactory[e].priority);
|
|
127
|
+
}
|
|
128
|
+
initializeBackendsAndReturnBest() {
|
|
129
|
+
const e = this.getSortedBackends();
|
|
130
|
+
for (let t = 0; t < e.length; t++) {
|
|
131
|
+
const n = e[t], { success: r, asyncInit: a } = this.initializeBackend(n);
|
|
132
|
+
if (a || r)
|
|
133
|
+
return { name: n, asyncInit: a };
|
|
134
|
+
}
|
|
135
|
+
throw new Error("Could not initialize any backends, all backend initializations failed.");
|
|
136
|
+
}
|
|
137
|
+
moveData(e, t) {
|
|
138
|
+
const n = this.state.tensorInfo.get(t), r = n.backend, a = this.readSync(t), i = r.refCount(t);
|
|
139
|
+
r.disposeData(t, !0), n.backend = e, e.move(t, a, n.shape, n.dtype, i), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
|
|
140
|
+
}
|
|
141
|
+
tidy(e, t) {
|
|
142
|
+
let n = null;
|
|
143
|
+
if (t == null) {
|
|
144
|
+
if (typeof e != "function")
|
|
145
|
+
throw new Error("Please provide a function to tidy()");
|
|
146
|
+
t = e;
|
|
147
|
+
} else {
|
|
148
|
+
if (typeof e != "string" && !(e instanceof String))
|
|
149
|
+
throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
|
|
150
|
+
if (typeof t != "function")
|
|
151
|
+
throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
|
|
152
|
+
n = e;
|
|
153
|
+
}
|
|
154
|
+
let r;
|
|
155
|
+
return this.scopedRun(() => this.startScope(n), () => this.endScope(r), () => (r = t(), r instanceof Promise && console.error("Cannot return a Promise inside of tidy."), r));
|
|
156
|
+
}
|
|
157
|
+
scopedRun(e, t, n) {
|
|
158
|
+
e();
|
|
159
|
+
try {
|
|
160
|
+
const r = n();
|
|
161
|
+
return t(), r;
|
|
162
|
+
} catch (r) {
|
|
163
|
+
throw t(), r;
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
nextTensorId() {
|
|
167
|
+
return q.nextTensorId++;
|
|
168
|
+
}
|
|
169
|
+
nextVariableId() {
|
|
170
|
+
return q.nextVariableId++;
|
|
171
|
+
}
|
|
172
|
+
/**
|
|
173
|
+
* This method is called instead of the public-facing tensor.clone() when
|
|
174
|
+
* saving a tensor for backwards pass. It makes sure to add the clone
|
|
175
|
+
* operation to the tape regardless of being called inside a kernel
|
|
176
|
+
* execution.
|
|
177
|
+
*/
|
|
178
|
+
clone(e) {
|
|
179
|
+
const t = u.runKernel(Ee, { x: e }), n = { x: e }, r = (i) => ({
|
|
180
|
+
x: () => {
|
|
181
|
+
const o = "float32", c = { x: i }, l = { dtype: o };
|
|
182
|
+
return u.runKernel(
|
|
183
|
+
Ae,
|
|
184
|
+
c,
|
|
185
|
+
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
186
|
+
l
|
|
187
|
+
);
|
|
188
|
+
}
|
|
189
|
+
}), a = [];
|
|
190
|
+
return this.addTapeNode(this.state.activeScope.name, n, [t], r, a, {}), t;
|
|
191
|
+
}
|
|
192
|
+
/**
|
|
193
|
+
* Execute a kernel with the given name and return the output tensor.
|
|
194
|
+
*
|
|
195
|
+
* @param kernelName The name of the kernel to execute.
|
|
196
|
+
* @param inputs A map of input names to tensors.
|
|
197
|
+
* @param attrs A map of attribute names to their values. An attribute is a
|
|
198
|
+
* primitive (non-tensor) input to the kernel.
|
|
199
|
+
* @param inputsToSave A list of tensors, inputs to save for the backprop
|
|
200
|
+
* computation.
|
|
201
|
+
* @param outputsToSave A list of booleans, specifying which output to save
|
|
202
|
+
* for the backprop computation. These are booleans since the output
|
|
203
|
+
* tensors are not visible to the user.
|
|
204
|
+
*/
|
|
205
|
+
runKernel(e, t, n) {
|
|
206
|
+
if (this.backendName == null && this.backend, !(pe(e, this.backendName) != null))
|
|
207
|
+
throw new Error(`Kernel '${e}' not registered for backend '${this.backendName}'`);
|
|
208
|
+
return this.runKernelFunc({ kernelName: e, inputs: t, attrs: n });
|
|
209
|
+
}
|
|
210
|
+
shouldCheckForMemLeaks() {
|
|
211
|
+
return this.ENV.getBool("IS_TEST");
|
|
212
|
+
}
|
|
213
|
+
checkKernelForMemLeak(e, t, n) {
|
|
214
|
+
const r = this.backend.numDataIds();
|
|
215
|
+
let a = 0;
|
|
216
|
+
n.forEach((c) => {
|
|
217
|
+
a += c.dtype === "complex64" ? 3 : 1;
|
|
218
|
+
});
|
|
219
|
+
const i = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], o = r - t - a - i;
|
|
220
|
+
if (o > 0)
|
|
221
|
+
throw new Error(`Backend '${this.backendName}' has an internal memory leak (${o} data ids) after running '${e}'`);
|
|
222
|
+
}
|
|
223
|
+
/**
|
|
224
|
+
* Internal helper method to execute a kernel Func
|
|
225
|
+
*
|
|
226
|
+
* Use `runKernel` to execute kernels from outside of engine.
|
|
227
|
+
*/
|
|
228
|
+
runKernelFunc(e) {
|
|
229
|
+
let t, n = [];
|
|
230
|
+
const r = this.isTapeOn(), a = this.state.numBytes, i = this.state.numTensors;
|
|
231
|
+
this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
|
|
232
|
+
let o;
|
|
233
|
+
this.backendName == null && this.backend;
|
|
234
|
+
let c;
|
|
235
|
+
const l = ne(e) ? e.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
|
|
236
|
+
if (ne(e)) {
|
|
237
|
+
const { kernelName: b, inputs: w, attrs: I } = e;
|
|
238
|
+
this.backendName == null && this.backend;
|
|
239
|
+
const T = pe(b, this.backendName);
|
|
240
|
+
g(T != null, () => `Cannot find registered kernel '${b}' for backend '${this.backendName}'`), o = () => {
|
|
241
|
+
const te = this.backend.numDataIds();
|
|
242
|
+
c = T.kernelFunc({ inputs: w, attrs: I, backend: this.backend });
|
|
243
|
+
const fe = Array.isArray(c) ? c : [c];
|
|
244
|
+
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, te, fe);
|
|
245
|
+
const me = fe.map((X) => X.rank != null ? X : this.makeTensorFromTensorInfo(X));
|
|
246
|
+
if (r) {
|
|
247
|
+
const X = this.getTensorsForGradient(b, w, me);
|
|
248
|
+
n = this.saveTensorsForBackwardMode(X);
|
|
249
|
+
}
|
|
250
|
+
return me;
|
|
251
|
+
};
|
|
252
|
+
} else {
|
|
253
|
+
const { forwardFunc: b } = e, w = (I) => {
|
|
254
|
+
r && (n = I.map((T) => this.keep(this.clone(T))));
|
|
255
|
+
};
|
|
256
|
+
o = () => {
|
|
257
|
+
const I = this.backend.numDataIds();
|
|
258
|
+
c = this.tidy(() => b(this.backend, w));
|
|
259
|
+
const T = Array.isArray(c) ? c : [c];
|
|
260
|
+
return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(l, I, T), T;
|
|
261
|
+
};
|
|
262
|
+
}
|
|
263
|
+
const { inputs: h, attrs: m } = e, d = ne(e) ? null : e.backwardsFunc;
|
|
264
|
+
let p;
|
|
265
|
+
return this.scopedRun(
|
|
266
|
+
// Stop recording to a tape when running a kernel.
|
|
267
|
+
() => this.state.kernelDepth++,
|
|
268
|
+
() => this.state.kernelDepth--,
|
|
269
|
+
() => {
|
|
270
|
+
!this.ENV.getBool("DEBUG") && !this.state.profiling ? t = o() : (p = this.profiler.profileKernel(l, h, () => o()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(p), t = p.outputs);
|
|
271
|
+
}
|
|
272
|
+
), r && this.addTapeNode(l, h, t, d, n, m), this.state.profiling && this.state.activeProfile.kernels.push({
|
|
273
|
+
name: l,
|
|
274
|
+
bytesAdded: this.state.numBytes - a,
|
|
275
|
+
totalBytesSnapshot: this.state.numBytes,
|
|
276
|
+
tensorsAdded: this.state.numTensors - i,
|
|
277
|
+
totalTensorsSnapshot: this.state.numTensors,
|
|
278
|
+
inputShapes: Object.keys(h).map((b) => h[b] != null ? h[b].shape : null),
|
|
279
|
+
outputShapes: t.map((b) => b.shape),
|
|
280
|
+
kernelTimeMs: p.timeMs,
|
|
281
|
+
extraInfo: p.extraInfo
|
|
282
|
+
}), Array.isArray(c) ? t : t[0];
|
|
283
|
+
}
|
|
284
|
+
/**
|
|
285
|
+
* Saves tensors used in forward mode for use in backward mode.
|
|
286
|
+
*
|
|
287
|
+
* @param tensors the list of tensors to save.
|
|
288
|
+
*/
|
|
289
|
+
saveTensorsForBackwardMode(e) {
|
|
290
|
+
return e.map((n) => this.keep(this.clone(n)));
|
|
291
|
+
}
|
|
292
|
+
/**
|
|
293
|
+
* Returns a list of tensors to save for a given gradient calculation.
|
|
294
|
+
*
|
|
295
|
+
* @param kernelName name of kernel to look up gradient for.
|
|
296
|
+
* @param inputs a map of input tensors.
|
|
297
|
+
* @param outputs an array of output tensors from forward mode of kernel.
|
|
298
|
+
*/
|
|
299
|
+
getTensorsForGradient(e, t, n) {
|
|
300
|
+
const r = ye(e);
|
|
301
|
+
if (r != null) {
|
|
302
|
+
const a = r.inputsToSave || [], i = r.outputsToSave || [];
|
|
303
|
+
let o;
|
|
304
|
+
r.saveAllInputs ? (g(Array.isArray(t), () => "saveAllInputs is true, expected inputs to be an array."), o = Object.keys(t).map((l) => t[l])) : o = a.map((l) => t[l]);
|
|
305
|
+
const c = n.filter((l, h) => i[h]);
|
|
306
|
+
return o.concat(c);
|
|
307
|
+
}
|
|
308
|
+
return [];
|
|
309
|
+
}
|
|
310
|
+
/**
|
|
311
|
+
* Internal method used by public APIs for tensor creation. Makes a new
|
|
312
|
+
* tensor with the provided shape, dtype and values. It always
|
|
313
|
+
* creates a new data id and writes the values to the underlying backend.
|
|
314
|
+
*/
|
|
315
|
+
makeTensor(e, t, n, r) {
|
|
316
|
+
if (e == null)
|
|
317
|
+
throw new Error("Values passed to engine.makeTensor() are null");
|
|
318
|
+
n = n || "float32", r = r || this.backend;
|
|
319
|
+
let a = e;
|
|
320
|
+
n === "string" && ut(e[0]) && (a = e.map((c) => ht(c)));
|
|
321
|
+
const i = r.write(a, t, n), o = new R(t, n, i, this.nextTensorId());
|
|
322
|
+
if (this.trackTensor(o, r), n === "string") {
|
|
323
|
+
const c = this.state.tensorInfo.get(i), l = dt(a);
|
|
324
|
+
this.state.numBytes += l - c.bytes, c.bytes = l;
|
|
325
|
+
}
|
|
326
|
+
return o;
|
|
327
|
+
}
|
|
328
|
+
/**
|
|
329
|
+
* Internal method used by backends. Makes a new tensor
|
|
330
|
+
* that is a wrapper around an existing data id. It doesn't create
|
|
331
|
+
* a new data id, only increments the ref count used in memory tracking.
|
|
332
|
+
* @deprecated
|
|
333
|
+
*/
|
|
334
|
+
makeTensorFromDataId(e, t, n, r) {
|
|
335
|
+
n = n || "float32";
|
|
336
|
+
const a = { dataId: e, shape: t, dtype: n };
|
|
337
|
+
return this.makeTensorFromTensorInfo(a, r);
|
|
338
|
+
}
|
|
339
|
+
/**
|
|
340
|
+
* Internal method used by backends. Makes a new tensor that is a wrapper
|
|
341
|
+
* around an existing data id in TensorInfo. It doesn't create a new data id,
|
|
342
|
+
* only increments the ref count used in memory tracking.
|
|
343
|
+
*/
|
|
344
|
+
makeTensorFromTensorInfo(e, t) {
|
|
345
|
+
const { dataId: n, shape: r, dtype: a } = e, i = new R(r, a, n, this.nextTensorId());
|
|
346
|
+
return this.trackTensor(i, t), i;
|
|
347
|
+
}
|
|
348
|
+
makeVariable(e, t = !0, n, r) {
|
|
349
|
+
n = n || this.nextVariableId().toString(), r != null && r !== e.dtype && (e = e.cast(r));
|
|
350
|
+
const a = new re(e, t, n, this.nextTensorId());
|
|
351
|
+
if (this.state.registeredVariables[a.name] != null)
|
|
352
|
+
throw new Error(`Variable with name ${a.name} was already registered`);
|
|
353
|
+
return this.state.registeredVariables[a.name] = a, this.incRef(a, this.backend), a;
|
|
354
|
+
}
|
|
355
|
+
trackTensor(e, t) {
|
|
356
|
+
this.state.numTensors++, e.dtype === "string" && this.state.numStringTensors++;
|
|
357
|
+
let n = 0;
|
|
358
|
+
e.dtype !== "complex64" && e.dtype !== "string" && (n = e.size * ae(e.dtype)), this.state.numBytes += n, this.state.tensorInfo.has(e.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(e.dataId, {
|
|
359
|
+
backend: t || this.backend,
|
|
360
|
+
dtype: e.dtype,
|
|
361
|
+
shape: e.shape,
|
|
362
|
+
bytes: n
|
|
363
|
+
})), e instanceof re || this.track(e);
|
|
364
|
+
}
|
|
365
|
+
// Track the tensor by dataId and increase the refCount for the dataId in the
|
|
366
|
+
// backend.
|
|
367
|
+
// TODO(pyu10055): This is currently used by makeVariable method, to increase
|
|
368
|
+
// refCount on the backend for the dataId. It can potentially be replaced with
|
|
369
|
+
// Identity op indead of calling backend directly.
|
|
370
|
+
incRef(e, t) {
|
|
371
|
+
this.trackTensor(e, t), this.backend.incRef(e.dataId);
|
|
372
|
+
}
|
|
373
|
+
removeDataId(e, t) {
|
|
374
|
+
this.state.tensorInfo.has(e) && this.state.tensorInfo.get(e).backend === t && (this.state.tensorInfo.delete(e), this.state.numDataBuffers--);
|
|
375
|
+
}
|
|
376
|
+
disposeTensor(e) {
|
|
377
|
+
if (!this.state.tensorInfo.has(e.dataId))
|
|
378
|
+
return;
|
|
379
|
+
const t = this.state.tensorInfo.get(e.dataId);
|
|
380
|
+
if (this.state.numTensors--, e.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= t.bytes), e.dtype !== "complex64" && e.dtype !== "string") {
|
|
381
|
+
const n = e.size * ae(e.dtype);
|
|
382
|
+
this.state.numBytes -= n;
|
|
383
|
+
}
|
|
384
|
+
t.backend.disposeData(e.dataId) && this.removeDataId(e.dataId, t.backend);
|
|
385
|
+
}
|
|
386
|
+
disposeVariables() {
|
|
387
|
+
for (const e in this.state.registeredVariables) {
|
|
388
|
+
const t = this.state.registeredVariables[e];
|
|
389
|
+
this.disposeVariable(t);
|
|
390
|
+
}
|
|
391
|
+
}
|
|
392
|
+
disposeVariable(e) {
|
|
393
|
+
this.disposeTensor(e), this.state.registeredVariables[e.name] != null && delete this.state.registeredVariables[e.name];
|
|
394
|
+
}
|
|
395
|
+
memory() {
|
|
396
|
+
const e = this.backend.memory();
|
|
397
|
+
return e.numTensors = this.state.numTensors, e.numDataBuffers = this.state.numDataBuffers, e.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (e.unreliable = !0, e.reasons == null && (e.reasons = []), e.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), e;
|
|
398
|
+
}
|
|
399
|
+
async profile(e) {
|
|
400
|
+
this.state.profiling = !0;
|
|
401
|
+
const t = this.state.numBytes, n = this.state.numTensors;
|
|
402
|
+
this.state.activeProfile.kernels = [], this.state.activeProfile.result = await e(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map((r) => r.totalBytesSnapshot)), this.state.activeProfile.newBytes = this.state.numBytes - t, this.state.activeProfile.newTensors = this.state.numTensors - n;
|
|
403
|
+
for (const r of this.state.activeProfile.kernels)
|
|
404
|
+
r.kernelTimeMs = await r.kernelTimeMs, r.extraInfo = await r.extraInfo;
|
|
405
|
+
return this.state.activeProfile;
|
|
406
|
+
}
|
|
407
|
+
isTapeOn() {
|
|
408
|
+
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
|
|
409
|
+
}
|
|
410
|
+
addTapeNode(e, t, n, r, a, i) {
|
|
411
|
+
const o = { id: this.state.nextTapeNodeId++, kernelName: e, inputs: t, outputs: n, saved: a }, c = ye(e);
|
|
412
|
+
c != null && (r = c.gradFunc), r != null && (o.gradient = (l) => (l = l.map((h, m) => {
|
|
413
|
+
if (h == null) {
|
|
414
|
+
const d = n[m], p = ft(d.size, d.dtype);
|
|
415
|
+
return this.makeTensor(p, d.shape, d.dtype);
|
|
416
|
+
}
|
|
417
|
+
return h;
|
|
418
|
+
}), r(l.length > 1 ? l : l[0], a, i))), this.state.activeTape.push(o);
|
|
419
|
+
}
|
|
420
|
+
keep(e) {
|
|
421
|
+
return e.kept = !0, e;
|
|
422
|
+
}
|
|
423
|
+
startTape() {
|
|
424
|
+
this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
|
|
425
|
+
}
|
|
426
|
+
endTape() {
|
|
427
|
+
this.state.gradientDepth--;
|
|
428
|
+
}
|
|
429
|
+
/**
|
|
430
|
+
* Start a scope. Use this with endScope() to achieve the same functionality
|
|
431
|
+
* as scope() without the need for a function closure.
|
|
432
|
+
*/
|
|
433
|
+
startScope(e) {
|
|
434
|
+
const t = {
|
|
435
|
+
track: [],
|
|
436
|
+
name: "unnamed scope",
|
|
437
|
+
id: this.state.nextScopeId++
|
|
438
|
+
};
|
|
439
|
+
e && (t.name = e), this.state.scopeStack.push(t), this.state.activeScope = t;
|
|
440
|
+
}
|
|
441
|
+
/**
|
|
442
|
+
* End a scope. Use this with startScope() to achieve the same functionality
|
|
443
|
+
* as scope() without the need for a function closure.
|
|
444
|
+
*/
|
|
445
|
+
endScope(e) {
|
|
446
|
+
const t = Te(e), n = new Set(t.map((a) => a.id));
|
|
447
|
+
for (let a = 0; a < this.state.activeScope.track.length; a++) {
|
|
448
|
+
const i = this.state.activeScope.track[a];
|
|
449
|
+
!i.kept && !n.has(i.id) && i.dispose();
|
|
450
|
+
}
|
|
451
|
+
const r = this.state.scopeStack.pop();
|
|
452
|
+
this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], t.forEach((a) => {
|
|
453
|
+
!a.kept && a.scopeId === r.id && this.track(a);
|
|
454
|
+
});
|
|
455
|
+
}
|
|
456
|
+
/**
|
|
457
|
+
* Returns gradients of `f` with respect to each of the `xs`. The gradients
|
|
458
|
+
* returned are of the same length as `xs`, but some might be null if `f`
|
|
459
|
+
* was not a function of that `x`. It also takes optional dy to multiply the
|
|
460
|
+
* gradient, which defaults to `1`.
|
|
461
|
+
*/
|
|
462
|
+
gradients(e, t, n, r = !1) {
|
|
463
|
+
if (g(t.length > 0, () => "gradients() received an empty list of xs."), n != null && n.dtype !== "float32")
|
|
464
|
+
throw new Error(`dy must have 'float32' dtype, but has '${n.dtype}'`);
|
|
465
|
+
const a = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy("forward", e));
|
|
466
|
+
g(a instanceof R, () => "The result y returned by f() must be a tensor.");
|
|
467
|
+
const i = Ve(this.state.activeTape, t, a);
|
|
468
|
+
if (!r && i.length === 0 && t.length > 0)
|
|
469
|
+
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
|
|
470
|
+
return this.tidy("backward", () => {
|
|
471
|
+
const o = {};
|
|
472
|
+
o[a.id] = n ?? kt(a.shape), He(
|
|
473
|
+
o,
|
|
474
|
+
i,
|
|
475
|
+
// Pass the tidy function to avoid circular dep with `tape.ts`.
|
|
476
|
+
(l) => this.tidy(l),
|
|
477
|
+
// Pass an add function to avoide a circular dep with `tape.ts`.
|
|
478
|
+
It
|
|
479
|
+
);
|
|
480
|
+
const c = t.map((l) => o[l.id]);
|
|
481
|
+
return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((l) => {
|
|
482
|
+
for (const h of l.saved)
|
|
483
|
+
h.dispose();
|
|
484
|
+
}), this.state.activeTape = null), { value: a, grads: c };
|
|
485
|
+
});
|
|
486
|
+
}
|
|
487
|
+
customGrad(e) {
|
|
488
|
+
return g(Y(e), () => "The f passed in customGrad(f) must be a function."), (...t) => {
|
|
489
|
+
g(t.every((o) => o instanceof R), () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors");
|
|
490
|
+
let n;
|
|
491
|
+
const r = {};
|
|
492
|
+
t.forEach((o, c) => {
|
|
493
|
+
r[c] = o;
|
|
494
|
+
});
|
|
495
|
+
const a = (o, c) => (n = e(...t, c), g(n.value instanceof R, () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"), g(Y(n.gradFunc), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."), n.value), i = (o, c) => {
|
|
496
|
+
const l = n.gradFunc(o, c), h = Array.isArray(l) ? l : [l];
|
|
497
|
+
g(h.length === t.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."), g(h.every((d) => d instanceof R), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.");
|
|
498
|
+
const m = {};
|
|
499
|
+
return h.forEach((d, p) => {
|
|
500
|
+
m[p] = () => d;
|
|
501
|
+
}), m;
|
|
502
|
+
};
|
|
503
|
+
return this.runKernelFunc({
|
|
504
|
+
forwardFunc: a,
|
|
505
|
+
backwardsFunc: i,
|
|
506
|
+
inputs: r
|
|
507
|
+
});
|
|
508
|
+
};
|
|
509
|
+
}
|
|
510
|
+
readSync(e) {
|
|
511
|
+
return this.state.tensorInfo.get(e).backend.readSync(e);
|
|
512
|
+
}
|
|
513
|
+
read(e) {
|
|
514
|
+
return this.state.tensorInfo.get(e).backend.read(e);
|
|
515
|
+
}
|
|
516
|
+
readToGPU(e, t) {
|
|
517
|
+
return this.state.tensorInfo.get(e).backend.readToGPU(e, t);
|
|
518
|
+
}
|
|
519
|
+
async time(e) {
|
|
520
|
+
const t = be(), n = await this.backend.time(e);
|
|
521
|
+
return n.wallMs = be() - t, n;
|
|
522
|
+
}
|
|
523
|
+
/**
|
|
524
|
+
* Tracks a Tensor in the current scope to be automatically cleaned up
|
|
525
|
+
* when the current scope ends, and returns the value.
|
|
526
|
+
*
|
|
527
|
+
* @param result The Tensor to track in the current scope.
|
|
528
|
+
*/
|
|
529
|
+
track(e) {
|
|
530
|
+
return this.state.activeScope != null && (e.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(e)), e;
|
|
531
|
+
}
|
|
532
|
+
get registeredVariables() {
|
|
533
|
+
return this.state.registeredVariables;
|
|
534
|
+
}
|
|
535
|
+
/**
|
|
536
|
+
* Resets the engine state. Removes all backends but does not remove
|
|
537
|
+
* registered backend factories.
|
|
538
|
+
*/
|
|
539
|
+
reset() {
|
|
540
|
+
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new we();
|
|
541
|
+
for (const e in this.registry)
|
|
542
|
+
this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e];
|
|
543
|
+
this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
|
|
544
|
+
}
|
|
545
|
+
}
|
|
546
|
+
q.nextTensorId = 0;
|
|
547
|
+
q.nextVariableId = 0;
|
|
548
|
+
function kt(s) {
|
|
549
|
+
const e = mt(Z(s), "float32");
|
|
550
|
+
return u.makeTensor(e, s, "float32");
|
|
551
|
+
}
|
|
552
|
+
function Re() {
|
|
553
|
+
const s = it();
|
|
554
|
+
if (s._tfengine == null) {
|
|
555
|
+
const e = new ot(s);
|
|
556
|
+
s._tfengine = new q(e);
|
|
557
|
+
}
|
|
558
|
+
return ct(s._tfengine.ENV), lt(() => s._tfengine), s._tfengine;
|
|
559
|
+
}
|
|
560
|
+
const u = Re();
|
|
561
|
+
function It(s, e) {
|
|
562
|
+
const t = { a: s, b: e };
|
|
563
|
+
return u.runKernel(Me, t);
|
|
564
|
+
}
|
|
565
|
+
function vt() {
|
|
566
|
+
return typeof navigator < "u" && navigator != null;
|
|
567
|
+
}
|
|
568
|
+
function $n(s) {
|
|
569
|
+
if (s || vt()) {
|
|
570
|
+
if (s || (s = navigator), s.product === "ReactNative")
|
|
571
|
+
return !0;
|
|
572
|
+
const e = s.userAgent || s.vendor || // tslint:disable-next-line:no-any
|
|
573
|
+
(typeof window < "u" ? window.opera : "");
|
|
574
|
+
if (!e) {
|
|
575
|
+
const t = s;
|
|
576
|
+
return t.userAgentData && t.userAgentData.mobile;
|
|
577
|
+
}
|
|
578
|
+
return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(e) || // tslint:disable-next-line:max-line-length
|
|
579
|
+
/1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(e.substr(0, 4));
|
|
580
|
+
}
|
|
581
|
+
return !1;
|
|
582
|
+
}
|
|
583
|
+
function Bt() {
|
|
584
|
+
return typeof window < "u" && window.document != null || //@ts-ignore
|
|
585
|
+
typeof WorkerGlobalScope < "u";
|
|
586
|
+
}
|
|
587
|
+
const E = B();
|
|
588
|
+
E.registerFlag("DEBUG", () => !1, (s) => {
|
|
589
|
+
s && console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
|
|
590
|
+
});
|
|
591
|
+
E.registerFlag("IS_BROWSER", () => Bt());
|
|
592
|
+
E.registerFlag("IS_NODE", () => typeof Q < "u" && typeof Q.versions < "u" && typeof Q.versions.node < "u");
|
|
593
|
+
E.registerFlag("IS_CHROME", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor));
|
|
594
|
+
E.registerFlag("IS_SAFARI", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Safari/.test(navigator.userAgent) && /Apple/.test(navigator.vendor));
|
|
595
|
+
E.registerFlag("PROD", () => !1);
|
|
596
|
+
E.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", () => E.getBool("DEBUG"));
|
|
597
|
+
E.registerFlag("DEPRECATION_WARNINGS_ENABLED", () => !0);
|
|
598
|
+
E.registerFlag("IS_TEST", () => !1);
|
|
599
|
+
E.registerFlag("CHECK_COMPUTATION_FOR_ERRORS", () => E.getBool("DEBUG"));
|
|
600
|
+
E.registerFlag("WRAP_TO_IMAGEBITMAP", () => !1);
|
|
601
|
+
E.registerFlag("CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU", () => !1);
|
|
602
|
+
E.registerFlag("USE_SETTIMEOUTCUSTOM", () => !1);
|
|
603
|
+
function Et(s, e) {
|
|
604
|
+
let t = s;
|
|
605
|
+
if (D(s))
|
|
606
|
+
return e === "string" ? [] : [s.length];
|
|
607
|
+
if (Ne(s)) {
|
|
608
|
+
const r = s.channels || "RGBA";
|
|
609
|
+
return [s.height, s.width * r.length];
|
|
610
|
+
} else if (xe(s))
|
|
611
|
+
return [s.buffer.size / (e == null ? 4 : ae(e))];
|
|
612
|
+
if (!Array.isArray(s))
|
|
613
|
+
return [];
|
|
614
|
+
const n = [];
|
|
615
|
+
for (; Array.isArray(t) || D(t) && e !== "string"; )
|
|
616
|
+
n.push(t.length), t = t[0];
|
|
617
|
+
return Array.isArray(s) && B().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && $e(s, n, []), n;
|
|
618
|
+
}
|
|
619
|
+
function $e(s, e, t) {
|
|
620
|
+
if (t = t || [], !Array.isArray(s) && !D(s)) {
|
|
621
|
+
g(e.length === 0, () => `Element arr[${t.join("][")}] is a primitive, but should be an array/TypedArray of ${e[0]} elements`);
|
|
622
|
+
return;
|
|
623
|
+
}
|
|
624
|
+
g(e.length > 0, () => `Element arr[${t.join("][")}] should be a primitive, but is an array of ${s.length} elements`), g(s.length === e[0], () => `Element arr[${t.join("][")}] should have ${e[0]} elements, but has ${s.length} elements`);
|
|
625
|
+
const n = e.slice(1);
|
|
626
|
+
for (let r = 0; r < s.length; ++r)
|
|
627
|
+
$e(s[r], n, t.concat(r));
|
|
628
|
+
}
|
|
629
|
+
function Se(s, e, t, n) {
|
|
630
|
+
if (s !== "string_or_numeric") {
|
|
631
|
+
if (s == null)
|
|
632
|
+
throw new Error("Expected dtype cannot be null.");
|
|
633
|
+
if (s !== "numeric" && s !== e || s === "numeric" && e === "string")
|
|
634
|
+
throw new Error(`Argument '${t}' passed to '${n}' must be ${s} tensor, but got ${e} tensor`);
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
function S(s, e, t, n = "numeric") {
|
|
638
|
+
if (s instanceof gt())
|
|
639
|
+
return Se(n, s.dtype, e, t), s;
|
|
640
|
+
let r = ue(s);
|
|
641
|
+
if (r !== "string" && ["bool", "int32", "float32"].indexOf(n) >= 0 && (r = n), Se(n, r, e, t), s == null || !D(s) && !Array.isArray(s) && typeof s != "number" && typeof s != "boolean" && typeof s != "string") {
|
|
642
|
+
const c = s == null ? "null" : s.constructor.name;
|
|
643
|
+
throw new Error(`Argument '${e}' passed to '${t}' must be a Tensor or TensorLike, but got '${c}'`);
|
|
644
|
+
}
|
|
645
|
+
const a = Et(s, r);
|
|
646
|
+
!D(s) && !Array.isArray(s) && (s = [s]);
|
|
647
|
+
const o = r !== "string" ? Fe(s, r) : De(s, [], !0);
|
|
648
|
+
return u.makeTensor(o, a, r);
|
|
649
|
+
}
|
|
650
|
+
function Tt(s, e, t, n = "numeric") {
|
|
651
|
+
if (!Array.isArray(s))
|
|
652
|
+
throw new Error(`Argument ${e} passed to ${t} must be a \`Tensor[]\` or \`TensorLike[]\``);
|
|
653
|
+
return s.map((a, i) => S(a, `${e}[${i}]`, t, n));
|
|
654
|
+
}
|
|
655
|
+
const At = "__op";
|
|
656
|
+
function M(s) {
|
|
657
|
+
const e = Object.keys(s);
|
|
658
|
+
if (e.length !== 1)
|
|
659
|
+
throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${e.length} keys.`);
|
|
660
|
+
let t = e[0];
|
|
661
|
+
const n = s[t];
|
|
662
|
+
t.endsWith("_") && (t = t.substring(0, t.length - 1)), t = t + At;
|
|
663
|
+
const r = (...a) => {
|
|
664
|
+
u.startScope(t);
|
|
665
|
+
try {
|
|
666
|
+
const i = n(...a);
|
|
667
|
+
return pt(i) && console.error("Cannot return a Promise inside of tidy."), u.endScope(i), i;
|
|
668
|
+
} catch (i) {
|
|
669
|
+
throw u.endScope(null), i;
|
|
670
|
+
}
|
|
671
|
+
};
|
|
672
|
+
return Object.defineProperty(r, "name", { value: t, configurable: !0 }), r;
|
|
673
|
+
}
|
|
674
|
+
function Mt(s, e, t, n) {
|
|
675
|
+
if (n == null)
|
|
676
|
+
n = ue(s);
|
|
677
|
+
else if (n === "complex64")
|
|
678
|
+
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
|
|
679
|
+
if (xe(s) || Ne(s)) {
|
|
680
|
+
if (n !== "float32" && n !== "int32")
|
|
681
|
+
throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${n}.`);
|
|
682
|
+
return u.backend.createTensorFromGPUData(s, e || t, n);
|
|
683
|
+
}
|
|
684
|
+
if (!D(s) && !Array.isArray(s) && typeof s != "number" && typeof s != "boolean" && typeof s != "string")
|
|
685
|
+
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
|
|
686
|
+
if (e != null) {
|
|
687
|
+
he(e);
|
|
688
|
+
const r = Z(e), a = Z(t);
|
|
689
|
+
g(r === a, () => `Based on the provided shape, [${e}], the tensor should have ${r} values but has ${a}`);
|
|
690
|
+
for (let i = 0; i < t.length; ++i) {
|
|
691
|
+
const o = t[i], c = i === t.length - 1 ? o !== Z(e.slice(i)) : !0;
|
|
692
|
+
g(t[i] === e[i] || !c, () => `Error creating a new Tensor. Inferred shape (${t}) does not match the provided shape (${e}). `);
|
|
693
|
+
}
|
|
694
|
+
}
|
|
695
|
+
return !D(s) && !Array.isArray(s) && (s = [s]), e = e || t, s = n !== "string" ? Fe(s, n) : De(s, [], !0), u.makeTensor(s, e, n);
|
|
696
|
+
}
|
|
697
|
+
class H {
|
|
698
|
+
/**
|
|
699
|
+
* Concatenate a number of ArrayBuffers into one.
|
|
700
|
+
*
|
|
701
|
+
* @param buffers An array of ArrayBuffers to concatenate, or a single
|
|
702
|
+
* ArrayBuffer.
|
|
703
|
+
* @returns Result of concatenating `buffers` in order.
|
|
704
|
+
*/
|
|
705
|
+
static join(e) {
|
|
706
|
+
return new H(e).slice();
|
|
707
|
+
}
|
|
708
|
+
constructor(e) {
|
|
709
|
+
if (this.shards = [], this.previousShardIndex = 0, e == null || (e instanceof Array || (e = [e]), e = e.map((n) => D(n) ? n.buffer : n), e.length === 0))
|
|
710
|
+
return;
|
|
711
|
+
this.bufferUniformSize = e[0].byteLength;
|
|
712
|
+
let t = 0;
|
|
713
|
+
for (let n = 0; n < e.length; n++) {
|
|
714
|
+
const r = e[n];
|
|
715
|
+
n !== e.length - 1 && r.byteLength !== this.bufferUniformSize && (this.bufferUniformSize = void 0);
|
|
716
|
+
const a = t + r.byteLength;
|
|
717
|
+
this.shards.push({ buffer: r, start: t, end: a }), t = a;
|
|
718
|
+
}
|
|
719
|
+
this.shards.length === 0 && (this.byteLength = 0), this.byteLength = this.shards[this.shards.length - 1].end;
|
|
720
|
+
}
|
|
721
|
+
slice(e = 0, t = this.byteLength) {
|
|
722
|
+
if (this.shards.length === 0)
|
|
723
|
+
return new ArrayBuffer(0);
|
|
724
|
+
if (e = isNaN(Number(e)) ? 0 : e, t = isNaN(Number(t)) ? 0 : t, e = Math.max(0, e), t = Math.min(this.byteLength, t), t <= e)
|
|
725
|
+
return new ArrayBuffer(0);
|
|
726
|
+
const n = this.findShardForByte(e);
|
|
727
|
+
if (n === -1)
|
|
728
|
+
throw new Error(`Could not find start shard for byte ${e}`);
|
|
729
|
+
const r = t - e, a = new ArrayBuffer(r), i = new Uint8Array(a);
|
|
730
|
+
let o = 0;
|
|
731
|
+
for (let c = n; c < this.shards.length; c++) {
|
|
732
|
+
const l = this.shards[c], m = e + o - l.start, d = o, b = Math.min(t, l.end) - l.start, w = new Uint8Array(l.buffer, m, b - m);
|
|
733
|
+
if (i.set(w, d), o += w.length, t < l.end)
|
|
734
|
+
break;
|
|
735
|
+
}
|
|
736
|
+
return a;
|
|
737
|
+
}
|
|
738
|
+
/**
|
|
739
|
+
* Get the index of the shard that contains the byte at `byteIndex`.
|
|
740
|
+
*/
|
|
741
|
+
findShardForByte(e) {
|
|
742
|
+
if (this.shards.length === 0 || e < 0 || e >= this.byteLength)
|
|
743
|
+
return -1;
|
|
744
|
+
if (this.bufferUniformSize != null)
|
|
745
|
+
return this.previousShardIndex = Math.floor(e / this.bufferUniformSize), this.previousShardIndex;
|
|
746
|
+
function t(r) {
|
|
747
|
+
return e < r.start ? -1 : e >= r.end ? 1 : 0;
|
|
748
|
+
}
|
|
749
|
+
if (t(this.shards[this.previousShardIndex]) === 0)
|
|
750
|
+
return this.previousShardIndex;
|
|
751
|
+
const n = Nt(this.shards, t);
|
|
752
|
+
return n === -1 ? -1 : (this.previousShardIndex = n, this.previousShardIndex);
|
|
753
|
+
}
|
|
754
|
+
}
|
|
755
|
+
function Nt(s, e) {
|
|
756
|
+
let t = 0, n = s.length;
|
|
757
|
+
for (; t <= n; ) {
|
|
758
|
+
const r = Math.floor((n - t) / 2) + t, a = e(s[r]);
|
|
759
|
+
if (a === 0)
|
|
760
|
+
return r;
|
|
761
|
+
a < 0 ? n = r : t = r + 1;
|
|
762
|
+
}
|
|
763
|
+
return -1;
|
|
764
|
+
}
function _n() {
u.disposeVariables();
}
function On() {
return u;
}
function Ln() {
return u.memory();
}
function k(s, e) {
return u.tidy(s, e);
}
function A(s) {
Te(s).forEach((t) => t.dispose());
}
function xt(s) {
return u.keep(s);
}
function zn(s) {
return u.setBackend(s);
}
function Pn() {
return u.ready();
}
function Un() {
return u.backendName;
}
function Gn(s, e, t = 1) {
return u.registerBackend(s, e, t);
}
function Wn() {
return u.backend;
}
const ke = 4;
async function Kn(s, e) {
const t = [], n = [], r = Array.isArray(s) ? s.map((i) => i.name) : Object.keys(s);
for (let i = 0; i < r.length; ++i) {
const o = r[i], c = Array.isArray(s) ? s[i].tensor : s[o];
if (c.dtype !== "float32" && c.dtype !== "int32" && c.dtype !== "bool" && c.dtype !== "string" && c.dtype !== "complex64")
throw new Error(`Unsupported dtype in weight '${o}': ${c.dtype}`);
const l = { name: o, shape: c.shape, dtype: c.dtype };
if (c.dtype === "string") {
const h = new Promise(async (m) => {
const d = await c.bytes(), p = d.reduce((I, T) => I + T.length, 0) + ke * d.length, b = new Uint8Array(p);
let w = 0;
for (let I = 0; I < d.length; I++) {
const T = d[I], te = new Uint8Array(new Uint32Array([T.length]).buffer);
b.set(te, w), w += ke, b.set(T, w), w += T.length;
}
m(b);
});
n.push(h);
} else
n.push(c.data());
e != null && (l.group = e), t.push(l);
}
const a = await Promise.all(n);
return { data: Ft(a), specs: t };
}
function Ft(s) {
if (s === null)
throw new Error(`Invalid input value: ${JSON.stringify(s)}`);
let e = 0;
const t = [];
s.forEach((a) => {
if (e += a.byteLength, t.push(a.byteLength === a.buffer.byteLength ? a : new a.constructor(a)), !(a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array))
throw new Error(`Unsupported TypedArray subtype: ${a.constructor.name}`);
});
const n = new Uint8Array(e);
let r = 0;
return t.forEach((a) => {
n.set(new Uint8Array(a.buffer), r), r += a.byteLength;
}), n.buffer;
}
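`Kn` serializes a map of named tensors into one contiguous buffer plus a list of specs; string tensors get a per-element header. Judging from the loop above, each string is written as a 4-byte length prefix (`ke === 4`, platform-endian, which is little-endian on mainstream JS engines) followed by the raw bytes. A hedged sketch of just that string layout, with illustrative names:

// Sketch of the length-prefixed encoding used for dtype "string" above.
function encodeStrings(byteArrays) {
  const total = byteArrays.reduce((n, p) => n + p.length, 0) + 4 * byteArrays.length;
  const out = new Uint8Array(total);
  let offset = 0;
  for (const bytes of byteArrays) {
    // 4-byte length prefix, then the payload bytes.
    out.set(new Uint8Array(Uint32Array.of(bytes.length).buffer), offset);
    offset += 4;
    out.set(bytes, offset);
    offset += bytes.length;
  }
  return out;
}
// encodeStrings([new TextEncoder().encode("hi")]) -> [2, 0, 0, 0, 104, 105]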
const de = typeof ee < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
function Ie(s) {
return de ? ee.byteLength(s, "utf8") : new Blob([s]).size;
}
function Dt(s) {
if (de)
return ee.from(s).toString("base64");
const e = new Uint8Array(s);
let t = "";
for (let n = 0, r = e.length; n < r; n++)
t += String.fromCharCode(e[n]);
return btoa(t);
}
function Ct(s) {
if (de) {
const n = ee.from(s, "base64");
return n.buffer.slice(n.byteOffset, n.byteOffset + n.byteLength);
}
const e = atob(s), t = new Uint8Array(e.length);
for (let n = 0; n < e.length; ++n)
t.set([e.charCodeAt(n)], n);
return t.buffer;
}
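`Dt` and `Ct` round-trip an ArrayBuffer through base64, preferring Node's `Buffer` (bound to `ee`) and falling back to `btoa`/`atob` in the browser. The browser path, isolated as a sketch:

// Browser-only round trip; fromBase64(toBase64(x)) reproduces x byte for byte.
function toBase64(buffer) {
  let s = "";
  for (const byte of new Uint8Array(buffer)) s += String.fromCharCode(byte);
  return btoa(s);
}
function fromBase64(b64) {
  const s = atob(b64);
  const out = new Uint8Array(s.length);
  for (let i = 0; i < s.length; i++) out[i] = s.charCodeAt(i);
  return out.buffer;
}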
function jn(s) {
return H.join(s);
}
function qn(s, e) {
const t = {
modelTopology: s.modelTopology,
format: s.format,
generatedBy: s.generatedBy,
convertedBy: s.convertedBy,
weightsManifest: e
};
return s.signature != null && (t.signature = s.signature), s.userDefinedMetadata != null && (t.userDefinedMetadata = s.userDefinedMetadata), s.modelInitializer != null && (t.modelInitializer = s.modelInitializer), s.initializerSignature != null && (t.initializerSignature = s.initializerSignature), s.trainingConfig != null && (t.trainingConfig = s.trainingConfig), t;
}
function Rt(s, e, t) {
const n = {
modelTopology: s.modelTopology,
format: s.format,
generatedBy: s.generatedBy,
convertedBy: s.convertedBy
};
if (s.trainingConfig != null && (n.trainingConfig = s.trainingConfig), s.weightsManifest != null) {
if (!e)
throw new Error("modelJSON has weightsManifest but weightSpecs is null");
if (!t)
throw new Error("modelJSON has weightsManifest but weightData is null");
n.weightSpecs = e, n.weightData = t;
}
return s.signature != null && (n.signature = s.signature), s.userDefinedMetadata != null && (n.userDefinedMetadata = s.userDefinedMetadata), s.modelInitializer != null && (n.modelInitializer = s.modelInitializer), s.initializerSignature != null && (n.initializerSignature = s.initializerSignature), n;
}
async function Vn(s, e) {
let t, n;
return s.weightsManifest != null && ([t, n] = await e(s.weightsManifest)), Rt(s, t, n);
}
function _e(s) {
if (s.modelTopology instanceof ArrayBuffer)
throw new Error("Expected JSON model topology, received ArrayBuffer.");
return {
dateSaved: /* @__PURE__ */ new Date(),
modelTopologyType: "JSON",
modelTopologyBytes: s.modelTopology == null ? 0 : Ie(JSON.stringify(s.modelTopology)),
weightSpecsBytes: s.weightSpecs == null ? 0 : Ie(JSON.stringify(s.weightSpecs)),
weightDataBytes: s.weightData == null ? 0 : new H(s.weightData).byteLength
};
}
function Hn(s) {
const e = [];
for (const t of s)
e.push(...t.weights);
return e;
}
class v {
constructor() {
this.saveRouters = [], this.loadRouters = [];
}
static getInstance() {
return v.instance == null && (v.instance = new v()), v.instance;
}
/**
* Register a save-handler router.
*
* @param saveRouter A function that maps a URL-like string onto an instance
* of `IOHandler` with the `save` method defined or `null`.
*/
static registerSaveRouter(e) {
v.getInstance().saveRouters.push(e);
}
/**
* Register a load-handler router.
*
* @param loadRouter A function that maps a URL-like string onto an instance
* of `IOHandler` with the `load` method defined or `null`.
*/
static registerLoadRouter(e) {
v.getInstance().loadRouters.push(e);
}
/**
* Look up IOHandler for saving, given a URL-like string.
*
* @param url
* @returns If only one match is found, an instance of IOHandler with the
* `save` method defined. If no match is found, `null`.
* @throws Error, if more than one match is found.
*/
static getSaveHandlers(e) {
return v.getHandlers(e, "save");
}
/**
* Look up IOHandler for loading, given a URL-like string.
*
* @param url
* @param loadOptions Optional, custom load options.
* @returns All valid handlers for `url`, given the currently registered
* handler routers.
*/
static getLoadHandlers(e, t) {
return v.getHandlers(e, "load", t);
}
static getHandlers(e, t, n) {
const r = [];
return (t === "load" ? v.getInstance().loadRouters : v.getInstance().saveRouters).forEach((i) => {
const o = i(e, n);
o !== null && r.push(o);
}), r;
}
}
const Xn = (s) => v.getSaveHandlers(s);
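`v` is a registry of save/load routers: each router inspects a URL-like string and either returns an IO handler for it or `null`, and `getHandlers` simply collects the non-null answers. A hedged sketch of a custom router; the scheme and handler body are invented for illustration:

// Hypothetical scheme; only the router's shape matches the registry above.
const myRouter = (url) => {
  if (Array.isArray(url) || !url.startsWith("mystore://")) return null;
  return {
    // A real handler would resolve model artifacts here.
    load: async () => { throw new Error("not implemented"); }
  };
};
// v.registerLoadRouter(myRouter) would make "mystore://..." URLs resolvable
// through v.getLoadHandlers("mystore://my-model").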
const ie = "tensorflowjs", oe = 1, O = "models_store", $ = "model_info_store";
function Oe() {
if (!B().getBool("IS_BROWSER"))
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
const s = typeof window > "u" ? self : window, e = s.indexedDB || s.mozIndexedDB || s.webkitIndexedDB || s.msIndexedDB || s.shimIndexedDB;
if (e == null)
throw new Error("The current browser does not appear to support IndexedDB.");
return e;
}
function ce(s) {
const e = s.result;
e.createObjectStore(O, { keyPath: "modelPath" }), e.createObjectStore($, { keyPath: "modelPath" });
}
class z {
constructor(e) {
if (this.indexedDB = Oe(), e == null || !e)
throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
this.modelPath = e;
}
async save(e) {
if (e.modelTopology instanceof ArrayBuffer)
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
return this.databaseAction(this.modelPath, e);
}
async load() {
return this.databaseAction(this.modelPath);
}
/**
* Perform database action to put model artifacts into or read model artifacts
* from IndexedDB object store.
*
* Whether the action is put or get depends on whether `modelArtifacts` is
* specified. If it is specified, the action will be put; otherwise the action
* will be get.
*
* @param modelPath A unique string path for the model.
* @param modelArtifacts If specified, it will be the model artifacts to be
* stored in IndexedDB.
* @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
* of `ModelArtifacts`, if the action is get.
*/
databaseAction(e, t) {
return new Promise((n, r) => {
const a = this.indexedDB.open(ie, oe);
a.onupgradeneeded = () => ce(a), a.onsuccess = () => {
const i = a.result;
if (t == null) {
const o = i.transaction(O, "readonly"), l = o.objectStore(O).get(this.modelPath);
l.onsuccess = () => {
if (l.result == null)
return i.close(), r(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));
n(l.result.modelArtifacts);
}, l.onerror = (h) => (i.close(), r(l.error)), o.oncomplete = () => i.close();
} else {
t.weightData = H.join(t.weightData);
const o = _e(t), c = i.transaction($, "readwrite");
let l = c.objectStore($), h;
try {
h = l.put({ modelPath: this.modelPath, modelArtifactsInfo: o });
} catch (d) {
return r(d);
}
let m;
h.onsuccess = () => {
m = i.transaction(O, "readwrite");
const d = m.objectStore(O);
let p;
try {
p = d.put({
modelPath: this.modelPath,
modelArtifacts: t,
modelArtifactsInfo: o
});
} catch (b) {
return r(b);
}
p.onsuccess = () => n({ modelArtifactsInfo: o }), p.onerror = (b) => {
l = c.objectStore($);
const w = l.delete(this.modelPath);
w.onsuccess = () => (i.close(), r(p.error)), w.onerror = (I) => (i.close(), r(p.error));
};
}, h.onerror = (d) => (i.close(), r(h.error)), c.oncomplete = () => {
m == null ? i.close() : m.oncomplete = () => i.close();
};
}
}, a.onerror = (i) => r(a.error);
});
}
}
z.URL_SCHEME = "indexeddb://";
const Le = (s) => B().getBool("IS_BROWSER") && !Array.isArray(s) && s.startsWith(z.URL_SCHEME) ? $t(s.slice(z.URL_SCHEME.length)) : null;
v.registerSaveRouter(Le);
v.registerLoadRouter(Le);
function $t(s) {
return new z(s);
}
function _t(s) {
return s.startsWith(z.URL_SCHEME) ? s.slice(z.URL_SCHEME.length) : s;
}
class Ot {
constructor() {
this.indexedDB = Oe();
}
async listModels() {
return new Promise((e, t) => {
const n = this.indexedDB.open(ie, oe);
n.onupgradeneeded = () => ce(n), n.onsuccess = () => {
const r = n.result, a = r.transaction($, "readonly"), o = a.objectStore($).getAll();
o.onsuccess = () => {
const c = {};
for (const l of o.result)
c[l.modelPath] = l.modelArtifactsInfo;
e(c);
}, o.onerror = (c) => (r.close(), t(o.error)), a.oncomplete = () => r.close();
}, n.onerror = (r) => t(n.error);
});
}
async removeModel(e) {
return e = _t(e), new Promise((t, n) => {
const r = this.indexedDB.open(ie, oe);
r.onupgradeneeded = () => ce(r), r.onsuccess = () => {
const a = r.result, i = a.transaction($, "readwrite"), o = i.objectStore($), c = o.get(e);
let l;
c.onsuccess = () => {
if (c.result == null)
return a.close(), n(new Error(`Cannot find model with path '${e}' in IndexedDB.`));
{
const h = o.delete(e), m = () => {
l = a.transaction(O, "readwrite");
const p = l.objectStore(O).delete(e);
p.onsuccess = () => t(c.result.modelArtifactsInfo), p.onerror = (b) => n(c.error);
};
h.onsuccess = m, h.onerror = (d) => (m(), a.close(), n(c.error));
}
}, c.onerror = (h) => (a.close(), n(c.error)), i.oncomplete = () => {
l == null ? a.close() : l.oncomplete = () => a.close();
};
}, r.onerror = (a) => n(r.error);
});
}
}
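The IndexedDB handler (`z`) and manager (`Ot`) above share one database layout: database `tensorflowjs` at version 1, with two object stores keyed by `modelPath` — full artifacts in `models_store` and lightweight info records in `model_info_store`. A browser-only sketch of opening that layout directly, using the literal constants from the bundle:

// Mirrors Oe/ce above.
const request = indexedDB.open("tensorflowjs", 1);
request.onupgradeneeded = () => {
  const db = request.result;
  db.createObjectStore("models_store", { keyPath: "modelPath" });
  db.createObjectStore("model_info_store", { keyPath: "modelPath" });
};
request.onsuccess = () => request.result.close();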
const C = "/", K = "tensorflowjs_models", ze = "info", Lt = "model_topology", zt = "weight_specs", Pt = "weight_data", Ut = "model_metadata";
function Pe(s) {
return {
info: [K, s, ze].join(C),
topology: [K, s, Lt].join(C),
weightSpecs: [K, s, zt].join(C),
weightData: [K, s, Pt].join(C),
modelMetadata: [K, s, Ut].join(C)
};
}
function Ue(s) {
for (const e of Object.values(s))
window.localStorage.removeItem(e);
}
function Gt(s) {
const e = s.split(C);
if (e.length < 3)
throw new Error(`Invalid key format: ${s}`);
return e.slice(1, e.length - 1).join(C);
}
function Wt(s) {
return s.startsWith(P.URL_SCHEME) ? s.slice(P.URL_SCHEME.length) : s;
}
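`Pe` derives the five local-storage keys for one model, and `Gt` inverts the mapping to recover the model path from an `.../info` key. Spelled out with the constants above:

// Equivalent of Pe("my-model"), using the literal constants from the bundle.
function localStorageKeys(modelPath) {
  const prefix = `tensorflowjs_models/${modelPath}`;
  return {
    info: `${prefix}/info`,
    topology: `${prefix}/model_topology`,
    weightSpecs: `${prefix}/weight_specs`,
    weightData: `${prefix}/weight_data`,
    modelMetadata: `${prefix}/model_metadata`
  };
}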
class P {
constructor(e) {
if (!B().getBool("IS_BROWSER") || typeof window > "u" || typeof window.localStorage > "u")
throw new Error("The current environment does not support local storage.");
if (this.LS = window.localStorage, e == null || !e)
throw new Error("For local storage, modelPath must not be null, undefined or empty.");
this.modelPath = e, this.keys = Pe(this.modelPath);
}
/**
* Save model artifacts to browser local storage.
*
* See the documentation to `browserLocalStorage` for details on the saved
* artifacts.
*
* @param modelArtifacts The model artifacts to be stored.
* @returns An instance of SaveResult.
*/
async save(e) {
if (e.modelTopology instanceof ArrayBuffer)
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
{
const t = JSON.stringify(e.modelTopology), n = JSON.stringify(e.weightSpecs), r = _e(e), a = H.join(e.weightData);
try {
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, t), this.LS.setItem(this.keys.weightSpecs, n), this.LS.setItem(this.keys.weightData, Dt(a));
const i = {
format: e.format,
generatedBy: e.generatedBy,
convertedBy: e.convertedBy,
signature: e.signature != null ? e.signature : void 0,
userDefinedMetadata: e.userDefinedMetadata != null ? e.userDefinedMetadata : void 0,
modelInitializer: e.modelInitializer != null ? e.modelInitializer : void 0,
initializerSignature: e.initializerSignature != null ? e.initializerSignature : void 0,
trainingConfig: e.trainingConfig != null ? e.trainingConfig : void 0
};
return this.LS.setItem(this.keys.modelMetadata, JSON.stringify(i)), { modelArtifactsInfo: r };
} catch {
throw Ue(this.keys), new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${r.modelTopologyBytes}, weightSpecsBytes=${r.weightSpecsBytes}, weightDataBytes=${r.weightDataBytes}.`);
}
}
}
/**
* Load a model from local storage.
*
* See the documentation to `browserLocalStorage` for details on the saved
* artifacts.
*
* @returns The loaded model (if loading succeeds).
*/
async load() {
const e = JSON.parse(this.LS.getItem(this.keys.info));
if (e == null)
throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
if (e.modelTopologyType !== "JSON")
throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet.");
const t = {}, n = JSON.parse(this.LS.getItem(this.keys.topology));
if (n == null)
throw new Error(`In local storage, the topology of model '${this.modelPath}' is missing.`);
t.modelTopology = n;
const r = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
if (r == null)
throw new Error(`In local storage, the weight specs of model '${this.modelPath}' are missing.`);
t.weightSpecs = r;
const a = this.LS.getItem(this.keys.modelMetadata);
if (a != null) {
const o = JSON.parse(a);
t.format = o.format, t.generatedBy = o.generatedBy, t.convertedBy = o.convertedBy, o.signature != null && (t.signature = o.signature), o.userDefinedMetadata != null && (t.userDefinedMetadata = o.userDefinedMetadata), o.modelInitializer != null && (t.modelInitializer = o.modelInitializer), o.initializerSignature != null && (t.initializerSignature = o.initializerSignature), o.trainingConfig != null && (t.trainingConfig = o.trainingConfig);
}
const i = this.LS.getItem(this.keys.weightData);
if (i == null)
throw new Error(`In local storage, the binary weight values of model '${this.modelPath}' are missing.`);
return t.weightData = Ct(i), t;
}
}
P.URL_SCHEME = "localstorage://";
const Ge = (s) => B().getBool("IS_BROWSER") && !Array.isArray(s) && s.startsWith(P.URL_SCHEME) ? Kt(s.slice(P.URL_SCHEME.length)) : null;
v.registerSaveRouter(Ge);
v.registerLoadRouter(Ge);
function Kt(s) {
return new P(s);
}
class jt {
constructor() {
g(B().getBool("IS_BROWSER"), () => "Current environment is not a web browser"), g(typeof window > "u" || typeof window.localStorage < "u", () => "Current browser does not appear to support localStorage"), this.LS = window.localStorage;
}
async listModels() {
const e = {}, t = K + C, n = C + ze;
for (let r = 0; r < this.LS.length; ++r) {
const a = this.LS.key(r);
if (a.startsWith(t) && a.endsWith(n)) {
const i = Gt(a);
e[i] = JSON.parse(this.LS.getItem(a));
}
}
return e;
}
async removeModel(e) {
e = Wt(e);
const t = Pe(e);
if (this.LS.getItem(t.info) == null)
throw new Error(`Cannot find model at path '${e}'`);
const n = JSON.parse(this.LS.getItem(t.info));
return Ue(t), n;
}
}
const ve = "://";
class N {
constructor() {
this.managers = {};
}
static getInstance() {
return N.instance == null && (N.instance = new N()), N.instance;
}
/**
* Register a save-handler router.
*
* @param saveRouter A function that maps a URL-like string onto an instance
* of `IOHandler` with the `save` method defined or `null`.
*/
static registerManager(e, t) {
g(e != null, () => "scheme must not be undefined or null."), e.endsWith(ve) && (e = e.slice(0, e.indexOf(ve))), g(e.length > 0, () => "scheme must not be an empty string.");
const n = N.getInstance();
g(n.managers[e] == null, () => `A model store manager is already registered for scheme '${e}'.`), n.managers[e] = t;
}
static getManager(e) {
const t = N.getInstance().managers[e];
if (t == null)
throw new Error(`Cannot find model manager for scheme '${e}'`);
return t;
}
static getSchemes() {
return Object.keys(N.getInstance().managers);
}
}
class qt {
constructor() {
this.messageName = "setTimeoutCustom", this.functionRefs = [], this.handledMessageCount = 0, this.hasEventListener = !1;
}
fetch(e, t) {
return fetch(e, t);
}
now() {
return performance.now();
}
encode(e, t) {
if (t !== "utf-8" && t !== "utf8")
throw new Error(`Browser's encoder only supports utf-8, but got ${t}`);
return this.textEncoder == null && (this.textEncoder = new TextEncoder()), this.textEncoder.encode(e);
}
decode(e, t) {
return new TextDecoder(t).decode(e);
}
// If the setTimeout nesting level is greater than 5 and timeout is less
// than 4ms, timeout will be clamped to 4ms, which hurts the perf.
// Interleaving window.postMessage and setTimeout will trick the browser and
// avoid the clamp.
setTimeoutCustom(e, t) {
if (typeof window > "u" || !B().getBool("USE_SETTIMEOUTCUSTOM")) {
setTimeout(e, t);
return;
}
this.functionRefs.push(e), setTimeout(() => {
window.postMessage({ name: this.messageName, index: this.functionRefs.length - 1 }, "*");
}, t), this.hasEventListener || (this.hasEventListener = !0, window.addEventListener("message", (n) => {
if (n.source === window && n.data.name === this.messageName) {
n.stopPropagation();
const r = this.functionRefs[n.data.index];
r(), this.handledMessageCount++, this.handledMessageCount === this.functionRefs.length && (this.functionRefs = [], this.handledMessageCount = 0);
}
}, !0));
}
isTypedArray(e) {
return yt(e);
}
}
if (B().get("IS_BROWSER")) {
B().setPlatform("browser", new qt());
try {
N.registerManager(P.URL_SCHEME, new jt());
} catch {
}
try {
N.registerManager(z.URL_SCHEME, new Ot());
} catch {
}
}
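The comment inside `qt.setTimeoutCustom` explains the trick: browsers clamp deeply nested `setTimeout` calls to roughly 4 ms, so the callback is parked in an array and resumed through `window.postMessage`, which is not clamped. A reduced sketch of the same mechanism, with illustrative names:

const pending = [];
window.addEventListener("message", (ev) => {
  if (ev.source === window && ev.data && ev.data.name === "fastTimeout")
    pending[ev.data.index]();
}, true);
function fastTimeout(fn, ms) {
  const index = pending.push(fn) - 1; // remember the slot at registration time
  setTimeout(() => {
    window.postMessage({ name: "fastTimeout", index }, "*");
  }, ms);
}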
const Vt = {
// tslint:disable-next-line:no-require-imports
importFetch: () => require("node-fetch")
};
let se;
class Ht {
constructor() {
this.util = require("util"), this.textEncoder = new this.util.TextEncoder();
}
fetch(e, t) {
return B().global.fetch != null ? B().global.fetch(e, t) : (se == null && (se = Vt.importFetch()), se(e, t));
}
now() {
const e = Q.hrtime();
return e[0] * 1e3 + e[1] / 1e6;
}
encode(e, t) {
if (t !== "utf-8" && t !== "utf8")
throw new Error(`Node built-in encoder only supports utf-8, but got ${t}`);
return this.textEncoder.encode(e);
}
decode(e, t) {
return e.length === 0 ? "" : new this.util.TextDecoder(t).decode(e);
}
isTypedArray(e) {
return this.util.types.isFloat32Array(e) || this.util.types.isInt32Array(e) || this.util.types.isUint8Array(e) || this.util.types.isUint8ClampedArray(e);
}
}
B().get("IS_NODE") && !B().get("IS_BROWSER") && B().setPlatform("node", new Ht());
function Xt(s, e = "float32", t) {
return e = e || "float32", he(s), new bt(s, e, t);
}
function Jt(s, e) {
const t = S(s, "x", "cast");
if (!wt(e))
throw new Error(`Failed to cast to unknown dtype ${e}`);
if (e === "string" && t.dtype !== "string" || e !== "string" && t.dtype === "string")
throw new Error("Only strings can be casted to strings");
const n = { x: t }, r = { dtype: e };
return u.runKernel(Ae, n, r);
}
const le = /* @__PURE__ */ M({ cast_: Jt });
function Yt(s) {
const t = { x: S(s, "x", "clone", "string_or_numeric") };
return u.runKernel(Ee, t);
}
const Zt = /* @__PURE__ */ M({ clone_: Yt });
function Qt(s, e = !1) {
console.log(s.toString(e));
}
Re();
const en = {
buffer: Xt,
cast: le,
clone: Zt,
print: Qt
};
St(en);
function tn(s, e) {
let t = S(s, "a", "add"), n = S(e, "b", "add");
[t, n] = G(t, n);
const r = { a: t, b: n };
return u.runKernel(Me, r);
}
const y = /* @__PURE__ */ M({ add_: tn });
function nn(s, e) {
let t = S(s, "a", "floorDiv"), n = S(e, "b", "floorDiv");
[t, n] = G(t, n);
const r = { a: t, b: n };
return u.runKernel(Xe, r);
}
const sn = /* @__PURE__ */ M({ floorDiv_: nn });
function rn(s, e) {
let t = S(s, "a", "div"), n = S(e, "b", "div");
if ([t, n] = G(t, n), t.dtype === "int32" && n.dtype === "int32")
return sn(t, n);
const r = { a: t, b: n }, a = {};
return u.runKernel(Je, r, a);
}
const x = /* @__PURE__ */ M({ div_: rn });
function an(s, e) {
let t = S(s, "a", "mul"), n = S(e, "b", "mul");
[t, n] = G(t, n);
const r = { a: t, b: n };
return u.runKernel(Ye, r);
}
const f = /* @__PURE__ */ M({ mul_: an });
function on(s) {
const e = S(s, "x", "abs");
if (e.dtype === "complex64") {
const t = { x: e };
return u.runKernel(Ze, t);
} else {
const t = { x: e };
return u.runKernel(Qe, t);
}
}
const cn = /* @__PURE__ */ M({ abs_: on });
function ln(s, e, t) {
he(s), t = t || ue(e);
const n = { shape: s, value: e, dtype: t };
return u.runKernel(et, {}, n);
}
function Jn(s, e) {
const t = s.length, n = [];
for (let r = 0; r < t; r++) {
const a = t - 1 - r, i = s[a] || 1;
(e[e.length - 1 - r] || 1) > 1 && i === 1 && n.unshift(a);
}
return n;
}
function Yn(s, e) {
const t = [];
for (let n = 0; n < e.length; n++) {
const r = s[s.length - n - 1], a = e.length - n - 1, i = e[a];
(r == null || r === 1 && i > 1) && t.unshift(a);
}
return t;
}
function un(s, e) {
const t = Math.max(s.length, e.length), n = new Array(t);
for (let r = 0; r < t; r++) {
let a = s[s.length - r - 1];
a == null && (a = 1);
let i = e[e.length - r - 1];
if (i == null && (i = 1), a === 1)
n[t - r - 1] = i;
else if (i === 1)
n[t - r - 1] = a;
else if (a !== i) {
const o = `Operands could not be broadcast together with shapes ${s} and ${e}.`;
throw Error(o);
} else
n[t - r - 1] = a;
}
return n;
}
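`un` implements the standard broadcasting rule: align shapes from the right; at each position the sizes must match or one of them must be 1, and the output takes the larger size. The same rule as a compact sketch:

function broadcastShape(a, b) {
  const rank = Math.max(a.length, b.length);
  const out = new Array(rank);
  for (let i = 0; i < rank; i++) {
    const da = a[a.length - 1 - i] ?? 1; // missing leading dims count as 1
    const db = b[b.length - 1 - i] ?? 1;
    if (da !== db && da !== 1 && db !== 1)
      throw new Error(`Operands could not be broadcast together with shapes ${a} and ${b}.`);
    out[rank - 1 - i] = Math.max(da, db);
  }
  return out;
}
// broadcastShape([8, 1, 6, 1], [7, 1, 5]) -> [8, 7, 6, 5]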
function hn(s) {
const t = { x: S(s, "x", "zerosLike") };
return u.runKernel(tt, t);
}
const F = /* @__PURE__ */ M({ zerosLike_: hn });
function dn(s, e) {
let t = S(s, "base", "pow"), n = S(e, "exp", "pow");
[t, n] = G(t, n);
const r = { a: t, b: n };
return u.runKernel(nt, r);
}
const Be = /* @__PURE__ */ M({ pow_: dn });
function U(s, e) {
if ((D(s) && e !== "string" || Array.isArray(s)) && e !== "complex64")
throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
if (e === "string" && D(s) && !(s instanceof Uint8Array))
throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");
return Mt(s, [], [], e);
}
function fn(s) {
const t = { x: S(s, "x", "sqrt", "float32") };
return u.runKernel(st, t);
}
const V = /* @__PURE__ */ M({ sqrt_: fn });
function mn(s) {
const e = S(s, "x", "square"), t = {};
return u.runKernel("Square", { x: e }, t);
}
const L = /* @__PURE__ */ M({ square_: mn });
function Zn(s) {
return g(Y(s), () => "The f passed in grads(f) must be a function"), (e, t) => {
g(Array.isArray(e), () => "The args passed in grads(f)(args) must be an array of `Tensor`s or `TensorLike`s");
const n = Tt(e, "args", "tf.grads", "string_or_numeric"), r = t != null ? S(t, "dy", "tf.grads") : null;
return u.tidy(() => {
const { value: a, grads: i } = u.gradients(() => s(...n), n, r);
return r != null && Ce(a.shape, r.shape, "The shape of dy passed in grads(f)([x1,...], dy) must match the shape returned by f([x1,...])"), We(i), i;
});
};
}
function Qn(s) {
return g(Y(s), () => "The f passed in valueAndGrads(f) must be a function"), (e, t) => {
g(Array.isArray(e) && e.every((r) => r instanceof R), () => "The args passed in valueAndGrads(f)(args) must be array of tensors"), g(t == null || t instanceof R, () => "The dy passed in valueAndGrads(f)(args, dy) must be a tensor");
const n = u.gradients(() => s(...e), e, t);
return t != null && Ce(n.value.shape, t.shape, "The shape of dy passed in valueAndGrads(f)([x1,...], dy) must match the shape returned by f([x1,...])"), We(n.grads), n;
};
}
function gn(s, e) {
g(Y(s), () => "The f passed in variableGrads(f) must be a function"), g(e == null || Array.isArray(e) && e.every((l) => l instanceof re), () => "The varList passed in variableGrads(f, varList) must be an array of variables");
const t = e != null;
if (!t) {
e = [];
for (const l in u.registeredVariables)
e.push(u.registeredVariables[l]);
}
const n = t ? e.filter((l) => !l.trainable) : null, r = e.length;
e = e.filter((l) => l.trainable), g(e.length > 0, () => `variableGrads() expects at least one of the input variables to be trainable, but none of the ${r} variables is trainable.`);
const a = !0, { value: i, grads: o } = u.gradients(s, e, null, a);
g(o.some((l) => l != null), () => "Cannot find a connection between any variable and the result of the loss function y=f(x). Please make sure the operations that use variables are inside the function f passed to minimize()."), g(i.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it returned a rank-${i.rank} tensor`);
const c = {};
return e.forEach((l, h) => {
o[h] != null && (c[l.name] = o[h]);
}), n?.forEach((l) => c[l.name] = null), { value: i, grads: c };
}
function es(s) {
return u.customGrad(s);
}
function We(s) {
if (s.filter((t) => t == null).length > 0)
throw new Error(`Cannot compute gradient of y=f(x) with respect to x. Make sure that
the f you passed encloses all operations that lead from x to y.`);
}
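`gn` (exposed as `variableGrads` in TensorFlow.js) evaluates `f`, requires its output to be a scalar, and returns that value plus one gradient per trainable variable, keyed by variable name. A hedged usage sketch against the public TF.js API that this bundle vendors; the import path is an assumption for the sketch:

import * as tf from "@tensorflow/tfjs";

const w = tf.scalar(3).variable();
// Evaluates the loss and differentiates it with respect to every variable.
const { value, grads } = tf.variableGrads(() => tf.square(w).asScalar());
value.print();         // 9
grads[w.name].print(); // 6, i.e. d(w^2)/dw at w = 3
tf.dispose(grads);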
function pn(s, e) {
let t = S(s, "a", "sub"), n = S(e, "b", "sub");
[t, n] = G(t, n);
const r = { a: t, b: n };
return u.runKernel(rt, r);
}
const j = /* @__PURE__ */ M({ sub_: pn });
function yn(s, e) {
let t = S(s, "a", "maximum"), n = S(e, "b", "maximum");
[t, n] = G(t, n), t.dtype === "bool" && (t = le(t, "int32"), n = le(n, "int32")), un(t.shape, n.shape);
const r = { a: t, b: n };
return u.runKernel(at, r);
}
const bn = /* @__PURE__ */ M({ maximum_: yn });
const wn = /* @__PURE__ */ new Map(), Sn = /* @__PURE__ */ new Map();
class kn {
/**
* Return the class name for this class to use in serialization contexts.
*
* Generally speaking this will be the same thing that constructor.name
* would have returned. However, the class name needs to be robust
* against minification for serialization/deserialization to work properly.
*
* There's also places such as initializers.VarianceScaling, where
* implementation details between different languages led to different
* class hierarchies and a non-leaf node is used for serialization purposes.
*/
getClassName() {
return this.constructor.className;
}
/**
* Creates an instance of T from a ConfigDict.
*
* This works for most descendants of serializable. A few need to
* provide special handling.
* @param cls A Constructor for the class to instantiate.
* @param config The Configuration for the object.
*/
/** @nocollapse */
static fromConfig(e, t) {
return new e(t);
}
}
class _ {
constructor() {
this.classNameMap = {};
}
/**
* Returns the singleton instance of the map.
*/
static getMap() {
return _.instance == null && (_.instance = new _()), _.instance;
}
/**
* Registers the class as serializable.
*/
static register(e) {
_.getMap().classNameMap[e.className] = [e, e.fromConfig];
}
}
function In(s, e, t) {
g(s.className != null, () => "Class being registered does not have the static className property defined."), g(typeof s.className == "string", () => "className is required to be a string, but got type " + typeof s.className), g(s.className.length > 0, () => "Class being registered has an empty-string as its className, which is disallowed."), typeof e > "u" && (e = "Custom"), typeof t > "u" && (t = s.className);
const n = t, r = e + ">" + n;
return _.register(s), wn.set(r, s), Sn.set(s, r), s;
}
class W extends kn {
/**
* Executes `f()` and minimizes the scalar output of `f()` by computing
* gradients of y with respect to the list of trainable variables provided by
* `varList`. If no list is provided, it defaults to all trainable variables.
*
* @param f The function to execute and whose output to minimize.
* @param returnCost Whether to return the scalar cost value produced by
* executing `f()`.
* @param varList An optional list of variables to update. If specified, only
* the trainable variables in varList will be updated by minimize. Defaults to
* all trainable variables.
*
* @doc {heading: 'Training', subheading: 'Optimizers'}
*/
minimize(e, t = !1, n) {
const { value: r, grads: a } = this.computeGradients(e, n);
if (n != null) {
const i = n.map((o) => ({ name: o.name, tensor: a[o.name] }));
this.applyGradients(i);
} else
this.applyGradients(a);
return A(a), t ? r : (r.dispose(), null);
}
/**
* The number of iterations that this optimizer instance has been invoked for.
*/
get iterations() {
return this.iterations_ == null && (this.iterations_ = 0), this.iterations_;
}
incrementIterations() {
this.iterations_ = this.iterations + 1;
}
/**
* Executes f() and computes the gradient of the scalar output of f() with
* respect to the list of trainable variables provided by `varList`. If no
* list is provided, it defaults to all trainable variables.
*
* @param f The function to execute and whose output to use for computing
* gradients with respect to variables.
* @param varList An optional list of variables to compute gradients with
* respect to. If specified, only the trainable variables in varList will have
* gradients computed with respect to. Defaults to all trainable variables.
*
* @doc {heading: 'Training', subheading: 'Optimizers'}
*/
computeGradients(e, t) {
return gn(e, t);
}
/**
* Dispose the variables (if any) owned by this optimizer instance.
*/
dispose() {
this.iterations_ != null && A(this.iterations_);
}
async saveIterations() {
return this.iterations_ == null && (this.iterations_ = 0), {
name: "iter",
// TODO(cais): Use 'int64' type when available.
tensor: U(this.iterations_, "int32")
};
}
async getWeights() {
throw new Error("getWeights() is not implemented for this optimizer yet.");
}
async setWeights(e) {
throw new Error(`setWeights() is not implemented for this optimizer class ${this.getClassName()}`);
}
/**
* Extract the first element of the weight values and set it
* as the iterations counter variable of this instance of optimizer.
*
* @param weightValues
* @returns Weight values with the first element consumed and excluded.
*/
async extractIterations(e) {
return this.iterations_ = (await e[0].tensor.data())[0], e.slice(1);
}
}
Object.defineProperty(W, Symbol.hasInstance, {
value: (s) => s.minimize != null && s.computeGradients != null && s.applyGradients != null
});
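`W` is the abstract optimizer: `minimize` runs `computeGradients` (which delegates to `gn` above), hands the gradients to the subclass's `applyGradients`, then disposes them, returning the cost only when `returnCost` is true. A hedged usage sketch with the public TF.js API; the import path is an assumption:

import * as tf from "@tensorflow/tfjs";

const x = tf.scalar(2).variable();
const optimizer = tf.train.sgd(0.1);
for (let i = 0; i < 10; i++) {
  // One gradient step on loss = x^2; x shrinks toward the minimum at 0.
  optimizer.minimize(() => tf.square(x).asScalar());
}
x.print(); // roughly 2 * 0.8^10 ≈ 0.21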
class vn extends W {
/** @nocollapse */
static get className() {
return "Adadelta";
}
constructor(e, t, n = null) {
super(), this.learningRate = e, this.rho = t, this.epsilon = n, this.accumulatedGrads = [], this.accumulatedUpdates = [], n == null && (this.epsilon = u.backend.epsilon());
}
applyGradients(e) {
(Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e)).forEach((n, r) => {
const a = u.registeredVariables[n], i = !1;
this.accumulatedGrads[r] == null && (this.accumulatedGrads[r] = {
originalName: `${n}/accum_grad`,
variable: k(() => F(a).variable(i))
}), this.accumulatedUpdates[r] == null && (this.accumulatedUpdates[r] = {
originalName: `${n}/accum_var`,
variable: k(() => F(a).variable(i))
});
const o = Array.isArray(e) ? e[r].tensor : e[n];
if (o == null)
return;
const c = this.accumulatedGrads[r].variable, l = this.accumulatedUpdates[r].variable;
k(() => {
const h = y(f(c, this.rho), f(L(o), 1 - this.rho)), m = f(x(V(y(l, this.epsilon)), V(y(c, this.epsilon))), o), d = y(f(l, this.rho), f(L(m), 1 - this.rho));
c.assign(h), l.assign(d);
const p = y(f(m, -this.learningRate), a);
a.assign(p);
});
}), this.incrementIterations();
}
dispose() {
this.accumulatedUpdates != null && (A(this.accumulatedGrads.map((e) => e.variable)), A(this.accumulatedUpdates.map((e) => e.variable)));
}
async getWeights() {
const e = [...this.accumulatedGrads, ...this.accumulatedUpdates];
return [await this.saveIterations()].concat(e.map((t) => ({ name: t.originalName, tensor: t.variable })));
}
async setWeights(e) {
e = await this.extractIterations(e);
const t = e.length / 2, n = !1;
this.accumulatedGrads = e.slice(0, t).map((r) => ({
originalName: r.name,
variable: r.tensor.variable(n)
})), this.accumulatedUpdates = e.slice(t, t * 2).map((r) => ({
originalName: r.name,
variable: r.tensor.variable(n)
}));
}
getConfig() {
return {
learningRate: this.learningRate,
rho: this.rho,
epsilon: this.epsilon
};
}
/** @nocollapse */
static fromConfig(e, t) {
return new e(t.learningRate, t.rho, t.epsilon);
}
}
class Bn extends W {
/** @nocollapse */
static get className() {
return "Adagrad";
}
constructor(e, t = 0.1) {
super(), this.learningRate = e, this.initialAccumulatorValue = t, this.accumulatedGrads = [];
}
applyGradients(e) {
(Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e)).forEach((n, r) => {
const a = u.registeredVariables[n];
this.accumulatedGrads[r] == null && (this.accumulatedGrads[r] = {
originalName: `${n}/accumulator`,
variable: k(() => ln(a.shape, this.initialAccumulatorValue).variable(!1))
});
const i = Array.isArray(e) ? e[r].tensor : e[n];
if (i == null)
return;
const o = this.accumulatedGrads[r].variable;
k(() => {
const c = y(o, L(i));
o.assign(c);
const l = y(f(x(i, V(y(c, u.backend.epsilon()))), -this.learningRate), a);
a.assign(l);
});
}), this.incrementIterations();
}
dispose() {
this.accumulatedGrads != null && A(this.accumulatedGrads.map((e) => e.variable));
}
async getWeights() {
return [await this.saveIterations()].concat(this.accumulatedGrads.map((e) => ({ name: e.originalName, tensor: e.variable })));
}
async setWeights(e) {
e = await this.extractIterations(e);
const t = !1;
this.accumulatedGrads = e.map((n) => ({ originalName: n.name, variable: n.tensor.variable(t) }));
}
getConfig() {
return {
learningRate: this.learningRate,
initialAccumulatorValue: this.initialAccumulatorValue
};
}
/** @nocollapse */
static fromConfig(e, t) {
return new e(t.learningRate, t.initialAccumulatorValue);
}
}
class En extends W {
/** @nocollapse */
static get className() {
return "Adam";
}
constructor(e, t, n, r = null) {
super(), this.learningRate = e, this.beta1 = t, this.beta2 = n, this.epsilon = r, this.accumulatedFirstMoment = [], this.accumulatedSecondMoment = [], k(() => {
this.accBeta1 = U(t).variable(), this.accBeta2 = U(n).variable();
}), r == null && (this.epsilon = u.backend.epsilon());
}
applyGradients(e) {
const t = Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e);
k(() => {
const n = j(1, this.accBeta1), r = j(1, this.accBeta2);
t.forEach((a, i) => {
const o = u.registeredVariables[a], c = !1;
this.accumulatedFirstMoment[i] == null && (this.accumulatedFirstMoment[i] = {
originalName: `${a}/m`,
variable: k(() => F(o).variable(c))
}), this.accumulatedSecondMoment[i] == null && (this.accumulatedSecondMoment[i] = {
originalName: `${a}/v`,
variable: k(() => F(o).variable(c))
});
const l = Array.isArray(e) ? e[i].tensor : e[a];
if (l == null)
return;
const h = this.accumulatedFirstMoment[i].variable, m = this.accumulatedSecondMoment[i].variable, d = y(f(h, this.beta1), f(l, 1 - this.beta1)), p = y(f(m, this.beta2), f(L(l), 1 - this.beta2)), b = x(d, n), w = x(p, r);
h.assign(d), m.assign(p);
const I = y(f(x(b, y(V(w), this.epsilon)), -this.learningRate), o);
o.assign(I);
}), this.accBeta1.assign(f(this.accBeta1, this.beta1)), this.accBeta2.assign(f(this.accBeta2, this.beta2));
}), this.incrementIterations();
}
dispose() {
this.accBeta1.dispose(), this.accBeta2.dispose(), this.accumulatedFirstMoment != null && A(this.accumulatedFirstMoment.map((e) => e.variable)), this.accumulatedSecondMoment != null && A(this.accumulatedSecondMoment.map((e) => e.variable));
}
async getWeights() {
const e = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
return [await this.saveIterations()].concat(e.map((t) => ({ name: t.originalName, tensor: t.variable })));
}
async setWeights(e) {
e = await this.extractIterations(e), k(() => {
this.accBeta1.assign(Be(this.beta1, this.iterations_ + 1)), this.accBeta2.assign(Be(this.beta2, this.iterations_ + 1));
});
const t = e.length / 2, n = !1;
this.accumulatedFirstMoment = e.slice(0, t).map((r) => ({
originalName: r.name,
variable: r.tensor.variable(n)
})), this.accumulatedSecondMoment = e.slice(t, t * 2).map((r) => ({
originalName: r.name,
variable: r.tensor.variable(n)
}));
}
getConfig() {
return {
learningRate: this.learningRate,
beta1: this.beta1,
beta2: this.beta2,
epsilon: this.epsilon
};
}
/** @nocollapse */
static fromConfig(e, t) {
return new e(t.learningRate, t.beta1, t.beta2, t.epsilon);
}
}
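In `En` (Adam), `d` and `p` are the exponential moving averages of the gradient and its square, `accBeta1`/`accBeta2` track beta1^t and beta2^t for bias correction, and the update divides the corrected first moment by the square root of the corrected second moment plus epsilon. One step with plain numbers, matching the tensor arithmetic above (the helper itself is illustrative, not from the bundle):

// Scalar sketch of a single Adam update.
function adamStep(param, grad, m, v, t, lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8) {
  m = beta1 * m + (1 - beta1) * grad;          // first moment (d above)
  v = beta2 * v + (1 - beta2) * grad * grad;   // second moment (p above)
  const mHat = m / (1 - Math.pow(beta1, t));   // bias correction via accBeta1
  const vHat = v / (1 - Math.pow(beta2, t));   // bias correction via accBeta2
  return { param: param - lr * mHat / (Math.sqrt(vHat) + eps), m, v };
}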
class Tn extends W {
/** @nocollapse */
static get className() {
return "Adamax";
}
constructor(e, t, n, r = null, a = 0) {
super(), this.learningRate = e, this.beta1 = t, this.beta2 = n, this.epsilon = r, this.decay = a, this.accumulatedFirstMoment = [], this.accumulatedWeightedInfNorm = [], k(() => {
this.iteration = U(0).variable(), this.accBeta1 = U(t).variable();
}), r == null && (this.epsilon = u.backend.epsilon());
}
applyGradients(e) {
const t = Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e);
k(() => {
const n = j(1, this.accBeta1), r = x(-this.learningRate, y(f(this.iteration, this.decay), 1));
t.forEach((a, i) => {
const o = u.registeredVariables[a], c = !1;
this.accumulatedFirstMoment[i] == null && (this.accumulatedFirstMoment[i] = {
originalName: `${a}/m`,
variable: F(o).variable(c)
}), this.accumulatedWeightedInfNorm[i] == null && (this.accumulatedWeightedInfNorm[i] = {
originalName: `${a}/v`,
variable: F(o).variable(c)
});
const l = Array.isArray(e) ? e[i].tensor : e[a];
if (l == null)
return;
const h = this.accumulatedFirstMoment[i].variable, m = this.accumulatedWeightedInfNorm[i].variable, d = y(f(h, this.beta1), f(l, 1 - this.beta1)), p = f(m, this.beta2), b = cn(l), w = bn(p, b);
h.assign(d), m.assign(w);
const I = y(f(x(r, n), x(d, y(w, this.epsilon))), o);
o.assign(I);
}), this.iteration.assign(y(this.iteration, 1)), this.accBeta1.assign(f(this.accBeta1, this.beta1));
}), this.incrementIterations();
}
dispose() {
this.accBeta1.dispose(), this.iteration.dispose(), this.accumulatedFirstMoment != null && A(this.accumulatedFirstMoment.map((e) => e.variable)), this.accumulatedWeightedInfNorm != null && A(this.accumulatedWeightedInfNorm.map((e) => e.variable));
}
async getWeights() {
throw new Error("getWeights() is not implemented for Adamax yet.");
}
async setWeights(e) {
throw new Error("setWeights() is not implemented for Adamax yet.");
}
getConfig() {
return {
learningRate: this.learningRate,
beta1: this.beta1,
beta2: this.beta2,
epsilon: this.epsilon,
decay: this.decay
};
}
/** @nocollapse */
static fromConfig(e, t) {
return new e(t.learningRate, t.beta1, t.beta2, t.epsilon, t.decay);
}
}
class Ke extends W {
/** @nocollapse */
static get className() {
return "SGD";
}
constructor(e) {
super(), this.learningRate = e, this.setLearningRate(e);
}
applyGradients(e) {
(Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e)).forEach((n, r) => {
const a = Array.isArray(e) ? e[r].tensor : e[n];
if (a == null)
return;
const i = u.registeredVariables[n];
k(() => {
const o = y(f(this.c, a), i);
i.assign(o);
});
}), this.incrementIterations();
}
/**
* Sets the learning rate of the optimizer.
*/
setLearningRate(e) {
this.learningRate = e, this.c != null && this.c.dispose(), this.c = xt(U(-e));
}
dispose() {
this.c.dispose();
}
async getWeights() {
return [await this.saveIterations()];
}
async setWeights(e) {
if (e = await this.extractIterations(e), e.length !== 0)
throw new Error("SGD optimizer does not have settable weights.");
}
getConfig() {
return { learningRate: this.learningRate };
}
/** @nocollapse */
static fromConfig(e, t) {
return new e(t.learningRate);
}
}
+class An extends Ke {
+  /** @nocollapse */
+  // Name matters for Python compatibility.
+  static get className() {
+    return "Momentum";
+  }
+  constructor(e, t, n = !1) {
+    super(e), this.learningRate = e, this.momentum = t, this.useNesterov = n, this.accumulations = [], this.m = U(this.momentum);
+  }
+  applyGradients(e) {
+    (Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e)).forEach((n, r) => {
+      const a = u.registeredVariables[n];
+      this.accumulations[r] == null && (this.accumulations[r] = {
+        originalName: `${n}/momentum`,
+        variable: k(() => F(a).variable(!1))
+      });
+      const i = this.accumulations[r].variable, o = Array.isArray(e) ? e[r].tensor : e[n];
+      o != null && k(() => {
+        let c;
+        const l = y(f(this.m, i), o);
+        this.useNesterov ? c = y(f(this.c, y(o, f(l, this.m))), a) : c = y(f(this.c, l), a), i.assign(l), a.assign(c);
+      });
+    }), this.incrementIterations();
+  }
+  dispose() {
+    this.m.dispose(), this.accumulations != null && A(this.accumulations.map((e) => e.variable));
+  }
+  /**
+   * Sets the momentum of the optimizer.
+   *
+   * @param momentum
+   */
+  setMomentum(e) {
+    this.momentum = e;
+  }
+  async getWeights() {
+    return [await this.saveIterations()].concat(this.accumulations.map((e) => ({ name: e.originalName, tensor: e.variable })));
+  }
+  async setWeights(e) {
+    e = await this.extractIterations(e);
+    const t = !1;
+    this.accumulations = e.map((n) => ({ originalName: n.name, variable: n.tensor.variable(t) }));
+  }
+  getConfig() {
+    return {
+      learningRate: this.learningRate,
+      momentum: this.momentum,
+      useNesterov: this.useNesterov
+    };
+  }
+  /** @nocollapse */
+  static fromConfig(e, t) {
+    return new e(t.learningRate, t.momentum, t.useNesterov);
+  }
+}
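
`An` is the bundled Momentum optimizer. With the same inferred op mapping, the update reads: the accumulator becomes `momentum·acc + grad`, and the weight moves by `-lr` times either the accumulator (plain) or `grad + momentum·acc` (Nesterov). A hedged sketch, with `momentumStep` invented for illustration:

```js
import * as tf from "@tensorflow/tfjs-core";

// acc <- momentum * acc + grad
// plain:    w <- w + (-lr) * acc
// Nesterov: w <- w + (-lr) * (grad + momentum * acc)
function momentumStep(weight, acc, grad, lr, momentum, useNesterov) {
  tf.tidy(() => {
    const newAcc = tf.add(tf.mul(momentum, acc), grad);
    const step = useNesterov
      ? tf.mul(-lr, tf.add(grad, tf.mul(momentum, newAcc)))
      : tf.mul(-lr, newAcc);
    acc.assign(newAcc);
    weight.assign(tf.add(weight, step));
  });
}
```

The per-variable accumulators are created lazily on first use (the `originalName: \`${n}/momentum\`` entries above), which is also what `getWeights`/`setWeights` serialize alongside the iteration count.
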
+class Mn extends W {
+  /** @nocollapse */
+  static get className() {
+    return "RMSProp";
+  }
+  constructor(e, t = 0.9, n = 0, r = null, a = !1) {
+    if (super(), this.learningRate = e, this.decay = t, this.momentum = n, this.epsilon = r, this.accumulatedMeanSquares = [], this.accumulatedMoments = [], this.accumulatedMeanGrads = [], this.centered = a, r == null && (this.epsilon = u.backend.epsilon()), e == null)
+      throw new Error("learningRate for RMSPropOptimizer must be defined.");
+  }
+  applyGradients(e) {
+    (Array.isArray(e) ? e.map((n) => n.name) : Object.keys(e)).forEach((n, r) => {
+      const a = u.registeredVariables[n], i = !1;
+      this.accumulatedMeanSquares[r] == null && (this.accumulatedMeanSquares[r] = {
+        originalName: `${n}/rms`,
+        variable: k(() => F(a).variable(i))
+      }), this.accumulatedMoments[r] == null && (this.accumulatedMoments[r] = {
+        originalName: `${n}/momentum`,
+        variable: k(() => F(a).variable(i))
+      }), this.accumulatedMeanGrads[r] == null && this.centered && (this.accumulatedMeanGrads[r] = {
+        originalName: `${n}/mg`,
+        variable: k(() => F(a).variable(i))
+      });
+      const o = Array.isArray(e) ? e[r].tensor : e[n];
+      if (o == null)
+        return;
+      const c = this.accumulatedMeanSquares[r].variable, l = this.accumulatedMoments[r].variable;
+      k(() => {
+        const h = y(f(c, this.decay), f(L(o), 1 - this.decay));
+        if (this.centered) {
+          const m = this.accumulatedMeanGrads[r].variable, d = y(f(m, this.decay), f(o, 1 - this.decay)), p = x(f(o, this.learningRate), V(j(h, y(L(d), this.epsilon)))), b = y(f(l, this.momentum), p);
+          c.assign(h), m.assign(d), l.assign(b);
+          const w = j(a, b);
+          a.assign(w);
+        } else {
+          const m = y(f(c, this.decay), f(L(o), 1 - this.decay)), d = y(f(l, this.momentum), x(f(o, this.learningRate), V(y(m, this.epsilon))));
+          c.assign(m), l.assign(d);
+          const p = j(a, d);
+          a.assign(p);
+        }
+      });
+    }), this.incrementIterations();
+  }
+  dispose() {
+    this.accumulatedMeanSquares != null && A(this.accumulatedMeanSquares.map((e) => e.variable)), this.accumulatedMeanGrads != null && this.centered && A(this.accumulatedMeanGrads.map((e) => e.variable)), this.accumulatedMoments != null && A(this.accumulatedMoments.map((e) => e.variable));
+  }
+  async getWeights() {
+    const e = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
+    return this.centered && e.push(...this.accumulatedMeanGrads), [await this.saveIterations()].concat(e.map((t) => ({ name: t.originalName, tensor: t.variable })));
+  }
+  async setWeights(e) {
+    e = await this.extractIterations(e);
+    const t = this.centered ? e.length / 3 : e.length / 2, n = !1;
+    this.accumulatedMeanSquares = e.slice(0, t).map((r) => ({
+      originalName: r.name,
+      variable: r.tensor.variable(n)
+    })), this.accumulatedMoments = e.slice(t, t * 2).map((r) => ({
+      originalName: r.name,
+      variable: r.tensor.variable(n)
+    })), this.centered && (this.accumulatedMeanGrads = e.slice(t * 2, t * 3).map((r) => ({
+      originalName: r.name,
+      variable: r.tensor.variable(n)
+    })));
+  }
+  getConfig() {
+    return {
+      learningRate: this.learningRate,
+      decay: this.decay,
+      momentum: this.momentum,
+      epsilon: this.epsilon,
+      centered: this.centered
+    };
+  }
+  /** @nocollapse */
+  static fromConfig(e, t) {
+    return new e(t.learningRate, t.decay, t.momentum, t.epsilon, t.centered);
+  }
+}
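
`Mn` is the bundled RMSProp optimizer, with an optional centered variant that additionally tracks a mean-gradient accumulator (the `.../mg` entries). A sketch mirroring the uncentered (`else`) branch above, under the same inferred op mapping (`L` square, `V` sqrt, `x` divide, `j` subtract) and with `rmspropStep` invented for illustration:

```js
import * as tf from "@tensorflow/tfjs-core";

// ms  <- decay * ms + (1 - decay) * grad^2
// mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
// w   <- w - mom
function rmspropStep(weight, ms, mom, grad, lr, decay, momentum, epsilon) {
  tf.tidy(() => {
    const newMS = tf.add(tf.mul(ms, decay), tf.mul(tf.square(grad), 1 - decay));
    const newMom = tf.add(
      tf.mul(mom, momentum),
      tf.div(tf.mul(grad, lr), tf.sqrt(tf.add(newMS, epsilon)))
    );
    ms.assign(newMS);
    mom.assign(newMom);
    weight.assign(tf.sub(weight, newMom));
  });
}
```

The centered branch differs only in the denominator, which uses the variance estimate built from the mean-gradient accumulator instead of the raw mean square.
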
+const Nn = [
+  vn,
+  Bn,
+  En,
+  Tn,
+  An,
+  Mn,
+  Ke
+];
+function xn() {
+  for (const s of Nn)
+    In(s);
+}
+xn();
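
`Nn` gathers the optimizer classes and `xn()` passes each to `In`, which, given the `className`/`fromConfig` pairs above, behaves like tfjs's serialization-registry hook (`registerClass`); registration is what lets an optimizer saved with a model be rebuilt by name. The round-trip this enables, sketched with the public tfjs API rather than the minified names:

```js
import * as tf from "@tensorflow/tfjs-core";

// getConfig() captures hyperparameters; the registered className lets
// fromConfig() rebuild an equivalent optimizer from that plain object.
const opt = tf.train.momentum(0.01, 0.9, /* useNesterov */ true);
const config = opt.getConfig();
// => { learningRate: 0.01, momentum: 0.9, useNesterov: true }
const clone = tf.MomentumOptimizer.fromConfig(tf.MomentumOptimizer, config);
```
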
+export {
+  bn as $,
+  M as A,
+  S as B,
+  Tt as C,
+  Zt as D,
+  u as E,
+  Zn as F,
+  le as G,
+  _n as H,
+  It as I,
+  Xt as J,
+  v as K,
+  H as L,
+  qn as M,
+  _e as N,
+  W as O,
+  Vn as P,
+  Hn as Q,
+  Wn as R,
+  cn as S,
+  L as T,
+  Ke as U,
+  An as V,
+  Mn as W,
+  En as X,
+  vn as Y,
+  Tn as Z,
+  Bn as _,
+  Ln as a,
+  In as a0,
+  kn as a1,
+  _ as a2,
+  Xn as a3,
+  Kn as a4,
+  jn as a5,
+  sn as a6,
+  $n as a7,
+  Bt as a8,
+  U as b,
+  j as c,
+  A as d,
+  On as e,
+  Gn as f,
+  Un as g,
+  Jn as h,
+  ln as i,
+  un as j,
+  xt as k,
+  x as l,
+  f as m,
+  y as n,
+  V as o,
+  Be as p,
+  Yn as q,
+  Pn as r,
+  zn as s,
+  k as t,
+  Qn as u,
+  gn as v,
+  es as w,
+  Et as x,
+  Mt as y,
+  F as z
+};