@genai-fi/nanogpt 0.17.4 → 0.18.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +2 -15
- package/dist/Generator.js +45 -34
- package/dist/{RealDiv-CGwv0liw.js → RealDiv-ioj6Z-ox.js} +9 -9
- package/dist/{Reshape-BW__R4mZ.js → Reshape-BZC-ebeR.js} +7 -7
- package/dist/{Reshape-CPBkTIH2.js → Reshape-pwprEaej.js} +1 -1
- package/dist/TeachableLLM.d.ts +3 -8
- package/dist/TeachableLLM.js +61 -44
- package/dist/Trainer.d.ts +6 -4
- package/dist/Trainer.js +107 -92
- package/dist/{axis_util-GTVlo58H.js → axis_util-QWWgLjut.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-GaFarB78.js → backend_util-qwSFfxYx.js} +21 -21
- package/dist/{backend_webgpu-BqASlsbV.js → backend_webgpu-DI2wXEC2.js} +8 -8
- package/dist/{broadcast_to-eS93CCN_.js → broadcast_to-C_EJTVTZ.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +5 -5
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +2 -2
- package/dist/checks/normRMS.js +6 -6
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +6 -6
- package/dist/checks/qkv.js +2 -2
- package/dist/checks/rope.js +2 -2
- package/dist/{clip_by_value-DDA7rrcT.js → clip_by_value-CLAD4h_I.js} +1 -1
- package/dist/complex-3DpPEG9B.js +11 -0
- package/dist/{concat-CAQpCret.js → concat-Dqk7Xk7h.js} +5 -5
- package/dist/{concat_util-D18dJ4fD.js → concat_util-C1Mxe27t.js} +1 -1
- package/dist/{dataset-CGGp1z9P.js → dataset-DlqAN81i.js} +3 -3
- package/dist/{dropout_util--NxWuYg2.js → dropout_util-N0z8Os-K.js} +1 -1
- package/dist/{expand_dims-Bkd1YD5x.js → expand_dims-D0rBtgT1.js} +4 -4
- package/dist/{exports_initializers-CYzKLjN7.js → exports_initializers-DIOZQt_L.js} +1 -1
- package/dist/{floor-BQtb-Azg.js → floor-CymuCmTO.js} +1 -1
- package/dist/{gather-qIqEqaGn.js → gather-DEyjXNb1.js} +1 -1
- package/dist/{gelu-B220X1Go.js → gelu-DpTCC3eB.js} +1 -1
- package/dist/{gpgpu_math-BwvV12df.js → gpgpu_math-3bCb5ooU.js} +25 -25
- package/dist/{index-CjOWnMXP.js → index-BQvB7LCC.js} +15 -15
- package/dist/{index-CUXkjxiT.js → index-DSGwv2Yx.js} +33 -33
- package/dist/inference/types.d.ts +16 -0
- package/dist/inference/types.js +1 -0
- package/dist/{kernel_funcs_utils-pq0CK9co.js → kernel_funcs_utils-DGqzNlHT.js} +6 -6
- package/dist/layers/BaseLayer.js +4 -4
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/LoRA.js +4 -4
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/layers/WeightStore.js +2 -2
- package/dist/loader/load.d.ts +2 -8
- package/dist/loader/loadTransformers.d.ts +2 -8
- package/dist/loader/loadTransformers.js +13 -11
- package/dist/loader/newZipLoad.d.ts +2 -8
- package/dist/loader/newZipLoad.js +25 -10
- package/dist/loader/oldZipLoad.js +13 -13
- package/dist/loader/save.d.ts +9 -2
- package/dist/loader/save.js +64 -55
- package/dist/loader/types.d.ts +29 -1
- package/dist/main.d.ts +2 -0
- package/dist/main.js +45 -43
- package/dist/{matMul16-BcVC_E62.js → matMul16-BIT70Vya.js} +3 -3
- package/dist/{matMulGelu-JNLZqKQp.js → matMulGelu-CsZnh18H.js} +18 -18
- package/dist/mat_mul-DP86qZtZ.js +11 -0
- package/dist/mod-BXjLYwvM.js +11 -0
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/NanoGPTV2.js +2 -2
- package/dist/models/model.d.ts +3 -2
- package/dist/models/model.js +13 -13
- package/dist/{not_equal-hurPF26l.js → not_equal-CkQKkKZy.js} +15 -15
- package/dist/{ones-BytntneX.js → ones-DbVB5N58.js} +3 -3
- package/dist/ops/adamAdjust.js +3 -3
- package/dist/ops/adamMoments.js +3 -3
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +6 -6
- package/dist/ops/attentionMask.js +3 -3
- package/dist/ops/concat16.js +3 -3
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +5 -5
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +4 -4
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +4 -4
- package/dist/ops/cpu/matMul16.js +2 -2
- package/dist/ops/cpu/matMulGelu.js +7 -7
- package/dist/ops/cpu/matMulMul.js +2 -2
- package/dist/ops/cpu/mulDropout.js +5 -5
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +5 -5
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/dropout.js +6 -6
- package/dist/ops/dropout16.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/globalNorm.js +7 -7
- package/dist/ops/grads/add16.js +1 -1
- package/dist/ops/grads/attentionMask.js +2 -2
- package/dist/ops/grads/dropout16.js +1 -1
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMul16.js +3 -3
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/mul16.js +1 -1
- package/dist/ops/grads/normRMS.js +7 -7
- package/dist/ops/grads/pack16.js +3 -3
- package/dist/ops/grads/qkv.js +11 -11
- package/dist/ops/grads/rope.js +2 -2
- package/dist/ops/grads/softmax16.js +1 -1
- package/dist/ops/grads/unpack16.js +2 -2
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +6 -6
- package/dist/ops/matMulMul.js +3 -3
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +3 -3
- package/dist/ops/normRMS.js +4 -4
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +3 -3
- package/dist/ops/reshape16.js +6 -6
- package/dist/ops/rope.js +2 -2
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +6 -6
- package/dist/ops/transpose16.js +3 -3
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/dropout16.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +7 -7
- package/dist/ops/webgl/gatherSub.js +3 -3
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMul16.js +13 -13
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +2 -2
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +2 -2
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/add16.js +6 -6
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +2 -2
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/clipScale.js +7 -7
- package/dist/ops/webgpu/concat16.js +5 -5
- package/dist/ops/webgpu/dropout16.js +6 -6
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +8 -8
- package/dist/ops/webgpu/matMul16.js +16 -16
- package/dist/ops/webgpu/matMul16_program.js +2 -2
- package/dist/ops/webgpu/mul16.js +5 -5
- package/dist/ops/webgpu/norm2.js +1 -1
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +4 -4
- package/dist/ops/webgpu/pack16.js +4 -4
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +2 -2
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/slice16.js +4 -4
- package/dist/ops/webgpu/softmax16.js +4 -4
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +4 -4
- package/dist/ops/webgpu/sub16.js +6 -6
- package/dist/ops/webgpu/sum16.js +3 -3
- package/dist/ops/webgpu/transpose16.js +8 -8
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
- package/dist/ops/webgpu/unpack16.js +3 -3
- package/dist/ops/webgpu/utils/binary_op.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +5 -5
- package/dist/{ops-CsXeTq1P.js → ops-CURIZSVt.js} +100 -100
- package/dist/{pack16-bqltoUlR.js → pack16-WlOSOuZA.js} +2 -2
- package/dist/patches/webgpu_backend.js +6 -6
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +2 -2
- package/dist/{random_normal-IBRrha8a.js → random_normal-CIm8lk2-.js} +1 -1
- package/dist/{random_width-DN5ZtQkM.js → random_width-B_fVXhGx.js} +131 -131
- package/dist/{range-C-CjF-LI.js → range-BDxO73mk.js} +1 -1
- package/dist/{readers-iz5u3HBo.js → readers-17HLdxVM.js} +2 -2
- package/dist/relu-DTvZKBsZ.js +9 -0
- package/dist/{reshape-BDOuCSNW.js → reshape-BIN71H3p.js} +1 -1
- package/dist/{resize_nearest_neighbor-BojqlfRe.js → resize_nearest_neighbor-C6_0dAnK.js} +41 -41
- package/dist/{rope-0j_f1TPm.js → rope-CC5RjmKU.js} +4 -4
- package/dist/{scatter_nd_util-ByNJaL6I.js → scatter_nd_util-C-x73Cj6.js} +1 -1
- package/dist/{segment_util-Dasb2Zaf.js → segment_util-4zuHV5IG.js} +2 -2
- package/dist/{selu_util-BLhIqRkw.js → selu_util-BXdhy_W6.js} +5 -5
- package/dist/{shared-CagdqkLh.js → shared-DRWDyk9w.js} +6 -6
- package/dist/{shared-3agzAqQ_.js → shared-zTaJ5siv.js} +1 -1
- package/dist/slice-BvItlgXu.js +12 -0
- package/dist/{slice_util-CC35pLmT.js → slice_util-DPY56GzQ.js} +5 -5
- package/dist/{softmax-D4q1LJN7.js → softmax-BLGJqdwx.js} +1 -1
- package/dist/split-BN9LkEgS.js +9 -0
- package/dist/{squeeze-ho4wLUek.js → squeeze-O_YWJpw_.js} +2 -2
- package/dist/{stack-DudVrtmG.js → stack-z6QE7kmP.js} +1 -1
- package/dist/{step-BTxPtq1r.js → step-DQY6_ABw.js} +4 -4
- package/dist/{sum-BpiwSWvg.js → sum-D39FeU5h.js} +3 -3
- package/dist/{tensor-BWFldCso.js → tensor-D8e0Gd7c.js} +1 -1
- package/dist/{tensor1d-LMGMIUlr.js → tensor1d-BMl0eZYV.js} +1 -1
- package/dist/{tensor2d-BnXMKScO.js → tensor2d-DTtQ1QcT.js} +1 -1
- package/dist/{tensor4d-C6UCG_u8.js → tensor4d-Dj4rDssL.js} +1 -1
- package/dist/{tfjs_backend-BGnG-ppu.js → tfjs_backend-Bk3PmK91.js} +65 -65
- package/dist/{tile-CFy-xTO6.js → tile-CsWlVKKz.js} +1 -1
- package/dist/tokeniser/BaseTokeniser.d.ts +4 -1
- package/dist/tokeniser/BaseTokeniser.js +21 -5
- package/dist/tokeniser/CharTokeniser.d.ts +1 -1
- package/dist/tokeniser/CharTokeniser.js +62 -50
- package/dist/tokeniser/bpe.d.ts +1 -1
- package/dist/tokeniser/bpe.js +41 -35
- package/dist/tokeniser/type.d.ts +3 -1
- package/dist/training/AdamW.d.ts +3 -0
- package/dist/training/AdamW.js +59 -30
- package/dist/training/BasicTrainer.d.ts +1 -0
- package/dist/training/BasicTrainer.js +112 -92
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/Evaluator.js +2 -2
- package/dist/training/LRScheduler.d.ts +1 -0
- package/dist/training/LRScheduler.js +18 -12
- package/dist/training/PreTrainer.js +3 -3
- package/dist/training/SFTDatasetBuilder.js +3 -3
- package/dist/training/SFTTrainer.js +1 -1
- package/dist/training/orthoGrad.js +1 -1
- package/dist/training/sparseCrossEntropy.js +30 -30
- package/dist/training/types.d.ts +5 -3
- package/dist/training/validation.js +13 -13
- package/dist/{transpose-9kRxIXWR.js → transpose-Qxz-4os3.js} +7 -7
- package/dist/{unsorted_segment_sum-DJvk5xnh.js → unsorted_segment_sum-BfFVV9Zm.js} +20 -20
- package/dist/utilities/datasetID.d.ts +2 -0
- package/dist/utilities/datasetID.js +21 -0
- package/dist/utilities/dummy.js +6 -6
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-Ck482e3n.js → variable-SSATClyt.js} +1 -1
- package/dist/{webgpu_program-B4HmApL1.js → webgpu_program-CbjdYLYk.js} +1 -1
- package/dist/{webgpu_util-DYlGSwOJ.js → webgpu_util-DuofJBMo.js} +7 -7
- package/dist/{zeros-DvZpK8s6.js → zeros-Bw0puq_w.js} +2 -2
- package/dist/{zeros_like-CWjDdwr-.js → zeros_like-rOHr54NY.js} +69 -69
- package/package.json +3 -3
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
|
@@ -1,37 +1,37 @@
|
|
|
1
|
-
import { o as tt,
|
|
2
|
-
import { c as te, k as Gs, m as Tn, b as ou, t as hs, a as Ps, s as Us, l as lu, p as uu, e as cu } from "./step-
|
|
3
|
-
import { n as pt, t as P } from "./transpose-
|
|
4
|
-
import { r as C } from "./reshape-
|
|
5
|
-
import { s as _ } from "./sum-
|
|
6
|
-
import { m as Mt } from "./mat_mul-
|
|
7
|
-
import { j as Vs, o as Xe, G as pi, b as hu, s as pu, D as du, t as fu, A as mu, F as gu, e as Ht, q as En, r as Ye, l as bu, c as di, p as yu, z as fi, d as wu, x as ku, a as js, i as Qe, y as vt, B as Nu, E as Hs, w as xu, v as vu, n as Su, C as Au, k as Iu, u as Cu, h as Ln, g as Du, m as zu, f as Tu } from "./unsorted_segment_sum-
|
|
8
|
-
import { b as Re, c as _e, l as
|
|
9
|
-
import { a as st, w as
|
|
10
|
-
import { s as Vt } from "./split-
|
|
11
|
-
import { e as Es, a as gi, g as
|
|
12
|
-
import { t as Ce } from "./tile-
|
|
13
|
-
import { s as ps } from "./stack-
|
|
14
|
-
import { o as he } from "./ones-
|
|
15
|
-
import { s as me } from "./slice-
|
|
16
|
-
import { f as bi } from "./floor-
|
|
17
|
-
import { z as ft } from "./zeros-
|
|
18
|
-
import { s as Ou, b as Ru, c as _u, g as Bu, a as Wu, S as Gu } from "./selu_util-
|
|
19
|
-
import { p as Pu } from "./slice_util-
|
|
20
|
-
import { c as Zs } from "./concat-
|
|
21
|
-
import { e as ue } from "./expand_dims-
|
|
22
|
-
import { g as Uu } from "./gather-
|
|
23
|
-
import { V as d, N as B, r as ds, c as Js, a as et, b as ge, e as Be, s as Xs, g as yi, f as wi, t as Ot, R as zt, h as ht, A as Gt, i as V, n as oe, j as nt, k as Vu, l as ju, m as We, o as Pt, p as Tt, q as ae, u as jt, v as
|
|
24
|
-
import { s as tn } from "./squeeze-
|
|
25
|
-
import { t as Ls } from "./tensor1d-
|
|
26
|
-
import { r as Pe } from "./relu-
|
|
27
|
-
import { c as At } from "./clip_by_value-
|
|
28
|
-
import { s as Si } from "./softmax-
|
|
29
|
-
import { r as ms } from "./dropout_util
|
|
30
|
-
import { e as ec, l as sc, i as Rt } from "./ops-
|
|
31
|
-
import { t as nc } from "./tensor-
|
|
32
|
-
import { r as ic } from "./range-
|
|
1
|
+
import { o as tt, n as R, v as S, E as X, cy as li, J as L, cz as ui, cA as Oa, cB as Ra, cC as ci, ah as Yt, j as ut, a1 as U, q as _a, a5 as Ba, cD as Wa, h as z, z as Ga, _ as Lt, a2 as An, cE as In, cF as Pa, cG as Ua, cH as Va, cI as ja, cJ as Ha, cK as Ka, cL as qa, cM as Za, cN as Ja, bP as Xa, m as f, c7 as Ya, i as Qt, b as Q, d as W, c8 as Qa, bV as to, $ as lt, cO as eo, bo as so, a3 as Y, c9 as no, ca as io, cb as ro, cd as ao, cc as oo, ce as lo, cP as uo, cQ as co, bp as ho, B as po, br as fo, cR as mo, bS as go, bY as bo, C as yo, cS as wo, x as ko, bt as No, bu as xo, cT as vo, bv as So, bw as Ao, by as Io, bz as Co, cg as Do, cU as zo, cV as To, aG as Eo, cW as Lo, bB as $o, aN as Fo, y as Mo, bZ as Oo, F as Ro, b_ as _o, bs as Bo, G as Wo, b0 as Go, aU as Po, ch as Uo, ci as Vo, cj as jo, aH as Ho, b3 as Ko, aO as qo, cX as Zo, cY as Jo, ck as Xo, aM as Yo, b$ as Qo, cZ as tl, c_ as el, bE as sl, aP as nl, N as hi, aZ as il, b5 as rl, cm as al, M as ol, c0 as ll, ap as ul, bF as cl, bG as hl, P as pl, bH as dl, c$ as fl, p as Ws, aI as ml, c1 as gl, aX as bl, cn as yl, aJ as wl, A as kl, R as Nl, b9 as xl, d0 as vl, ba as Sl, d1 as Al, bJ as Il, b6 as Cl, bK as Dl, aL as zl, bL as Tl, aF as El, co as Ll, bM as $l, bN as Fl, S as Ml, D as Ol, bC as Rl, bI as _l, H as Bl, c3 as Wl, d2 as Gl, b7 as Pl, aK as Ul, c5 as Vl, K as jl, cs as Hl, bq as Kl, T as ql, ar as Zl, b8 as Jl, bO as Xl, cx as Ne, d3 as Yl, c as Ql, d4 as v, t as y, d5 as Me, d6 as Oe, ac as $t, l as K, ab as tu, d7 as Cn, k as _t, a_ as ze, f as eu, u as su, ag as we, O as nu, d8 as iu, Y as Dn, d9 as ru, da as zn, db as au } from "./index-DSGwv2Yx.js";
|
|
2
|
+
import { c as te, k as Gs, m as Tn, b as ou, t as hs, a as Ps, s as Us, l as lu, p as uu, e as cu } from "./step-DQY6_ABw.js";
|
|
3
|
+
import { n as pt, t as P } from "./transpose-Qxz-4os3.js";
|
|
4
|
+
import { r as C } from "./reshape-BIN71H3p.js";
|
|
5
|
+
import { s as _ } from "./sum-D39FeU5h.js";
|
|
6
|
+
import { m as Mt } from "./mat_mul-DP86qZtZ.js";
|
|
7
|
+
import { j as Vs, o as Xe, G as pi, b as hu, s as pu, D as du, t as fu, A as mu, F as gu, e as Ht, q as En, r as Ye, l as bu, c as di, p as yu, z as fi, d as wu, x as ku, a as js, i as Qe, y as vt, B as Nu, E as Hs, w as xu, v as vu, n as Su, C as Au, k as Iu, u as Cu, h as Ln, g as Du, m as zu, f as Tu } from "./unsorted_segment_sum-BfFVV9Zm.js";
|
|
8
|
+
import { b as Re, c as _e, l as Ks, g as Wt, a as Eu, u as ts, d as Lu, m as mi, h as $u } from "./resize_nearest_neighbor-C6_0dAnK.js";
|
|
9
|
+
import { a as st, w as Kt, e as xe, b as le, m as Te, l as Fu, n as Ts } from "./not_equal-CkQKkKZy.js";
|
|
10
|
+
import { s as Vt } from "./split-BN9LkEgS.js";
|
|
11
|
+
import { e as Es, a as gi, g as qs, c as Mu } from "./axis_util-QWWgLjut.js";
|
|
12
|
+
import { t as Ce } from "./tile-CsWlVKKz.js";
|
|
13
|
+
import { s as ps } from "./stack-z6QE7kmP.js";
|
|
14
|
+
import { o as he } from "./ones-DbVB5N58.js";
|
|
15
|
+
import { s as me } from "./slice-BvItlgXu.js";
|
|
16
|
+
import { f as bi } from "./floor-CymuCmTO.js";
|
|
17
|
+
import { z as ft } from "./zeros-Bw0puq_w.js";
|
|
18
|
+
import { s as Ou, b as Ru, c as _u, g as Bu, a as Wu, S as Gu } from "./selu_util-BXdhy_W6.js";
|
|
19
|
+
import { p as Pu } from "./slice_util-DPY56GzQ.js";
|
|
20
|
+
import { c as Zs } from "./concat-Dqk7Xk7h.js";
|
|
21
|
+
import { e as ue } from "./expand_dims-D0rBtgT1.js";
|
|
22
|
+
import { g as Uu } from "./gather-DEyjXNb1.js";
|
|
23
|
+
import { V as d, N as B, r as ds, c as Js, a as et, b as ge, e as Be, s as Xs, g as yi, f as wi, t as Ot, R as zt, h as ht, A as Gt, i as V, n as oe, j as nt, k as Vu, l as ju, m as We, o as Pt, p as Tt, q as ae, u as jt, v as Ke, w as ce, x as Hu, y as es, z as fs, B as ki, C as St, D as $n, E as Ku, F as qu, G as Zu, H as qt, I as Ju, J as Ys, K as It, L as qe, M as Xu, O as ot, P as Ni, Q as mt, S as ee, T as Fn, U as ke, d as Et, W as Mn, X as Ge, Y as xi, Z as Qs, _ as Yu, $ as Qu, a0 as vi, a1 as tc } from "./tfjs_backend-Bk3PmK91.js";
|
|
24
|
+
import { s as tn } from "./squeeze-O_YWJpw_.js";
|
|
25
|
+
import { t as Ls } from "./tensor1d-BMl0eZYV.js";
|
|
26
|
+
import { r as Pe } from "./relu-DTvZKBsZ.js";
|
|
27
|
+
import { c as At } from "./clip_by_value-CLAD4h_I.js";
|
|
28
|
+
import { s as Si } from "./softmax-BLGJqdwx.js";
|
|
29
|
+
import { r as ms } from "./dropout_util-N0z8Os-K.js";
|
|
30
|
+
import { e as ec, l as sc, i as Rt } from "./ops-CURIZSVt.js";
|
|
31
|
+
import { t as nc } from "./tensor-D8e0Gd7c.js";
|
|
32
|
+
import { r as ic } from "./range-BDxO73mk.js";
|
|
33
33
|
import { M as rc } from "./rand_util-CZ7yLoUm.js";
|
|
34
|
-
import { v as ac } from "./variable-
|
|
34
|
+
import { v as ac } from "./variable-SSATClyt.js";
|
|
35
35
|
function oc(n, t, e, s, i, r = "NDHWC") {
|
|
36
36
|
const a = R(n, "x", "avgPool3d", "float32");
|
|
37
37
|
let o = a, l = !1;
|
|
@@ -310,7 +310,7 @@ class Lc {
|
|
|
310
310
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
311
311
|
*/
|
|
312
312
|
static adam(t = 1e-3, e = 0.9, s = 0.999, i = null) {
|
|
313
|
-
return new
|
|
313
|
+
return new Ka(t, e, s, i);
|
|
314
314
|
}
|
|
315
315
|
/**
|
|
316
316
|
* Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
|
|
@@ -325,7 +325,7 @@ class Lc {
|
|
|
325
325
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
326
326
|
*/
|
|
327
327
|
static adadelta(t = 1e-3, e = 0.95, s = null) {
|
|
328
|
-
return new
|
|
328
|
+
return new qa(t, e, s);
|
|
329
329
|
}
|
|
330
330
|
/**
|
|
331
331
|
* Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
|
|
@@ -512,14 +512,14 @@ function Hc(n, t, e, s, i, r) {
|
|
|
512
512
|
const h = { dy: l, input: u }, p = { filterSize: e, strides: s, pad: i, dimRoundingMode: r }, w = X.runKernel(uo, h, p);
|
|
513
513
|
return c ? C(w, [w.shape[1], w.shape[2], w.shape[3], w.shape[4]]) : w;
|
|
514
514
|
}
|
|
515
|
-
const
|
|
516
|
-
const
|
|
515
|
+
const Kc = /* @__PURE__ */ tt({ avgPool3dGrad_: Hc });
|
|
516
|
+
const qc = {
|
|
517
517
|
kernelName: li,
|
|
518
518
|
inputsToSave: ["x"],
|
|
519
519
|
gradFunc: (n, t, e) => {
|
|
520
520
|
const [s] = t, { filterSize: i, strides: r, pad: a, dimRoundingMode: o } = e;
|
|
521
521
|
return {
|
|
522
|
-
x: () =>
|
|
522
|
+
x: () => Kc(n, s, i, r, a, o)
|
|
523
523
|
};
|
|
524
524
|
}
|
|
525
525
|
};
|
|
@@ -601,7 +601,7 @@ const nh = {
|
|
|
601
601
|
gradFunc: (n, t, e) => {
|
|
602
602
|
const [s] = t, { clipValueMin: i, clipValueMax: r } = e;
|
|
603
603
|
return {
|
|
604
|
-
x: () =>
|
|
604
|
+
x: () => Kt(Re(_e(s, i), Ks(s, r)), n, Y(n))
|
|
605
605
|
};
|
|
606
606
|
}
|
|
607
607
|
};
|
|
@@ -821,7 +821,7 @@ const Sh = {
|
|
|
821
821
|
b
|
|
822
822
|
]), I = C(h, N), D = C(c, [w]), M = _n([[g], m, A]), T = P(I, M);
|
|
823
823
|
let E = gu(T, D, u.shape[o]);
|
|
824
|
-
const F =
|
|
824
|
+
const F = qs(M);
|
|
825
825
|
return E = P(E, F), E;
|
|
826
826
|
};
|
|
827
827
|
if (a === 1) {
|
|
@@ -873,11 +873,11 @@ const Th = {
|
|
|
873
873
|
inputsToSave: ["x"],
|
|
874
874
|
gradFunc: (n, t, e) => {
|
|
875
875
|
const [s] = t, { alpha: i } = e, r = Wt(s, 0);
|
|
876
|
-
return { x: () =>
|
|
876
|
+
return { x: () => Kt(r, n, f(n, i)) };
|
|
877
877
|
}
|
|
878
878
|
};
|
|
879
879
|
const Eh = {
|
|
880
|
-
kernelName:
|
|
880
|
+
kernelName: Ko,
|
|
881
881
|
inputsToSave: ["x"],
|
|
882
882
|
gradFunc: (n, t) => {
|
|
883
883
|
const [e] = t;
|
|
@@ -885,7 +885,7 @@ const Eh = {
|
|
|
885
885
|
}
|
|
886
886
|
};
|
|
887
887
|
const Lh = {
|
|
888
|
-
kernelName:
|
|
888
|
+
kernelName: qo,
|
|
889
889
|
inputsToSave: ["x"],
|
|
890
890
|
gradFunc: (n, t) => {
|
|
891
891
|
const [e] = t;
|
|
@@ -1026,10 +1026,10 @@ const Hh = {
|
|
|
1026
1026
|
inputsToSave: ["a", "b"],
|
|
1027
1027
|
gradFunc: (n, t) => {
|
|
1028
1028
|
const [e, s] = t;
|
|
1029
|
-
return { a: () => f(n, L(
|
|
1029
|
+
return { a: () => f(n, L(Ks(e, s), "float32")), b: () => f(n, L(Wt(e, s), "float32")) };
|
|
1030
1030
|
}
|
|
1031
1031
|
};
|
|
1032
|
-
const
|
|
1032
|
+
const Kh = {
|
|
1033
1033
|
kernelName: al,
|
|
1034
1034
|
inputsToSave: ["x"],
|
|
1035
1035
|
gradFunc: (n, t, e) => {
|
|
@@ -1037,7 +1037,7 @@ const qh = {
|
|
|
1037
1037
|
return { x: () => me(n, r, s.shape) };
|
|
1038
1038
|
}
|
|
1039
1039
|
};
|
|
1040
|
-
const
|
|
1040
|
+
const qh = {
|
|
1041
1041
|
kernelName: ol,
|
|
1042
1042
|
inputsToSave: ["a", "b"],
|
|
1043
1043
|
gradFunc: (n, t) => {
|
|
@@ -1109,7 +1109,7 @@ const tp = {
|
|
|
1109
1109
|
const p = lt(r.shape, o);
|
|
1110
1110
|
return p.length > 0 && (h = _(h, p)), C(h, r.shape);
|
|
1111
1111
|
}, b: () => {
|
|
1112
|
-
const c = Wt(r, 0), h =
|
|
1112
|
+
const c = Wt(r, 0), h = Kt(c, le(r), Y(r));
|
|
1113
1113
|
let p = f(n, f(i, h));
|
|
1114
1114
|
const w = lt(a.shape, o);
|
|
1115
1115
|
return w.length > 0 && (p = _(p, w)), C(p, a.shape);
|
|
@@ -1122,9 +1122,9 @@ const ep = {
|
|
|
1122
1122
|
gradFunc: (n, t) => {
|
|
1123
1123
|
const [e, s] = t, i = Wt(e, 0);
|
|
1124
1124
|
return {
|
|
1125
|
-
x: () =>
|
|
1125
|
+
x: () => Kt(i, n, f(n, s)),
|
|
1126
1126
|
alpha: () => {
|
|
1127
|
-
let r =
|
|
1127
|
+
let r = Kt(i, Y(n), f(n, e));
|
|
1128
1128
|
const a = lt(s.shape, n.shape);
|
|
1129
1129
|
return a.length > 0 && (r = _(r, a)), C(r, s.shape);
|
|
1130
1130
|
}
|
|
@@ -1146,7 +1146,7 @@ function np(n, t, e) {
|
|
|
1146
1146
|
const c = a.reshape(o);
|
|
1147
1147
|
let h = sp(c, t, i);
|
|
1148
1148
|
if (h = h.reshape(a.shape), r != null) {
|
|
1149
|
-
const p =
|
|
1149
|
+
const p = qs(r);
|
|
1150
1150
|
h = P(h, p);
|
|
1151
1151
|
}
|
|
1152
1152
|
return h;
|
|
@@ -1189,7 +1189,7 @@ const op = {
|
|
|
1189
1189
|
kernelName: wl,
|
|
1190
1190
|
inputsToSave: ["x"],
|
|
1191
1191
|
gradFunc: (n, t) => {
|
|
1192
|
-
const [e] = t, s = f(
|
|
1192
|
+
const [e] = t, s = f(Ks(e, 6), Ps(e));
|
|
1193
1193
|
return { x: () => f(n, L(s, "float32")) };
|
|
1194
1194
|
}
|
|
1195
1195
|
};
|
|
@@ -1272,7 +1272,7 @@ const gp = {
|
|
|
1272
1272
|
return {
|
|
1273
1273
|
x: () => {
|
|
1274
1274
|
const s = Wt(e, Q(0)), i = Q(Wu), r = Q(Gu), a = f(n, r), o = f(f(n, i), xe(L(e, "float32")));
|
|
1275
|
-
return
|
|
1275
|
+
return Kt(s, a, o);
|
|
1276
1276
|
}
|
|
1277
1277
|
};
|
|
1278
1278
|
}
|
|
@@ -1412,7 +1412,7 @@ const Tp = {
|
|
|
1412
1412
|
}
|
|
1413
1413
|
};
|
|
1414
1414
|
const Ep = {
|
|
1415
|
-
kernelName:
|
|
1415
|
+
kernelName: Kl,
|
|
1416
1416
|
outputsToSave: [!0],
|
|
1417
1417
|
gradFunc: (n, t) => {
|
|
1418
1418
|
const [e] = t;
|
|
@@ -1420,7 +1420,7 @@ const Ep = {
|
|
|
1420
1420
|
}
|
|
1421
1421
|
};
|
|
1422
1422
|
const Lp = {
|
|
1423
|
-
kernelName:
|
|
1423
|
+
kernelName: ql,
|
|
1424
1424
|
inputsToSave: ["x"],
|
|
1425
1425
|
gradFunc: (n, t, e) => {
|
|
1426
1426
|
const [s] = t, { reps: i } = e;
|
|
@@ -1461,7 +1461,7 @@ const Lp = {
|
|
|
1461
1461
|
const $p = {
|
|
1462
1462
|
kernelName: Zl,
|
|
1463
1463
|
gradFunc: (n, t, e) => {
|
|
1464
|
-
const s = e, { perm: i } = s, r =
|
|
1464
|
+
const s = e, { perm: i } = s, r = qs(i);
|
|
1465
1465
|
return { x: () => P(n, r) };
|
|
1466
1466
|
}
|
|
1467
1467
|
};
|
|
@@ -1488,7 +1488,7 @@ function Op(n, t) {
|
|
|
1488
1488
|
i = ue(i, o + 1);
|
|
1489
1489
|
i = Re(i, he(s.shape, "bool"));
|
|
1490
1490
|
const a = Y(s);
|
|
1491
|
-
return
|
|
1491
|
+
return Kt(i, s, a);
|
|
1492
1492
|
}
|
|
1493
1493
|
const Rp = {
|
|
1494
1494
|
kernelName: Yl,
|
|
@@ -1507,7 +1507,7 @@ const _p = [
|
|
|
1507
1507
|
Uc,
|
|
1508
1508
|
Vc,
|
|
1509
1509
|
jc,
|
|
1510
|
-
|
|
1510
|
+
qc,
|
|
1511
1511
|
Xc,
|
|
1512
1512
|
Yc,
|
|
1513
1513
|
Qc,
|
|
@@ -1553,8 +1553,8 @@ const _p = [
|
|
|
1553
1553
|
Vh,
|
|
1554
1554
|
jh,
|
|
1555
1555
|
Hh,
|
|
1556
|
-
qh,
|
|
1557
1556
|
Kh,
|
|
1557
|
+
qh,
|
|
1558
1558
|
Zh,
|
|
1559
1559
|
Jh,
|
|
1560
1560
|
Xh,
|
|
@@ -1971,7 +1971,7 @@ function Vn(n, t = {}) {
|
|
|
1971
1971
|
function J(n) {
|
|
1972
1972
|
return Xs(n);
|
|
1973
1973
|
}
|
|
1974
|
-
function
|
|
1974
|
+
function q(n) {
|
|
1975
1975
|
if (typeof n == "string") {
|
|
1976
1976
|
const t = n in Un ? Un[n] : n;
|
|
1977
1977
|
if (t === "GlorotNormal")
|
|
@@ -2113,10 +2113,10 @@ class Ft {
|
|
|
2113
2113
|
this.dtype = t, this.shape = e, this.sourceLayer = s, this.inputs = i, this.callArgs = r, this.outputTensorIndex = o, this.id = Ti(), a != null && (this.originalName = yi(a), this.name = wi(this.originalName)), this.rank = e.length;
|
|
2114
2114
|
}
|
|
2115
2115
|
}
|
|
2116
|
-
let
|
|
2116
|
+
let Kp = 0;
|
|
2117
2117
|
class bs {
|
|
2118
2118
|
constructor(t, e) {
|
|
2119
|
-
this.callArgs = e, this.id =
|
|
2119
|
+
this.callArgs = e, this.id = Kp++, this.outboundLayer = t.outboundLayer, this.inboundLayers = t.inboundLayers, this.nodeIndices = t.nodeIndices, this.tensorIndices = t.tensorIndices, this.inputTensors = t.inputTensors, this.outputTensors = t.outputTensors, this.inputMasks = t.inputMasks, this.outputMasks = t.outputMasks, this.inputShapes = t.inputShapes, this.outputShapes = t.outputShapes;
|
|
2120
2120
|
for (const s of t.inboundLayers)
|
|
2121
2121
|
s?.outboundNodes.push(this);
|
|
2122
2122
|
t.outboundLayer.inboundNodes.push(this);
|
|
@@ -2133,10 +2133,10 @@ class bs {
|
|
|
2133
2133
|
};
|
|
2134
2134
|
}
|
|
2135
2135
|
}
|
|
2136
|
-
let
|
|
2136
|
+
let qp = 0;
|
|
2137
2137
|
class O extends Me {
|
|
2138
2138
|
constructor(t = {}) {
|
|
2139
|
-
super(), this._callHook = null, this._addedWeightNames = [], this._stateful = !1, this.id =
|
|
2139
|
+
super(), this._callHook = null, this._addedWeightNames = [], this._stateful = !1, this.id = qp++, this.activityRegularizer = null, this.inputSpec = null, this.supportsMasking = !1, this._trainableWeights = [], this._nonTrainableWeights = [], this._losses = [], this._updates = [], this._built = !1, this.inboundNodes = [], this.outboundNodes = [];
|
|
2140
2140
|
let e = t.name;
|
|
2141
2141
|
if (!e) {
|
|
2142
2142
|
const s = this.getClassName();
|
|
@@ -2611,7 +2611,7 @@ class O extends Me {
|
|
|
2611
2611
|
addWeight(t, e, s, i, r, a, o, l) {
|
|
2612
2612
|
if (this._addedWeightNames.indexOf(t) !== -1)
|
|
2613
2613
|
throw new d(`Duplicate weight name ${t} for layer ${this.name}`);
|
|
2614
|
-
this._addedWeightNames.push(t), s == null && (s = "float32"), this.fastWeightInitDuringBuild && (i = l != null ? l() :
|
|
2614
|
+
this._addedWeightNames.push(t), s == null && (s = "float32"), this.fastWeightInitDuringBuild && (i = l != null ? l() : q("zeros"));
|
|
2615
2615
|
const u = i.apply(e, s), c = new jp(u, s, t, a, o);
|
|
2616
2616
|
return u.dispose(), r != null && this.addLoss(() => r.apply(c.read())), a == null && (a = !0), a ? this._trainableWeights.push(c) : this._nonTrainableWeights.push(c), c;
|
|
2617
2617
|
}
|
|
@@ -3007,7 +3007,7 @@ class Ut {
|
|
|
3007
3007
|
}
|
|
3008
3008
|
/** Dispose all mask Tensors held by this object. */
|
|
3009
3009
|
disposeMasks() {
|
|
3010
|
-
this.id2Mask != null &&
|
|
3010
|
+
this.id2Mask != null && K(this.id2Mask);
|
|
3011
3011
|
}
|
|
3012
3012
|
}
|
|
3013
3013
|
const is = new zi(), rs = new zi();
|
|
@@ -3046,7 +3046,7 @@ function De(n, t, e, s) {
|
|
|
3046
3046
|
const F = o.indexOf(T[E].name);
|
|
3047
3047
|
F !== -1 && (l[F] = I[E]);
|
|
3048
3048
|
}
|
|
3049
|
-
i ||
|
|
3049
|
+
i || K(A);
|
|
3050
3050
|
}
|
|
3051
3051
|
return w.disposeMasks(), r ? l : l[0];
|
|
3052
3052
|
}
|
|
@@ -3182,7 +3182,7 @@ class Pi extends Ve {
|
|
|
3182
3182
|
}
|
|
3183
3183
|
Pi.className = "MinMaxNorm";
|
|
3184
3184
|
v(Pi);
|
|
3185
|
-
const
|
|
3185
|
+
const Kn = {
|
|
3186
3186
|
maxNorm: "MaxNorm",
|
|
3187
3187
|
minMaxNorm: "MinMaxNorm",
|
|
3188
3188
|
nonNeg: "NonNeg",
|
|
@@ -3191,16 +3191,16 @@ const qn = {
|
|
|
3191
3191
|
function rt(n) {
|
|
3192
3192
|
return Xs(n);
|
|
3193
3193
|
}
|
|
3194
|
-
function
|
|
3194
|
+
function qn(n, t = {}) {
|
|
3195
3195
|
return Be(n, Oe.getMap().classNameMap, t, "constraint");
|
|
3196
3196
|
}
|
|
3197
3197
|
function at(n) {
|
|
3198
3198
|
if (n == null)
|
|
3199
3199
|
return null;
|
|
3200
3200
|
if (typeof n == "string") {
|
|
3201
|
-
const e = { className: n in
|
|
3202
|
-
return
|
|
3203
|
-
} else return n instanceof Ve ? n :
|
|
3201
|
+
const e = { className: n in Kn ? Kn[n] : n, config: {} };
|
|
3202
|
+
return qn(e);
|
|
3203
|
+
} else return n instanceof Ve ? n : qn(n);
|
|
3204
3204
|
}
|
|
3205
3205
|
async function re(n) {
|
|
3206
3206
|
if (n == null)
|
|
@@ -3217,7 +3217,7 @@ async function re(n) {
|
|
|
3217
3217
|
const i = await Promise.all(t);
|
|
3218
3218
|
for (let r = 0; r < i.length; ++r)
|
|
3219
3219
|
n[e[r]] = i[r][0];
|
|
3220
|
-
|
|
3220
|
+
K(s);
|
|
3221
3221
|
}
|
|
3222
3222
|
}
|
|
3223
3223
|
function Ui(n) {
|
|
@@ -3637,13 +3637,13 @@ function Is(n) {
|
|
|
3637
3637
|
} else
|
|
3638
3638
|
return n;
|
|
3639
3639
|
}
|
|
3640
|
-
function
|
|
3640
|
+
function Ki(n, t) {
|
|
3641
3641
|
return y(() => {
|
|
3642
3642
|
const e = f(0.5, vt(t)), s = Tt(Wt(t, e), n.dtype);
|
|
3643
3643
|
return st(Ht(n, s), -1);
|
|
3644
3644
|
});
|
|
3645
3645
|
}
|
|
3646
|
-
function
|
|
3646
|
+
function qi(n, t) {
|
|
3647
3647
|
return y(() => Tt(Ht(Qe(n, -1), Qe(t, -1)), "float32"));
|
|
3648
3648
|
}
|
|
3649
3649
|
function wd(n, t) {
|
|
@@ -3655,7 +3655,7 @@ function kd(n, t) {
|
|
|
3655
3655
|
function Nd(n, t) {
|
|
3656
3656
|
return y(() => {
|
|
3657
3657
|
const e = wd(n, t), s = kd(n, t), i = z(e, s);
|
|
3658
|
-
return L(
|
|
3658
|
+
return L(Kt(Wt(i, 0), W(e, i), 0), "float32");
|
|
3659
3659
|
});
|
|
3660
3660
|
}
|
|
3661
3661
|
function xd(n, t) {
|
|
@@ -3665,8 +3665,8 @@ function vd(n, t) {
|
|
|
3665
3665
|
return n.rank === t.rank && (n = tn(n, [n.rank - 1])), t = Qe(t, -1), t.dtype !== n.dtype && (t = L(t, n.dtype)), L(Ht(n, t), "float32");
|
|
3666
3666
|
}
|
|
3667
3667
|
const Sd = ys, Ad = ys, Id = dn, Cd = dn, Dd = fn, zd = fn, Zi = Le, Td = Hi, Ji = os, us = {
|
|
3668
|
-
binaryAccuracy:
|
|
3669
|
-
categoricalAccuracy:
|
|
3668
|
+
binaryAccuracy: Ki,
|
|
3669
|
+
categoricalAccuracy: qi,
|
|
3670
3670
|
precision: Nd,
|
|
3671
3671
|
categoricalCrossentropy: Zi,
|
|
3672
3672
|
sparseCategoricalCrossentropy: Ji,
|
|
@@ -3982,7 +3982,7 @@ class Nt extends O {
|
|
|
3982
3982
|
const A = i[m];
|
|
3983
3983
|
A in w || (w[A] = []), w[A].push(r[m]);
|
|
3984
3984
|
}
|
|
3985
|
-
let k = Object.keys(w).map((m) => parseInt(m, 10)).sort(
|
|
3985
|
+
let k = Object.keys(w).map((m) => parseInt(m, 10)).sort(Ke);
|
|
3986
3986
|
this.layers = [];
|
|
3987
3987
|
for (const m of k) {
|
|
3988
3988
|
const A = w[m];
|
|
@@ -3993,7 +3993,7 @@ class Nt extends O {
|
|
|
3993
3993
|
for (const N of A)
|
|
3994
3994
|
N instanceof Nt && this.internalContainerRefs.push(N), this.layers.push(N);
|
|
3995
3995
|
}
|
|
3996
|
-
this.layersByDepth = w, k = Object.keys(p).map((m) => parseInt(m, 10)).sort(
|
|
3996
|
+
this.layersByDepth = w, k = Object.keys(p).map((m) => parseInt(m, 10)).sort(Ke);
|
|
3997
3997
|
const g = this.inputs.slice(), b = [];
|
|
3998
3998
|
for (const m of k)
|
|
3999
3999
|
for (const A of p[m]) {
|
|
@@ -4236,7 +4236,7 @@ class Nt extends O {
|
|
|
4236
4236
|
const l = this.inputLayers[o], u = e[o], c = l.name + "_0_0";
|
|
4237
4237
|
s[c] = u;
|
|
4238
4238
|
}
|
|
4239
|
-
const i = Object.keys(this.nodesByDepth).map((o) => parseInt(o, 10)).sort(
|
|
4239
|
+
const i = Object.keys(this.nodesByDepth).map((o) => parseInt(o, 10)).sort(Ke);
|
|
4240
4240
|
if (i.length > 1)
|
|
4241
4241
|
for (const o of i) {
|
|
4242
4242
|
const l = this.nodesByDepth[o];
|
|
@@ -4284,7 +4284,7 @@ class Nt extends O {
|
|
|
4284
4284
|
const u = this.inputs[l], c = t[l], h = e[l];
|
|
4285
4285
|
s[u.id] = [c, h];
|
|
4286
4286
|
}
|
|
4287
|
-
const i = Object.keys(this.nodesByDepth).map((l) => parseInt(l, 10)).sort(
|
|
4287
|
+
const i = Object.keys(this.nodesByDepth).map((l) => parseInt(l, 10)).sort(Ke);
|
|
4288
4288
|
for (const l of i) {
|
|
4289
4289
|
const u = this.nodesByDepth[l];
|
|
4290
4290
|
for (const c of u) {
|
|
@@ -4557,7 +4557,7 @@ async function tr(n, t, e, s) {
|
|
|
4557
4557
|
} else
|
|
4558
4558
|
throw new Error(`Unexpected rank of target (y) tensor (${n.rank}) during handling of class weights. The rank is expected to be 1 or 2.`);
|
|
4559
4559
|
}), r = Array.from(await i.data());
|
|
4560
|
-
|
|
4560
|
+
K(i);
|
|
4561
4561
|
const a = [];
|
|
4562
4562
|
return r.forEach((o) => {
|
|
4563
4563
|
if (e[o] == null)
|
|
@@ -4659,7 +4659,7 @@ async function Ud(n, t, e) {
|
|
|
4659
4659
|
M.push(await tr(I[H], null, F[H]));
|
|
4660
4660
|
}
|
|
4661
4661
|
const T = N.concat(I).concat(M), E = o(T);
|
|
4662
|
-
|
|
4662
|
+
K(T);
|
|
4663
4663
|
for (let F = 0; F < l.length; ++F) {
|
|
4664
4664
|
const H = l[F], gt = E[F];
|
|
4665
4665
|
D[H] = gt, _t(gt);
|
|
@@ -4713,15 +4713,15 @@ async function Hd(n, t, e) {
|
|
|
4713
4713
|
if (r = y(() => {
|
|
4714
4714
|
if (u.value) {
|
|
4715
4715
|
const { xs: c, ys: h } = er(n, u.value), p = c.concat(h), w = y(() => i(p));
|
|
4716
|
-
if (
|
|
4716
|
+
if (K(p), l === 0)
|
|
4717
4717
|
for (let g = 0; g < w.length; ++g)
|
|
4718
4718
|
r.push(Q(0));
|
|
4719
4719
|
const k = p[0].shape[0];
|
|
4720
4720
|
for (let g = 0; g < w.length; ++g) {
|
|
4721
4721
|
const b = w[g], x = r[g];
|
|
4722
|
-
r[g] = y(() => z(r[g], f(k, b))), l > 0 &&
|
|
4722
|
+
r[g] = y(() => z(r[g], f(k, b))), l > 0 && K(x);
|
|
4723
4723
|
}
|
|
4724
|
-
|
|
4724
|
+
K(w), o += k, ++l;
|
|
4725
4725
|
}
|
|
4726
4726
|
return r;
|
|
4727
4727
|
}), u.done) {
|
|
@@ -4731,7 +4731,7 @@ async function Hd(n, t, e) {
|
|
|
4731
4731
|
}
|
|
4732
4732
|
for (let u = 0; u < r.length; ++u) {
|
|
4733
4733
|
const c = r[u];
|
|
4734
|
-
r[u] = W(r[u], o),
|
|
4734
|
+
r[u] = W(r[u], o), K(c);
|
|
4735
4735
|
}
|
|
4736
4736
|
return ht(r);
|
|
4737
4737
|
}
|
|
@@ -4795,14 +4795,14 @@ function kt(n, t) {
|
|
|
4795
4795
|
i.isDisposed || i.dispose();
|
|
4796
4796
|
});
|
|
4797
4797
|
}
|
|
4798
|
-
function
|
|
4798
|
+
function Kd(n) {
|
|
4799
4799
|
return n instanceof we;
|
|
4800
4800
|
}
|
|
4801
4801
|
function Bs(n) {
|
|
4802
4802
|
return Array.isArray(n);
|
|
4803
4803
|
}
|
|
4804
4804
|
function ti(n) {
|
|
4805
|
-
return !
|
|
4805
|
+
return !Kd(n) && !Bs(n);
|
|
4806
4806
|
}
|
|
4807
4807
|
function ei(n, t, e, s = !0, i = "") {
|
|
4808
4808
|
if (t == null || t.length === 0) {
|
|
@@ -4859,7 +4859,7 @@ function ei(n, t, e, s = !0, i = "") {
|
|
|
4859
4859
|
}
|
|
4860
4860
|
return r;
|
|
4861
4861
|
}
|
|
4862
|
-
function
|
|
4862
|
+
function qd(n, t, e) {
|
|
4863
4863
|
const s = jt(n.map((r) => r.shape[0]));
|
|
4864
4864
|
s.sort();
|
|
4865
4865
|
const i = jt(t.map((r) => r.shape[0]));
|
|
@@ -5048,7 +5048,7 @@ class be extends Nt {
|
|
|
5048
5048
|
for (const k of u) {
|
|
5049
5049
|
if (typeof k == "string" && ["accuracy", "acc", "crossentropy", "ce"].indexOf(k) !== -1) {
|
|
5050
5050
|
const b = this.internalOutputShapes[a];
|
|
5051
|
-
b[b.length - 1] === 1 || this.lossFunctions[a] === ws ? ["accuracy", "acc"].indexOf(k) !== -1 ? p =
|
|
5051
|
+
b[b.length - 1] === 1 || this.lossFunctions[a] === ws ? ["accuracy", "acc"].indexOf(k) !== -1 ? p = Ki : ["crossentropy", "ce"].indexOf(k) !== -1 && (p = xd) : this.lossFunctions[a] === os ? ["accuracy", "acc"].indexOf(k) !== -1 ? p = vd : ["crossentropy", "ce"].indexOf(k) !== -1 && (p = Ji) : ["accuracy", "acc"].indexOf(k) !== -1 ? p = qi : ["crossentropy", "ce"].indexOf(k) !== -1 && (p = Zi);
|
|
5052
5052
|
let x;
|
|
5053
5053
|
["accuracy", "acc"].indexOf(k) !== -1 ? x = "acc" : ["crossentropy", "ce"].indexOf(k) !== -1 && (x = "ce"), w = p, h = "" + x;
|
|
5054
5054
|
} else
|
|
@@ -5312,7 +5312,7 @@ class be extends Nt {
|
|
|
5312
5312
|
const o = this.feedOutputShapes[a];
|
|
5313
5313
|
this.feedLossFns[a] === os ? r.push(o.slice(0, o.length - 1).concat([1])) : r.push(o);
|
|
5314
5314
|
}
|
|
5315
|
-
if (t = ei(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = ei(e, this.feedOutputNames, r, !1, "target"),
|
|
5315
|
+
if (t = ei(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = ei(e, this.feedOutputNames, r, !1, "target"), qd(t, e), Zd(e, this.feedLossFns, this.feedOutputShapes), this.stateful && i != null && i > 0 && t[0].shape[0] % i !== 0)
|
|
5316
5316
|
throw new d(`In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size ${i}. Found: ${t[0].shape[0]} sample(s).`);
|
|
5317
5317
|
return [t, e];
|
|
5318
5318
|
}
|
|
@@ -5516,7 +5516,7 @@ class be extends Nt {
|
|
|
5516
5516
|
const M = Vi(s.callbacks, s.yieldEvery);
|
|
5517
5517
|
return await this.fitLoop(A, m, N, w, s.epochs, s.verbose, M, I, x, s.shuffle, D, s.initialEpoch, null, null);
|
|
5518
5518
|
} finally {
|
|
5519
|
-
this.isTraining = !1, kt(i, t), kt(r, e), kt(a, t), kt(o, e), kt(c, l), kt(h, u), p != null &&
|
|
5519
|
+
this.isTraining = !1, kt(i, t), kt(r, e), kt(a, t), kt(o, e), kt(c, l), kt(h, u), p != null && K(p);
|
|
5520
5520
|
}
|
|
5521
5521
|
}
|
|
5522
5522
|
/**
|
|
@@ -5648,7 +5648,7 @@ class be extends Nt {
|
|
|
5648
5648
|
const c = await u.data();
|
|
5649
5649
|
l.push(c[0]);
|
|
5650
5650
|
}
|
|
5651
|
-
return
|
|
5651
|
+
return K(o), kt(s[0], t), kt(s[1], e), ht(l);
|
|
5652
5652
|
}
|
|
5653
5653
|
/**
|
|
5654
5654
|
* Extract weight values of the model.
|
|
@@ -6432,7 +6432,7 @@ class ir extends ct {
|
|
|
6432
6432
|
* @return Output of the ELU activation.
|
|
6433
6433
|
*/
|
|
6434
6434
|
apply(t, e = 1) {
|
|
6435
|
-
return
|
|
6435
|
+
return Ku(t, e);
|
|
6436
6436
|
}
|
|
6437
6437
|
}
|
|
6438
6438
|
ir.className = "elu";
|
|
@@ -6474,7 +6474,7 @@ ur.className = "sigmoid";
|
|
|
6474
6474
|
v(ur);
|
|
6475
6475
|
class cr extends ct {
|
|
6476
6476
|
apply(t) {
|
|
6477
|
-
return
|
|
6477
|
+
return qu(t);
|
|
6478
6478
|
}
|
|
6479
6479
|
}
|
|
6480
6480
|
cr.className = "hardSigmoid";
|
|
@@ -6697,7 +6697,7 @@ xr.className = "LeakyReLU";
|
|
|
6697
6697
|
v(xr);
|
|
6698
6698
|
class vr extends O {
|
|
6699
6699
|
constructor(t) {
|
|
6700
|
-
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer =
|
|
6700
|
+
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer = q(t.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER), this.alphaRegularizer = Z(t.alphaRegularizer), this.alphaConstraint = at(t.alphaConstraint), t.sharedAxes == null)
|
|
6701
6701
|
this.sharedAxes = null;
|
|
6702
6702
|
else if (Array.isArray(t.sharedAxes))
|
|
6703
6703
|
this.sharedAxes = t.sharedAxes;
|
|
@@ -6823,7 +6823,7 @@ function Dt(n, t, e, s) {
|
|
|
6823
6823
|
if (n == null)
|
|
6824
6824
|
return null;
|
|
6825
6825
|
if (s === "valid")
|
|
6826
|
-
n = n * t +
|
|
6826
|
+
n = n * t + qt([e - t, 0]);
|
|
6827
6827
|
else if (s === "same")
|
|
6828
6828
|
n = n * t;
|
|
6829
6829
|
else
|
|
@@ -6887,7 +6887,7 @@ class ks extends O {
|
|
|
6887
6887
|
constructor(t, e) {
|
|
6888
6888
|
if (super(e), this.bias = null, this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_BIAS_INITIALIZER = "zeros", ks.verifyArgs(e), this.rank = t, ot(this.rank, "rank"), this.rank !== 1 && this.rank !== 2 && this.rank !== 3)
|
|
6889
6889
|
throw new B(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is not implemented yet.`);
|
|
6890
|
-
if (this.kernelSize = ye(e.kernelSize, t, "kernelSize"), this.strides = ye(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, mt(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, et(this.dataFormat), this.activation = Jt(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer =
|
|
6890
|
+
if (this.kernelSize = ye(e.kernelSize, t, "kernelSize"), this.strides = ye(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, mt(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, et(this.dataFormat), this.activation = Jt(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer = q(e.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.biasConstraint = at(e.biasConstraint), this.biasRegularizer = Z(e.biasRegularizer), this.activityRegularizer = Z(e.activityRegularizer), this.dilationRate = ye(e.dilationRate == null ? 1 : e.dilationRate, t, "dilationRate"), this.rank === 1 && Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)
|
|
6891
6891
|
throw new d(`dilationRate must be a number or an array of a single number for 1D convolution, but received ${JSON.stringify(this.dilationRate)}`);
|
|
6892
6892
|
if (this.rank === 2) {
|
|
6893
6893
|
if (typeof this.dilationRate == "number")
|
|
@@ -6924,7 +6924,7 @@ class ks extends O {
|
|
|
6924
6924
|
}
|
|
6925
6925
|
class ve extends ks {
|
|
6926
6926
|
constructor(t, e) {
|
|
6927
|
-
super(t, e), this.kernel = null, ve.verifyArgs(e), this.filters = e.filters, ot(this.filters, "filters"), this.kernelInitializer =
|
|
6927
|
+
super(t, e), this.kernel = null, ve.verifyArgs(e), this.filters = e.filters, ot(this.filters, "filters"), this.kernelInitializer = q(e.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.kernelConstraint = at(e.kernelConstraint), this.kernelRegularizer = Z(e.kernelRegularizer);
|
|
6928
6928
|
}
|
|
6929
6929
|
build(t) {
|
|
6930
6930
|
t = G(t);
|
|
@@ -7103,7 +7103,7 @@ class Tr extends ve {
|
|
|
7103
7103
|
throw new d("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");
|
|
7104
7104
|
if (e.padding != null && e.padding !== "same" && e.padding !== "valid")
|
|
7105
7105
|
throw new d(`SeparableConv${this.rank}D supports only padding modes: 'same' and 'valid', but received ${JSON.stringify(e.padding)}`);
|
|
7106
|
-
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer =
|
|
7106
|
+
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer = q(e.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER), this.depthwiseRegularizer = Z(e.depthwiseRegularizer), this.depthwiseConstraint = at(e.depthwiseConstraint), this.pointwiseInitializer = q(e.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER), this.pointwiseRegularizer = Z(e.pointwiseRegularizer), this.pointwiseConstraint = at(e.pointwiseConstraint);
|
|
7107
7107
|
}
|
|
7108
7108
|
build(t) {
|
|
7109
7109
|
if (t = G(t), t.length < this.rank + 2)
|
|
@@ -7178,11 +7178,11 @@ class Lr extends O {
|
|
|
7178
7178
|
call(t, e) {
|
|
7179
7179
|
return y(() => {
|
|
7180
7180
|
if (t = $(t), this.dataFormat === "channelsLast") {
|
|
7181
|
-
const s =
|
|
7182
|
-
return
|
|
7181
|
+
const s = qe(t, this.cropping[0][0], t.shape[1] - this.cropping[0][0] - this.cropping[0][1], 2);
|
|
7182
|
+
return qe(s, this.cropping[1][0], t.shape[2] - this.cropping[1][1] - this.cropping[1][0], 3);
|
|
7183
7183
|
} else {
|
|
7184
|
-
const s =
|
|
7185
|
-
return
|
|
7184
|
+
const s = qe(t, this.cropping[0][0], t.shape[2] - this.cropping[0][0] - this.cropping[0][1], 3);
|
|
7185
|
+
return qe(s, this.cropping[1][0], t.shape[3] - this.cropping[1][1] - this.cropping[1][0], 4);
|
|
7186
7186
|
}
|
|
7187
7187
|
});
|
|
7188
7188
|
}
|
|
@@ -7244,7 +7244,7 @@ function ef(n, t, e = [1, 1], s = "valid", i, r) {
|
|
|
7244
7244
|
}
|
|
7245
7245
|
class Fr extends ks {
|
|
7246
7246
|
constructor(t) {
|
|
7247
|
-
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer =
|
|
7247
|
+
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer = q(t.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.depthwiseConstraint = at(t.depthwiseConstraint), this.depthwiseRegularizer = Z(t.depthwiseRegularizer);
|
|
7248
7248
|
}
|
|
7249
7249
|
build(t) {
|
|
7250
7250
|
if (t = G(t), t.length < 4)
|
|
@@ -7429,11 +7429,11 @@ class se extends O {
|
|
|
7429
7429
|
if (this.states_ == null)
|
|
7430
7430
|
Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map((i) => ft([s, i])) : this.states_ = [ft([s, this.cell.stateSize])];
|
|
7431
7431
|
else if (t == null)
|
|
7432
|
-
|
|
7432
|
+
K(this.states_), this.keptStates != null && (K(this.keptStates), this.keptStates = []), Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map((i) => ft([s, i])) : this.states_[0] = ft([s, this.cell.stateSize]);
|
|
7433
7433
|
else {
|
|
7434
7434
|
if (Array.isArray(t) || (t = [t]), t.length !== this.states_.length)
|
|
7435
7435
|
throw new d(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);
|
|
7436
|
-
e === !0 ? this.keptStates.push(this.states_.slice()) :
|
|
7436
|
+
e === !0 ? this.keptStates.push(this.states_.slice()) : K(this.states_);
|
|
7437
7437
|
for (let i = 0; i < this.states_.length; ++i) {
|
|
7438
7438
|
const r = t[i], a = Array.isArray(this.cell.stateSize) ? this.cell.stateSize[i] : this.cell.stateSize, o = [s, a];
|
|
7439
7439
|
if (!$t(r.shape, o))
|
|
@@ -7525,9 +7525,9 @@ class xs extends O {
|
|
|
7525
7525
|
}
|
|
7526
7526
|
class bn extends xs {
|
|
7527
7527
|
constructor(t) {
|
|
7528
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
7528
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = q(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = q(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = q(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = Z(t.kernelRegularizer), this.recurrentRegularizer = Z(t.recurrentRegularizer), this.biasRegularizer = Z(t.biasRegularizer), this.kernelConstraint = at(t.kernelConstraint), this.recurrentConstraint = at(t.recurrentConstraint), this.biasConstraint = at(t.biasConstraint), this.dropout = ke([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ke([
|
|
7529
7529
|
1,
|
|
7530
|
-
|
|
7530
|
+
qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
7531
7531
|
]), this.dropoutFunc = t.dropoutFunc, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
7532
7532
|
}
|
|
7533
7533
|
build(t) {
|
|
@@ -7593,7 +7593,7 @@ class Rr extends se {
|
|
|
7593
7593
|
}
|
|
7594
7594
|
call(t, e) {
|
|
7595
7595
|
return y(() => {
|
|
7596
|
-
this.cell.dropoutMask != null && (
|
|
7596
|
+
this.cell.dropoutMask != null && (K(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (K(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
7597
7597
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
7598
7598
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
7599
7599
|
});
|
|
@@ -7609,9 +7609,9 @@ class yn extends xs {
|
|
|
7609
7609
|
constructor(t) {
|
|
7610
7610
|
if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
|
|
7611
7611
|
throw new d("GRUCell does not support reset_after parameter set to true.");
|
|
7612
|
-
this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Jt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
7612
|
+
this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Jt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = q(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = q(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = q(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = Z(t.kernelRegularizer), this.recurrentRegularizer = Z(t.recurrentRegularizer), this.biasRegularizer = Z(t.biasRegularizer), this.kernelConstraint = at(t.kernelConstraint), this.recurrentConstraint = at(t.recurrentConstraint), this.biasConstraint = at(t.biasConstraint), this.dropout = ke([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ke([
|
|
7613
7613
|
1,
|
|
7614
|
-
|
|
7614
|
+
qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
7615
7615
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
7616
7616
|
}
|
|
7617
7617
|
build(t) {
|
|
@@ -7683,7 +7683,7 @@ class _r extends se {
|
|
|
7683
7683
|
}
|
|
7684
7684
|
call(t, e) {
|
|
7685
7685
|
return y(() => {
|
|
7686
|
-
this.cell.dropoutMask != null && (
|
|
7686
|
+
this.cell.dropoutMask != null && (K(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (K(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
7687
7687
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
7688
7688
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
7689
7689
|
});
|
|
@@ -7697,9 +7697,9 @@ _r.className = "GRU";
|
|
|
7697
7697
|
v(_r);
|
|
7698
7698
|
class vs extends xs {
|
|
7699
7699
|
constructor(t) {
|
|
7700
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Jt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
7700
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Jt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = q(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = q(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = q(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = Z(t.kernelRegularizer), this.recurrentRegularizer = Z(t.recurrentRegularizer), this.biasRegularizer = Z(t.biasRegularizer), this.kernelConstraint = at(t.kernelConstraint), this.recurrentConstraint = at(t.recurrentConstraint), this.biasConstraint = at(t.biasConstraint), this.dropout = ke([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ke([
|
|
7701
7701
|
1,
|
|
7702
|
-
|
|
7702
|
+
qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
7703
7703
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = [this.units, this.units], this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
7704
7704
|
}
|
|
7705
7705
|
build(t) {
|
|
@@ -7788,7 +7788,7 @@ class Br extends se {
|
|
|
7788
7788
|
}
|
|
7789
7789
|
call(t, e) {
|
|
7790
7790
|
return y(() => {
|
|
7791
|
-
this.cell.dropoutMask != null && (
|
|
7791
|
+
this.cell.dropoutMask != null && (K(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (K(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
7792
7792
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
7793
7793
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
7794
7794
|
});
|
|
@@ -7924,7 +7924,7 @@ class Wr extends se {
|
|
|
7924
7924
|
}
|
|
7925
7925
|
call(t, e) {
|
|
7926
7926
|
return y(() => {
|
|
7927
|
-
if (this.cell.dropoutMask != null && (
|
|
7927
|
+
if (this.cell.dropoutMask != null && (K(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (K(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null), e && e.constants)
|
|
7928
7928
|
throw new d("ConvRNN2D cell does not support constants");
|
|
7929
7929
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
7930
7930
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
@@ -7950,11 +7950,11 @@ class Wr extends se {
|
|
|
7950
7950
|
if (this.getStates() == null)
|
|
7951
7951
|
Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map(() => ft(r)) : this.states_ = [ft(r)];
|
|
7952
7952
|
else if (t == null)
|
|
7953
|
-
|
|
7953
|
+
K(this.states_), this.keptStates != null && (K(this.keptStates), this.keptStates = []), Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map(() => ft(r)) : this.states_[0] = ft(r);
|
|
7954
7954
|
else {
|
|
7955
7955
|
if (Array.isArray(t) || (t = [t]), t.length !== this.states_.length)
|
|
7956
7956
|
throw new d(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);
|
|
7957
|
-
e ? this.keptStates.push(this.states_.slice()) :
|
|
7957
|
+
e ? this.keptStates.push(this.states_.slice()) : K(this.states_);
|
|
7958
7958
|
for (let o = 0; o < this.states_.length; ++o) {
|
|
7959
7959
|
const l = t[o], u = r;
|
|
7960
7960
|
if (!$t(l.shape, u))
|
|
@@ -8123,7 +8123,7 @@ class Ur extends O {
|
|
|
8123
8123
|
let e = null;
|
|
8124
8124
|
t.batchSize != null && (e = t.batchSize), this.batchInputShape = [e, t.inputDim];
|
|
8125
8125
|
}
|
|
8126
|
-
this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer =
|
|
8126
|
+
this.units = t.units, ot(this.units, "units"), this.activation = Jt(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer = q(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.biasInitializer = q(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelConstraint = at(t.kernelConstraint), this.biasConstraint = at(t.biasConstraint), this.kernelRegularizer = Z(t.kernelRegularizer), this.biasRegularizer = Z(t.biasRegularizer), this.activityRegularizer = Z(t.activityRegularizer), this.supportsMasking = !0, this.inputSpec = [{ minNDim: 2 }];
|
|
8127
8127
|
}
|
|
8128
8128
|
build(t) {
|
|
8129
8129
|
t = G(t);
|
|
@@ -8231,7 +8231,7 @@ class Hr extends O {
|
|
|
8231
8231
|
}
|
|
8232
8232
|
Hr.className = "RepeatVector";
|
|
8233
8233
|
v(Hr);
|
|
8234
|
-
class
|
|
8234
|
+
class Kr extends O {
|
|
8235
8235
|
constructor(t) {
|
|
8236
8236
|
super(t), this.targetShape = t.targetShape;
|
|
8237
8237
|
for (let e = 0; e < this.targetShape.length; ++e)
|
|
@@ -8299,9 +8299,9 @@ class qr extends O {
|
|
|
8299
8299
|
return Object.assign(t, e), t;
|
|
8300
8300
|
}
|
|
8301
8301
|
}
|
|
8302
|
-
|
|
8303
|
-
v(
|
|
8304
|
-
class
|
|
8302
|
+
Kr.className = "Reshape";
|
|
8303
|
+
v(Kr);
|
|
8304
|
+
class qr extends O {
|
|
8305
8305
|
constructor(t) {
|
|
8306
8306
|
if (super(t), t.dims == null)
|
|
8307
8307
|
throw new Error("Required configuration field `dims` is missing during Permute constructor call.");
|
|
@@ -8329,8 +8329,8 @@ class Kr extends O {
|
|
|
8329
8329
|
return Object.assign(t, e), t;
|
|
8330
8330
|
}
|
|
8331
8331
|
}
|
|
8332
|
-
|
|
8333
|
-
v(
|
|
8332
|
+
qr.className = "Permute";
|
|
8333
|
+
v(qr);
|
|
8334
8334
|
class Zr extends O {
|
|
8335
8335
|
constructor(t) {
|
|
8336
8336
|
super(t ?? {}), this.supportsMasking = !0, t != null ? this.maskValue = t.maskValue == null ? 0 : t.maskValue : this.maskValue = 0;
|
|
@@ -8362,7 +8362,7 @@ class Jr extends O {
|
|
|
8362
8362
|
let e = null;
|
|
8363
8363
|
t.batchSize != null && (e = t.batchSize), t.inputLength == null ? this.batchInputShape = [e, null] : this.batchInputShape = [e].concat(V(t.inputLength));
|
|
8364
8364
|
}
|
|
8365
|
-
this.inputDim = t.inputDim, ot(this.inputDim, "inputDim"), this.outputDim = t.outputDim, ot(this.outputDim, "outputDim"), this.embeddingsInitializer =
|
|
8365
|
+
this.inputDim = t.inputDim, ot(this.inputDim, "inputDim"), this.outputDim = t.outputDim, ot(this.outputDim, "outputDim"), this.embeddingsInitializer = q(t.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER), this.embeddingsRegularizer = Z(t.embeddingsRegularizer), this.activityRegularizer = Z(t.activityRegularizer), this.embeddingsConstraint = at(t.embeddingsConstraint), this.maskZero = t.maskZero, this.supportsMasking = t.maskZero, this.inputLength = t.inputLength;
|
|
8366
8366
|
}
|
|
8367
8367
|
build(t) {
|
|
8368
8368
|
this.embeddings = this.addWeight("embeddings", [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, !0, this.embeddingsConstraint), this.built = !0;
|
|
@@ -8482,7 +8482,7 @@ class pe extends O {
|
|
|
8482
8482
|
if (t = t, this.reshapeRequired) {
|
|
8483
8483
|
const s = [], i = t.map((r) => r.rank);
|
|
8484
8484
|
if (i.indexOf(null) === -1) {
|
|
8485
|
-
const r =
|
|
8485
|
+
const r = qt(i);
|
|
8486
8486
|
for (let a of t) {
|
|
8487
8487
|
const o = a.rank;
|
|
8488
8488
|
for (let l = 0; l < r - o; ++l)
|
|
@@ -8918,7 +8918,7 @@ function of(n, t, e, s, i = 1e-3) {
|
|
|
8918
8918
|
}
|
|
8919
8919
|
class oa extends O {
|
|
8920
8920
|
constructor(t) {
|
|
8921
|
-
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer =
|
|
8921
|
+
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = q(t.betaInitializer || "zeros"), this.gammaInitializer = q(t.gammaInitializer || "ones"), this.movingMeanInitializer = q(t.movingMeanInitializer || "zeros"), this.movingVarianceInitializer = q(t.movingVarianceInitializer || "ones"), this.betaConstraint = at(t.betaConstraint), this.gammaConstraint = at(t.gammaConstraint), this.betaRegularizer = Z(t.betaRegularizer), this.gammaRegularizer = Z(t.gammaRegularizer);
|
|
8922
8922
|
}
|
|
8923
8923
|
build(t) {
|
|
8924
8924
|
t = G(t);
|
|
@@ -8987,7 +8987,7 @@ class la extends O {
|
|
|
8987
8987
|
throw new Error(`Expected axis to be an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
8988
8988
|
} else
|
|
8989
8989
|
throw new Error(`Expected axis to be an integer or an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
8990
|
-
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer =
|
|
8990
|
+
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = q(t.betaInitializer || "zeros"), this.gammaInitializer = q(t.gammaInitializer || "ones"), this.betaRegularizer = Z(t.betaRegularizer), this.gammaRegularizer = Z(t.gammaRegularizer), this.supportsMasking = !0;
|
|
8991
8991
|
}
|
|
8992
8992
|
build(t) {
|
|
8993
8993
|
t = G(t);
|