@genai-fi/nanogpt 0.7.2 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +36 -4
- package/dist/Generator.js +183 -69
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-N8TpOMYv.js} +14 -14
- package/dist/{Reshape-DvudQDvJ.js → Reshape-B-lWQRnF.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bo8HzP8V.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +51 -50
- package/dist/Trainer.d.ts +19 -3
- package/dist/Trainer.js +71 -28
- package/dist/{axis_util-BzbKo31C.js → axis_util-DubwyOhW.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-BJ-_jSeK.js} +46 -46
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-BYfCp5iL.js} +2 -2
- package/dist/{concat-CsxrgovM.js → concat-BmDqqFsa.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-CJmEGu6D.js} +5 -5
- package/dist/{dropout-DYs5QFGQ.js → dropout-sx0sjVAT.js} +8 -8
- package/dist/exports_initializers-DAKM8UO9.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-C1siEkdp.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-Bd3UBBxg.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-TFLxaLkw.js} +26 -26
- package/dist/{index-CLthM0TO.js → index-BaPo_0H8.js} +185 -185
- package/dist/{index-BoWRt-10.js → index-CUQrfsw_.js} +266 -265
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-P9aFa232.js} +9 -9
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +42 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C142qZqY.js} +14 -14
- package/dist/main.d.ts +5 -4
- package/dist/main.js +22 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-DMkduNJu.js} +1 -1
- package/dist/{max-Ddnnb5xe.js → max-B3JOcNGb.js} +1 -1
- package/dist/mod-uUuj4gSb.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +68 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-Cm2gw-c8.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-ZdgQGBCP.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +11 -11
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-BFGCx8Ri.js → ops-C_1K_-35.js} +103 -103
- package/dist/{random_width-sZORGo5k.js → random_width-D8Pwy_na.js} +136 -136
- package/dist/{range-CRuAh-gd.js → range-LVHrSLdi.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-CaR9e67G.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-DUshvVWP.js} +2026 -2049
- package/dist/{reshape-CdBq1WJ6.js → reshape-DEfQGSin.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-CUPPNLaA.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-8vv5JxQV.js} +3 -3
- package/dist/{shared-B8ztnyEk.js → shared-CkNorDcU.js} +83 -83
- package/dist/{shared-wS99K7_n.js → shared-D1elLckx.js} +1 -1
- package/dist/{sin-BeA3tsEd.js → sin-D2CKKmyR.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-BnyE-M_7.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-DLoZWYBx.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-By_n4TKP.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-DkdFLq37.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-l_0SqM4h.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-BAQdLqoU.js} +1 -1
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-BHy261cI.js} +1 -1
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +16 -3
- package/dist/training/FullTrainer.js +91 -53
- package/dist/training/Trainer.d.ts +25 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +9 -9
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-C9hihzDB.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-dFEVbDPL.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-DLImlSc6.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-VZ72lWXM.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
|
@@ -1,25 +1,25 @@
|
|
|
1
|
-
import { B as T, C as I, E as O,
|
|
2
|
-
import { k as ke, c as Nt, o as ze, s as er, b as nr, d as Wu, m as sr, t as In, l as ir, v as os, a as Gu, S as Pu, p as Uu, w as as, x as rr, y as
|
|
3
|
-
import { n as wt, w as ne, a as
|
|
4
|
-
import { r as N } from "./reshape-
|
|
5
|
-
import { s as W } from "./sum-
|
|
6
|
-
import { m as ct } from "./mat_mul-
|
|
7
|
-
import { s as Qt } from "./split-
|
|
8
|
-
import { s as
|
|
9
|
-
import { e as Hn, g as lr, h as cs, c as Xu } from "./axis_util-
|
|
10
|
-
import { a as se, e as ie, l as Yu } from "./log_sum_exp-
|
|
11
|
-
import { s as Dn } from "./stack-
|
|
12
|
-
import { o as xe } from "./ones-
|
|
13
|
-
import { s as Dt } from "./slice-
|
|
14
|
-
import { M as Qu, f as ur, r as tc, d as ec, a as $n } from "./dropout-
|
|
15
|
-
import { z as vt } from "./zeros
|
|
16
|
-
import { c as pe } from "./concat-
|
|
17
|
-
import { g as cr } from "./gather-
|
|
18
|
-
import { s as hr } from "./softmax-
|
|
19
|
-
import { m as Ee } from "./max-
|
|
20
|
-
import { t as nc } from "./tensor-
|
|
21
|
-
import { r as sc } from "./range-
|
|
22
|
-
import { v as ic } from "./variable-
|
|
1
|
+
import { B as T, C as I, E as O, bG as Oa, bH as Ma, bI as Ci, n as b, bJ as Ii, $ as L, bK as Di, bL as $i, bM as Ti, bN as zi, h as Ei, bO as Li, bP as Fi, bQ as Oi, bR as Mi, bS as _a, bT as _i, bU as Ra, bV as Ri, bW as Ba, bX as Bi, N as Ge, l as kt, bs as Wa, bY as Wi, bZ as Gi, b_ as Pi, I as Pe, c as V, a as w, b$ as Ga, c0 as Ui, c1 as Vi, c2 as ji, p as ce, aW as pt, bx as Pa, c3 as Ki, c4 as Hi, c5 as qi, c6 as Zi, c7 as Ji, bz as Xi, c8 as Yi, c9 as Qi, L as Ua, an as Va, as as ja, ca as tr, cb as Ka, q as z, cc as Ps, cd as Ha, ce as qa, j as pn, cf as Us, cg as Za, ch as Ja, ci as Xa, cj as Ya, ck as Qa, cl as tl, cm as el, bh as nl, cn as sl, w as he, b as et, o as U, co as il, bo as rl, ay as ht, cp as ol, K as Q, cq as al, cr as ll, cs as ul, ct as cl, cu as hl, cv as pl, cw as dl, cx as fl, Q as ml, cy as gl, bl, br as yl, cz as wl, H as kl, cA as xl, ab as Nl, cB as vl, cC as Al, cD as Sl, au as Cl, cE as Il, ag as Dl, aX as $l, bt as Tl, am as zl, bu as El, G as Ll, aZ as Fl, ao as Ol, cF as Ml, cG as _l, cH as Rl, ap as Bl, ah as Wl, cI as Gl, cJ as Pl, cK as Ul, X as Vl, bv as jl, cL as Kl, cM as Hl, aU as ql, b3 as Zl, cN as Jl, Y as Xl, bw as Yl, b1 as Ql, P as tu, cO as eu, x as rs, ar as nu, by as su, aB as iu, Z as ru, aw as ou, av as au, U as lu, be as uu, cP as cu, bf as hu, cQ as pu, b5 as du, aS as fu, at as mu, cR as gu, ac as bu, _ as yu, S as wu, V as ku, bB as xu, cS as Nu, bC as vu, ax as Au, bE as Su, a0 as Cu, cT as Iu, W as Du, b7 as $u, b6 as Tu, cU as Oe, cV as zu, i as Eu, al as Vs, cW as Lu, t as x, aV as $e, cX as S, cY as He, cZ as qe, ae as Vt, d as Z, af as Fu, c_ as js, k as Zt, F as Ou, T as Te, O as Mu, c$ as _u, m as Ks, d0 as Ru, d1 as Hs, d2 as Bu } from "./index-CUQrfsw_.js";
|
|
2
|
+
import { k as ke, c as Nt, o as ze, s as er, b as nr, d as Wu, m as sr, t as In, l as ir, v as os, a as Gu, S as Pu, p as Uu, w as as, x as rr, y as Ze, z as Vu, A as ju } from "./selu_util-8vv5JxQV.js";
|
|
3
|
+
import { n as wt, w as ne, a as Je, g as Xe, b as ls, t as K, c as Ce, d as Xt, e as Ku, u as dn, f as ye, h as fn, i as Hu, l as qu, s as us, m as or, j as qt, k as Zu } from "./ops-C_1K_-35.js";
|
|
4
|
+
import { r as N } from "./reshape-DEfQGSin.js";
|
|
5
|
+
import { s as W } from "./sum-l_0SqM4h.js";
|
|
6
|
+
import { m as ct } from "./mat_mul-DMkduNJu.js";
|
|
7
|
+
import { s as Qt } from "./split-By_n4TKP.js";
|
|
8
|
+
import { s as Ju, c as ar } from "./sin-D2CKKmyR.js";
|
|
9
|
+
import { e as Hn, g as lr, h as cs, c as Xu } from "./axis_util-DubwyOhW.js";
|
|
10
|
+
import { a as se, e as ie, l as Yu } from "./log_sum_exp-C142qZqY.js";
|
|
11
|
+
import { s as Dn } from "./stack-DkdFLq37.js";
|
|
12
|
+
import { o as xe } from "./ones-ZdgQGBCP.js";
|
|
13
|
+
import { s as Dt } from "./slice-BnyE-M_7.js";
|
|
14
|
+
import { M as Qu, f as ur, r as tc, d as ec, a as $n } from "./dropout-sx0sjVAT.js";
|
|
15
|
+
import { z as vt } from "./zeros-VZ72lWXM.js";
|
|
16
|
+
import { c as pe } from "./concat-BmDqqFsa.js";
|
|
17
|
+
import { g as cr } from "./gather-C1siEkdp.js";
|
|
18
|
+
import { s as hr } from "./softmax-DLoZWYBx.js";
|
|
19
|
+
import { m as Ee } from "./max-B3JOcNGb.js";
|
|
20
|
+
import { t as nc } from "./tensor-BAQdLqoU.js";
|
|
21
|
+
import { r as sc } from "./range-LVHrSLdi.js";
|
|
22
|
+
import { v as ic } from "./variable-C9hihzDB.js";
|
|
23
23
|
/**
|
|
24
24
|
* @license
|
|
25
25
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -444,7 +444,7 @@ function Kc(n, t = 0, e = !1, s = !1) {
|
|
|
444
444
|
const r = { x: I(n, "x", "cumprod") }, o = { axis: t, exclusive: e, reverse: s };
|
|
445
445
|
return O.runKernel(Ra, r, o);
|
|
446
446
|
}
|
|
447
|
-
const
|
|
447
|
+
const Zs = /* @__PURE__ */ T({ cumprod_: Kc });
|
|
448
448
|
/**
|
|
449
449
|
* @license
|
|
450
450
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -482,13 +482,13 @@ const qc = /* @__PURE__ */ T({ cumsum_: Hc });
|
|
|
482
482
|
* limitations under the License.
|
|
483
483
|
* =============================================================================
|
|
484
484
|
*/
|
|
485
|
-
function
|
|
485
|
+
function Zc(n, t, e, s = !1) {
|
|
486
486
|
const i = I(n, "x", "denseBincount"), r = I(t, "weights", "denseBincount");
|
|
487
487
|
b(i.dtype === "int32", () => `Error in denseBincount: input dtype must be int32, but got ${i.dtype}`), b(i.rank <= 2, () => `Error in denseBincount: input must be at most rank 2, but got rank ${i.rank}.`), b(e >= 0, () => `size must be non-negative, but got ${e}.`), b(r.size === i.size || r.size === 0, () => `Error in denseBincount: weights must have the same shape as x or 0-length, but got x shape: ${i.shape}, weights shape: ${r.shape}.`);
|
|
488
488
|
const o = { x: i, weights: r }, a = { size: e, binaryOutput: s };
|
|
489
489
|
return O.runKernel(Ba, o, a);
|
|
490
490
|
}
|
|
491
|
-
const
|
|
491
|
+
const Js = /* @__PURE__ */ T({ denseBincount_: Zc });
|
|
492
492
|
/**
|
|
493
493
|
* @license
|
|
494
494
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -505,7 +505,7 @@ const Zs = /* @__PURE__ */ T({ denseBincount_: Jc });
|
|
|
505
505
|
* limitations under the License.
|
|
506
506
|
* =============================================================================
|
|
507
507
|
*/
|
|
508
|
-
function
|
|
508
|
+
function Jc(n, t, e, s, i = "NHWC", r = [1, 1], o) {
|
|
509
509
|
const a = I(n, "x", "depthwiseConv2d", "float32"), l = I(t, "filter", "depthwiseConv2d", "float32");
|
|
510
510
|
let u = a, c = !1;
|
|
511
511
|
a.rank === 3 && (c = !0, u = N(a, [1, a.shape[0], a.shape[1], a.shape[2]])), b(u.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${u.rank}.`), b(l.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${l.rank}.`);
|
|
@@ -514,7 +514,7 @@ function Zc(n, t, e, s, i = "NHWC", r = [1, 1], o) {
|
|
|
514
514
|
const p = { x: u, filter: l }, f = { strides: e, pad: s, dataFormat: i, dilations: r, dimRoundingMode: o }, y = O.runKernel(Bi, p, f);
|
|
515
515
|
return c ? N(y, [y.shape[1], y.shape[2], y.shape[3]]) : y;
|
|
516
516
|
}
|
|
517
|
-
const dr = /* @__PURE__ */ T({ depthwiseConv2d_:
|
|
517
|
+
const dr = /* @__PURE__ */ T({ depthwiseConv2d_: Jc });
|
|
518
518
|
/**
|
|
519
519
|
* @license
|
|
520
520
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -858,7 +858,7 @@ function yh(n, t, e) {
|
|
|
858
858
|
const s = I(n, "x", "spaceToBatchND");
|
|
859
859
|
b(s.rank >= 1 + t.length, () => `input rank ${s.rank} should be > than [blockShape] ${t.length}`), b(e.length === t.length, () => `paddings.shape[0] ${e.length} must be equal to [blockShape] ${t.length}`), b(s.shape.reduce((o, a, l) => l > 0 && l <= t.length ? o && (a + e[l - 1][0] + e[l - 1][1]) % t[l - 1] === 0 : o, !0), () => `input spatial dimensions ${s.shape.slice(1)} with paddings ${e.toString()} must be divisible by blockShapes ${t.toString()}`);
|
|
860
860
|
const i = { x: s }, r = { blockShape: t, paddings: e };
|
|
861
|
-
return O.runKernel(
|
|
861
|
+
return O.runKernel(Zi, i, r);
|
|
862
862
|
}
|
|
863
863
|
const wh = /* @__PURE__ */ T({ spaceToBatchND_: yh });
|
|
864
864
|
/**
|
|
@@ -879,7 +879,7 @@ const wh = /* @__PURE__ */ T({ spaceToBatchND_: yh });
|
|
|
879
879
|
*/
|
|
880
880
|
function kh(n, t) {
|
|
881
881
|
const s = { x: I(n, "x", "reverse") }, i = { dims: t };
|
|
882
|
-
return O.runKernel(
|
|
882
|
+
return O.runKernel(Ji, s, i);
|
|
883
883
|
}
|
|
884
884
|
const gn = /* @__PURE__ */ T({ reverse_: kh });
|
|
885
885
|
/**
|
|
@@ -1361,7 +1361,7 @@ class Vh {
|
|
|
1361
1361
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1362
1362
|
*/
|
|
1363
1363
|
static sgd(t) {
|
|
1364
|
-
return new
|
|
1364
|
+
return new Za(t);
|
|
1365
1365
|
}
|
|
1366
1366
|
/**
|
|
1367
1367
|
* Constructs a `tf.MomentumOptimizer` that uses momentum gradient
|
|
@@ -1379,7 +1379,7 @@ class Vh {
|
|
|
1379
1379
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1380
1380
|
*/
|
|
1381
1381
|
static momentum(t, e, s = !1) {
|
|
1382
|
-
return new
|
|
1382
|
+
return new Ja(t, e, s);
|
|
1383
1383
|
}
|
|
1384
1384
|
/**
|
|
1385
1385
|
* Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
|
|
@@ -1605,7 +1605,7 @@ const qh = {
|
|
|
1605
1605
|
* limitations under the License.
|
|
1606
1606
|
* =============================================================================
|
|
1607
1607
|
*/
|
|
1608
|
-
const
|
|
1608
|
+
const Zh = {
|
|
1609
1609
|
kernelName: rl,
|
|
1610
1610
|
inputsToSave: ["a", "b"],
|
|
1611
1611
|
gradFunc: (n, t) => {
|
|
@@ -1637,7 +1637,7 @@ const Jh = {
|
|
|
1637
1637
|
* limitations under the License.
|
|
1638
1638
|
* =============================================================================
|
|
1639
1639
|
*/
|
|
1640
|
-
const
|
|
1640
|
+
const Jh = {
|
|
1641
1641
|
kernelName: ol,
|
|
1642
1642
|
saveAllInputs: !0,
|
|
1643
1643
|
gradFunc: (n, t) => {
|
|
@@ -2092,7 +2092,7 @@ const mp = {
|
|
|
2092
2092
|
gradFunc: (n, t, e) => {
|
|
2093
2093
|
const [s] = t, { clipValueMin: i, clipValueMax: r } = e;
|
|
2094
2094
|
return {
|
|
2095
|
-
x: () => ne(
|
|
2095
|
+
x: () => ne(Je(Xe(s, i), ls(s, r)), n, Q(n))
|
|
2096
2096
|
};
|
|
2097
2097
|
}
|
|
2098
2098
|
};
|
|
@@ -2270,7 +2270,7 @@ const vp = {
|
|
|
2270
2270
|
inputsToSave: ["x"],
|
|
2271
2271
|
gradFunc: (n, t) => {
|
|
2272
2272
|
const [e] = t;
|
|
2273
|
-
return { x: () => w(wt(
|
|
2273
|
+
return { x: () => w(wt(Ju(L(e, "float32"))), n) };
|
|
2274
2274
|
}
|
|
2275
2275
|
};
|
|
2276
2276
|
/**
|
|
@@ -2973,7 +2973,7 @@ const ti = {
|
|
|
2973
2973
|
* limitations under the License.
|
|
2974
2974
|
* =============================================================================
|
|
2975
2975
|
*/
|
|
2976
|
-
const
|
|
2976
|
+
const Zp = {
|
|
2977
2977
|
kernelName: jl,
|
|
2978
2978
|
inputsToSave: ["a", "b"],
|
|
2979
2979
|
gradFunc: (n, t) => {
|
|
@@ -2997,7 +2997,7 @@ const Jp = {
|
|
|
2997
2997
|
* limitations under the License.
|
|
2998
2998
|
* =============================================================================
|
|
2999
2999
|
*/
|
|
3000
|
-
function
|
|
3000
|
+
function Jp(n, t, e, s, i, r, o) {
|
|
3001
3001
|
const a = I(n, "dy", "maxPool3dGrad"), l = I(t, "input", "maxPool3dGrad"), u = I(e, "output", "maxPool3dGrad");
|
|
3002
3002
|
let c = a, h = l, p = u, f = !1;
|
|
3003
3003
|
l.rank === 4 && (f = !0, c = N(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]]), h = N(l, [
|
|
@@ -3016,7 +3016,7 @@ function Zp(n, t, e, s, i, r, o) {
|
|
|
3016
3016
|
const y = { dy: c, input: h, output: p }, g = { filterSize: s, strides: i, pad: r, dimRoundingMode: o }, m = O.runKernel(Kl, y, g);
|
|
3017
3017
|
return f ? N(m, [m.shape[1], m.shape[2], m.shape[3], m.shape[4]]) : m;
|
|
3018
3018
|
}
|
|
3019
|
-
const Xp = /* @__PURE__ */ T({ maxPool3dGrad_:
|
|
3019
|
+
const Xp = /* @__PURE__ */ T({ maxPool3dGrad_: Jp });
|
|
3020
3020
|
/**
|
|
3021
3021
|
* @license
|
|
3022
3022
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3169,7 +3169,7 @@ const sd = {
|
|
|
3169
3169
|
* =============================================================================
|
|
3170
3170
|
*/
|
|
3171
3171
|
const id = {
|
|
3172
|
-
kernelName:
|
|
3172
|
+
kernelName: Zl,
|
|
3173
3173
|
inputsToSave: ["a", "b"],
|
|
3174
3174
|
gradFunc: (n, t) => {
|
|
3175
3175
|
const [e, s] = t;
|
|
@@ -3193,7 +3193,7 @@ const id = {
|
|
|
3193
3193
|
* =============================================================================
|
|
3194
3194
|
*/
|
|
3195
3195
|
const rd = {
|
|
3196
|
-
kernelName:
|
|
3196
|
+
kernelName: Jl,
|
|
3197
3197
|
inputsToSave: ["x"],
|
|
3198
3198
|
gradFunc: (n, t, e) => {
|
|
3199
3199
|
const s = t[0], { paddings: i } = e, r = i.map((o) => o[0]);
|
|
@@ -3457,7 +3457,7 @@ const dd = {
|
|
|
3457
3457
|
function fd(n, t, e) {
|
|
3458
3458
|
const s = n.shape.slice();
|
|
3459
3459
|
s[e] = 1;
|
|
3460
|
-
const i = N(t, s), r =
|
|
3460
|
+
const i = N(t, s), r = Zs(n, e, !0, !1), o = Zs(n, e, !0, !0), a = w(r, o);
|
|
3461
3461
|
return w(i, a);
|
|
3462
3462
|
}
|
|
3463
3463
|
function md(n, t, e) {
|
|
@@ -3683,7 +3683,7 @@ const vd = {
|
|
|
3683
3683
|
* =============================================================================
|
|
3684
3684
|
*/
|
|
3685
3685
|
const Ad = {
|
|
3686
|
-
kernelName:
|
|
3686
|
+
kernelName: Ji,
|
|
3687
3687
|
gradFunc: (n, t, e) => {
|
|
3688
3688
|
const { dims: s } = e, i = ce(s, n.shape);
|
|
3689
3689
|
return { x: () => gn(n, i) };
|
|
@@ -3977,7 +3977,7 @@ const Od = {
|
|
|
3977
3977
|
* =============================================================================
|
|
3978
3978
|
*/
|
|
3979
3979
|
const ni = {
|
|
3980
|
-
kernelName:
|
|
3980
|
+
kernelName: Zi,
|
|
3981
3981
|
gradFunc: (n, t, e) => {
|
|
3982
3982
|
const { blockShape: s, paddings: i } = e;
|
|
3983
3983
|
return { x: () => mc(n, s, i) };
|
|
@@ -4337,7 +4337,7 @@ function qd(n, t) {
|
|
|
4337
4337
|
const r = s.rank - i.rank;
|
|
4338
4338
|
for (let a = 0; a < r; ++a)
|
|
4339
4339
|
i = ye(i, a + 1);
|
|
4340
|
-
i =
|
|
4340
|
+
i = Je(i, xe(s.shape, "bool"));
|
|
4341
4341
|
const o = Q(s);
|
|
4342
4342
|
return ne(i, s, o);
|
|
4343
4343
|
}
|
|
@@ -4357,7 +4357,7 @@ function qd(n, t) {
|
|
|
4357
4357
|
* limitations under the License.
|
|
4358
4358
|
* =============================================================================
|
|
4359
4359
|
*/
|
|
4360
|
-
const
|
|
4360
|
+
const Zd = {
|
|
4361
4361
|
kernelName: zu,
|
|
4362
4362
|
gradFunc: (n) => ({ x: () => Q(n) })
|
|
4363
4363
|
};
|
|
@@ -4377,12 +4377,12 @@ const Jd = {
|
|
|
4377
4377
|
* limitations under the License.
|
|
4378
4378
|
* =============================================================================
|
|
4379
4379
|
*/
|
|
4380
|
-
const
|
|
4380
|
+
const Jd = [
|
|
4381
4381
|
br,
|
|
4382
4382
|
Hh,
|
|
4383
4383
|
qh,
|
|
4384
|
-
Jh,
|
|
4385
4384
|
Zh,
|
|
4385
|
+
Jh,
|
|
4386
4386
|
Xh,
|
|
4387
4387
|
Yh,
|
|
4388
4388
|
Qh,
|
|
@@ -4430,7 +4430,7 @@ const Zd = [
|
|
|
4430
4430
|
qp,
|
|
4431
4431
|
ti,
|
|
4432
4432
|
ti,
|
|
4433
|
-
|
|
4433
|
+
Zp,
|
|
4434
4434
|
Yp,
|
|
4435
4435
|
ed,
|
|
4436
4436
|
nd,
|
|
@@ -4482,9 +4482,9 @@ const Zd = [
|
|
|
4482
4482
|
jd,
|
|
4483
4483
|
Kd,
|
|
4484
4484
|
Hd,
|
|
4485
|
-
|
|
4485
|
+
Zd
|
|
4486
4486
|
];
|
|
4487
|
-
for (const n of
|
|
4487
|
+
for (const n of Jd)
|
|
4488
4488
|
Eu(n);
|
|
4489
4489
|
/**
|
|
4490
4490
|
* @license
|
|
@@ -4624,15 +4624,15 @@ function ks(n) {
|
|
|
4624
4624
|
const t = {};
|
|
4625
4625
|
return t.className = n.getClassName(), t.config = n.getConfig(), t;
|
|
4626
4626
|
}
|
|
4627
|
-
function
|
|
4627
|
+
function Zn(n) {
|
|
4628
4628
|
if (!(n == null || typeof n != "object"))
|
|
4629
4629
|
if (Array.isArray(n))
|
|
4630
|
-
n.forEach((t) =>
|
|
4630
|
+
n.forEach((t) => Zn(t));
|
|
4631
4631
|
else {
|
|
4632
4632
|
const t = Object.keys(n);
|
|
4633
4633
|
for (const e of t) {
|
|
4634
4634
|
const s = n[e];
|
|
4635
|
-
s != null && typeof s == "object" && (!Array.isArray(s) && s.type === "ndarray" && typeof s.value == "number" ? n[e] = s.value :
|
|
4635
|
+
s != null && typeof s == "object" && (!Array.isArray(s) && s.type === "ndarray" && typeof s.value == "number" ? n[e] = s.value : Zn(s));
|
|
4636
4636
|
}
|
|
4637
4637
|
}
|
|
4638
4638
|
}
|
|
@@ -4671,7 +4671,7 @@ function Ye(n, t = {}, e = {}, s = "object", i = !1) {
|
|
|
4671
4671
|
const h = Object.assign({}, Ct);
|
|
4672
4672
|
for (const f of Object.keys(e))
|
|
4673
4673
|
Ct[f] = e[f];
|
|
4674
|
-
|
|
4674
|
+
Zn(r.config);
|
|
4675
4675
|
const p = l(a, r.config, e, i);
|
|
4676
4676
|
return Ct = Object.assign({}, h), p;
|
|
4677
4677
|
} else {
|
|
@@ -4894,7 +4894,7 @@ function hf(n, t) {
|
|
|
4894
4894
|
if (n.shape.length !== 2)
|
|
4895
4895
|
throw new d(`repeat() expects a rank-2 tensor, but received a rank-${n.shape.length} tensor.`);
|
|
4896
4896
|
const e = Qe(n, 1);
|
|
4897
|
-
return
|
|
4897
|
+
return Jn(e, [1, t, 1]);
|
|
4898
4898
|
});
|
|
4899
4899
|
}
|
|
4900
4900
|
function pf(n) {
|
|
@@ -5017,7 +5017,7 @@ function oi(n, t) {
|
|
|
5017
5017
|
throw new d(`concatAlongFirstAxis() received an unsupported tensor rank: ${n.rank}`);
|
|
5018
5018
|
}
|
|
5019
5019
|
}
|
|
5020
|
-
function
|
|
5020
|
+
function Jn(n, t) {
|
|
5021
5021
|
if (Array.isArray(t) || (t = [t]), n.rank !== t.length)
|
|
5022
5022
|
throw new d(`The length of input n (${t.length}) does not match the number of dimensions in input x (${n.rank})`);
|
|
5023
5023
|
return Ce(n, t);
|
|
@@ -5443,7 +5443,7 @@ function li(n, t = {}) {
|
|
|
5443
5443
|
function Y(n) {
|
|
5444
5444
|
return ks(n);
|
|
5445
5445
|
}
|
|
5446
|
-
function
|
|
5446
|
+
function J(n) {
|
|
5447
5447
|
if (typeof n == "string") {
|
|
5448
5448
|
const t = n in ai ? ai[n] : n;
|
|
5449
5449
|
if (t === "GlorotNormal")
|
|
@@ -6119,7 +6119,7 @@ class G extends He {
|
|
|
6119
6119
|
addWeight(t, e, s, i, r, o, a, l) {
|
|
6120
6120
|
if (this._addedWeightNames.indexOf(t) !== -1)
|
|
6121
6121
|
throw new d(`Duplicate weight name ${t} for layer ${this.name}`);
|
|
6122
|
-
this._addedWeightNames.push(t), s == null && (s = "float32"), this.fastWeightInitDuringBuild && (i = l != null ? l() :
|
|
6122
|
+
this._addedWeightNames.push(t), s == null && (s = "float32"), this.fastWeightInitDuringBuild && (i = l != null ? l() : J("zeros"));
|
|
6123
6123
|
const u = i.apply(e, s), c = new Nf(u, s, t, o, a);
|
|
6124
6124
|
return u.dispose(), r != null && this.addLoss(() => r.apply(c.read())), o == null && (o = !0), o ? this._trainableWeights.push(c) : this._nonTrainableWeights.push(c), c;
|
|
6125
6125
|
}
|
|
@@ -6533,7 +6533,7 @@ class Yt {
|
|
|
6533
6533
|
}
|
|
6534
6534
|
/** Dispose all mask Tensors held by this object. */
|
|
6535
6535
|
disposeMasks() {
|
|
6536
|
-
this.id2Mask != null &&
|
|
6536
|
+
this.id2Mask != null && Z(this.id2Mask);
|
|
6537
6537
|
}
|
|
6538
6538
|
}
|
|
6539
6539
|
const kn = new wr(), xn = new wr();
|
|
@@ -6572,7 +6572,7 @@ function Be(n, t, e, s) {
|
|
|
6572
6572
|
const M = a.indexOf(E[F].name);
|
|
6573
6573
|
M !== -1 && (l[M] = D[F]);
|
|
6574
6574
|
}
|
|
6575
|
-
i ||
|
|
6575
|
+
i || Z(C);
|
|
6576
6576
|
}
|
|
6577
6577
|
return f.disposeMasks(), r ? l : l[0];
|
|
6578
6578
|
}
|
|
@@ -6707,7 +6707,7 @@ Rr.className = "UnitNorm";
|
|
|
6707
6707
|
S(Rr);
|
|
6708
6708
|
class Br extends sn {
|
|
6709
6709
|
apply(t) {
|
|
6710
|
-
return
|
|
6710
|
+
return Ze(t);
|
|
6711
6711
|
}
|
|
6712
6712
|
}
|
|
6713
6713
|
Br.className = "NonNeg";
|
|
@@ -6777,7 +6777,7 @@ async function fe(n) {
|
|
|
6777
6777
|
const i = await Promise.all(t);
|
|
6778
6778
|
for (let r = 0; r < i.length; ++r)
|
|
6779
6779
|
n[e[r]] = i[r][0];
|
|
6780
|
-
|
|
6780
|
+
Z(s);
|
|
6781
6781
|
}
|
|
6782
6782
|
}
|
|
6783
6783
|
function Gr(n) {
|
|
@@ -6943,7 +6943,7 @@ class Bf extends Ue {
|
|
|
6943
6943
|
for (const s of this.params.metrics)
|
|
6944
6944
|
this.totals[s] != null && (typeof this.totals[s] == "number" ? e[s] = this.totals[s] / this.seen : x(() => {
|
|
6945
6945
|
const i = w(U(1, this.seen), this.totals[s]);
|
|
6946
|
-
e[s] = i, this.totals[s].dispose(),
|
|
6946
|
+
e[s] = i, this.totals[s].dispose(), Zt(e[s]);
|
|
6947
6947
|
}));
|
|
6948
6948
|
}
|
|
6949
6949
|
}
|
|
@@ -7087,7 +7087,7 @@ function Ur(n, t, e, s, i, r, o, a, l) {
|
|
|
7087
7087
|
* https://opensource.org/licenses/MIT.
|
|
7088
7088
|
* =============================================================================
|
|
7089
7089
|
*/
|
|
7090
|
-
function
|
|
7090
|
+
function Jt(n, t = {}, e = !1) {
|
|
7091
7091
|
return Ye(n, qe.getMap().classNameMap, t, "layer", e);
|
|
7092
7092
|
}
|
|
7093
7093
|
/**
|
|
@@ -7171,7 +7171,7 @@ function Hf(n, t) {
|
|
|
7171
7171
|
if (!Vt(n.shape, t.shape))
|
|
7172
7172
|
throw new d(`logits and labels must have the same shape, but got shapes ${JSON.stringify(n.shape)} and ${JSON.stringify(t.shape)}`);
|
|
7173
7173
|
return x(() => {
|
|
7174
|
-
const e =
|
|
7174
|
+
const e = Ze(t), s = wt($e(t));
|
|
7175
7175
|
return z(V(e, w(t, n)), eh(ie(s)));
|
|
7176
7176
|
});
|
|
7177
7177
|
}
|
|
@@ -7187,7 +7187,7 @@ function qf(n, t) {
|
|
|
7187
7187
|
return W(w(n, se(U(e, s))), -1);
|
|
7188
7188
|
});
|
|
7189
7189
|
}
|
|
7190
|
-
function
|
|
7190
|
+
function Zf(n, t) {
|
|
7191
7191
|
return x(() => {
|
|
7192
7192
|
const e = se(z(st(), t));
|
|
7193
7193
|
return nt(V(t, w(n, e)), -1);
|
|
@@ -7212,7 +7212,7 @@ const An = {
|
|
|
7212
7212
|
sparseCategoricalCrossentropy: vn,
|
|
7213
7213
|
binaryCrossentropy: Fn,
|
|
7214
7214
|
kullbackLeiblerDivergence: qf,
|
|
7215
|
-
poisson:
|
|
7215
|
+
poisson: Zf,
|
|
7216
7216
|
cosineProximity: Vr
|
|
7217
7217
|
};
|
|
7218
7218
|
function Un(n) {
|
|
@@ -7242,15 +7242,15 @@ function jr(n, t) {
|
|
|
7242
7242
|
function Kr(n, t) {
|
|
7243
7243
|
return x(() => Pt(re(mn(n, -1), mn(t, -1)), "float32"));
|
|
7244
7244
|
}
|
|
7245
|
-
function
|
|
7246
|
-
return x(() => L(W(
|
|
7245
|
+
function Jf(n, t) {
|
|
7246
|
+
return x(() => L(W(Je(re(n, 1), re(t, 1))), "float32"));
|
|
7247
7247
|
}
|
|
7248
7248
|
function Xf(n, t) {
|
|
7249
|
-
return x(() => L(W(
|
|
7249
|
+
return x(() => L(W(Je(re(n, 0), re(t, 1))), "float32"));
|
|
7250
7250
|
}
|
|
7251
7251
|
function Yf(n, t) {
|
|
7252
7252
|
return x(() => {
|
|
7253
|
-
const e =
|
|
7253
|
+
const e = Jf(n, t), s = Xf(n, t), i = z(e, s);
|
|
7254
7254
|
return L(ne(Xt(i, 0), U(e, i), 0), "float32");
|
|
7255
7255
|
});
|
|
7256
7256
|
}
|
|
@@ -7485,7 +7485,7 @@ function fm(n, t, e, s) {
|
|
|
7485
7485
|
* https://opensource.org/licenses/MIT.
|
|
7486
7486
|
* =============================================================================
|
|
7487
7487
|
*/
|
|
7488
|
-
function
|
|
7488
|
+
function Zr(n, t, e) {
|
|
7489
7489
|
return (n === "inboundNodes" || n === "outputLayers" || n === "inputLayers") && t === 0 && typeof e == "string";
|
|
7490
7490
|
}
|
|
7491
7491
|
function es(n, t) {
|
|
@@ -7499,7 +7499,7 @@ function es(n, t) {
|
|
|
7499
7499
|
const e = [], s = n.length;
|
|
7500
7500
|
for (let i = 0; i < s; ++i) {
|
|
7501
7501
|
const r = n[i];
|
|
7502
|
-
|
|
7502
|
+
Zr(t, i, r) ? e.push(r) : e.push(es(r, t));
|
|
7503
7503
|
}
|
|
7504
7504
|
return e;
|
|
7505
7505
|
} else {
|
|
@@ -7527,7 +7527,7 @@ function ns(n, t) {
|
|
|
7527
7527
|
const e = [], s = n.length;
|
|
7528
7528
|
for (let i = 0; i < s; ++i) {
|
|
7529
7529
|
const r = n[i];
|
|
7530
|
-
|
|
7530
|
+
Zr(t, i, r) ? e.push(r) : e.push(ns(r, t));
|
|
7531
7531
|
}
|
|
7532
7532
|
return e;
|
|
7533
7533
|
} else {
|
|
@@ -7540,7 +7540,7 @@ function ns(n, t) {
|
|
|
7540
7540
|
}
|
|
7541
7541
|
}
|
|
7542
7542
|
/** @license See the LICENSE file. */
|
|
7543
|
-
const
|
|
7543
|
+
const Jr = "4.22.0";
|
|
7544
7544
|
/**
|
|
7545
7545
|
* @license
|
|
7546
7546
|
* Copyright 2018 Google LLC
|
|
@@ -7804,7 +7804,7 @@ class Lt extends G {
|
|
|
7804
7804
|
*/
|
|
7805
7805
|
updatedConfig() {
|
|
7806
7806
|
const t = this.getConfig(), e = {};
|
|
7807
|
-
return e.className = this.getClassName(), e.config = t, e.kerasVersion = `tfjs-layers ${
|
|
7807
|
+
return e.className = this.getClassName(), e.config = t, e.kerasVersion = `tfjs-layers ${Jr}`, e.backend = "TensorFlow.js", e;
|
|
7808
7808
|
}
|
|
7809
7809
|
/**
|
|
7810
7810
|
* Returns a JSON string containing the network configuration.
|
|
@@ -8100,7 +8100,7 @@ class Lt extends G {
|
|
|
8100
8100
|
k.length > 0 && m.apply(yt(k), C);
|
|
8101
8101
|
}
|
|
8102
8102
|
function u(m) {
|
|
8103
|
-
const A = m.name, k =
|
|
8103
|
+
const A = m.name, k = Jt(m, e.customObjects != null ? e.customObjects : {});
|
|
8104
8104
|
k.setFastWeightInitDuringBuild(i), r[A] = k, m.inboundNodes.forEach((v) => {
|
|
8105
8105
|
if (!(v instanceof Array))
|
|
8106
8106
|
throw new d(`Corrupted configuration, expected array for nodeData: ${v}`);
|
|
@@ -8208,7 +8208,7 @@ async function Yr(n, t, e, s) {
|
|
|
8208
8208
|
} else
|
|
8209
8209
|
throw new Error(`Unexpected rank of target (y) tensor (${n.rank}) during handling of class weights. The rank is expected to be 1 or 2.`);
|
|
8210
8210
|
}), r = Array.from(await i.data());
|
|
8211
|
-
|
|
8211
|
+
Z(i);
|
|
8212
8212
|
const o = [];
|
|
8213
8213
|
return r.forEach((a) => {
|
|
8214
8214
|
if (e[a] == null)
|
|
@@ -8319,10 +8319,10 @@ async function km(n, t, e) {
|
|
|
8319
8319
|
R.push(await Yr(D[P], null, M[P]));
|
|
8320
8320
|
}
|
|
8321
8321
|
const E = v.concat(D).concat(R), F = a(E);
|
|
8322
|
-
|
|
8322
|
+
Z(E);
|
|
8323
8323
|
for (let M = 0; M < l.length; ++M) {
|
|
8324
8324
|
const P = l[M], ut = F[M];
|
|
8325
|
-
$[P] = ut,
|
|
8325
|
+
$[P] = ut, Zt(ut);
|
|
8326
8326
|
}
|
|
8327
8327
|
await p.onBatchEnd(k, $), Gr($), k++, A++;
|
|
8328
8328
|
}
|
|
@@ -8373,15 +8373,15 @@ async function vm(n, t, e) {
|
|
|
8373
8373
|
if (r = x(() => {
|
|
8374
8374
|
if (u.value) {
|
|
8375
8375
|
const { xs: c, ys: h } = Qr(n, u.value), p = c.concat(h), f = x(() => i(p));
|
|
8376
|
-
if (
|
|
8376
|
+
if (Z(p), l === 0)
|
|
8377
8377
|
for (let g = 0; g < f.length; ++g)
|
|
8378
8378
|
r.push(et(0));
|
|
8379
8379
|
const y = p[0].shape[0];
|
|
8380
8380
|
for (let g = 0; g < f.length; ++g) {
|
|
8381
8381
|
const m = f[g], A = r[g];
|
|
8382
|
-
r[g] = x(() => z(r[g], w(y, m))), l > 0 &&
|
|
8382
|
+
r[g] = x(() => z(r[g], w(y, m))), l > 0 && Z(A);
|
|
8383
8383
|
}
|
|
8384
|
-
|
|
8384
|
+
Z(f), a += y, ++l;
|
|
8385
8385
|
}
|
|
8386
8386
|
return r;
|
|
8387
8387
|
}), u.done) {
|
|
@@ -8391,7 +8391,7 @@ async function vm(n, t, e) {
|
|
|
8391
8391
|
}
|
|
8392
8392
|
for (let u = 0; u < r.length; ++u) {
|
|
8393
8393
|
const c = r[u];
|
|
8394
|
-
r[u] = U(r[u], a),
|
|
8394
|
+
r[u] = U(r[u], a), Z(c);
|
|
8395
8395
|
}
|
|
8396
8396
|
return yt(r);
|
|
8397
8397
|
}
|
|
@@ -9089,7 +9089,7 @@ class Ie extends Lt {
|
|
|
9089
9089
|
const A = this.metricsTensors[g][0], k = this.metricsTensors[g][1];
|
|
9090
9090
|
m = nt(A(i[k], f[k]));
|
|
9091
9091
|
}
|
|
9092
|
-
|
|
9092
|
+
Zt(m), o.push(m);
|
|
9093
9093
|
}
|
|
9094
9094
|
return y = nt(y), this.calculateLosses().forEach((g) => {
|
|
9095
9095
|
y = z(y, g);
|
|
@@ -9194,7 +9194,7 @@ class Ie extends Lt {
|
|
|
9194
9194
|
const R = Pr(s.callbacks, s.yieldEvery);
|
|
9195
9195
|
return await this.fitLoop(C, k, v, f, s.epochs, s.verbose, R, D, A, s.shuffle, $, s.initialEpoch, null, null);
|
|
9196
9196
|
} finally {
|
|
9197
|
-
this.isTraining = !1, Et(i, t), Et(r, e), Et(o, t), Et(a, e), Et(c, l), Et(h, u), p != null &&
|
|
9197
|
+
this.isTraining = !1, Et(i, t), Et(r, e), Et(o, t), Et(a, e), Et(c, l), Et(h, u), p != null && Z(p);
|
|
9198
9198
|
}
|
|
9199
9199
|
}
|
|
9200
9200
|
/**
|
|
@@ -9252,13 +9252,13 @@ class Ie extends Lt {
|
|
|
9252
9252
|
const ft = ss(e, ut), mt = t(ft);
|
|
9253
9253
|
for (let at = 0; at < s.length; ++at) {
|
|
9254
9254
|
const bt = s[at], gt = mt[at];
|
|
9255
|
-
F[bt] = gt,
|
|
9255
|
+
F[bt] = gt, Zt(gt);
|
|
9256
9256
|
}
|
|
9257
9257
|
if (E === R.length - 1 && g) {
|
|
9258
9258
|
const at = this.testLoop(l, u, i);
|
|
9259
9259
|
for (let bt = 0; bt < s.length; ++bt) {
|
|
9260
9260
|
const gt = s[bt], St = at[bt];
|
|
9261
|
-
|
|
9261
|
+
Zt(St), D["val_" + gt] = St;
|
|
9262
9262
|
}
|
|
9263
9263
|
}
|
|
9264
9264
|
}), await k.onBatchEnd(E, F), Gr(F), this.stopTraining_)
|
|
@@ -9326,7 +9326,7 @@ class Ie extends Lt {
|
|
|
9326
9326
|
const c = await u.data();
|
|
9327
9327
|
l.push(c[0]);
|
|
9328
9328
|
}
|
|
9329
|
-
return
|
|
9329
|
+
return Z(a), Et(s[0], t), Et(s[1], e), yt(l);
|
|
9330
9330
|
}
|
|
9331
9331
|
/**
|
|
9332
9332
|
* Extract weight values of the model.
|
|
@@ -9443,7 +9443,7 @@ class Ie extends Lt {
|
|
|
9443
9443
|
throw new Error("Loading loss_weights is not supported yet.");
|
|
9444
9444
|
if (t.sample_weight_mode != null)
|
|
9445
9445
|
throw new Error("Loading sample_weight_mode is not supported yet.");
|
|
9446
|
-
const e = es(t.optimizer_config), s =
|
|
9446
|
+
const e = es(t.optimizer_config), s = Jt(e);
|
|
9447
9447
|
let i;
|
|
9448
9448
|
if (typeof t.loss == "string")
|
|
9449
9449
|
i = me(t.loss);
|
|
@@ -9559,7 +9559,7 @@ class Ie extends Lt {
|
|
|
9559
9559
|
const s = await Hs(this.getNamedWeights(e)), a = {
|
|
9560
9560
|
modelTopology: this.toJSON(null, !1),
|
|
9561
9561
|
format: Dm,
|
|
9562
|
-
generatedBy: `TensorFlow.js tfjs-layers v${
|
|
9562
|
+
generatedBy: `TensorFlow.js tfjs-layers v${Jr}`,
|
|
9563
9563
|
convertedBy: null
|
|
9564
9564
|
};
|
|
9565
9565
|
if ((e == null ? !1 : e.includeOptimizer) && this.optimizer != null) {
|
|
@@ -10049,7 +10049,7 @@ class je extends Ie {
|
|
|
10049
10049
|
if (!(a instanceof je))
|
|
10050
10050
|
throw new B(`Sequential.fromConfig called on non-Sequential input: ${a}`);
|
|
10051
10051
|
for (const l of r) {
|
|
10052
|
-
const c =
|
|
10052
|
+
const c = Jt(l, void 0, i);
|
|
10053
10053
|
i && c.setFastWeightInitDuringBuild(!0), a.add(c);
|
|
10054
10054
|
}
|
|
10055
10055
|
return a;
|
|
@@ -10142,14 +10142,14 @@ so.className = "selu";
|
|
|
10142
10142
|
S(so);
|
|
10143
10143
|
class io extends dt {
|
|
10144
10144
|
apply(t) {
|
|
10145
|
-
return
|
|
10145
|
+
return Ze(t);
|
|
10146
10146
|
}
|
|
10147
10147
|
}
|
|
10148
10148
|
io.className = "relu";
|
|
10149
10149
|
S(io);
|
|
10150
10150
|
class ro extends dt {
|
|
10151
10151
|
apply(t) {
|
|
10152
|
-
return x(() => or(6,
|
|
10152
|
+
return x(() => or(6, Ze(t)));
|
|
10153
10153
|
}
|
|
10154
10154
|
}
|
|
10155
10155
|
ro.className = "relu6";
|
|
@@ -10378,7 +10378,7 @@ class ko extends G {
|
|
|
10378
10378
|
}
|
|
10379
10379
|
call(t, e) {
|
|
10380
10380
|
t = _(t);
|
|
10381
|
-
let s =
|
|
10381
|
+
let s = Ze(t);
|
|
10382
10382
|
return this.maxValue != null && (s = Tt(s, 0, this.maxValue)), s;
|
|
10383
10383
|
}
|
|
10384
10384
|
computeOutputShape(t) {
|
|
@@ -10411,7 +10411,7 @@ xo.className = "LeakyReLU";
|
|
|
10411
10411
|
S(xo);
|
|
10412
10412
|
class No extends G {
|
|
10413
10413
|
constructor(t) {
|
|
10414
|
-
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer =
|
|
10414
|
+
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer = J(t.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER), this.alphaRegularizer = X(t.alphaRegularizer), this.alphaConstraint = ot(t.alphaConstraint), t.sharedAxes == null)
|
|
10415
10415
|
this.sharedAxes = null;
|
|
10416
10416
|
else if (Array.isArray(t.sharedAxes))
|
|
10417
10417
|
this.sharedAxes = t.sharedAxes;
|
|
@@ -10619,7 +10619,7 @@ class On extends G {
|
|
|
10619
10619
|
constructor(t, e) {
|
|
10620
10620
|
if (super(e), this.bias = null, this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_BIAS_INITIALIZER = "zeros", On.verifyArgs(e), this.rank = t, lt(this.rank, "rank"), this.rank !== 1 && this.rank !== 2 && this.rank !== 3)
|
|
10621
10621
|
throw new B(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is not implemented yet.`);
|
|
10622
|
-
if (this.kernelSize = De(e.kernelSize, t, "kernelSize"), this.strides = De(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, At(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, tt(this.dataFormat), this.activation = le(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer =
|
|
10622
|
+
if (this.kernelSize = De(e.kernelSize, t, "kernelSize"), this.strides = De(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, At(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, tt(this.dataFormat), this.activation = le(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer = J(e.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.biasConstraint = ot(e.biasConstraint), this.biasRegularizer = X(e.biasRegularizer), this.activityRegularizer = X(e.activityRegularizer), this.dilationRate = De(e.dilationRate == null ? 1 : e.dilationRate, t, "dilationRate"), this.rank === 1 && Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)
|
|
10623
10623
|
throw new d(`dilationRate must be a number or an array of a single number for 1D convolution, but received ${JSON.stringify(this.dilationRate)}`);
|
|
10624
10624
|
if (this.rank === 2) {
|
|
10625
10625
|
if (typeof this.dilationRate == "number")
|
|
@@ -10656,7 +10656,7 @@ class On extends G {
|
|
|
10656
10656
|
}
|
|
10657
10657
|
class Me extends On {
|
|
10658
10658
|
constructor(t, e) {
|
|
10659
|
-
super(t, e), this.kernel = null, Me.verifyArgs(e), this.filters = e.filters, lt(this.filters, "filters"), this.kernelInitializer =
|
|
10659
|
+
super(t, e), this.kernel = null, Me.verifyArgs(e), this.filters = e.filters, lt(this.filters, "filters"), this.kernelInitializer = J(e.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.kernelConstraint = ot(e.kernelConstraint), this.kernelRegularizer = X(e.kernelRegularizer);
|
|
10660
10660
|
}
|
|
10661
10661
|
build(t) {
|
|
10662
10662
|
t = j(t);
|
|
@@ -10835,7 +10835,7 @@ class $o extends Me {
|
|
|
10835
10835
|
throw new d("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");
|
|
10836
10836
|
if (e.padding != null && e.padding !== "same" && e.padding !== "valid")
|
|
10837
10837
|
throw new d(`SeparableConv${this.rank}D supports only padding modes: 'same' and 'valid', but received ${JSON.stringify(e.padding)}`);
|
|
10838
|
-
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer =
|
|
10838
|
+
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer = J(e.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER), this.depthwiseRegularizer = X(e.depthwiseRegularizer), this.depthwiseConstraint = ot(e.depthwiseConstraint), this.pointwiseInitializer = J(e.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER), this.pointwiseRegularizer = X(e.pointwiseRegularizer), this.pointwiseConstraint = ot(e.pointwiseConstraint);
|
|
10839
10839
|
}
|
|
10840
10840
|
build(t) {
|
|
10841
10841
|
if (t = j(t), t.length < this.rank + 2)
|
|
@@ -10985,7 +10985,7 @@ function Em(n, t, e = [1, 1], s = "valid", i, r) {
|
|
|
10985
10985
|
}
|
|
10986
10986
|
class Lo extends On {
|
|
10987
10987
|
constructor(t) {
|
|
10988
|
-
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer =
|
|
10988
|
+
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer = J(t.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.depthwiseConstraint = ot(t.depthwiseConstraint), this.depthwiseRegularizer = X(t.depthwiseRegularizer);
|
|
10989
10989
|
}
|
|
10990
10990
|
build(t) {
|
|
10991
10991
|
if (t = j(t), t.length < 4)
|
|
@@ -11179,11 +11179,11 @@ class de extends G {
|
|
|
11179
11179
|
if (this.states_ == null)
|
|
11180
11180
|
Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map((i) => vt([s, i])) : this.states_ = [vt([s, this.cell.stateSize])];
|
|
11181
11181
|
else if (t == null)
|
|
11182
|
-
|
|
11182
|
+
Z(this.states_), this.keptStates != null && (Z(this.keptStates), this.keptStates = []), Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map((i) => vt([s, i])) : this.states_[0] = vt([s, this.cell.stateSize]);
|
|
11183
11183
|
else {
|
|
11184
11184
|
if (Array.isArray(t) || (t = [t]), t.length !== this.states_.length)
|
|
11185
11185
|
throw new d(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);
|
|
11186
|
-
e === !0 ? this.keptStates.push(this.states_.slice()) :
|
|
11186
|
+
e === !0 ? this.keptStates.push(this.states_.slice()) : Z(this.states_);
|
|
11187
11187
|
for (let i = 0; i < this.states_.length; ++i) {
|
|
11188
11188
|
const r = t[i], o = Array.isArray(this.cell.stateSize) ? this.cell.stateSize[i] : this.cell.stateSize, a = [s, o];
|
|
11189
11189
|
if (!Vt(r.shape, a))
|
|
@@ -11191,7 +11191,7 @@ class de extends G {
|
|
|
11191
11191
|
this.states_[i] = r;
|
|
11192
11192
|
}
|
|
11193
11193
|
}
|
|
11194
|
-
this.states_ = this.states_.map((i) =>
|
|
11194
|
+
this.states_ = this.states_.map((i) => Zt(i.clone()));
|
|
11195
11195
|
});
|
|
11196
11196
|
}
|
|
11197
11197
|
apply(t, e) {
|
|
@@ -11236,7 +11236,7 @@ class de extends G {
|
|
|
11236
11236
|
getInitialState(t) {
|
|
11237
11237
|
return x(() => {
|
|
11238
11238
|
let e = vt(t.shape);
|
|
11239
|
-
return e = W(e, [1, 2]), e = Qe(e), Array.isArray(this.cell.stateSize) ? this.cell.stateSize.map((s) => s > 1 ?
|
|
11239
|
+
return e = W(e, [1, 2]), e = Qe(e), Array.isArray(this.cell.stateSize) ? this.cell.stateSize.map((s) => s > 1 ? Jn(e, [1, s]) : e) : this.cell.stateSize > 1 ? [Jn(e, [1, this.cell.stateSize])] : [e];
|
|
11240
11240
|
});
|
|
11241
11241
|
}
|
|
11242
11242
|
get trainableWeights() {
|
|
@@ -11265,7 +11265,7 @@ class de extends G {
|
|
|
11265
11265
|
}
|
|
11266
11266
|
/** @nocollapse */
|
|
11267
11267
|
static fromConfig(t, e, s = {}) {
|
|
11268
|
-
const i = e.cell, r =
|
|
11268
|
+
const i = e.cell, r = Jt(i, s);
|
|
11269
11269
|
return new t(Object.assign(e, { cell: r }));
|
|
11270
11270
|
}
|
|
11271
11271
|
}
|
|
@@ -11275,7 +11275,7 @@ class _n extends G {
|
|
|
11275
11275
|
}
|
|
11276
11276
|
class Ms extends _n {
|
|
11277
11277
|
constructor(t) {
|
|
11278
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
11278
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11279
11279
|
1,
|
|
11280
11280
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11281
11281
|
]), this.dropoutFunc = t.dropoutFunc, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11343,7 +11343,7 @@ class Mo extends de {
|
|
|
11343
11343
|
}
|
|
11344
11344
|
call(t, e) {
|
|
11345
11345
|
return x(() => {
|
|
11346
|
-
this.cell.dropoutMask != null && (
|
|
11346
|
+
this.cell.dropoutMask != null && (Z(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (Z(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
11347
11347
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
11348
11348
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
11349
11349
|
});
|
|
@@ -11359,7 +11359,7 @@ class _s extends _n {
|
|
|
11359
11359
|
constructor(t) {
|
|
11360
11360
|
if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
|
|
11361
11361
|
throw new d("GRUCell does not support reset_after parameter set to true.");
|
|
11362
|
-
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
11362
|
+
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11363
11363
|
1,
|
|
11364
11364
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11365
11365
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11433,7 +11433,7 @@ class _o extends de {
|
|
|
11433
11433
|
}
|
|
11434
11434
|
call(t, e) {
|
|
11435
11435
|
return x(() => {
|
|
11436
|
-
this.cell.dropoutMask != null && (
|
|
11436
|
+
this.cell.dropoutMask != null && (Z(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (Z(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
11437
11437
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
11438
11438
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
11439
11439
|
});
|
|
@@ -11447,7 +11447,7 @@ _o.className = "GRU";
|
|
|
11447
11447
|
S(_o);
|
|
11448
11448
|
class Rn extends _n {
|
|
11449
11449
|
constructor(t) {
|
|
11450
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer =
|
|
11450
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11451
11451
|
1,
|
|
11452
11452
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11453
11453
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = [this.units, this.units], this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11538,7 +11538,7 @@ class Ro extends de {
|
|
|
11538
11538
|
}
|
|
11539
11539
|
call(t, e) {
|
|
11540
11540
|
return x(() => {
|
|
11541
|
-
this.cell.dropoutMask != null && (
|
|
11541
|
+
this.cell.dropoutMask != null && (Z(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (Z(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null);
|
|
11542
11542
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
11543
11543
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
11544
11544
|
});
|
|
@@ -11600,7 +11600,7 @@ class Rs extends _n {
|
|
|
11600
11600
|
static fromConfig(t, e, s = {}) {
|
|
11601
11601
|
const i = [];
|
|
11602
11602
|
for (const r of e.cells)
|
|
11603
|
-
i.push(
|
|
11603
|
+
i.push(Jt(r, s));
|
|
11604
11604
|
return new t({ cells: i });
|
|
11605
11605
|
}
|
|
11606
11606
|
get trainableWeights() {
|
|
@@ -11654,7 +11654,7 @@ Rs.className = "StackedRNNCells";
|
|
|
11654
11654
|
S(Rs);
|
|
11655
11655
|
function ue(n) {
|
|
11656
11656
|
const { ones: t, rate: e, training: s = !1, count: i = 1, dropoutFunc: r } = n, o = () => r != null ? r(t(), e) : Dr(t(), e), a = () => en(o, t, s);
|
|
11657
|
-
return !i || i <= 1 ?
|
|
11657
|
+
return !i || i <= 1 ? Zt(a().clone()) : Array(i).fill(void 0).map(a).map((u) => Zt(u.clone()));
|
|
11658
11658
|
}
|
|
11659
11659
|
/**
|
|
11660
11660
|
* @license
|
|
@@ -11683,7 +11683,7 @@ class Bo extends de {
|
|
|
11683
11683
|
}
|
|
11684
11684
|
call(t, e) {
|
|
11685
11685
|
return x(() => {
|
|
11686
|
-
if (this.cell.dropoutMask != null && (
|
|
11686
|
+
if (this.cell.dropoutMask != null && (Z(this.cell.dropoutMask), this.cell.dropoutMask = null), this.cell.recurrentDropoutMask != null && (Z(this.cell.recurrentDropoutMask), this.cell.recurrentDropoutMask = null), e && e.constants)
|
|
11687
11687
|
throw new d("ConvRNN2D cell does not support constants");
|
|
11688
11688
|
const s = e == null ? null : e.mask, i = e == null ? null : e.training, r = e == null ? null : e.initialState;
|
|
11689
11689
|
return super.call(t, { mask: s, training: i, initialState: r });
|
|
@@ -11709,11 +11709,11 @@ class Bo extends de {
|
|
|
11709
11709
|
if (this.getStates() == null)
|
|
11710
11710
|
Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map(() => vt(r)) : this.states_ = [vt(r)];
|
|
11711
11711
|
else if (t == null)
|
|
11712
|
-
|
|
11712
|
+
Z(this.states_), this.keptStates != null && (Z(this.keptStates), this.keptStates = []), Array.isArray(this.cell.stateSize) ? this.states_ = this.cell.stateSize.map(() => vt(r)) : this.states_[0] = vt(r);
|
|
11713
11713
|
else {
|
|
11714
11714
|
if (Array.isArray(t) || (t = [t]), t.length !== this.states_.length)
|
|
11715
11715
|
throw new d(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);
|
|
11716
|
-
e ? this.keptStates.push(this.states_.slice()) :
|
|
11716
|
+
e ? this.keptStates.push(this.states_.slice()) : Z(this.states_);
|
|
11717
11717
|
for (let a = 0; a < this.states_.length; ++a) {
|
|
11718
11718
|
const l = t[a], u = r;
|
|
11719
11719
|
if (!Vt(l.shape, u))
|
|
@@ -11721,7 +11721,7 @@ class Bo extends de {
|
|
|
11721
11721
|
this.states_[a] = l;
|
|
11722
11722
|
}
|
|
11723
11723
|
}
|
|
11724
|
-
this.states_ = this.states_.map((a) =>
|
|
11724
|
+
this.states_ = this.states_.map((a) => Zt(a.clone()));
|
|
11725
11725
|
});
|
|
11726
11726
|
}
|
|
11727
11727
|
computeSingleOutputShape(t) {
|
|
@@ -11891,7 +11891,7 @@ class Po extends G {
|
|
|
11891
11891
|
let e = null;
|
|
11892
11892
|
t.batchSize != null && (e = t.batchSize), this.batchInputShape = [e, t.inputDim];
|
|
11893
11893
|
}
|
|
11894
|
-
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer =
|
|
11894
|
+
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelConstraint = ot(t.kernelConstraint), this.biasConstraint = ot(t.biasConstraint), this.kernelRegularizer = X(t.kernelRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.activityRegularizer = X(t.activityRegularizer), this.supportsMasking = !0, this.inputSpec = [{ minNDim: 2 }];
|
|
11895
11895
|
}
|
|
11896
11896
|
build(t) {
|
|
11897
11897
|
t = j(t);
|
|
@@ -12133,13 +12133,13 @@ S(qo);
|
|
|
12133
12133
|
* https://opensource.org/licenses/MIT.
|
|
12134
12134
|
* =============================================================================
|
|
12135
12135
|
*/
|
|
12136
|
-
class
|
|
12136
|
+
class Zo extends G {
|
|
12137
12137
|
constructor(t) {
|
|
12138
12138
|
if (super(t), this.embeddings = null, this.DEFAULT_EMBEDDINGS_INITIALIZER = "randomUniform", t.batchInputShape == null && t.inputShape == null) {
|
|
12139
12139
|
let e = null;
|
|
12140
12140
|
t.batchSize != null && (e = t.batchSize), t.inputLength == null ? this.batchInputShape = [e, null] : this.batchInputShape = [e].concat(H(t.inputLength));
|
|
12141
12141
|
}
|
|
12142
|
-
this.inputDim = t.inputDim, lt(this.inputDim, "inputDim"), this.outputDim = t.outputDim, lt(this.outputDim, "outputDim"), this.embeddingsInitializer =
|
|
12142
|
+
this.inputDim = t.inputDim, lt(this.inputDim, "inputDim"), this.outputDim = t.outputDim, lt(this.outputDim, "outputDim"), this.embeddingsInitializer = J(t.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER), this.embeddingsRegularizer = X(t.embeddingsRegularizer), this.activityRegularizer = X(t.activityRegularizer), this.embeddingsConstraint = ot(t.embeddingsConstraint), this.maskZero = t.maskZero, this.supportsMasking = t.maskZero, this.inputLength = t.inputLength;
|
|
12143
12143
|
}
|
|
12144
12144
|
build(t) {
|
|
12145
12145
|
this.embeddings = this.addWeight("embeddings", [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, !0, this.embeddingsConstraint), this.built = !0;
|
|
@@ -12191,8 +12191,8 @@ class Jo extends G {
|
|
|
12191
12191
|
return Object.assign(t, e), t;
|
|
12192
12192
|
}
|
|
12193
12193
|
}
|
|
12194
|
-
|
|
12195
|
-
S(
|
|
12194
|
+
Zo.className = "Embedding";
|
|
12195
|
+
S(Zo);
|
|
12196
12196
|
/**
|
|
12197
12197
|
* @license
|
|
12198
12198
|
* Copyright 2018 Google LLC
|
|
@@ -12335,12 +12335,12 @@ class ve extends G {
|
|
|
12335
12335
|
e = e.map((i) => i == null ? i : ye(i, 0));
|
|
12336
12336
|
let s = e[0];
|
|
12337
12337
|
for (let i = 1; i < e.length - 1; ++i)
|
|
12338
|
-
s =
|
|
12338
|
+
s = Je(s, e[i]);
|
|
12339
12339
|
return s;
|
|
12340
12340
|
});
|
|
12341
12341
|
}
|
|
12342
12342
|
}
|
|
12343
|
-
class
|
|
12343
|
+
class Jo extends ve {
|
|
12344
12344
|
constructor(t) {
|
|
12345
12345
|
super(t);
|
|
12346
12346
|
}
|
|
@@ -12353,8 +12353,8 @@ class Zo extends ve {
|
|
|
12353
12353
|
});
|
|
12354
12354
|
}
|
|
12355
12355
|
}
|
|
12356
|
-
|
|
12357
|
-
S(
|
|
12356
|
+
Jo.className = "Add";
|
|
12357
|
+
S(Jo);
|
|
12358
12358
|
class Xo extends ve {
|
|
12359
12359
|
constructor(t) {
|
|
12360
12360
|
super(t);
|
|
@@ -12722,7 +12722,7 @@ function _m(n, t, e, s, i = 1e-3) {
|
|
|
12722
12722
|
}
|
|
12723
12723
|
class oa extends G {
|
|
12724
12724
|
constructor(t) {
|
|
12725
|
-
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer =
|
|
12725
|
+
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = J(t.betaInitializer || "zeros"), this.gammaInitializer = J(t.gammaInitializer || "ones"), this.movingMeanInitializer = J(t.movingMeanInitializer || "zeros"), this.movingVarianceInitializer = J(t.movingVarianceInitializer || "ones"), this.betaConstraint = ot(t.betaConstraint), this.gammaConstraint = ot(t.gammaConstraint), this.betaRegularizer = X(t.betaRegularizer), this.gammaRegularizer = X(t.gammaRegularizer);
|
|
12726
12726
|
}
|
|
12727
12727
|
build(t) {
|
|
12728
12728
|
t = j(t);
|
|
@@ -12793,7 +12793,7 @@ class aa extends G {
|
|
|
12793
12793
|
throw new Error(`Expected axis to be an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
12794
12794
|
} else
|
|
12795
12795
|
throw new Error(`Expected axis to be an integer or an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
12796
|
-
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer =
|
|
12796
|
+
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = J(t.betaInitializer || "zeros"), this.gammaInitializer = J(t.gammaInitializer || "ones"), this.betaRegularizer = X(t.betaRegularizer), this.gammaRegularizer = X(t.gammaRegularizer), this.supportsMasking = !0;
|
|
12797
12797
|
}
|
|
12798
12798
|
build(t) {
|
|
12799
12799
|
t = j(t);
|
|
@@ -13229,7 +13229,7 @@ class Sa extends G {
|
|
|
13229
13229
|
}
|
|
13230
13230
|
/** @nocollapse */
|
|
13231
13231
|
static fromConfig(t, e, s = {}) {
|
|
13232
|
-
const i = e.layer, r =
|
|
13232
|
+
const i = e.layer, r = Jt(i, s);
|
|
13233
13233
|
delete e.layer;
|
|
13234
13234
|
const o = { layer: r };
|
|
13235
13235
|
return Object.assign(o, e), new t(o);
|
|
@@ -13275,9 +13275,9 @@ class Ia extends Sa {
|
|
|
13275
13275
|
constructor(t) {
|
|
13276
13276
|
super(t);
|
|
13277
13277
|
const e = t.layer.getConfig(), s = {};
|
|
13278
|
-
s.className = t.layer.getClassName(), s.config = e, this.forwardLayer =
|
|
13278
|
+
s.className = t.layer.getClassName(), s.config = e, this.forwardLayer = Jt(s), e.goBackwards = e.goBackwards !== !0;
|
|
13279
13279
|
const i = {};
|
|
13280
|
-
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer =
|
|
13280
|
+
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer = Jt(i), this.forwardLayer.name = "forward_" + this.forwardLayer.name, this.backwardLayer.name = "backward_" + this.backwardLayer.name, this.mergeMode = t.mergeMode === void 0 ? Wm : t.mergeMode, Bm(this.mergeMode), t.weights)
|
|
13281
13281
|
throw new B("weights support is not implemented for Bidirectional layer yet.");
|
|
13282
13282
|
this._stateful = t.layer.stateful, this.returnSequences = t.layer.returnSequences, this.returnState = t.layer.returnState, this.supportsMasking = !0, this._trainable = !0, this.inputSpec = t.layer.inputSpec, this.numConstants = null;
|
|
13283
13283
|
}
|
|
@@ -13382,7 +13382,7 @@ class Ia extends Sa {
|
|
|
13382
13382
|
}
|
|
13383
13383
|
/** @nocollapse */
|
|
13384
13384
|
static fromConfig(t, e) {
|
|
13385
|
-
const s =
|
|
13385
|
+
const s = Jt(e.layer);
|
|
13386
13386
|
if (delete e.layer, e.numConstants != null)
|
|
13387
13387
|
throw new B("Deserialization of a Bidirectional layer with numConstants present is not supported yet.");
|
|
13388
13388
|
const i = e;
|
|
@@ -13490,7 +13490,7 @@ function Um(n, t, e, s) {
|
|
|
13490
13490
|
throw new d(`When outputMode is not int, maximum output rank is 2 Received outputMode ${t} and input shape ${r} which would result in output rank ${i.rank}.`);
|
|
13491
13491
|
const o = ["multiHot", "oneHot"].includes(t), a = i;
|
|
13492
13492
|
let l;
|
|
13493
|
-
if (typeof s < "u" && t === "count" ? l =
|
|
13493
|
+
if (typeof s < "u" && t === "count" ? l = Js(a, s, e, o) : l = Js(a, [], e, o), t !== "tfIdf")
|
|
13494
13494
|
return l;
|
|
13495
13495
|
if (s)
|
|
13496
13496
|
return w(l, s);
|
|
@@ -13529,7 +13529,7 @@ class Ta extends G {
|
|
|
13529
13529
|
Received countWeights=${e.countWeights}`);
|
|
13530
13530
|
s = _(e.countWeights);
|
|
13531
13531
|
}
|
|
13532
|
-
const i = Ee(t), r =
|
|
13532
|
+
const i = Ee(t), r = Zu(t), o = Xt(this.numTokens, i).bufferSync().get(0), a = Xe(r, 0).bufferSync().get(0);
|
|
13533
13533
|
if (!(o && a))
|
|
13534
13534
|
throw new d(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
|
|
13535
13535
|
return Um(t, this.outputMode, this.numTokens, s);
|
|
@@ -13699,7 +13699,7 @@ export {
|
|
|
13699
13699
|
mh as B,
|
|
13700
13700
|
Ot as C,
|
|
13701
13701
|
Ws as D,
|
|
13702
|
-
|
|
13702
|
+
Zo as E,
|
|
13703
13703
|
Nh as F,
|
|
13704
13704
|
Ah as G,
|
|
13705
13705
|
Ch as H,
|
|
@@ -13727,7 +13727,7 @@ export {
|
|
|
13727
13727
|
jc as q,
|
|
13728
13728
|
gn as r,
|
|
13729
13729
|
fs as s,
|
|
13730
|
-
|
|
13730
|
+
Zs as t,
|
|
13731
13731
|
qc as u,
|
|
13732
13732
|
dr as v,
|
|
13733
13733
|
Qc as w,
|