@genai-fi/nanogpt 0.4.1 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +3 -3
- package/dist/NanoGPTModel.js +83 -70
- package/dist/TeachableLLM.js +1 -1
- package/dist/{random_width-CMHmdbSu.js → TiedEmbedding-CnJ1bx4q.js} +760 -719
- package/dist/{axis_util-DeydwOoC.js → axis_util-BgTGy5w8.js} +1 -1
- package/dist/{concat-DS_qH7MI.js → concat-CuRsVY-K.js} +1 -1
- package/dist/dropout-DfDdklfL.js +193 -0
- package/dist/{gather-BUmJIS8n.js → gather-ZYRWhmXR.js} +1 -1
- package/dist/gelu-CnCt17Lk.js +26 -0
- package/dist/{index-XjBAhiFO.js → index-C4JCoBvj.js} +61 -61
- package/dist/kernel_funcs_utils-CAd1h9X1.js +388 -0
- package/dist/layers/CausalSelfAttention.js +71 -70
- package/dist/layers/MLP.d.ts +3 -1
- package/dist/layers/MLP.js +93 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +6 -46
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-DJPkVZZn.js → log_sum_exp-BswFnwOb.js} +5 -5
- package/dist/main.js +1 -1
- package/dist/{mat_mul-CKwFEV1Q.js → mat_mul-415y5Qn2.js} +1 -1
- package/dist/{max-DJvEiCAJ.js → max-CP_9O2Yd.js} +1 -1
- package/dist/{moments-CrWRPcR3.js → moments-CjeIaVdp.js} +3 -3
- package/dist/{norm-BzY929B_.js → norm-CZM380I3.js} +5 -5
- package/dist/{ones-BO01zpJG.js → ones-Bf3YR48P.js} +2 -2
- package/dist/ops/appendCache.js +1 -1
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +13 -9
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +3 -3
- package/dist/ops/cpu/gelu.d.ts +1 -0
- package/dist/ops/cpu/gelu.js +40 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.d.ts +3 -0
- package/dist/ops/gelu.js +8 -0
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.d.ts +2 -0
- package/dist/ops/grads/gelu.js +5 -0
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +19 -18
- package/dist/ops/webgl/fusedSoftmax.js +483 -782
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.d.ts +2 -0
- package/dist/ops/webgl/gelu.js +50 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{range-DQMNzBWs.js → range-9AzeApCc.js} +1 -1
- package/dist/{reshape-DFzh97Sc.js → reshape-Boe4DuIO.js} +1 -1
- package/dist/{sin-BYM-U4Ut.js → sin-KmhiDuMa.js} +1 -1
- package/dist/{slice_util-CnVNPQI-.js → slice_util-19zDNNSn.js} +2 -2
- package/dist/{softmax-4DOn6cPq.js → softmax-Cujsg4ay.js} +1 -1
- package/dist/{split-CkbeVdF8.js → split-DbcNm1-i.js} +1 -1
- package/dist/{stack-DaIMO5iX.js → stack-D1YjmgKN.js} +1 -1
- package/dist/{sum-C6u3xMi3.js → sum-R28pucR5.js} +1 -1
- package/dist/{tensor-Cu1fU7H7.js → tensor-BVeHdl7V.js} +1 -1
- package/dist/{tensor2d-D0CKdG6B.js → tensor2d-DqFGNs_K.js} +1 -1
- package/dist/{tfjs_backend-Bzl2SrRo.js → tfjs_backend-Cug-PH75.js} +826 -1015
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +3 -3
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BS4AKqNU.js → variable-LJT9Ld63.js} +1 -1
- package/dist/{zeros-CmJFiC84.js → zeros-dnQxFgAD.js} +1 -1
- package/package.json +1 -1
- package/dist/MLP-KHhikThU.js +0 -83
|
@@ -1,25 +1,26 @@
|
|
|
1
|
-
import { o as F, h as D, E as M, bb as
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
import {
|
|
7
|
-
import {
|
|
8
|
-
import {
|
|
9
|
-
import {
|
|
10
|
-
import { s as
|
|
11
|
-
import {
|
|
12
|
-
import {
|
|
13
|
-
import {
|
|
14
|
-
import {
|
|
15
|
-
import {
|
|
16
|
-
import {
|
|
17
|
-
import {
|
|
18
|
-
import { m as
|
|
19
|
-
import {
|
|
20
|
-
import {
|
|
21
|
-
import {
|
|
22
|
-
import {
|
|
1
|
+
import { o as F, h as D, E as M, bb as fo, bc as mo, bd as mi, j as k, b8 as Mn, be as gi, x as L, bf as bi, bg as yi, bh as wi, bi as ki, bj as xi, bk as Ni, bl as vi, bm as go, bn as Si, bo, bp as Ai, bq as yo, br as Ci, p as Hn, al as wt, bs as wo, bt as Ii, bu as Di, bv as zi, c as On, s as V, b as w, bw as ko, bx as Ti, by as $i, bz as xo, bA as Ei, bB as Li, bC as Fi, bD as Mi, bE as Oi, bF as Ri, bG as _i, bH as Bi, k as No, aa as vo, bI as Wi, bJ as So, a5 as $, bK as Ls, bL as Ao, bM as Co, bN as Io, bO as Do, bP as zo, A as To, bQ as $o, bR as Eo, bS as Lo, bT as S, t as x, f as tt, n as Gi, bU as Be, bV as We, ab as Ft, a as Z, af as Fo, bW as Mo, bX as Oo, Z as ct, a0 as ee, ad as P, bY as Ro, bZ as _o, aN as lt, b_ as Bo, z as Q, b$ as Wo, c0 as Go, c1 as Po, c2 as Uo, c3 as Vo, c4 as jo, c5 as Ko, c6 as Ho, B as qo, c7 as Zo, c8 as Jo, c9 as Xo, aq as Yo, ca as Qo, C as tl, Y as he, cb as el, Q as nl, cc as sl, cd as il, ce as rl, at as al, cf as ol, a3 as ll, au as ul, cg as cl, a9 as hl, ch as pl, G as dl, aw as fl, ci as ml, cj as gl, ck as bl, cl as yl, ay as wl, a4 as kl, cm as xl, cn as Nl, co as vl, M as Sl, cp as Al, cq as Cl, cr as Il, X as Dl, _ as zl, aD as Tl, cs as $l, a6 as El, ct as Ll, aB as Fl, P as Ml, cu as Ol, O as qn, aE as Rl, cv as _l, cw as Bl, N as Wl, aH as Gl, aG as Pl, q as Ul, aV as Vl, cx as jl, aW as Kl, cy as Hl, aI as ql, ar as Zl, an as Jl, cz as Xl, T as Yl, ao as Ql, S as tu, u as eu, cA as nu, cB as su, cC as iu, aK as ru, cD as au, y as ou, cE as lu, a1 as uu, aM as cu, aL as hu, cF as Ie, cG as pu, g as du, cH as Fs, F as Bt, $ as Fe, D as fu, w as mu, ac as xe, cI as gu, cJ as bu, m as Ms, cK as yu, cL as Os, cM as wu } from "./index-C4JCoBvj.js";
|
|
2
|
+
import { s as ku, a as xu, g as Nu, b as vu, V as d, N as G, r as bn, e as Su, l as Au, c as Zn, f as et, h as ye, i as Ge, j as Jn, k as Pi, m as Ui, t as Rt, R as Et, n as ht, A as Pt, o as K, p as le, q as Xn, u as pt, w as Ht, v as Pe, x as Ue, y as Yn, z as j, B as Ee, C as Gt, D as Cu, E as be, F as en, G as Qn, H as ue, I as Ct, J as nt, K as Ve, L as Iu, M as Du, O as je, P as Ut, Q as Lt, S as ts, T as oe, U as jt, W as Xe, X as ce, Y as zu, Z as Rn, _ as nn, $ as yn, a0 as Vi, a1 as It, a2 as Rs, a3 as Tu, a4 as ji, a5 as $u, a6 as Eu, a7 as Lu, a8 as Fu, a9 as Mu, aa as qt, ab as Ou, ac as es, ad as zt, ae as Ye, af as Ru, ag as _t, ah as ot, ai as Ki, aj as gt, ak as ne, al as _s, am as Ne, d as St, an as Bs, ao as Ke, ap as Hi, aq as ns, ar as _u, as as Bu, at as qi, au as Wu } from "./tfjs_backend-Cug-PH75.js";
|
|
3
|
+
import { M as Gu, a as wn, f as Zi } from "./dropout-DfDdklfL.js";
|
|
4
|
+
import { z as mt } from "./zeros-dnQxFgAD.js";
|
|
5
|
+
import { o as pe } from "./ones-Bf3YR48P.js";
|
|
6
|
+
import { v as Ji } from "./variable-LJT9Ld63.js";
|
|
7
|
+
import { r as A } from "./reshape-Boe4DuIO.js";
|
|
8
|
+
import { s as B } from "./sum-R28pucR5.js";
|
|
9
|
+
import { m as Ot } from "./mat_mul-415y5Qn2.js";
|
|
10
|
+
import { s as Kt } from "./split-DbcNm1-i.js";
|
|
11
|
+
import { s as Pu, c as Xi } from "./sin-KmhiDuMa.js";
|
|
12
|
+
import { g as Yi, d as ss, e as Ws, c as Uu } from "./axis_util-BgTGy5w8.js";
|
|
13
|
+
import { a as Zt, e as Jt, l as Vu } from "./log_sum_exp-BswFnwOb.js";
|
|
14
|
+
import { s as kn } from "./stack-D1YjmgKN.js";
|
|
15
|
+
import { p as ju } from "./slice_util-19zDNNSn.js";
|
|
16
|
+
import { c as is } from "./concat-CuRsVY-K.js";
|
|
17
|
+
import { g as Qi } from "./gather-ZYRWhmXR.js";
|
|
18
|
+
import { m as at, a as rs } from "./moments-CjeIaVdp.js";
|
|
19
|
+
import { s as tr } from "./softmax-Cujsg4ay.js";
|
|
20
|
+
import { m as ve } from "./max-CP_9O2Yd.js";
|
|
21
|
+
import { t as Ku } from "./tensor-BVeHdl7V.js";
|
|
22
|
+
import { r as Hu } from "./range-9AzeApCc.js";
|
|
23
|
+
import { m as qu } from "./norm-CZM380I3.js";
|
|
23
24
|
/**
|
|
24
25
|
* @license
|
|
25
26
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -38,7 +39,7 @@ import { v as qu } from "./variable-BS4AKqNU.js";
|
|
|
38
39
|
*/
|
|
39
40
|
function Zu(s, t = null, e = !1) {
|
|
40
41
|
const i = { x: D(s, "x", "all", "bool") }, r = { axis: t, keepDims: e };
|
|
41
|
-
return M.runKernel(
|
|
42
|
+
return M.runKernel(fo, i, r);
|
|
42
43
|
}
|
|
43
44
|
const Ju = /* @__PURE__ */ F({ all_: Zu });
|
|
44
45
|
/**
|
|
@@ -59,7 +60,7 @@ const Ju = /* @__PURE__ */ F({ all_: Zu });
|
|
|
59
60
|
*/
|
|
60
61
|
function Xu(s, t = null, e = !1) {
|
|
61
62
|
const i = { x: D(s, "x", "any", "bool") }, r = { axis: t, keepDims: e };
|
|
62
|
-
return M.runKernel(
|
|
63
|
+
return M.runKernel(mo, i, r);
|
|
63
64
|
}
|
|
64
65
|
const Gs = /* @__PURE__ */ F({ any_: Xu });
|
|
65
66
|
/**
|
|
@@ -473,13 +474,13 @@ function Cc(s, t, e, n, i) {
|
|
|
473
474
|
t.rank === 4 && (o = !0, a = A(t, [1, t.shape[0], t.shape[1], t.shape[2], t.shape[3]]), r = [1, s[0], s[1], s[2], s[3]]);
|
|
474
475
|
const l = r[4], u = a.shape[4];
|
|
475
476
|
k(r.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ${r.length}.`), k(a.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got rank ${a.rank}`), k(e.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got rank ${e.rank}`), k(l === e.shape[3], () => `Error in conv3dDerInput: depth of input (${l}) must match input depth for filter ${e.shape[3]}.`), k(u === e.shape[4], () => `Error in conv3dDerInput: depth of output (${u}) must match output depth for filter ${e.shape[4]}.`);
|
|
476
|
-
const c = { dy: a, filter: e }, h = { pad: i, strides: n, inputShape: r }, p = M.runKernel(
|
|
477
|
+
const c = { dy: a, filter: e }, h = { pad: i, strides: n, inputShape: r }, p = M.runKernel(go, c, h);
|
|
477
478
|
return o ? A(p, [p.shape[1], p.shape[2], p.shape[3], p.shape[4]]) : p;
|
|
478
479
|
}
|
|
479
|
-
const
|
|
480
|
+
const er = /* @__PURE__ */ F({ conv3DBackpropInput_: Cc });
|
|
480
481
|
function Ic(s, t, e, n, i) {
|
|
481
482
|
const r = D(s, "x", "conv3dTranspose"), a = D(t, "filter", "conv3dTranspose");
|
|
482
|
-
return
|
|
483
|
+
return er(e, r, a, n, i);
|
|
483
484
|
}
|
|
484
485
|
const Dc = /* @__PURE__ */ F({ conv3dTranspose_: Ic });
|
|
485
486
|
/**
|
|
@@ -521,7 +522,7 @@ const Tc = /* @__PURE__ */ F({ cosh_: zc });
|
|
|
521
522
|
*/
|
|
522
523
|
function $c(s, t = 0, e = !1, n = !1) {
|
|
523
524
|
const r = { x: D(s, "x", "cumprod") }, a = { axis: t, exclusive: e, reverse: n };
|
|
524
|
-
return M.runKernel(
|
|
525
|
+
return M.runKernel(bo, r, a);
|
|
525
526
|
}
|
|
526
527
|
const Ps = /* @__PURE__ */ F({ cumprod_: $c });
|
|
527
528
|
/**
|
|
@@ -565,7 +566,7 @@ function Fc(s, t, e, n = !1) {
|
|
|
565
566
|
const i = D(s, "x", "denseBincount"), r = D(t, "weights", "denseBincount");
|
|
566
567
|
k(i.dtype === "int32", () => `Error in denseBincount: input dtype must be int32, but got ${i.dtype}`), k(i.rank <= 2, () => `Error in denseBincount: input must be at most rank 2, but got rank ${i.rank}.`), k(e >= 0, () => `size must be non-negative, but got ${e}.`), k(r.size === i.size || r.size === 0, () => `Error in denseBincount: weights must have the same shape as x or 0-length, but got x shape: ${i.shape}, weights shape: ${r.shape}.`);
|
|
567
568
|
const a = { x: i, weights: r }, o = { size: e, binaryOutput: n };
|
|
568
|
-
return M.runKernel(
|
|
569
|
+
return M.runKernel(yo, a, o);
|
|
569
570
|
}
|
|
570
571
|
const Us = /* @__PURE__ */ F({ denseBincount_: Fc });
|
|
571
572
|
/**
|
|
@@ -593,7 +594,7 @@ function Mc(s, t, e, n, i = "NHWC", r = [1, 1], a) {
|
|
|
593
594
|
const p = { x: u, filter: l }, f = { strides: e, pad: n, dataFormat: i, dilations: r, dimRoundingMode: a }, g = M.runKernel(Ci, p, f);
|
|
594
595
|
return c ? A(g, [g.shape[1], g.shape[2], g.shape[3]]) : g;
|
|
595
596
|
}
|
|
596
|
-
const
|
|
597
|
+
const nr = /* @__PURE__ */ F({ depthwiseConv2d_: Mc });
|
|
597
598
|
/**
|
|
598
599
|
* @license
|
|
599
600
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -614,7 +615,7 @@ function Oc(s, t) {
|
|
|
614
615
|
let e = D(s, "a", "equal", "string_or_numeric"), n = D(t, "b", "equal", "string_or_numeric");
|
|
615
616
|
[e, n] = Hn(e, n), wt(e.shape, n.shape);
|
|
616
617
|
const i = { a: e, b: n };
|
|
617
|
-
return M.runKernel(
|
|
618
|
+
return M.runKernel(wo, i);
|
|
618
619
|
}
|
|
619
620
|
const Xt = /* @__PURE__ */ F({ equal_: Oc });
|
|
620
621
|
/**
|
|
@@ -729,7 +730,7 @@ const Uc = /* @__PURE__ */ F({ logSoftmax_: Pc });
|
|
|
729
730
|
*/
|
|
730
731
|
function Vc(s) {
|
|
731
732
|
const e = { x: D(s, "x", "logicalNot", "bool") };
|
|
732
|
-
return M.runKernel(
|
|
733
|
+
return M.runKernel(ko, e);
|
|
733
734
|
}
|
|
734
735
|
const jc = /* @__PURE__ */ F({ logicalNot_: Vc });
|
|
735
736
|
/**
|
|
@@ -800,7 +801,7 @@ function Jc(s, t) {
|
|
|
800
801
|
let e = D(s, "a", "notEqual", "string_or_numeric"), n = D(t, "b", "notEqual", "string_or_numeric");
|
|
801
802
|
[e, n] = Hn(e, n), wt(e.shape, n.shape);
|
|
802
803
|
const i = { a: e, b: n };
|
|
803
|
-
return M.runKernel(
|
|
804
|
+
return M.runKernel(xo, i);
|
|
804
805
|
}
|
|
805
806
|
const Bn = /* @__PURE__ */ F({ notEqual_: Jc });
|
|
806
807
|
/**
|
|
@@ -846,7 +847,7 @@ function Qc(s) {
|
|
|
846
847
|
const e = { x: D(s, "x", "onesLike") };
|
|
847
848
|
return M.runKernel(Li, e);
|
|
848
849
|
}
|
|
849
|
-
const
|
|
850
|
+
const Dt = /* @__PURE__ */ F({ onesLike_: Qc });
|
|
850
851
|
/**
|
|
851
852
|
* @license
|
|
852
853
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -870,7 +871,7 @@ function th(s, t, e = 0) {
|
|
|
870
871
|
const i = { paddings: t, constantValue: e }, r = { x: n };
|
|
871
872
|
return M.runKernel(Fi, r, i);
|
|
872
873
|
}
|
|
873
|
-
const
|
|
874
|
+
const sr = /* @__PURE__ */ F({ pad_: th });
|
|
874
875
|
/**
|
|
875
876
|
* @license
|
|
876
877
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -965,7 +966,7 @@ function lh(s, t, e, n, i, r = [1, 1], a = "NHWC") {
|
|
|
965
966
|
k(c.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${c.rank}.`), k(l.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${l.rank}.`), k(u.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${l.rank}.`), k(u.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${u.shape[0]}.`), k(u.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${u.shape[1]}.`);
|
|
966
967
|
const p = l.shape[2], f = l.shape[3];
|
|
967
968
|
k(u.shape[2] === p * f, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${p * f}, but got ${u.shape[2]}.`);
|
|
968
|
-
const g =
|
|
969
|
+
const g = nr(c, l, n, i, a, r), m = Ce(g, u, 1, "valid", a);
|
|
969
970
|
return h ? A(m, [m.shape[1], m.shape[2], m.shape[3]]) : m;
|
|
970
971
|
}
|
|
971
972
|
const uh = /* @__PURE__ */ F({ separableConv2d_: lh });
|
|
@@ -1007,14 +1008,14 @@ const hh = /* @__PURE__ */ F({ sinh_: ch });
|
|
|
1007
1008
|
* =============================================================================
|
|
1008
1009
|
*/
|
|
1009
1010
|
function ph(s, t = 0, e = 1, n, i) {
|
|
1010
|
-
if (
|
|
1011
|
+
if (No(s), n != null && n === "bool")
|
|
1011
1012
|
throw new Error("Unsupported data type $ { dtype }");
|
|
1012
|
-
const r = new
|
|
1013
|
+
const r = new Gu(t, e, n, !0, i), a = vo(s, n);
|
|
1013
1014
|
for (let o = 0; o < a.values.length; o++)
|
|
1014
1015
|
a.values[o] = r.nextValue();
|
|
1015
1016
|
return a.toTensor();
|
|
1016
1017
|
}
|
|
1017
|
-
const
|
|
1018
|
+
const ir = /* @__PURE__ */ F({ truncatedNormal_: ph });
|
|
1018
1019
|
/**
|
|
1019
1020
|
* @license
|
|
1020
1021
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1062,7 +1063,7 @@ function mh(s, t, e, n, i, r = "NHWC", a) {
|
|
|
1062
1063
|
const u = r === "NHWC" ? o.shape[3] : o.shape[1], c = r === "NHWC" ? l.shape[3] : l.shape[1];
|
|
1063
1064
|
k(u === e[2], () => `Error in conv2dDerFilter: depth of input ${u}) must match input depth in filter (${e[2]}.`), k(c === e[3], () => `Error in conv2dDerFilter: depth of dy (${c}) must match output depth for filter (${e[3]}).`), ft("conv2dDerFilter", i, a);
|
|
1064
1065
|
const h = { x: o, dy: l }, p = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, filterShape: e };
|
|
1065
|
-
return M.runKernel(
|
|
1066
|
+
return M.runKernel(So, h, p);
|
|
1066
1067
|
}
|
|
1067
1068
|
const cs = /* @__PURE__ */ F({ conv2DBackpropFilter_: mh });
|
|
1068
1069
|
/**
|
|
@@ -1082,10 +1083,10 @@ const cs = /* @__PURE__ */ F({ conv2DBackpropFilter_: mh });
|
|
|
1082
1083
|
* =============================================================================
|
|
1083
1084
|
*/
|
|
1084
1085
|
function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilations: r = [1, 1], dimRoundingMode: a, bias: o, activation: l = "linear", preluActivationWeights: u, leakyreluAlpha: c }) {
|
|
1085
|
-
if (l = l || "linear",
|
|
1086
|
+
if (l = l || "linear", ku(M.state.gradientDepth, l) === !1) {
|
|
1086
1087
|
k(i === "NHWC", () => `Error in fused conv2d: got dataFormat of ${i} but only NHWC is currently supported for the case of gradient depth is 0 and the activation is not linear.`);
|
|
1087
1088
|
let z = Ce(s, t, e, n, i, r, a);
|
|
1088
|
-
return o != null && (z = $(z, o)),
|
|
1089
|
+
return o != null && (z = $(z, o)), xu(z, l, u, c);
|
|
1089
1090
|
}
|
|
1090
1091
|
const h = D(s, "x", "conv2d", "float32"), p = D(t, "filter", "conv2d", "float32");
|
|
1091
1092
|
let f = h, g = !1;
|
|
@@ -1111,12 +1112,12 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1111
1112
|
}
|
|
1112
1113
|
const C = (z, _) => {
|
|
1113
1114
|
k(i === "NHWC", () => `Error in gradient of fused conv2D: got dataFormat of ${i} but only NHWC is currently supported.`);
|
|
1114
|
-
const [T, E, R, q] = _, bt =
|
|
1115
|
+
const [T, E, R, q] = _, bt = Nu(z, R, l);
|
|
1115
1116
|
k(Se(r), () => `Error in gradient of fused conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${r}'`);
|
|
1116
1117
|
const ie = ls(E.shape, bt, T, e, n), re = cs(E, bt, T.shape, e, n), xt = [ie, re];
|
|
1117
1118
|
if (q != null) {
|
|
1118
|
-
const
|
|
1119
|
-
xt.push(
|
|
1119
|
+
const Tt = vu(q, bt);
|
|
1120
|
+
xt.push(Tt);
|
|
1120
1121
|
}
|
|
1121
1122
|
return xt;
|
|
1122
1123
|
}, N = {
|
|
@@ -1167,7 +1168,7 @@ function yh(s, t, e, n, i, r = [1, 1], a) {
|
|
|
1167
1168
|
let l = t;
|
|
1168
1169
|
l.rank === 3 && (l = A(t, [1, t.shape[0], t.shape[1], t.shape[2]]));
|
|
1169
1170
|
const u = { x: o, dy: l }, c = { strides: n, pad: i, dimRoundingMode: a, dilations: r, filterShape: e };
|
|
1170
|
-
return M.runKernel(
|
|
1171
|
+
return M.runKernel(Ao, u, c);
|
|
1171
1172
|
}
|
|
1172
1173
|
const wh = F({ depthwiseConv2dNativeBackpropFilter_: yh });
|
|
1173
1174
|
/**
|
|
@@ -1191,7 +1192,7 @@ function kh(s, t, e, n, i, r = [1, 1], a) {
|
|
|
1191
1192
|
t.rank === 3 && (l = !0, o = A(t, [1, t.shape[0], t.shape[1], t.shape[2]]));
|
|
1192
1193
|
const u = { dy: o, filter: e }, c = { strides: n, pad: i, dimRoundingMode: a, dilations: r, inputShape: s }, h = (
|
|
1193
1194
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1194
|
-
M.runKernel(
|
|
1195
|
+
M.runKernel(Co, u, c)
|
|
1195
1196
|
);
|
|
1196
1197
|
return l ? A(h, [h.shape[1], h.shape[2], h.shape[3]]) : h;
|
|
1197
1198
|
}
|
|
@@ -1251,7 +1252,7 @@ class Nh {
|
|
|
1251
1252
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1252
1253
|
*/
|
|
1253
1254
|
static sgd(t) {
|
|
1254
|
-
return new
|
|
1255
|
+
return new Io(t);
|
|
1255
1256
|
}
|
|
1256
1257
|
/**
|
|
1257
1258
|
* Constructs a `tf.MomentumOptimizer` that uses momentum gradient
|
|
@@ -1269,7 +1270,7 @@ class Nh {
|
|
|
1269
1270
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1270
1271
|
*/
|
|
1271
1272
|
static momentum(t, e, n = !1) {
|
|
1272
|
-
return new
|
|
1273
|
+
return new Do(t, e, n);
|
|
1273
1274
|
}
|
|
1274
1275
|
/**
|
|
1275
1276
|
* Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
|
|
@@ -1292,7 +1293,7 @@ class Nh {
|
|
|
1292
1293
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1293
1294
|
*/
|
|
1294
1295
|
static rmsprop(t, e = 0.9, n = 0, i = null, r = !1) {
|
|
1295
|
-
return new
|
|
1296
|
+
return new zo(t, e, n, i, r);
|
|
1296
1297
|
}
|
|
1297
1298
|
/**
|
|
1298
1299
|
* Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
|
|
@@ -1307,7 +1308,7 @@ class Nh {
|
|
|
1307
1308
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1308
1309
|
*/
|
|
1309
1310
|
static adam(t = 1e-3, e = 0.9, n = 0.999, i = null) {
|
|
1310
|
-
return new
|
|
1311
|
+
return new To(t, e, n, i);
|
|
1311
1312
|
}
|
|
1312
1313
|
/**
|
|
1313
1314
|
* Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
|
|
@@ -1322,7 +1323,7 @@ class Nh {
|
|
|
1322
1323
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1323
1324
|
*/
|
|
1324
1325
|
static adadelta(t = 1e-3, e = 0.95, n = null) {
|
|
1325
|
-
return new
|
|
1326
|
+
return new $o(t, e, n);
|
|
1326
1327
|
}
|
|
1327
1328
|
/**
|
|
1328
1329
|
* Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
|
|
@@ -1338,7 +1339,7 @@ class Nh {
|
|
|
1338
1339
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1339
1340
|
*/
|
|
1340
1341
|
static adamax(t = 2e-3, e = 0.9, n = 0.999, i = null, r = 0) {
|
|
1341
|
-
return new
|
|
1342
|
+
return new Eo(t, e, n, i, r);
|
|
1342
1343
|
}
|
|
1343
1344
|
/**
|
|
1344
1345
|
* Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
|
|
@@ -1357,7 +1358,7 @@ class Nh {
|
|
|
1357
1358
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1358
1359
|
*/
|
|
1359
1360
|
static adagrad(t, e = 0.1) {
|
|
1360
|
-
return new
|
|
1361
|
+
return new Lo(t, e);
|
|
1361
1362
|
}
|
|
1362
1363
|
}
|
|
1363
1364
|
/**
|
|
@@ -1423,7 +1424,7 @@ const Ah = 1.7580993408473768, Ch = 1.0507009873554805;
|
|
|
1423
1424
|
* https://opensource.org/licenses/MIT.
|
|
1424
1425
|
* =============================================================================
|
|
1425
1426
|
*/
|
|
1426
|
-
class
|
|
1427
|
+
class rr {
|
|
1427
1428
|
constructor(t) {
|
|
1428
1429
|
this.maxEntries = t || 100, this.cache = /* @__PURE__ */ new Map();
|
|
1429
1430
|
}
|
|
@@ -1478,7 +1479,7 @@ class sr {
|
|
|
1478
1479
|
* =============================================================================
|
|
1479
1480
|
*/
|
|
1480
1481
|
let Ih = 0;
|
|
1481
|
-
function
|
|
1482
|
+
function ar() {
|
|
1482
1483
|
return Ih++;
|
|
1483
1484
|
}
|
|
1484
1485
|
const Qe = {};
|
|
@@ -1518,13 +1519,13 @@ class kt extends Be {
|
|
|
1518
1519
|
return {};
|
|
1519
1520
|
}
|
|
1520
1521
|
}
|
|
1521
|
-
class
|
|
1522
|
+
class or extends kt {
|
|
1522
1523
|
apply(t, e) {
|
|
1523
1524
|
return mt(t, e);
|
|
1524
1525
|
}
|
|
1525
1526
|
}
|
|
1526
|
-
|
|
1527
|
-
S(
|
|
1527
|
+
or.className = "Zeros";
|
|
1528
|
+
S(or);
|
|
1528
1529
|
class hs extends kt {
|
|
1529
1530
|
apply(t, e) {
|
|
1530
1531
|
return pe(t, e);
|
|
@@ -1532,7 +1533,7 @@ class hs extends kt {
|
|
|
1532
1533
|
}
|
|
1533
1534
|
hs.className = "Ones";
|
|
1534
1535
|
S(hs);
|
|
1535
|
-
class
|
|
1536
|
+
class lr extends kt {
|
|
1536
1537
|
constructor(t) {
|
|
1537
1538
|
if (super(), typeof t != "object")
|
|
1538
1539
|
throw new d(`Expected argument of type ConstantConfig but got ${t}`);
|
|
@@ -1549,21 +1550,21 @@ class ar extends kt {
|
|
|
1549
1550
|
};
|
|
1550
1551
|
}
|
|
1551
1552
|
}
|
|
1552
|
-
|
|
1553
|
-
S(
|
|
1554
|
-
class
|
|
1553
|
+
lr.className = "Constant";
|
|
1554
|
+
S(lr);
|
|
1555
|
+
class ur extends kt {
|
|
1555
1556
|
constructor(t) {
|
|
1556
1557
|
super(), this.DEFAULT_MINVAL = -0.05, this.DEFAULT_MAXVAL = 0.05, this.minval = t.minval || this.DEFAULT_MINVAL, this.maxval = t.maxval || this.DEFAULT_MAXVAL, this.seed = t.seed;
|
|
1557
1558
|
}
|
|
1558
1559
|
apply(t, e) {
|
|
1559
|
-
return
|
|
1560
|
+
return wn(t, this.minval, this.maxval, e, this.seed);
|
|
1560
1561
|
}
|
|
1561
1562
|
getConfig() {
|
|
1562
1563
|
return { minval: this.minval, maxval: this.maxval, seed: this.seed };
|
|
1563
1564
|
}
|
|
1564
1565
|
}
|
|
1565
|
-
|
|
1566
|
-
S(
|
|
1566
|
+
ur.className = "RandomUniform";
|
|
1567
|
+
S(ur);
|
|
1567
1568
|
class ps extends kt {
|
|
1568
1569
|
constructor(t) {
|
|
1569
1570
|
super(), this.DEFAULT_MEAN = 0, this.DEFAULT_STDDEV = 0.05, this.mean = t.mean || this.DEFAULT_MEAN, this.stddev = t.stddev || this.DEFAULT_STDDEV, this.seed = t.seed;
|
|
@@ -1571,7 +1572,7 @@ class ps extends kt {
|
|
|
1571
1572
|
apply(t, e) {
|
|
1572
1573
|
if (e = e || "float32", e !== "float32" && e !== "int32")
|
|
1573
1574
|
throw new G(`randomNormal does not support dType ${e}.`);
|
|
1574
|
-
return
|
|
1575
|
+
return bn(t, this.mean, this.stddev, e, this.seed);
|
|
1575
1576
|
}
|
|
1576
1577
|
getConfig() {
|
|
1577
1578
|
return { mean: this.mean, stddev: this.stddev, seed: this.seed };
|
|
@@ -1579,22 +1580,22 @@ class ps extends kt {
|
|
|
1579
1580
|
}
|
|
1580
1581
|
ps.className = "RandomNormal";
|
|
1581
1582
|
S(ps);
|
|
1582
|
-
class
|
|
1583
|
+
class cr extends kt {
|
|
1583
1584
|
constructor(t) {
|
|
1584
1585
|
super(), this.DEFAULT_MEAN = 0, this.DEFAULT_STDDEV = 0.05, this.mean = t.mean || this.DEFAULT_MEAN, this.stddev = t.stddev || this.DEFAULT_STDDEV, this.seed = t.seed;
|
|
1585
1586
|
}
|
|
1586
1587
|
apply(t, e) {
|
|
1587
1588
|
if (e = e || "float32", e !== "float32" && e !== "int32")
|
|
1588
1589
|
throw new G(`truncatedNormal does not support dType ${e}.`);
|
|
1589
|
-
return
|
|
1590
|
+
return ir(t, this.mean, this.stddev, e, this.seed);
|
|
1590
1591
|
}
|
|
1591
1592
|
getConfig() {
|
|
1592
1593
|
return { mean: this.mean, stddev: this.stddev, seed: this.seed };
|
|
1593
1594
|
}
|
|
1594
1595
|
}
|
|
1595
|
-
|
|
1596
|
-
S(
|
|
1597
|
-
class
|
|
1596
|
+
cr.className = "TruncatedNormal";
|
|
1597
|
+
S(cr);
|
|
1598
|
+
class hr extends kt {
|
|
1598
1599
|
constructor(t) {
|
|
1599
1600
|
super(), this.gain = t.gain != null ? t.gain : 1;
|
|
1600
1601
|
}
|
|
@@ -1602,15 +1603,15 @@ class ur extends kt {
|
|
|
1602
1603
|
return x(() => {
|
|
1603
1604
|
if (t.length !== 2 || t[0] !== t[1])
|
|
1604
1605
|
throw new d("Identity matrix initializer can only be used for 2D square matrices.");
|
|
1605
|
-
return w(this.gain,
|
|
1606
|
+
return w(this.gain, Su(t[0]));
|
|
1606
1607
|
});
|
|
1607
1608
|
}
|
|
1608
1609
|
getConfig() {
|
|
1609
1610
|
return { gain: this.gain };
|
|
1610
1611
|
}
|
|
1611
1612
|
}
|
|
1612
|
-
|
|
1613
|
-
S(
|
|
1613
|
+
hr.className = "Identity";
|
|
1614
|
+
S(hr);
|
|
1614
1615
|
function Eh(s, t = "channelsLast") {
|
|
1615
1616
|
let e, n;
|
|
1616
1617
|
if (et(t), s.length === 2)
|
|
@@ -1646,10 +1647,10 @@ class dt extends kt {
|
|
|
1646
1647
|
const o = Math.sqrt(a);
|
|
1647
1648
|
if (e = e || "float32", e !== "float32" && e !== "int32")
|
|
1648
1649
|
throw new G(`${this.getClassName()} does not support dType ${e}.`);
|
|
1649
|
-
return
|
|
1650
|
+
return ir(t, 0, o, e, this.seed);
|
|
1650
1651
|
} else {
|
|
1651
1652
|
const o = Math.sqrt(3 * a);
|
|
1652
|
-
return
|
|
1653
|
+
return wn(t, -o, o, e, this.seed);
|
|
1653
1654
|
}
|
|
1654
1655
|
}
|
|
1655
1656
|
getConfig() {
|
|
@@ -1767,7 +1768,7 @@ class ys extends dt {
|
|
|
1767
1768
|
}
|
|
1768
1769
|
ys.className = "LeCunUniform";
|
|
1769
1770
|
S(ys);
|
|
1770
|
-
class
|
|
1771
|
+
class pr extends kt {
|
|
1771
1772
|
constructor(t) {
|
|
1772
1773
|
super(), this.DEFAULT_GAIN = 1, this.ELEMENTS_WARN_SLOW = 2e3, this.gain = t.gain == null ? this.DEFAULT_GAIN : t.gain, this.seed = t.seed;
|
|
1773
1774
|
}
|
|
@@ -1780,7 +1781,7 @@ class cr extends kt {
|
|
|
1780
1781
|
e = e;
|
|
1781
1782
|
const n = Gi(t.slice(0, -1)), i = t[t.length - 1], r = n * i;
|
|
1782
1783
|
r > this.ELEMENTS_WARN_SLOW && console.warn(`Orthogonal initializer is being called on a matrix with more than ${this.ELEMENTS_WARN_SLOW} (${r}) elements: Slowness may result.`);
|
|
1783
|
-
const a = [Math.max(i, n), Math.min(i, n)], o =
|
|
1784
|
+
const a = [Math.max(i, n), Math.min(i, n)], o = bn(a, 0, 1, e, this.seed), l = Au.qr(o, !1);
|
|
1784
1785
|
let u = l[0];
|
|
1785
1786
|
const h = l[1].flatten().stridedSlice([0], [Math.min(i, n) * Math.min(i, n)], [Math.min(i, n) + 1]);
|
|
1786
1787
|
return u = w(u, h.sign()), n < i && (u = u.transpose()), w(tt(this.gain), u.reshape(t));
|
|
@@ -1793,8 +1794,8 @@ class cr extends kt {
|
|
|
1793
1794
|
};
|
|
1794
1795
|
}
|
|
1795
1796
|
}
|
|
1796
|
-
|
|
1797
|
-
S(
|
|
1797
|
+
pr.className = "Orthogonal";
|
|
1798
|
+
S(pr);
|
|
1798
1799
|
const Vs = {
|
|
1799
1800
|
constant: "Constant",
|
|
1800
1801
|
glorotNormal: "GlorotNormal",
|
|
@@ -1913,7 +1914,7 @@ class Lh {
|
|
|
1913
1914
|
* @throws ValueError if `name` is `null` or `undefined`.
|
|
1914
1915
|
*/
|
|
1915
1916
|
constructor(t, e = "float32", n = Ks, i = !0, r = null) {
|
|
1916
|
-
this.dtype = e ?? "float32", this.shape = t.shape, this.id =
|
|
1917
|
+
this.dtype = e ?? "float32", this.shape = t.shape, this.id = ar(), n = n ?? Ks, this.originalName = Pi(n), this.name = Ui(this.originalName), this.trainable_ = i, this.constraint = r, this.val = Ji(t, this.trainable_, this.name, this.dtype);
|
|
1917
1918
|
}
|
|
1918
1919
|
/**
|
|
1919
1920
|
* Get a snapshot of the Variable's value.
|
|
@@ -1993,7 +1994,7 @@ class Mt {
|
|
|
1993
1994
|
* returned by apply().
|
|
1994
1995
|
*/
|
|
1995
1996
|
constructor(t, e, n, i, r, a, o) {
|
|
1996
|
-
this.dtype = t, this.shape = e, this.sourceLayer = n, this.inputs = i, this.callArgs = r, this.outputTensorIndex = o, this.id =
|
|
1997
|
+
this.dtype = t, this.shape = e, this.sourceLayer = n, this.inputs = i, this.callArgs = r, this.outputTensorIndex = o, this.id = ar(), a != null && (this.originalName = Pi(a), this.name = Ui(this.originalName)), this.rank = e.length;
|
|
1997
1998
|
}
|
|
1998
1999
|
}
|
|
1999
2000
|
let Mh = 0;
|
|
@@ -2060,7 +2061,7 @@ class W extends Be {
|
|
|
2060
2061
|
*/
|
|
2061
2062
|
getNodeAtIndex(t, e) {
|
|
2062
2063
|
if (this.inboundNodes.length === 0)
|
|
2063
|
-
throw new
|
|
2064
|
+
throw new Et(`The layer has never been called and thus has no defined ${e}.`);
|
|
2064
2065
|
if (this.inboundNodes.length <= t)
|
|
2065
2066
|
throw new d(`Asked to get ${e} at node ${t}, but the layer has only ${this.inboundNodes.length} inbound nodes.`);
|
|
2066
2067
|
return this.inboundNodes[t];
|
|
@@ -2419,7 +2420,7 @@ class W extends Be {
|
|
|
2419
2420
|
*/
|
|
2420
2421
|
countParams() {
|
|
2421
2422
|
if (!this.built)
|
|
2422
|
-
throw new
|
|
2423
|
+
throw new Et(`You tried to call countParams() on ${this.name}, but the layer is not built yet. Build it first by calling build(batchInputShape).`);
|
|
2423
2424
|
return un(this.weights);
|
|
2424
2425
|
}
|
|
2425
2426
|
/**
|
|
@@ -2688,7 +2689,7 @@ function Rh(s) {
|
|
|
2688
2689
|
function _h(s) {
|
|
2689
2690
|
return "float32";
|
|
2690
2691
|
}
|
|
2691
|
-
function
|
|
2692
|
+
function dr(s, t, e) {
|
|
2692
2693
|
if ((t == null || e != null && e > 0) && (t = s.sourceLayer, e = s.nodeIndex), t.inboundNodes.length === 0)
|
|
2693
2694
|
return [s];
|
|
2694
2695
|
{
|
|
@@ -2698,7 +2699,7 @@ function hr(s, t, e) {
|
|
|
2698
2699
|
{
|
|
2699
2700
|
const i = [];
|
|
2700
2701
|
for (let r = 0; r < n.inboundLayers.length; r++) {
|
|
2701
|
-
const a = n.inputTensors[r], o = n.inboundLayers[r], l = n.nodeIndices[r], u =
|
|
2702
|
+
const a = n.inputTensors[r], o = n.inboundLayers[r], l = n.nodeIndices[r], u = dr(a, o, l);
|
|
2702
2703
|
for (const c of u)
|
|
2703
2704
|
i.indexOf(c) === -1 && i.push(c);
|
|
2704
2705
|
}
|
|
@@ -2911,7 +2912,7 @@ class Vt {
|
|
|
2911
2912
|
this.id2Mask != null && Z(this.id2Mask);
|
|
2912
2913
|
}
|
|
2913
2914
|
}
|
|
2914
|
-
const cn = new
|
|
2915
|
+
const cn = new rr(), hn = new rr();
|
|
2915
2916
|
function Uh(s) {
|
|
2916
2917
|
cn?.setMaxEntries(s), hn?.setMaxEntries(s);
|
|
2917
2918
|
}
|
|
@@ -3032,7 +3033,7 @@ function Kh(s) {
|
|
|
3032
3033
|
* limitations under the License.
|
|
3033
3034
|
* =============================================================================
|
|
3034
3035
|
*/
|
|
3035
|
-
const Hh =
|
|
3036
|
+
const Hh = Fo();
|
|
3036
3037
|
Hh.registerFlag("TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES", () => 100, Uh);
|
|
3037
3038
|
/**
|
|
3038
3039
|
* @license
|
|
@@ -3050,8 +3051,8 @@ Hh.registerFlag("TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES", () => 100, Uh);
|
|
|
3050
3051
|
* limitations under the License.
|
|
3051
3052
|
* =============================================================================
|
|
3052
3053
|
*/
|
|
3053
|
-
const
|
|
3054
|
-
kernelName:
|
|
3054
|
+
const fr = {
|
|
3055
|
+
kernelName: Mo,
|
|
3055
3056
|
inputsToSave: ["x"],
|
|
3056
3057
|
gradFunc: (s, t) => {
|
|
3057
3058
|
const [e] = t;
|
|
@@ -3075,7 +3076,7 @@ const pr = {
|
|
|
3075
3076
|
* =============================================================================
|
|
3076
3077
|
*/
|
|
3077
3078
|
const qh = {
|
|
3078
|
-
kernelName:
|
|
3079
|
+
kernelName: Oo,
|
|
3079
3080
|
inputsToSave: ["x"],
|
|
3080
3081
|
gradFunc: (s, t) => {
|
|
3081
3082
|
const [e] = t;
|
|
@@ -3104,7 +3105,7 @@ const qh = {
|
|
|
3104
3105
|
* =============================================================================
|
|
3105
3106
|
*/
|
|
3106
3107
|
const Zh = {
|
|
3107
|
-
kernelName:
|
|
3108
|
+
kernelName: Ro,
|
|
3108
3109
|
inputsToSave: ["x"],
|
|
3109
3110
|
gradFunc: (s, t) => {
|
|
3110
3111
|
const [e] = t;
|
|
@@ -3133,7 +3134,7 @@ const Zh = {
|
|
|
3133
3134
|
* =============================================================================
|
|
3134
3135
|
*/
|
|
3135
3136
|
const Jh = {
|
|
3136
|
-
kernelName:
|
|
3137
|
+
kernelName: _o,
|
|
3137
3138
|
inputsToSave: ["a", "b"],
|
|
3138
3139
|
gradFunc: (s, t) => {
|
|
3139
3140
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -3165,7 +3166,7 @@ const Jh = {
|
|
|
3165
3166
|
* =============================================================================
|
|
3166
3167
|
*/
|
|
3167
3168
|
const Xh = {
|
|
3168
|
-
kernelName:
|
|
3169
|
+
kernelName: Bo,
|
|
3169
3170
|
saveAllInputs: !0,
|
|
3170
3171
|
gradFunc: (s, t) => {
|
|
3171
3172
|
const e = {};
|
|
@@ -3215,7 +3216,7 @@ const Yh = {
|
|
|
3215
3216
|
* =============================================================================
|
|
3216
3217
|
*/
|
|
3217
3218
|
const Qh = {
|
|
3218
|
-
kernelName:
|
|
3219
|
+
kernelName: Wo,
|
|
3219
3220
|
inputsToSave: ["x"],
|
|
3220
3221
|
gradFunc: (s, t) => {
|
|
3221
3222
|
const [e] = t;
|
|
@@ -3239,7 +3240,7 @@ const Qh = {
|
|
|
3239
3240
|
* =============================================================================
|
|
3240
3241
|
*/
|
|
3241
3242
|
const tp = {
|
|
3242
|
-
kernelName:
|
|
3243
|
+
kernelName: Go,
|
|
3243
3244
|
inputsToSave: ["x"],
|
|
3244
3245
|
gradFunc: (s, t) => {
|
|
3245
3246
|
const [e] = t;
|
|
@@ -3263,7 +3264,7 @@ const tp = {
|
|
|
3263
3264
|
* =============================================================================
|
|
3264
3265
|
*/
|
|
3265
3266
|
const ep = {
|
|
3266
|
-
kernelName:
|
|
3267
|
+
kernelName: Po,
|
|
3267
3268
|
inputsToSave: ["x"],
|
|
3268
3269
|
gradFunc: (s, t) => {
|
|
3269
3270
|
const [e] = t;
|
|
@@ -3292,7 +3293,7 @@ const ep = {
|
|
|
3292
3293
|
* =============================================================================
|
|
3293
3294
|
*/
|
|
3294
3295
|
const np = {
|
|
3295
|
-
kernelName:
|
|
3296
|
+
kernelName: Uo,
|
|
3296
3297
|
inputsToSave: ["a", "b"],
|
|
3297
3298
|
gradFunc: (s, t) => {
|
|
3298
3299
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -3326,7 +3327,7 @@ const np = {
|
|
|
3326
3327
|
* =============================================================================
|
|
3327
3328
|
*/
|
|
3328
3329
|
const sp = {
|
|
3329
|
-
kernelName:
|
|
3330
|
+
kernelName: Vo,
|
|
3330
3331
|
inputsToSave: ["x"],
|
|
3331
3332
|
gradFunc: (s, t) => {
|
|
3332
3333
|
const [e] = t;
|
|
@@ -3350,7 +3351,7 @@ const sp = {
|
|
|
3350
3351
|
* =============================================================================
|
|
3351
3352
|
*/
|
|
3352
3353
|
const ip = {
|
|
3353
|
-
kernelName:
|
|
3354
|
+
kernelName: jo,
|
|
3354
3355
|
inputsToSave: ["x"],
|
|
3355
3356
|
gradFunc: (s, t) => {
|
|
3356
3357
|
const [e] = t;
|
|
@@ -3383,7 +3384,7 @@ function rp(s, t, e, n, i, r) {
|
|
|
3383
3384
|
o.shape[2],
|
|
3384
3385
|
o.shape[3]
|
|
3385
3386
|
])), k(l.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ${l.rank}.`), k(u.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ${u.rank}.`), ft("avgPool3dGrad", i, r);
|
|
3386
|
-
const h = { dy: l, input: u }, p = { filterSize: e, strides: n, pad: i, dimRoundingMode: r }, f = M.runKernel(
|
|
3387
|
+
const h = { dy: l, input: u }, p = { filterSize: e, strides: n, pad: i, dimRoundingMode: r }, f = M.runKernel(Ko, h, p);
|
|
3387
3388
|
return c ? A(f, [f.shape[1], f.shape[2], f.shape[3], f.shape[4]]) : f;
|
|
3388
3389
|
}
|
|
3389
3390
|
const ap = /* @__PURE__ */ F({ avgPool3dGrad_: rp });
|
|
@@ -3434,7 +3435,7 @@ function lp(s, t, e, n, i) {
|
|
|
3434
3435
|
k(a.rank === r.rank, () => `Rank of input (${a.rank}) does not match rank of dy (${r.rank})`);
|
|
3435
3436
|
let o = a, l = r, u = !1;
|
|
3436
3437
|
a.rank === 3 && (u = !0, o = A(a, [1, a.shape[0], a.shape[1], a.shape[2]]), l = A(r, [1, r.shape[0], r.shape[1], r.shape[2]])), k(l.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ${l.rank}.`), k(o.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ${o.rank}.`);
|
|
3437
|
-
const c = { dy: l, input: o }, h = { filterSize: e, strides: n, pad: i }, p = M.runKernel(
|
|
3438
|
+
const c = { dy: l, input: o }, h = { filterSize: e, strides: n, pad: i }, p = M.runKernel(Ho, c, h);
|
|
3438
3439
|
return u ? A(p, [p.shape[1], p.shape[2], p.shape[3]]) : p;
|
|
3439
3440
|
}
|
|
3440
3441
|
const up = /* @__PURE__ */ F({ avgPoolGrad_: lp });
|
|
@@ -3479,7 +3480,7 @@ const cp = {
|
|
|
3479
3480
|
* =============================================================================
|
|
3480
3481
|
*/
|
|
3481
3482
|
const hp = {
|
|
3482
|
-
kernelName:
|
|
3483
|
+
kernelName: qo,
|
|
3483
3484
|
inputsToSave: ["a", "b"],
|
|
3484
3485
|
gradFunc: (s, t, e) => {
|
|
3485
3486
|
const [n, i] = t, { transposeA: r, transposeB: a } = e;
|
|
@@ -3538,7 +3539,7 @@ const pp = {
|
|
|
3538
3539
|
* =============================================================================
|
|
3539
3540
|
*/
|
|
3540
3541
|
const dp = {
|
|
3541
|
-
kernelName:
|
|
3542
|
+
kernelName: Zo,
|
|
3542
3543
|
gradFunc: (s, t, e) => {
|
|
3543
3544
|
const n = e, i = n.inputShape, r = n.shape, a = Array.from(r);
|
|
3544
3545
|
for (let l = i.length - 1; l >= 0; l--)
|
|
@@ -3574,7 +3575,7 @@ const dp = {
|
|
|
3574
3575
|
* =============================================================================
|
|
3575
3576
|
*/
|
|
3576
3577
|
const fp = {
|
|
3577
|
-
kernelName:
|
|
3578
|
+
kernelName: Jo,
|
|
3578
3579
|
gradFunc: (s) => ({ x: () => s.clone() })
|
|
3579
3580
|
};
|
|
3580
3581
|
/**
|
|
@@ -3594,7 +3595,7 @@ const fp = {
|
|
|
3594
3595
|
* =============================================================================
|
|
3595
3596
|
*/
|
|
3596
3597
|
const mp = {
|
|
3597
|
-
kernelName:
|
|
3598
|
+
kernelName: Xo,
|
|
3598
3599
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
3599
3600
|
};
|
|
3600
3601
|
/**
|
|
@@ -3614,7 +3615,7 @@ const mp = {
|
|
|
3614
3615
|
* =============================================================================
|
|
3615
3616
|
*/
|
|
3616
3617
|
const gp = {
|
|
3617
|
-
kernelName:
|
|
3618
|
+
kernelName: Yo,
|
|
3618
3619
|
inputsToSave: ["x"],
|
|
3619
3620
|
gradFunc: (s, t, e) => {
|
|
3620
3621
|
const [n] = t, { clipValueMin: i, clipValueMax: r } = e;
|
|
@@ -3640,9 +3641,9 @@ const gp = {
|
|
|
3640
3641
|
* =============================================================================
|
|
3641
3642
|
*/
|
|
3642
3643
|
const bp = {
|
|
3643
|
-
kernelName:
|
|
3644
|
+
kernelName: Qo,
|
|
3644
3645
|
inputsToSave: ["x"],
|
|
3645
|
-
gradFunc:
|
|
3646
|
+
gradFunc: fr.gradFunc
|
|
3646
3647
|
};
|
|
3647
3648
|
/**
|
|
3648
3649
|
* @license
|
|
@@ -3661,7 +3662,7 @@ const bp = {
|
|
|
3661
3662
|
* =============================================================================
|
|
3662
3663
|
*/
|
|
3663
3664
|
const yp = {
|
|
3664
|
-
kernelName:
|
|
3665
|
+
kernelName: tl,
|
|
3665
3666
|
saveAllInputs: !0,
|
|
3666
3667
|
gradFunc: (s, t, e) => {
|
|
3667
3668
|
const n = t.map((l) => l.shape), { axis: i } = e, r = he(i, t[0].shape)[0], a = n.map((l) => l[r]);
|
|
@@ -3744,7 +3745,7 @@ function xp(s, t, e, n, i) {
|
|
|
3744
3745
|
let a = t;
|
|
3745
3746
|
a.rank === 4 && (a = A(t, [1, t.shape[0], t.shape[1], t.shape[2], t.shape[3]])), k(r.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ${r.shape}.`), k(a.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ${a.shape}.`), k(e.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ${e}.`), k(r.shape[4] === e[3], () => `Error in conv3dDerFilter: depth of input ${r.shape[4]}) must match input depth in filter (${e[3]}.`), k(a.shape[4] === e[4], () => `Error in conv3dDerFilter: depth of dy (${a.shape[4]}) must match output depth for filter (${e[4]}).`);
|
|
3746
3747
|
const o = { x: r, dy: a }, l = { strides: n, pad: i, filterShape: e };
|
|
3747
|
-
return M.runKernel(
|
|
3748
|
+
return M.runKernel(el, o, l);
|
|
3748
3749
|
}
|
|
3749
3750
|
const Np = /* @__PURE__ */ F({ conv3DBackpropFilter_: xp });
|
|
3750
3751
|
/**
|
|
@@ -3771,7 +3772,7 @@ const vp = {
|
|
|
3771
3772
|
k(Se(n), () => `Error in gradient of conv3D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${n}'`);
|
|
3772
3773
|
const [a, o] = t;
|
|
3773
3774
|
return {
|
|
3774
|
-
x: () =>
|
|
3775
|
+
x: () => er(a.shape, s, o, i, r),
|
|
3775
3776
|
filter: () => Np(a, s, o.shape, i, r)
|
|
3776
3777
|
};
|
|
3777
3778
|
}
|
|
@@ -3793,11 +3794,11 @@ const vp = {
|
|
|
3793
3794
|
* =============================================================================
|
|
3794
3795
|
*/
|
|
3795
3796
|
const Sp = {
|
|
3796
|
-
kernelName:
|
|
3797
|
+
kernelName: nl,
|
|
3797
3798
|
inputsToSave: ["x"],
|
|
3798
3799
|
gradFunc: (s, t) => {
|
|
3799
3800
|
const [e] = t;
|
|
3800
|
-
return { x: () => w(pt(
|
|
3801
|
+
return { x: () => w(pt(Pu(L(e, "float32"))), s) };
|
|
3801
3802
|
}
|
|
3802
3803
|
};
|
|
3803
3804
|
/**
|
|
@@ -3847,7 +3848,7 @@ const Cp = {
|
|
|
3847
3848
|
const [n] = t, { axis: i, exclusive: r, reverse: a } = e;
|
|
3848
3849
|
return {
|
|
3849
3850
|
x: () => {
|
|
3850
|
-
const o =
|
|
3851
|
+
const o = Yi([i], n.rank);
|
|
3851
3852
|
let l = Lc(s, i, r, !a);
|
|
3852
3853
|
return o != null && (l = j(l, o)), l;
|
|
3853
3854
|
}
|
|
@@ -3900,13 +3901,13 @@ const Ip = {
|
|
|
3900
3901
|
* =============================================================================
|
|
3901
3902
|
*/
|
|
3902
3903
|
const Dp = {
|
|
3903
|
-
kernelName:
|
|
3904
|
+
kernelName: sl,
|
|
3904
3905
|
inputsToSave: ["x", "filter"],
|
|
3905
3906
|
gradFunc: (s, t, e) => {
|
|
3906
3907
|
const [n, i] = t, r = { x: n, filter: i, dy: s }, a = { x: n, filter: i, dy: s };
|
|
3907
3908
|
return {
|
|
3908
|
-
x: () => M.runKernel(
|
|
3909
|
-
filter: () => M.runKernel(
|
|
3909
|
+
x: () => M.runKernel(rl, r, e),
|
|
3910
|
+
filter: () => M.runKernel(il, a, e)
|
|
3910
3911
|
};
|
|
3911
3912
|
}
|
|
3912
3913
|
};
|
|
@@ -3927,11 +3928,11 @@ const Dp = {
|
|
|
3927
3928
|
* =============================================================================
|
|
3928
3929
|
*/
|
|
3929
3930
|
const zp = {
|
|
3930
|
-
kernelName:
|
|
3931
|
+
kernelName: al,
|
|
3931
3932
|
outputsToSave: [!0],
|
|
3932
3933
|
gradFunc: (s, t) => {
|
|
3933
3934
|
const [e] = t, n = { dy: s, y: e };
|
|
3934
|
-
return { x: () => M.runKernel(
|
|
3935
|
+
return { x: () => M.runKernel(ol, n) };
|
|
3935
3936
|
}
|
|
3936
3937
|
};
|
|
3937
3938
|
/**
|
|
@@ -3975,7 +3976,7 @@ const Tp = {
|
|
|
3975
3976
|
* =============================================================================
|
|
3976
3977
|
*/
|
|
3977
3978
|
const $p = {
|
|
3978
|
-
kernelName:
|
|
3979
|
+
kernelName: ll,
|
|
3979
3980
|
outputsToSave: [!0],
|
|
3980
3981
|
gradFunc: (s, t) => {
|
|
3981
3982
|
const [e] = t;
|
|
@@ -3999,7 +4000,7 @@ const $p = {
|
|
|
3999
4000
|
* =============================================================================
|
|
4000
4001
|
*/
|
|
4001
4002
|
const Ep = {
|
|
4002
|
-
kernelName:
|
|
4003
|
+
kernelName: ul,
|
|
4003
4004
|
inputsToSave: ["input"],
|
|
4004
4005
|
gradFunc: (s, t) => {
|
|
4005
4006
|
const [e] = t;
|
|
@@ -4023,7 +4024,7 @@ const Ep = {
|
|
|
4023
4024
|
* =============================================================================
|
|
4024
4025
|
*/
|
|
4025
4026
|
const Lp = {
|
|
4026
|
-
kernelName:
|
|
4027
|
+
kernelName: cl,
|
|
4027
4028
|
inputsToSave: ["x"],
|
|
4028
4029
|
gradFunc: (s, t) => {
|
|
4029
4030
|
const [e] = t;
|
|
@@ -4047,7 +4048,7 @@ const Lp = {
|
|
|
4047
4048
|
* =============================================================================
|
|
4048
4049
|
*/
|
|
4049
4050
|
const Fp = {
|
|
4050
|
-
kernelName:
|
|
4051
|
+
kernelName: hl,
|
|
4051
4052
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4052
4053
|
};
|
|
4053
4054
|
/**
|
|
@@ -4067,7 +4068,7 @@ const Fp = {
|
|
|
4067
4068
|
* =============================================================================
|
|
4068
4069
|
*/
|
|
4069
4070
|
const Mp = {
|
|
4070
|
-
kernelName:
|
|
4071
|
+
kernelName: pl,
|
|
4071
4072
|
inputsToSave: ["a", "b"],
|
|
4072
4073
|
gradFunc: (s, t) => {
|
|
4073
4074
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -4149,7 +4150,7 @@ const Op = {
|
|
|
4149
4150
|
* =============================================================================
|
|
4150
4151
|
*/
|
|
4151
4152
|
const Rp = {
|
|
4152
|
-
kernelName:
|
|
4153
|
+
kernelName: dl,
|
|
4153
4154
|
inputsToSave: ["x", "indices"],
|
|
4154
4155
|
gradFunc: (s, t, e) => {
|
|
4155
4156
|
const [n, i] = t, { axis: r, batchDims: a } = e, o = he(r, n.shape)[0], l = (u, c, h) => () => {
|
|
@@ -4199,7 +4200,7 @@ function Zs(s) {
|
|
|
4199
4200
|
* =============================================================================
|
|
4200
4201
|
*/
|
|
4201
4202
|
const _p = {
|
|
4202
|
-
kernelName:
|
|
4203
|
+
kernelName: fl,
|
|
4203
4204
|
inputsToSave: ["a", "b"],
|
|
4204
4205
|
gradFunc: (s, t) => {
|
|
4205
4206
|
const [e, n] = t;
|
|
@@ -4223,7 +4224,7 @@ const _p = {
|
|
|
4223
4224
|
* =============================================================================
|
|
4224
4225
|
*/
|
|
4225
4226
|
const Bp = {
|
|
4226
|
-
kernelName:
|
|
4227
|
+
kernelName: ml,
|
|
4227
4228
|
gradFunc: (s) => ({ x: () => L(s, "float32") })
|
|
4228
4229
|
};
|
|
4229
4230
|
/**
|
|
@@ -4243,7 +4244,7 @@ const Bp = {
|
|
|
4243
4244
|
* =============================================================================
|
|
4244
4245
|
*/
|
|
4245
4246
|
const Wp = {
|
|
4246
|
-
kernelName:
|
|
4247
|
+
kernelName: gl,
|
|
4247
4248
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4248
4249
|
};
|
|
4249
4250
|
/**
|
|
@@ -4263,7 +4264,7 @@ const Wp = {
|
|
|
4263
4264
|
* =============================================================================
|
|
4264
4265
|
*/
|
|
4265
4266
|
const Gp = {
|
|
4266
|
-
kernelName:
|
|
4267
|
+
kernelName: bl,
|
|
4267
4268
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4268
4269
|
};
|
|
4269
4270
|
/**
|
|
@@ -4283,7 +4284,7 @@ const Gp = {
|
|
|
4283
4284
|
* =============================================================================
|
|
4284
4285
|
*/
|
|
4285
4286
|
const Pp = {
|
|
4286
|
-
kernelName:
|
|
4287
|
+
kernelName: yl,
|
|
4287
4288
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4288
4289
|
};
|
|
4289
4290
|
/**
|
|
@@ -4303,7 +4304,7 @@ const Pp = {
|
|
|
4303
4304
|
* =============================================================================
|
|
4304
4305
|
*/
|
|
4305
4306
|
const Up = {
|
|
4306
|
-
kernelName:
|
|
4307
|
+
kernelName: wl,
|
|
4307
4308
|
inputsToSave: ["x"],
|
|
4308
4309
|
gradFunc: (s, t, e) => {
|
|
4309
4310
|
const [n] = t, { alpha: i } = e, r = Gt(n, 0);
|
|
@@ -4351,7 +4352,7 @@ const Vp = {
|
|
|
4351
4352
|
* =============================================================================
|
|
4352
4353
|
*/
|
|
4353
4354
|
const jp = {
|
|
4354
|
-
kernelName:
|
|
4355
|
+
kernelName: kl,
|
|
4355
4356
|
inputsToSave: ["x"],
|
|
4356
4357
|
gradFunc: (s, t) => {
|
|
4357
4358
|
const [e] = t;
|
|
@@ -4375,7 +4376,7 @@ const jp = {
|
|
|
4375
4376
|
* =============================================================================
|
|
4376
4377
|
*/
|
|
4377
4378
|
const Kp = {
|
|
4378
|
-
kernelName:
|
|
4379
|
+
kernelName: xl,
|
|
4379
4380
|
inputsToSave: [],
|
|
4380
4381
|
outputsToSave: [!0],
|
|
4381
4382
|
gradFunc: (s, t, e) => {
|
|
@@ -4406,7 +4407,7 @@ const Kp = {
|
|
|
4406
4407
|
*/
|
|
4407
4408
|
function Hp(s, t, e, n = 5, i = 1, r = 1, a = 0.5) {
|
|
4408
4409
|
const o = { x: s, y: t, dy: e }, l = { depthRadius: n, bias: i, alpha: r, beta: a };
|
|
4409
|
-
return M.runKernel(
|
|
4410
|
+
return M.runKernel(Nl, o, l);
|
|
4410
4411
|
}
|
|
4411
4412
|
const qp = F({ localResponseNormalizationBackprop_: Hp });
|
|
4412
4413
|
/**
|
|
@@ -4426,7 +4427,7 @@ const qp = F({ localResponseNormalizationBackprop_: Hp });
|
|
|
4426
4427
|
* =============================================================================
|
|
4427
4428
|
*/
|
|
4428
4429
|
const Zp = {
|
|
4429
|
-
kernelName:
|
|
4430
|
+
kernelName: vl,
|
|
4430
4431
|
inputsToSave: ["x"],
|
|
4431
4432
|
outputsToSave: [!0],
|
|
4432
4433
|
gradFunc: (s, t, e) => {
|
|
@@ -4452,7 +4453,7 @@ const Zp = {
|
|
|
4452
4453
|
* limitations under the License.
|
|
4453
4454
|
* =============================================================================
|
|
4454
4455
|
*/
|
|
4455
|
-
function
|
|
4456
|
+
function mr(s, t, e, n) {
|
|
4456
4457
|
return t.rank < e.rank && (t = A(t, Ws(t.shape, n))), s.rank < e.rank && (s = A(s, Ws(s.shape, n))), {
|
|
4457
4458
|
x: () => w(s, L(Xt(e, t), s.dtype))
|
|
4458
4459
|
};
|
|
@@ -4474,11 +4475,11 @@ function dr(s, t, e, n) {
|
|
|
4474
4475
|
* =============================================================================
|
|
4475
4476
|
*/
|
|
4476
4477
|
const Js = {
|
|
4477
|
-
kernelName:
|
|
4478
|
+
kernelName: Sl,
|
|
4478
4479
|
inputsToSave: ["x"],
|
|
4479
4480
|
outputsToSave: [!0],
|
|
4480
4481
|
gradFunc: (s, t, e) => {
|
|
4481
|
-
const n = e, { reductionIndices: i } = n, r = t[0], a = t[1], o = he(i, r.shape), l =
|
|
4482
|
+
const n = e, { reductionIndices: i } = n, r = t[0], a = t[1], o = he(i, r.shape), l = mr(s, a, r, o);
|
|
4482
4483
|
return {
|
|
4483
4484
|
x: () => l.x()
|
|
4484
4485
|
};
|
|
@@ -4501,11 +4502,11 @@ const Js = {
|
|
|
4501
4502
|
* =============================================================================
|
|
4502
4503
|
*/
|
|
4503
4504
|
const Jp = {
|
|
4504
|
-
kernelName:
|
|
4505
|
+
kernelName: Al,
|
|
4505
4506
|
inputsToSave: ["a", "b"],
|
|
4506
4507
|
gradFunc: (s, t) => {
|
|
4507
4508
|
const [e, n] = t;
|
|
4508
|
-
return { a: () => w(s, L(Ue(e, n), "float32")), b: () => w(s, L(
|
|
4509
|
+
return { a: () => w(s, L(Ue(e, n), "float32")), b: () => w(s, L(Cu(e, n), "float32")) };
|
|
4509
4510
|
}
|
|
4510
4511
|
};
|
|
4511
4512
|
/**
|
|
@@ -4540,7 +4541,7 @@ function Xp(s, t, e, n, i, r, a) {
|
|
|
4540
4541
|
u.shape[2],
|
|
4541
4542
|
u.shape[3]
|
|
4542
4543
|
])), k(c.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ${c.rank}.`), k(h.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ${h.rank}.`), k(p.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ${p.rank}.`), ft("maxPool3dGrad", r, a);
|
|
4543
|
-
const g = { dy: c, input: h, output: p }, b = { filterSize: n, strides: i, pad: r, dimRoundingMode: a }, m = M.runKernel(
|
|
4544
|
+
const g = { dy: c, input: h, output: p }, b = { filterSize: n, strides: i, pad: r, dimRoundingMode: a }, m = M.runKernel(Cl, g, b);
|
|
4544
4545
|
return f ? A(m, [m.shape[1], m.shape[2], m.shape[3], m.shape[4]]) : m;
|
|
4545
4546
|
}
|
|
4546
4547
|
const Yp = /* @__PURE__ */ F({ maxPool3dGrad_: Xp });
|
|
@@ -4591,7 +4592,7 @@ function td(s, t, e, n, i, r, a) {
|
|
|
4591
4592
|
const o = D(s, "dy", "maxPoolGrad"), l = D(t, "input", "maxPoolGrad"), u = D(e, "output", "maxPoolGrad");
|
|
4592
4593
|
k(l.rank === o.rank, () => `Rank of input (${l.rank}) does not match rank of dy (${o.rank})`), k(o.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ${o.rank}.`), k(l.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ${l.rank}.`), ft("maxPoolGrad", r, a);
|
|
4593
4594
|
const c = { dy: o, input: l, output: u }, h = { filterSize: n, strides: i, pad: r, dimRoundingMode: a };
|
|
4594
|
-
return M.runKernel(
|
|
4595
|
+
return M.runKernel(Il, c, h);
|
|
4595
4596
|
}
|
|
4596
4597
|
const ed = /* @__PURE__ */ F({ maxPoolGrad_: td });
|
|
4597
4598
|
/**
|
|
@@ -4638,10 +4639,10 @@ const nd = {
|
|
|
4638
4639
|
* =============================================================================
|
|
4639
4640
|
*/
|
|
4640
4641
|
const sd = {
|
|
4641
|
-
kernelName:
|
|
4642
|
+
kernelName: Dl,
|
|
4642
4643
|
inputsToSave: ["x"],
|
|
4643
4644
|
gradFunc: (s, t, e) => {
|
|
4644
|
-
const [n] = t, { axis: i } = e, r = he(i, n.shape), o =
|
|
4645
|
+
const [n] = t, { axis: i } = e, r = he(i, n.shape), o = Uu(n.shape, r)[1], l = Gi(o);
|
|
4645
4646
|
return { x: () => {
|
|
4646
4647
|
const c = n.shape.slice();
|
|
4647
4648
|
r.forEach((f) => {
|
|
@@ -4669,11 +4670,11 @@ const sd = {
|
|
|
4669
4670
|
* =============================================================================
|
|
4670
4671
|
*/
|
|
4671
4672
|
const id = {
|
|
4672
|
-
kernelName:
|
|
4673
|
+
kernelName: zl,
|
|
4673
4674
|
inputsToSave: ["x"],
|
|
4674
4675
|
outputsToSave: [!0],
|
|
4675
4676
|
gradFunc: (s, t, e) => {
|
|
4676
|
-
const n = e, { axis: i } = n, [r, a] = t, o = he(i, r.shape), l =
|
|
4677
|
+
const n = e, { axis: i } = n, [r, a] = t, o = he(i, r.shape), l = mr(s, a, r, o);
|
|
4677
4678
|
return {
|
|
4678
4679
|
x: () => l.x()
|
|
4679
4680
|
};
|
|
@@ -4696,7 +4697,7 @@ const id = {
|
|
|
4696
4697
|
* =============================================================================
|
|
4697
4698
|
*/
|
|
4698
4699
|
const rd = {
|
|
4699
|
-
kernelName:
|
|
4700
|
+
kernelName: Tl,
|
|
4700
4701
|
inputsToSave: ["a", "b"],
|
|
4701
4702
|
gradFunc: (s, t) => {
|
|
4702
4703
|
const [e, n] = t;
|
|
@@ -4720,7 +4721,7 @@ const rd = {
|
|
|
4720
4721
|
* =============================================================================
|
|
4721
4722
|
*/
|
|
4722
4723
|
const ad = {
|
|
4723
|
-
kernelName:
|
|
4724
|
+
kernelName: $l,
|
|
4724
4725
|
inputsToSave: ["x"],
|
|
4725
4726
|
gradFunc: (s, t, e) => {
|
|
4726
4727
|
const n = t[0], { paddings: i } = e, r = i.map((a) => a[0]);
|
|
@@ -4744,7 +4745,7 @@ const ad = {
|
|
|
4744
4745
|
* =============================================================================
|
|
4745
4746
|
*/
|
|
4746
4747
|
const od = {
|
|
4747
|
-
kernelName:
|
|
4748
|
+
kernelName: El,
|
|
4748
4749
|
inputsToSave: ["a", "b"],
|
|
4749
4750
|
gradFunc: (s, t) => {
|
|
4750
4751
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -4752,7 +4753,7 @@ const od = {
|
|
|
4752
4753
|
const o = lt(e.shape, i);
|
|
4753
4754
|
return o.length > 0 ? A(B(s, o), e.shape) : s;
|
|
4754
4755
|
}, b: () => {
|
|
4755
|
-
const o = w(s, pt(
|
|
4756
|
+
const o = w(s, pt(Zi(P(e, n)))), l = lt(n.shape, i);
|
|
4756
4757
|
return l.length > 0 ? A(B(o, l), n.shape) : o;
|
|
4757
4758
|
} };
|
|
4758
4759
|
}
|
|
@@ -4774,7 +4775,7 @@ const od = {
|
|
|
4774
4775
|
* =============================================================================
|
|
4775
4776
|
*/
|
|
4776
4777
|
const ld = {
|
|
4777
|
-
kernelName:
|
|
4778
|
+
kernelName: Ll,
|
|
4778
4779
|
inputsToSave: ["a", "b"],
|
|
4779
4780
|
gradFunc: (s, t) => {
|
|
4780
4781
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -4804,7 +4805,7 @@ const ld = {
|
|
|
4804
4805
|
* =============================================================================
|
|
4805
4806
|
*/
|
|
4806
4807
|
const ud = {
|
|
4807
|
-
kernelName:
|
|
4808
|
+
kernelName: Fl,
|
|
4808
4809
|
gradFunc: (s) => ({ x: () => pt(s) })
|
|
4809
4810
|
};
|
|
4810
4811
|
/**
|
|
@@ -4868,7 +4869,7 @@ const hd = {
|
|
|
4868
4869
|
* =============================================================================
|
|
4869
4870
|
*/
|
|
4870
4871
|
const pd = {
|
|
4871
|
-
kernelName:
|
|
4872
|
+
kernelName: Ml,
|
|
4872
4873
|
saveAllInputs: !0,
|
|
4873
4874
|
gradFunc: (s, t, e) => {
|
|
4874
4875
|
const { axis: n } = e;
|
|
@@ -4916,7 +4917,7 @@ const Xs = {
|
|
|
4916
4917
|
* =============================================================================
|
|
4917
4918
|
*/
|
|
4918
4919
|
const dd = {
|
|
4919
|
-
kernelName:
|
|
4920
|
+
kernelName: Ol,
|
|
4920
4921
|
inputsToSave: ["a", "b"],
|
|
4921
4922
|
outputsToSave: [!0],
|
|
4922
4923
|
gradFunc: (s, t) => {
|
|
@@ -4951,7 +4952,7 @@ const dd = {
|
|
|
4951
4952
|
* =============================================================================
|
|
4952
4953
|
*/
|
|
4953
4954
|
const fd = {
|
|
4954
|
-
kernelName:
|
|
4955
|
+
kernelName: Rl,
|
|
4955
4956
|
inputsToSave: ["x", "alpha"],
|
|
4956
4957
|
gradFunc: (s, t) => {
|
|
4957
4958
|
const [e, n] = t, i = Gt(e, 0);
|
|
@@ -4988,7 +4989,7 @@ function md(s, t, e) {
|
|
|
4988
4989
|
return w(i, o);
|
|
4989
4990
|
}
|
|
4990
4991
|
function gd(s, t, e) {
|
|
4991
|
-
const n = s.shape.length, i = n - e.length, r =
|
|
4992
|
+
const n = s.shape.length, i = n - e.length, r = Yi(e, n);
|
|
4992
4993
|
let a = s;
|
|
4993
4994
|
r != null && (a = j(s, r));
|
|
4994
4995
|
const o = a.shape.slice(), u = o.splice(n - e.length, e.length).reduce((p, f) => p * f, 1);
|
|
@@ -5002,7 +5003,7 @@ function gd(s, t, e) {
|
|
|
5002
5003
|
return h;
|
|
5003
5004
|
}
|
|
5004
5005
|
const bd = {
|
|
5005
|
-
kernelName:
|
|
5006
|
+
kernelName: _l,
|
|
5006
5007
|
inputsToSave: ["x"],
|
|
5007
5008
|
gradFunc: (s, t, e) => {
|
|
5008
5009
|
const [n] = t, { axis: i } = e;
|
|
@@ -5027,7 +5028,7 @@ const bd = {
|
|
|
5027
5028
|
* =============================================================================
|
|
5028
5029
|
*/
|
|
5029
5030
|
const yd = {
|
|
5030
|
-
kernelName:
|
|
5031
|
+
kernelName: Bl,
|
|
5031
5032
|
inputsToSave: ["a", "b"],
|
|
5032
5033
|
gradFunc: (s, t) => {
|
|
5033
5034
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -5060,7 +5061,7 @@ const yd = {
|
|
|
5060
5061
|
* =============================================================================
|
|
5061
5062
|
*/
|
|
5062
5063
|
const wd = {
|
|
5063
|
-
kernelName:
|
|
5064
|
+
kernelName: Wl,
|
|
5064
5065
|
inputsToSave: ["x"],
|
|
5065
5066
|
gradFunc: (s, t) => {
|
|
5066
5067
|
const [e] = t;
|
|
@@ -5084,7 +5085,7 @@ const wd = {
|
|
|
5084
5085
|
* =============================================================================
|
|
5085
5086
|
*/
|
|
5086
5087
|
const kd = {
|
|
5087
|
-
kernelName:
|
|
5088
|
+
kernelName: Gl,
|
|
5088
5089
|
inputsToSave: ["x"],
|
|
5089
5090
|
gradFunc: (s, t) => {
|
|
5090
5091
|
const [e] = t, n = w(Yn(e, 6), Xn(e));
|
|
@@ -5108,7 +5109,7 @@ const kd = {
|
|
|
5108
5109
|
* =============================================================================
|
|
5109
5110
|
*/
|
|
5110
5111
|
const xd = {
|
|
5111
|
-
kernelName:
|
|
5112
|
+
kernelName: Pl,
|
|
5112
5113
|
inputsToSave: ["x"],
|
|
5113
5114
|
gradFunc: (s, t) => {
|
|
5114
5115
|
const [e] = t;
|
|
@@ -5132,7 +5133,7 @@ const xd = {
|
|
|
5132
5133
|
* =============================================================================
|
|
5133
5134
|
*/
|
|
5134
5135
|
const Nd = {
|
|
5135
|
-
kernelName:
|
|
5136
|
+
kernelName: Ul,
|
|
5136
5137
|
inputsToSave: ["x"],
|
|
5137
5138
|
gradFunc: (s, t) => {
|
|
5138
5139
|
const [e] = t;
|
|
@@ -5156,13 +5157,13 @@ const Nd = {
|
|
|
5156
5157
|
* =============================================================================
|
|
5157
5158
|
*/
|
|
5158
5159
|
const vd = {
|
|
5159
|
-
kernelName:
|
|
5160
|
+
kernelName: Vl,
|
|
5160
5161
|
inputsToSave: ["images"],
|
|
5161
5162
|
gradFunc: (s, t, e) => {
|
|
5162
5163
|
const [n] = t, i = { dy: s, images: n };
|
|
5163
5164
|
return { images: () => (
|
|
5164
5165
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
5165
|
-
M.runKernel(
|
|
5166
|
+
M.runKernel(jl, i, e)
|
|
5166
5167
|
) };
|
|
5167
5168
|
}
|
|
5168
5169
|
};
|
|
@@ -5183,13 +5184,13 @@ const vd = {
|
|
|
5183
5184
|
* =============================================================================
|
|
5184
5185
|
*/
|
|
5185
5186
|
const Sd = {
|
|
5186
|
-
kernelName:
|
|
5187
|
+
kernelName: Kl,
|
|
5187
5188
|
inputsToSave: ["images"],
|
|
5188
5189
|
gradFunc: (s, t, e) => {
|
|
5189
5190
|
const [n] = t, i = { dy: s, images: n };
|
|
5190
5191
|
return { images: () => (
|
|
5191
5192
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
5192
|
-
M.runKernel(
|
|
5193
|
+
M.runKernel(Hl, i, e)
|
|
5193
5194
|
) };
|
|
5194
5195
|
}
|
|
5195
5196
|
};
|
|
@@ -5233,7 +5234,7 @@ const Ad = {
|
|
|
5233
5234
|
* =============================================================================
|
|
5234
5235
|
*/
|
|
5235
5236
|
const Cd = {
|
|
5236
|
-
kernelName:
|
|
5237
|
+
kernelName: ql,
|
|
5237
5238
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5238
5239
|
};
|
|
5239
5240
|
/**
|
|
@@ -5277,7 +5278,7 @@ const Id = {
|
|
|
5277
5278
|
* =============================================================================
|
|
5278
5279
|
*/
|
|
5279
5280
|
const Dd = {
|
|
5280
|
-
kernelName:
|
|
5281
|
+
kernelName: Zl,
|
|
5281
5282
|
inputsToSave: ["condition"],
|
|
5282
5283
|
gradFunc: (s, t) => {
|
|
5283
5284
|
const [e] = t;
|
|
@@ -5336,7 +5337,7 @@ const zd = {
|
|
|
5336
5337
|
* =============================================================================
|
|
5337
5338
|
*/
|
|
5338
5339
|
const Td = {
|
|
5339
|
-
kernelName:
|
|
5340
|
+
kernelName: Jl,
|
|
5340
5341
|
outputsToSave: [!0],
|
|
5341
5342
|
gradFunc: (s, t) => {
|
|
5342
5343
|
const [e] = t;
|
|
@@ -5360,7 +5361,7 @@ const Td = {
|
|
|
5360
5361
|
* =============================================================================
|
|
5361
5362
|
*/
|
|
5362
5363
|
const $d = {
|
|
5363
|
-
kernelName:
|
|
5364
|
+
kernelName: Xl,
|
|
5364
5365
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5365
5366
|
};
|
|
5366
5367
|
/**
|
|
@@ -5380,11 +5381,11 @@ const $d = {
|
|
|
5380
5381
|
* =============================================================================
|
|
5381
5382
|
*/
|
|
5382
5383
|
const Ed = {
|
|
5383
|
-
kernelName:
|
|
5384
|
+
kernelName: Yl,
|
|
5384
5385
|
inputsToSave: ["x"],
|
|
5385
5386
|
gradFunc: (s, t) => {
|
|
5386
5387
|
const [e] = t;
|
|
5387
|
-
return { x: () => w(
|
|
5388
|
+
return { x: () => w(Xi(L(e, "float32")), s) };
|
|
5388
5389
|
}
|
|
5389
5390
|
};
|
|
5390
5391
|
/**
|
|
@@ -5428,13 +5429,13 @@ const Ld = {
|
|
|
5428
5429
|
* =============================================================================
|
|
5429
5430
|
*/
|
|
5430
5431
|
const Fd = {
|
|
5431
|
-
kernelName:
|
|
5432
|
+
kernelName: Ql,
|
|
5432
5433
|
inputsToSave: ["x"],
|
|
5433
5434
|
gradFunc: (s, t, e) => {
|
|
5434
|
-
const [n] = t, { begin: i, size: r } = e, a = n.shape, [o, l] =
|
|
5435
|
+
const [n] = t, { begin: i, size: r } = e, a = n.shape, [o, l] = ju(n, i, r), u = [];
|
|
5435
5436
|
for (let c = 0; c < s.rank; c++)
|
|
5436
5437
|
u.push([o[c], a[c] - o[c] - l[c]]);
|
|
5437
|
-
return { x: () =>
|
|
5438
|
+
return { x: () => sr(s, u) };
|
|
5438
5439
|
}
|
|
5439
5440
|
};
|
|
5440
5441
|
/**
|
|
@@ -5454,7 +5455,7 @@ const Fd = {
|
|
|
5454
5455
|
* =============================================================================
|
|
5455
5456
|
*/
|
|
5456
5457
|
const Md = {
|
|
5457
|
-
kernelName:
|
|
5458
|
+
kernelName: tu,
|
|
5458
5459
|
outputsToSave: [!0],
|
|
5459
5460
|
gradFunc: (s, t, e) => {
|
|
5460
5461
|
const [n] = t, { dim: i } = e, r = !0, a = w(s, n);
|
|
@@ -5527,7 +5528,7 @@ const Ys = {
|
|
|
5527
5528
|
* =============================================================================
|
|
5528
5529
|
*/
|
|
5529
5530
|
const Qs = {
|
|
5530
|
-
kernelName:
|
|
5531
|
+
kernelName: eu,
|
|
5531
5532
|
gradFunc: (s, t, e) => {
|
|
5532
5533
|
const { axis: n } = e;
|
|
5533
5534
|
return { x: () => is(s, n) };
|
|
@@ -5550,7 +5551,7 @@ const Qs = {
|
|
|
5550
5551
|
* =============================================================================
|
|
5551
5552
|
*/
|
|
5552
5553
|
const Rd = {
|
|
5553
|
-
kernelName:
|
|
5554
|
+
kernelName: nu,
|
|
5554
5555
|
inputsToSave: ["x"],
|
|
5555
5556
|
gradFunc: (s, t) => {
|
|
5556
5557
|
const [e] = t;
|
|
@@ -5574,7 +5575,7 @@ const Rd = {
|
|
|
5574
5575
|
* =============================================================================
|
|
5575
5576
|
*/
|
|
5576
5577
|
const _d = {
|
|
5577
|
-
kernelName:
|
|
5578
|
+
kernelName: su,
|
|
5578
5579
|
inputsToSave: ["x"],
|
|
5579
5580
|
gradFunc: (s, t) => {
|
|
5580
5581
|
const [e] = t;
|
|
@@ -5598,7 +5599,7 @@ const _d = {
|
|
|
5598
5599
|
* =============================================================================
|
|
5599
5600
|
*/
|
|
5600
5601
|
const Bd = {
|
|
5601
|
-
kernelName:
|
|
5602
|
+
kernelName: iu,
|
|
5602
5603
|
inputsToSave: ["a", "b"],
|
|
5603
5604
|
gradFunc: (s, t) => {
|
|
5604
5605
|
const [e, n] = t, i = tt(2);
|
|
@@ -5622,7 +5623,7 @@ const Bd = {
|
|
|
5622
5623
|
* =============================================================================
|
|
5623
5624
|
*/
|
|
5624
5625
|
const Wd = {
|
|
5625
|
-
kernelName:
|
|
5626
|
+
kernelName: ru,
|
|
5626
5627
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5627
5628
|
};
|
|
5628
5629
|
/**
|
|
@@ -5642,7 +5643,7 @@ const Wd = {
|
|
|
5642
5643
|
* =============================================================================
|
|
5643
5644
|
*/
|
|
5644
5645
|
const Gd = {
|
|
5645
|
-
kernelName:
|
|
5646
|
+
kernelName: au,
|
|
5646
5647
|
inputsToSave: ["a", "b"],
|
|
5647
5648
|
gradFunc: (s, t) => {
|
|
5648
5649
|
const [e, n] = t, i = wt(e.shape, n.shape);
|
|
@@ -5674,7 +5675,7 @@ const Gd = {
|
|
|
5674
5675
|
* =============================================================================
|
|
5675
5676
|
*/
|
|
5676
5677
|
const Pd = {
|
|
5677
|
-
kernelName:
|
|
5678
|
+
kernelName: ou,
|
|
5678
5679
|
inputsToSave: ["x"],
|
|
5679
5680
|
gradFunc: (s, t, e) => {
|
|
5680
5681
|
const [n] = t, i = n.shape.slice(), { axis: r } = e;
|
|
@@ -5702,11 +5703,11 @@ const Pd = {
|
|
|
5702
5703
|
* =============================================================================
|
|
5703
5704
|
*/
|
|
5704
5705
|
const Ud = {
|
|
5705
|
-
kernelName:
|
|
5706
|
+
kernelName: lu,
|
|
5706
5707
|
inputsToSave: ["x"],
|
|
5707
5708
|
gradFunc: (s, t) => {
|
|
5708
5709
|
const [e] = t;
|
|
5709
|
-
return { x: () => P(s, ct(
|
|
5710
|
+
return { x: () => P(s, ct(Xi(e))) };
|
|
5710
5711
|
}
|
|
5711
5712
|
};
|
|
5712
5713
|
/**
|
|
@@ -5750,7 +5751,7 @@ const Vd = {
|
|
|
5750
5751
|
* =============================================================================
|
|
5751
5752
|
*/
|
|
5752
5753
|
const jd = {
|
|
5753
|
-
kernelName:
|
|
5754
|
+
kernelName: uu,
|
|
5754
5755
|
inputsToSave: ["x"],
|
|
5755
5756
|
gradFunc: (s, t, e) => {
|
|
5756
5757
|
const [n] = t, { reps: i } = e;
|
|
@@ -5805,7 +5806,7 @@ const jd = {
|
|
|
5805
5806
|
* =============================================================================
|
|
5806
5807
|
*/
|
|
5807
5808
|
const Kd = {
|
|
5808
|
-
kernelName:
|
|
5809
|
+
kernelName: cu,
|
|
5809
5810
|
gradFunc: (s, t, e) => {
|
|
5810
5811
|
const n = e, { perm: i } = n, r = ss(i);
|
|
5811
5812
|
return { x: () => j(s, r) };
|
|
@@ -5828,7 +5829,7 @@ const Kd = {
|
|
|
5828
5829
|
* =============================================================================
|
|
5829
5830
|
*/
|
|
5830
5831
|
const Hd = {
|
|
5831
|
-
kernelName:
|
|
5832
|
+
kernelName: hu,
|
|
5832
5833
|
gradFunc: (s, t, e) => {
|
|
5833
5834
|
const n = e, { axis: i } = n;
|
|
5834
5835
|
return { value: () => kn(s, i) };
|
|
@@ -5859,7 +5860,7 @@ const qd = {
|
|
|
5859
5860
|
}
|
|
5860
5861
|
};
|
|
5861
5862
|
function Zd(s, t) {
|
|
5862
|
-
const e = Ie(t, Q(t)), n =
|
|
5863
|
+
const e = Ie(t, Q(t)), n = Qi(s, e);
|
|
5863
5864
|
let i = Ue(t, tt(0, "int32"));
|
|
5864
5865
|
const r = n.rank - i.rank;
|
|
5865
5866
|
for (let o = 0; o < r; ++o)
|
|
@@ -5885,7 +5886,7 @@ function Zd(s, t) {
|
|
|
5885
5886
|
* =============================================================================
|
|
5886
5887
|
*/
|
|
5887
5888
|
const Jd = {
|
|
5888
|
-
kernelName:
|
|
5889
|
+
kernelName: pu,
|
|
5889
5890
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5890
5891
|
};
|
|
5891
5892
|
/**
|
|
@@ -5905,7 +5906,7 @@ const Jd = {
|
|
|
5905
5906
|
* =============================================================================
|
|
5906
5907
|
*/
|
|
5907
5908
|
const Xd = [
|
|
5908
|
-
|
|
5909
|
+
fr,
|
|
5909
5910
|
qh,
|
|
5910
5911
|
Zh,
|
|
5911
5912
|
Jh,
|
|
@@ -6012,7 +6013,7 @@ const Xd = [
|
|
|
6012
6013
|
Jd
|
|
6013
6014
|
];
|
|
6014
6015
|
for (const s of Xd)
|
|
6015
|
-
|
|
6016
|
+
du(s);
|
|
6016
6017
|
/**
|
|
6017
6018
|
* @license
|
|
6018
6019
|
* Copyright 2018 Google LLC
|
|
@@ -6030,13 +6031,13 @@ class qe extends Be {
|
|
|
6030
6031
|
return {};
|
|
6031
6032
|
}
|
|
6032
6033
|
}
|
|
6033
|
-
class
|
|
6034
|
+
class gr extends qe {
|
|
6034
6035
|
constructor(t) {
|
|
6035
6036
|
super(), this.defaultMaxValue = 2, this.defaultAxis = 0, this.maxValue = t.maxValue != null ? t.maxValue : this.defaultMaxValue, this.axis = t.axis != null ? t.axis : this.defaultAxis;
|
|
6036
6037
|
}
|
|
6037
6038
|
apply(t) {
|
|
6038
6039
|
return x(() => {
|
|
6039
|
-
const e = ks(t, this.axis), n =
|
|
6040
|
+
const e = ks(t, this.axis), n = Ct(e, 0, this.maxValue);
|
|
6040
6041
|
return w(t, P(n, $(nt(), e)));
|
|
6041
6042
|
});
|
|
6042
6043
|
}
|
|
@@ -6044,9 +6045,9 @@ class fr extends qe {
|
|
|
6044
6045
|
return { maxValue: this.maxValue, axis: this.axis };
|
|
6045
6046
|
}
|
|
6046
6047
|
}
|
|
6047
|
-
|
|
6048
|
-
S(
|
|
6049
|
-
class
|
|
6048
|
+
gr.className = "MaxNorm";
|
|
6049
|
+
S(gr);
|
|
6050
|
+
class br extends qe {
|
|
6050
6051
|
constructor(t) {
|
|
6051
6052
|
super(), this.defaultAxis = 0, this.axis = t.axis != null ? t.axis : this.defaultAxis;
|
|
6052
6053
|
}
|
|
@@ -6057,22 +6058,22 @@ class mr extends qe {
|
|
|
6057
6058
|
return { axis: this.axis };
|
|
6058
6059
|
}
|
|
6059
6060
|
}
|
|
6060
|
-
|
|
6061
|
-
S(
|
|
6062
|
-
class
|
|
6061
|
+
br.className = "UnitNorm";
|
|
6062
|
+
S(br);
|
|
6063
|
+
class yr extends qe {
|
|
6063
6064
|
apply(t) {
|
|
6064
6065
|
return Ve(t);
|
|
6065
6066
|
}
|
|
6066
6067
|
}
|
|
6067
|
-
|
|
6068
|
-
S(
|
|
6069
|
-
class
|
|
6068
|
+
yr.className = "NonNeg";
|
|
6069
|
+
S(yr);
|
|
6070
|
+
class wr extends qe {
|
|
6070
6071
|
constructor(t) {
|
|
6071
6072
|
super(), this.defaultMinValue = 0, this.defaultMaxValue = 1, this.defaultRate = 1, this.defaultAxis = 0, this.minValue = t.minValue != null ? t.minValue : this.defaultMinValue, this.maxValue = t.maxValue != null ? t.maxValue : this.defaultMaxValue, this.rate = t.rate != null ? t.rate : this.defaultRate, this.axis = t.axis != null ? t.axis : this.defaultAxis;
|
|
6072
6073
|
}
|
|
6073
6074
|
apply(t) {
|
|
6074
6075
|
return x(() => {
|
|
6075
|
-
const e = ks(t, this.axis), n = $(w(this.rate,
|
|
6076
|
+
const e = ks(t, this.axis), n = $(w(this.rate, Ct(e, this.minValue, this.maxValue)), w(1 - this.rate, e));
|
|
6076
6077
|
return w(t, P(n, $(nt(), e)));
|
|
6077
6078
|
});
|
|
6078
6079
|
}
|
|
@@ -6085,8 +6086,8 @@ class br extends qe {
|
|
|
6085
6086
|
};
|
|
6086
6087
|
}
|
|
6087
6088
|
}
|
|
6088
|
-
|
|
6089
|
-
S(
|
|
6089
|
+
wr.className = "MinMaxNorm";
|
|
6090
|
+
S(wr);
|
|
6090
6091
|
const ti = {
|
|
6091
6092
|
maxNorm: "MaxNorm",
|
|
6092
6093
|
minMaxNorm: "MinMaxNorm",
|
|
@@ -6116,7 +6117,7 @@ function rt(s) {
|
|
|
6116
6117
|
* https://opensource.org/licenses/MIT.
|
|
6117
6118
|
* =============================================================================
|
|
6118
6119
|
*/
|
|
6119
|
-
function
|
|
6120
|
+
function Yd(s) {
|
|
6120
6121
|
return new ps(s);
|
|
6121
6122
|
}
|
|
6122
6123
|
/**
|
|
@@ -6146,7 +6147,7 @@ async function ae(s) {
|
|
|
6146
6147
|
Z(n);
|
|
6147
6148
|
}
|
|
6148
6149
|
}
|
|
6149
|
-
function
|
|
6150
|
+
function kr(s) {
|
|
6150
6151
|
if (s != null)
|
|
6151
6152
|
for (const t in s) {
|
|
6152
6153
|
const e = s[t];
|
|
@@ -6166,7 +6167,7 @@ var ni;
|
|
|
6166
6167
|
(function(s) {
|
|
6167
6168
|
s[s.SILENT = 0] = "SILENT", s[s.VERBOSE = 1] = "VERBOSE";
|
|
6168
6169
|
})(ni || (ni = {}));
|
|
6169
|
-
const
|
|
6170
|
+
const Qd = 125;
|
|
6170
6171
|
class Me {
|
|
6171
6172
|
constructor() {
|
|
6172
6173
|
this.validationData = null;
|
|
@@ -6196,7 +6197,7 @@ class Me {
|
|
|
6196
6197
|
setModel(t) {
|
|
6197
6198
|
}
|
|
6198
6199
|
}
|
|
6199
|
-
class
|
|
6200
|
+
class tf {
|
|
6200
6201
|
// TODO(cais): When the need arises, uncomment the following lines and
|
|
6201
6202
|
// implement the queue for time values.
|
|
6202
6203
|
// private deltaTBatch: number;
|
|
@@ -6281,7 +6282,7 @@ class Qd {
|
|
|
6281
6282
|
await e.onTrainEnd(t);
|
|
6282
6283
|
}
|
|
6283
6284
|
}
|
|
6284
|
-
class
|
|
6285
|
+
class ef extends Me {
|
|
6285
6286
|
constructor() {
|
|
6286
6287
|
super();
|
|
6287
6288
|
}
|
|
@@ -6313,7 +6314,7 @@ class tf extends Me {
|
|
|
6313
6314
|
}));
|
|
6314
6315
|
}
|
|
6315
6316
|
}
|
|
6316
|
-
class
|
|
6317
|
+
class nf extends Me {
|
|
6317
6318
|
async onTrainBegin(t) {
|
|
6318
6319
|
this.epoch = [], this.history = {};
|
|
6319
6320
|
}
|
|
@@ -6340,11 +6341,11 @@ class ef extends Me {
|
|
|
6340
6341
|
this.history[e[r]][n[r]].dispose(), this.history[e[r]][n[r]] = i[r][0];
|
|
6341
6342
|
}
|
|
6342
6343
|
}
|
|
6343
|
-
class
|
|
6344
|
+
class sf extends Me {
|
|
6344
6345
|
constructor(t, e) {
|
|
6345
|
-
if (super(), this.currentEpoch = 0, this.nowFunc = t.nowFunc, this.nextFrameFunc = t.nextFrameFunc || Sh, this.yieldEvery = e || "auto", this.yieldEvery === "auto" && (this.yieldEvery =
|
|
6346
|
+
if (super(), this.currentEpoch = 0, this.nowFunc = t.nowFunc, this.nextFrameFunc = t.nextFrameFunc || Sh, this.yieldEvery = e || "auto", this.yieldEvery === "auto" && (this.yieldEvery = Qd), this.yieldEvery === "never" && t.onYield != null)
|
|
6346
6347
|
throw new Error("yieldEvery is `never` but you provided an `onYield` callback. Either change `yieldEvery` or remove the callback");
|
|
6347
|
-
Fs(this.yieldEvery) && (this.maybeWait =
|
|
6348
|
+
Fs(this.yieldEvery) && (this.maybeWait = Iu(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc)), this.trainBegin = t.onTrainBegin, this.trainEnd = t.onTrainEnd, this.epochBegin = t.onEpochBegin, this.epochEnd = t.onEpochEnd, this.batchBegin = t.onBatchBegin, this.batchEnd = t.onBatchEnd, this.yield = t.onYield;
|
|
6348
6349
|
}
|
|
6349
6350
|
async maybeWait(t, e, n) {
|
|
6350
6351
|
const i = [];
|
|
@@ -6371,8 +6372,8 @@ class nf extends Me {
|
|
|
6371
6372
|
this.trainEnd != null && (await ae(t), await this.trainEnd(t));
|
|
6372
6373
|
}
|
|
6373
6374
|
}
|
|
6374
|
-
function
|
|
6375
|
-
return s == null && (s = {}), s instanceof Me ? [s] : Array.isArray(s) && s[0] instanceof Me ? s : K(s).map((n) => new
|
|
6375
|
+
function xr(s, t) {
|
|
6376
|
+
return s == null && (s = {}), s instanceof Me ? [s] : Array.isArray(s) && s[0] instanceof Me ? s : K(s).map((n) => new sf(n, t));
|
|
6376
6377
|
}
|
|
6377
6378
|
class yt {
|
|
6378
6379
|
/**
|
|
@@ -6426,13 +6427,13 @@ class yt {
|
|
|
6426
6427
|
}
|
|
6427
6428
|
}
|
|
6428
6429
|
yt.constructors = {};
|
|
6429
|
-
function
|
|
6430
|
-
const u = new
|
|
6431
|
-
new
|
|
6430
|
+
function Nr(s, t, e, n, i, r, a, o, l) {
|
|
6431
|
+
const u = new nf(), c = [
|
|
6432
|
+
new ef(),
|
|
6432
6433
|
...yt.createCallbacks(t)
|
|
6433
6434
|
];
|
|
6434
6435
|
s != null && c.push(...s), c.push(u);
|
|
6435
|
-
const h = new
|
|
6436
|
+
const h = new tf(c);
|
|
6436
6437
|
return h.setParams({
|
|
6437
6438
|
epochs: e,
|
|
6438
6439
|
initialEpoch: n,
|
|
@@ -6468,7 +6469,7 @@ function Wt(s, t = {}, e = !1) {
|
|
|
6468
6469
|
function pn(s, t) {
|
|
6469
6470
|
return x(() => {
|
|
6470
6471
|
s.dtype !== "float32" && (s = L(s, "float32"));
|
|
6471
|
-
const e = B(je(s), t, !0), n =
|
|
6472
|
+
const e = B(je(s), t, !0), n = fu(e.shape, nt()), i = ee(Ie(e, n));
|
|
6472
6473
|
return P(s, i);
|
|
6473
6474
|
});
|
|
6474
6475
|
}
|
|
@@ -6480,35 +6481,35 @@ function xs(s, t) {
|
|
|
6480
6481
|
}
|
|
6481
6482
|
function Ns(s, t) {
|
|
6482
6483
|
return x(() => {
|
|
6483
|
-
const e = V(s, t), n =
|
|
6484
|
+
const e = V(s, t), n = Ct(Fe(s), nt(), Number.MAX_VALUE), i = Fe(P(e, n));
|
|
6484
6485
|
return w(100, at(i, -1));
|
|
6485
6486
|
});
|
|
6486
6487
|
}
|
|
6487
|
-
function
|
|
6488
|
+
function rf(s, t) {
|
|
6488
6489
|
return x(() => {
|
|
6489
|
-
const e =
|
|
6490
|
+
const e = Ct(t, nt(), Number.MAX_VALUE), n = Zt($(1, e)), i = Ct(s, nt(), Number.MAX_VALUE), r = Zt($(1, i));
|
|
6490
6491
|
return at(je(V(n, r)), -1);
|
|
6491
6492
|
});
|
|
6492
6493
|
}
|
|
6493
|
-
function
|
|
6494
|
+
function af(s, t) {
|
|
6494
6495
|
return x(() => {
|
|
6495
6496
|
const e = Ie(0, V(1, w(s, t)));
|
|
6496
6497
|
return at(je(e), -1);
|
|
6497
6498
|
});
|
|
6498
6499
|
}
|
|
6499
|
-
function
|
|
6500
|
+
function of(s, t) {
|
|
6500
6501
|
return x(() => {
|
|
6501
6502
|
const e = Ie(0, V(1, w(s, t)));
|
|
6502
6503
|
return at(e, -1);
|
|
6503
6504
|
});
|
|
6504
6505
|
}
|
|
6505
|
-
function
|
|
6506
|
+
function lf(s, t) {
|
|
6506
6507
|
return x(() => {
|
|
6507
6508
|
const e = B(w(s, t), -1), n = ve(w(V(1, s), t), -1);
|
|
6508
6509
|
return Ie(0, $(1, V(n, e)));
|
|
6509
6510
|
});
|
|
6510
6511
|
}
|
|
6511
|
-
function
|
|
6512
|
+
function uf(s, t) {
|
|
6512
6513
|
return x(() => {
|
|
6513
6514
|
const e = Math.log(2), n = V(t, s), i = V($(n, us(w(-2, n))), e);
|
|
6514
6515
|
return at(i, -1);
|
|
@@ -6517,23 +6518,23 @@ function lf(s, t) {
|
|
|
6517
6518
|
function Oe(s, t, e = !1) {
|
|
6518
6519
|
return x(() => {
|
|
6519
6520
|
if (e)
|
|
6520
|
-
t =
|
|
6521
|
+
t = tr(t);
|
|
6521
6522
|
else {
|
|
6522
6523
|
const n = B(t, t.shape.length - 1, !0);
|
|
6523
6524
|
t = P(t, n);
|
|
6524
6525
|
}
|
|
6525
|
-
return t =
|
|
6526
|
+
return t = Ct(t, nt(), 1 - nt()), pt(B(w(L(s, "float32"), Zt(t)), t.shape.length - 1));
|
|
6526
6527
|
});
|
|
6527
6528
|
}
|
|
6528
6529
|
function dn(s, t, e = !1) {
|
|
6529
6530
|
return x(() => {
|
|
6530
|
-
const n = L(
|
|
6531
|
-
t =
|
|
6531
|
+
const n = L(Zi(Du(s)), "int32");
|
|
6532
|
+
t = Ct(t, nt(), 1 - nt());
|
|
6532
6533
|
const i = t.shape, r = A(Yc(n, i[i.length - 1]), i);
|
|
6533
6534
|
return Oe(r, t, e);
|
|
6534
6535
|
});
|
|
6535
6536
|
}
|
|
6536
|
-
function
|
|
6537
|
+
function cf(s, t) {
|
|
6537
6538
|
if (!Ft(s.shape, t.shape))
|
|
6538
6539
|
throw new d(`logits and labels must have the same shape, but got shapes ${JSON.stringify(s.shape)} and ${JSON.stringify(t.shape)}`);
|
|
6539
6540
|
return x(() => {
|
|
@@ -6544,22 +6545,22 @@ function uf(s, t) {
|
|
|
6544
6545
|
function Sn(s, t) {
|
|
6545
6546
|
return x(() => {
|
|
6546
6547
|
let e;
|
|
6547
|
-
return e =
|
|
6548
|
+
return e = Ct(t, nt(), 1 - nt()), e = Zt(P(e, V(1, e))), at(cf(s, e), -1);
|
|
6548
6549
|
});
|
|
6549
6550
|
}
|
|
6550
|
-
function
|
|
6551
|
+
function hf(s, t) {
|
|
6551
6552
|
return x(() => {
|
|
6552
|
-
const e =
|
|
6553
|
+
const e = Ct(s, nt(), 1), n = Ct(t, nt(), 1);
|
|
6553
6554
|
return B(w(s, Zt(P(e, n))), -1);
|
|
6554
6555
|
});
|
|
6555
6556
|
}
|
|
6556
|
-
function
|
|
6557
|
+
function pf(s, t) {
|
|
6557
6558
|
return x(() => {
|
|
6558
6559
|
const e = Zt($(nt(), t));
|
|
6559
6560
|
return at(V(t, w(s, e)), -1);
|
|
6560
6561
|
});
|
|
6561
6562
|
}
|
|
6562
|
-
function
|
|
6563
|
+
function vr(s, t) {
|
|
6563
6564
|
return x(() => {
|
|
6564
6565
|
const e = pn(s, -1), n = pn(t, -1), i = w(e, n);
|
|
6565
6566
|
return pt(B(i, -1));
|
|
@@ -6569,17 +6570,17 @@ const fn = {
|
|
|
6569
6570
|
meanSquaredError: vn,
|
|
6570
6571
|
meanAbsoluteError: xs,
|
|
6571
6572
|
meanAbsolutePercentageError: Ns,
|
|
6572
|
-
meanSquaredLogarithmicError:
|
|
6573
|
-
squaredHinge:
|
|
6574
|
-
hinge:
|
|
6575
|
-
categoricalHinge:
|
|
6576
|
-
logcosh:
|
|
6573
|
+
meanSquaredLogarithmicError: rf,
|
|
6574
|
+
squaredHinge: af,
|
|
6575
|
+
hinge: of,
|
|
6576
|
+
categoricalHinge: lf,
|
|
6577
|
+
logcosh: uf,
|
|
6577
6578
|
categoricalCrossentropy: Oe,
|
|
6578
6579
|
sparseCategoricalCrossentropy: dn,
|
|
6579
6580
|
binaryCrossentropy: Sn,
|
|
6580
|
-
kullbackLeiblerDivergence:
|
|
6581
|
-
poisson:
|
|
6582
|
-
cosineProximity:
|
|
6581
|
+
kullbackLeiblerDivergence: hf,
|
|
6582
|
+
poisson: pf,
|
|
6583
|
+
cosineProximity: vr
|
|
6583
6584
|
};
|
|
6584
6585
|
function $n(s) {
|
|
6585
6586
|
if (typeof s == "string") {
|
|
@@ -6599,48 +6600,48 @@ function $n(s) {
|
|
|
6599
6600
|
* https://opensource.org/licenses/MIT.
|
|
6600
6601
|
* =============================================================================
|
|
6601
6602
|
*/
|
|
6602
|
-
function
|
|
6603
|
+
function Sr(s, t) {
|
|
6603
6604
|
return x(() => {
|
|
6604
|
-
const e = w(0.5,
|
|
6605
|
+
const e = w(0.5, Dt(t)), n = Lt(Gt(t, e), s.dtype);
|
|
6605
6606
|
return at(Xt(s, n), -1);
|
|
6606
6607
|
});
|
|
6607
6608
|
}
|
|
6608
|
-
function
|
|
6609
|
-
return x(() =>
|
|
6609
|
+
function Ar(s, t) {
|
|
6610
|
+
return x(() => Lt(Xt(sn(s, -1), sn(t, -1)), "float32"));
|
|
6610
6611
|
}
|
|
6611
|
-
function
|
|
6612
|
+
function df(s, t) {
|
|
6612
6613
|
return x(() => L(B(Pe(Xt(s, 1), Xt(t, 1))), "float32"));
|
|
6613
6614
|
}
|
|
6614
|
-
function
|
|
6615
|
+
function ff(s, t) {
|
|
6615
6616
|
return x(() => L(B(Pe(Xt(s, 0), Xt(t, 1))), "float32"));
|
|
6616
6617
|
}
|
|
6617
|
-
function
|
|
6618
|
+
function mf(s, t) {
|
|
6618
6619
|
return x(() => {
|
|
6619
|
-
const e =
|
|
6620
|
+
const e = df(s, t), n = ff(s, t), i = $(e, n);
|
|
6620
6621
|
return L(Ht(Gt(i, 0), P(e, i), 0), "float32");
|
|
6621
6622
|
});
|
|
6622
6623
|
}
|
|
6623
|
-
function
|
|
6624
|
+
function gf(s, t) {
|
|
6624
6625
|
return Sn(s, t);
|
|
6625
6626
|
}
|
|
6626
|
-
function
|
|
6627
|
+
function bf(s, t) {
|
|
6627
6628
|
return s.rank === t.rank && (s = ts(s, [s.rank - 1])), t = sn(t, -1), t.dtype !== s.dtype && (t = L(t, s.dtype)), L(Xt(s, t), "float32");
|
|
6628
6629
|
}
|
|
6629
|
-
const
|
|
6630
|
-
binaryAccuracy:
|
|
6631
|
-
categoricalAccuracy:
|
|
6632
|
-
precision:
|
|
6633
|
-
categoricalCrossentropy:
|
|
6634
|
-
sparseCategoricalCrossentropy:
|
|
6635
|
-
mse:
|
|
6636
|
-
MSE:
|
|
6637
|
-
mae:
|
|
6638
|
-
MAE:
|
|
6639
|
-
mape:
|
|
6640
|
-
MAPE:
|
|
6641
|
-
cosine:
|
|
6630
|
+
const yf = vn, wf = vn, kf = xs, xf = xs, Nf = Ns, vf = Ns, Cr = Oe, Sf = vr, Ir = dn, mn = {
|
|
6631
|
+
binaryAccuracy: Sr,
|
|
6632
|
+
categoricalAccuracy: Ar,
|
|
6633
|
+
precision: mf,
|
|
6634
|
+
categoricalCrossentropy: Cr,
|
|
6635
|
+
sparseCategoricalCrossentropy: Ir,
|
|
6636
|
+
mse: yf,
|
|
6637
|
+
MSE: wf,
|
|
6638
|
+
mae: kf,
|
|
6639
|
+
MAE: xf,
|
|
6640
|
+
mape: Nf,
|
|
6641
|
+
MAPE: vf,
|
|
6642
|
+
cosine: Sf
|
|
6642
6643
|
};
|
|
6643
|
-
function
|
|
6644
|
+
function Af(s) {
|
|
6644
6645
|
if (typeof s == "string" && s in mn)
|
|
6645
6646
|
return mn[s];
|
|
6646
6647
|
if (typeof s != "string" && s != null)
|
|
@@ -6676,7 +6677,7 @@ function tn(s) {
|
|
|
6676
6677
|
* https://opensource.org/licenses/MIT.
|
|
6677
6678
|
* =============================================================================
|
|
6678
6679
|
*/
|
|
6679
|
-
function
|
|
6680
|
+
function Cf(s) {
|
|
6680
6681
|
const t = {
|
|
6681
6682
|
Adagrad: () => ge.adagrad(0.01),
|
|
6682
6683
|
Adadelta: () => ge.adadelta(1, 0.95, nt()),
|
|
@@ -6738,8 +6739,8 @@ function Pn(s) {
|
|
|
6738
6739
|
* https://opensource.org/licenses/MIT.
|
|
6739
6740
|
* =============================================================================
|
|
6740
6741
|
*/
|
|
6741
|
-
function
|
|
6742
|
-
const i =
|
|
6742
|
+
function If(s, t, e, n = console.log) {
|
|
6743
|
+
const i = zf(s), r = ["Layer (type)", "Input Shape", "Output shape", "Param #"];
|
|
6743
6744
|
i ? (t = t || 90, e = e || [0.32, 0.61, 0.89, 1]) : (t = t || 115, e = e || [0.24, 0.48, 0.7, 0.8, 1]), e[e.length - 1] <= 1 && (e = e.map((c) => Math.floor(t * c)));
|
|
6744
6745
|
let a;
|
|
6745
6746
|
if (!i) {
|
|
@@ -6750,16 +6751,16 @@ function Cf(s, t, e, n = console.log) {
|
|
|
6750
6751
|
n("_".repeat(t)), gn(r, e, n), n("=".repeat(t));
|
|
6751
6752
|
const o = s.layers;
|
|
6752
6753
|
for (let c = 0; c < o.length; ++c)
|
|
6753
|
-
i ?
|
|
6754
|
+
i ? Tf(o[c], e, n) : $f(o[c], e, a, n), n((c === o.length - 1 ? "=" : "_").repeat(t));
|
|
6754
6755
|
s.checkTrainableWeightsConsistency();
|
|
6755
|
-
const l =
|
|
6756
|
+
const l = Df(s), u = un(s.nonTrainableWeights);
|
|
6756
6757
|
n(`Total params: ${l + u}`), n(`Trainable params: ${l}`), n(`Non-trainable params: ${u}`), n("_".repeat(t));
|
|
6757
6758
|
}
|
|
6758
|
-
function
|
|
6759
|
+
function Df(s) {
|
|
6759
6760
|
let t;
|
|
6760
6761
|
return s.collectedTrainableWeights != null ? t = un(s.collectedTrainableWeights) : t = un(s.trainableWeights), t;
|
|
6761
6762
|
}
|
|
6762
|
-
function
|
|
6763
|
+
function zf(s) {
|
|
6763
6764
|
let t = !0;
|
|
6764
6765
|
const e = [], n = [];
|
|
6765
6766
|
for (const i in s.nodesByDepth)
|
|
@@ -6792,7 +6793,7 @@ function gn(s, t, e = console.log) {
|
|
|
6792
6793
|
i > 0 && (n = n.slice(0, n.length - 1) + " "), n += s[i], n = n.slice(0, t[i]), n += " ".repeat(t[i] - n.length);
|
|
6793
6794
|
e(n);
|
|
6794
6795
|
}
|
|
6795
|
-
function
|
|
6796
|
+
function Tf(s, t, e) {
|
|
6796
6797
|
let n, i;
|
|
6797
6798
|
try {
|
|
6798
6799
|
i = s.inboundNodes.map((l) => JSON.stringify(l.inputShapes)).join(",");
|
|
@@ -6812,7 +6813,7 @@ function zf(s, t, e) {
|
|
|
6812
6813
|
];
|
|
6813
6814
|
gn(o, t, e);
|
|
6814
6815
|
}
|
|
6815
|
-
function
|
|
6816
|
+
function $f(s, t, e, n) {
|
|
6816
6817
|
let i, r;
|
|
6817
6818
|
try {
|
|
6818
6819
|
r = s.inboundNodes.map((h) => JSON.stringify(h.inputShapes)).join(",");
|
|
@@ -6851,7 +6852,7 @@ function Tf(s, t, e, n) {
|
|
|
6851
6852
|
* https://opensource.org/licenses/MIT.
|
|
6852
6853
|
* =============================================================================
|
|
6853
6854
|
*/
|
|
6854
|
-
function
|
|
6855
|
+
function Dr(s, t, e) {
|
|
6855
6856
|
return (s === "inboundNodes" || s === "outputLayers" || s === "inputLayers") && t === 0 && typeof e == "string";
|
|
6856
6857
|
}
|
|
6857
6858
|
function Un(s, t) {
|
|
@@ -6865,7 +6866,7 @@ function Un(s, t) {
|
|
|
6865
6866
|
const e = [], n = s.length;
|
|
6866
6867
|
for (let i = 0; i < n; ++i) {
|
|
6867
6868
|
const r = s[i];
|
|
6868
|
-
|
|
6869
|
+
Dr(t, i, r) ? e.push(r) : e.push(Un(r, t));
|
|
6869
6870
|
}
|
|
6870
6871
|
return e;
|
|
6871
6872
|
} else {
|
|
@@ -6893,7 +6894,7 @@ function Vn(s, t) {
|
|
|
6893
6894
|
const e = [], n = s.length;
|
|
6894
6895
|
for (let i = 0; i < n; ++i) {
|
|
6895
6896
|
const r = s[i];
|
|
6896
|
-
|
|
6897
|
+
Dr(t, i, r) ? e.push(r) : e.push(Vn(r, t));
|
|
6897
6898
|
}
|
|
6898
6899
|
return e;
|
|
6899
6900
|
} else {
|
|
@@ -6906,7 +6907,7 @@ function Vn(s, t) {
|
|
|
6906
6907
|
}
|
|
6907
6908
|
}
|
|
6908
6909
|
/** @license See the LICENSE file. */
|
|
6909
|
-
const
|
|
6910
|
+
const zr = "4.22.0";
|
|
6910
6911
|
/**
|
|
6911
6912
|
* @license
|
|
6912
6913
|
* Copyright 2018 Google LLC
|
|
@@ -6916,7 +6917,7 @@ const Ir = "4.22.0";
|
|
|
6916
6917
|
* https://opensource.org/licenses/MIT.
|
|
6917
6918
|
* =============================================================================
|
|
6918
6919
|
*/
|
|
6919
|
-
const
|
|
6920
|
+
const Ef = (s) => {
|
|
6920
6921
|
const t = Object.keys(s);
|
|
6921
6922
|
if (t.length === 0)
|
|
6922
6923
|
return !1;
|
|
@@ -6954,7 +6955,7 @@ class vt extends W {
|
|
|
6954
6955
|
(I == null || z == null || _ == null) && (I = y.sourceLayer, z = y.nodeIndex, _ = y.tensorIndex);
|
|
6955
6956
|
const T = I.inboundNodes[z];
|
|
6956
6957
|
if (N.indexOf(T) !== -1)
|
|
6957
|
-
throw new
|
|
6958
|
+
throw new Et(`The tensor ${y.name} at layer "${I.name}" is part of a cycle.`);
|
|
6958
6959
|
if (C.indexOf(T) !== -1)
|
|
6959
6960
|
return;
|
|
6960
6961
|
this.containerNodes.add(vt.nodeKey(I, z)), I.id in a || (a[I.id] = Object.keys(a).length), N.indexOf(T) === -1 && N.push(T);
|
|
@@ -7009,7 +7010,7 @@ class vt extends W {
|
|
|
7009
7010
|
if (N != null) {
|
|
7010
7011
|
for (const I of C.inputTensors)
|
|
7011
7012
|
if (b.indexOf(I) === -1)
|
|
7012
|
-
throw new
|
|
7013
|
+
throw new Et(`Graph disconnected: cannot obtain value for tensor ${I} at layer "${N.name}". The following previous layers were accessed without issue: ${m}`);
|
|
7013
7014
|
for (const I of C.outputTensors)
|
|
7014
7015
|
b.push(I);
|
|
7015
7016
|
m.push(N.name);
|
|
@@ -7020,7 +7021,7 @@ class vt extends W {
|
|
|
7020
7021
|
for (const y of v) {
|
|
7021
7022
|
const C = v.filter((N) => N === y).length;
|
|
7022
7023
|
if (C !== 1)
|
|
7023
|
-
throw new
|
|
7024
|
+
throw new Et(`The name "${y}" is used ${C} times in the model. All layer names should be unique. Layer names: ` + JSON.stringify(v));
|
|
7024
7025
|
}
|
|
7025
7026
|
this.outboundNodes = [], this.inboundNodes = [], new Nn({
|
|
7026
7027
|
outboundLayer: this,
|
|
@@ -7127,7 +7128,7 @@ class vt extends W {
|
|
|
7127
7128
|
loadWeights(t, e = !0) {
|
|
7128
7129
|
const n = {};
|
|
7129
7130
|
let i = 0;
|
|
7130
|
-
const r =
|
|
7131
|
+
const r = Ef(t);
|
|
7131
7132
|
r && this.parseWeights(t);
|
|
7132
7133
|
for (const o of this.layers)
|
|
7133
7134
|
for (const [l, u] of o.weights.entries()) {
|
|
@@ -7170,7 +7171,7 @@ class vt extends W {
|
|
|
7170
7171
|
*/
|
|
7171
7172
|
updatedConfig() {
|
|
7172
7173
|
const t = this.getConfig(), e = {};
|
|
7173
|
-
return e.className = this.getClassName(), e.config = t, e.kerasVersion = `tfjs-layers ${
|
|
7174
|
+
return e.className = this.getClassName(), e.config = t, e.kerasVersion = `tfjs-layers ${zr}`, e.backend = "TensorFlow.js", e;
|
|
7174
7175
|
}
|
|
7175
7176
|
/**
|
|
7176
7177
|
* Returns a JSON string containing the network configuration.
|
|
@@ -7476,7 +7477,7 @@ class vt extends W {
|
|
|
7476
7477
|
const c = e.name, h = e.layers;
|
|
7477
7478
|
for (const m of h)
|
|
7478
7479
|
u(m);
|
|
7479
|
-
for (; !
|
|
7480
|
+
for (; !zu(a); )
|
|
7480
7481
|
for (const m of h) {
|
|
7481
7482
|
const v = r[m.name];
|
|
7482
7483
|
if (v.name in a) {
|
|
@@ -7539,7 +7540,7 @@ class vt extends W {
|
|
|
7539
7540
|
* https://opensource.org/licenses/MIT.
|
|
7540
7541
|
* =============================================================================
|
|
7541
7542
|
*/
|
|
7542
|
-
function
|
|
7543
|
+
function Lf(s, t, e) {
|
|
7543
7544
|
const n = t.length;
|
|
7544
7545
|
if (s == null || Array.isArray(s) && s.length === 0)
|
|
7545
7546
|
return t.map((i) => null);
|
|
@@ -7557,14 +7558,14 @@ function Ef(s, t, e) {
|
|
|
7557
7558
|
} else
|
|
7558
7559
|
throw new Error(`The model has multiple (${n}) outputs, so ${e} must be either an array with ${n} elements or an object with ${t} keys. Provided ${e} not understood: ${JSON.stringify(s)}`);
|
|
7559
7560
|
}
|
|
7560
|
-
function
|
|
7561
|
-
return
|
|
7561
|
+
function Tr(s, t) {
|
|
7562
|
+
return Lf(s, t, "classWeight");
|
|
7562
7563
|
}
|
|
7563
|
-
async function
|
|
7564
|
+
async function $r(s, t, e, n) {
|
|
7564
7565
|
if (e != null) {
|
|
7565
7566
|
const i = x(() => {
|
|
7566
7567
|
if (s.shape.length === 1)
|
|
7567
|
-
return
|
|
7568
|
+
return mu(s);
|
|
7568
7569
|
if (s.shape.length === 2) {
|
|
7569
7570
|
if (s.shape[1] > 1)
|
|
7570
7571
|
return sn(s, 1);
|
|
@@ -7584,7 +7585,7 @@ async function zr(s, t, e, n) {
|
|
|
7584
7585
|
} else
|
|
7585
7586
|
return null;
|
|
7586
7587
|
}
|
|
7587
|
-
function
|
|
7588
|
+
function Ff(s, t) {
|
|
7588
7589
|
return w(s, t);
|
|
7589
7590
|
}
|
|
7590
7591
|
/**
|
|
@@ -7596,8 +7597,8 @@ function Lf(s, t) {
|
|
|
7596
7597
|
* https://opensource.org/licenses/MIT.
|
|
7597
7598
|
* =============================================================================
|
|
7598
7599
|
*/
|
|
7599
|
-
const
|
|
7600
|
-
function
|
|
7600
|
+
const Mf = 32;
|
|
7601
|
+
function Er(s, t) {
|
|
7601
7602
|
let e, n;
|
|
7602
7603
|
const i = t;
|
|
7603
7604
|
e = i.xs, n = i.ys, k(e != null && n != null, () => `A Dataset iterator for fitDataset() is expected to generate objects of the form \`{xs: xVal, ys: yVal}\`, where the two values may be \`tf.Tensor\`, an array of Tensors, or a map of string to Tensor. The provided Dataset instead generates ${t}`);
|
|
@@ -7624,12 +7625,12 @@ function ri(s, t, e) {
|
|
|
7624
7625
|
return n;
|
|
7625
7626
|
}
|
|
7626
7627
|
}
|
|
7627
|
-
function
|
|
7628
|
+
function Of(s) {
|
|
7628
7629
|
if (s.length === 3)
|
|
7629
7630
|
throw new G("Validation with sample weights is not implemented yet.");
|
|
7630
7631
|
return { xs: s[0], ys: s[1] };
|
|
7631
7632
|
}
|
|
7632
|
-
async function
|
|
7633
|
+
async function Rf(s, t, e) {
|
|
7633
7634
|
const n = e.batchesPerEpoch != null;
|
|
7634
7635
|
if (k(s.optimizer != null, () => "You must compile a model before training/testing. Use LayersModel.compile(modelCompileConfig)."), k(e != null, () => "For fitDataset(), the 2nd argument (config) is required, but it is not provided in this call."), k(e.epochs != null && e.epochs > 0 && Number.isInteger(e.epochs), () => `For fitDataset(), config.epochs is expected to be a positive integer, but got ${e.epochs}`), k(!n || e.batchesPerEpoch > 0 && Number.isInteger(e.batchesPerEpoch), () => `For fitDataset(), config.batchesPerEpoch is expected to be a positive integer if specified, but got ${e.batchesPerEpoch}`), k(
|
|
7635
7636
|
// tslint:disable-next-line:no-any
|
|
@@ -7645,19 +7646,19 @@ async function Of(s, t, e) {
|
|
|
7645
7646
|
if (ai(e.validationData))
|
|
7646
7647
|
k(e.validationBatches == null || e.validationBatches > 0 && Number.isInteger(e.validationBatches), () => `For fitDataset() with dataset-based validation, config.validationBatches is expected not to be provided, or to be a positive integer, but got ${e.validationBatches}`);
|
|
7647
7648
|
else {
|
|
7648
|
-
const m =
|
|
7649
|
+
const m = Of(e.validationData);
|
|
7649
7650
|
r = m.xs, a = m.ys;
|
|
7650
7651
|
}
|
|
7651
7652
|
const o = s.makeTrainFunction(), l = s.getDedupedMetricsNames();
|
|
7652
7653
|
let u;
|
|
7653
7654
|
i ? u = l.slice().concat(l.map((m) => "val_" + m)) : u = l.slice();
|
|
7654
|
-
const c =
|
|
7655
|
+
const c = xr(e.callbacks, e.yieldEvery), h = e.verbose == null ? 1 : e.verbose, { callbackList: p, history: f } = Nr(
|
|
7655
7656
|
c,
|
|
7656
7657
|
h,
|
|
7657
7658
|
e.epochs,
|
|
7658
7659
|
null,
|
|
7659
7660
|
null,
|
|
7660
|
-
|
|
7661
|
+
_f(t, e),
|
|
7661
7662
|
null,
|
|
7662
7663
|
// Batch size determined by the dataset itself.
|
|
7663
7664
|
i,
|
|
@@ -7676,13 +7677,13 @@ async function Of(s, t, e) {
|
|
|
7676
7677
|
break;
|
|
7677
7678
|
}
|
|
7678
7679
|
if (C.value != null) {
|
|
7679
|
-
const { xs: N, ys: I } =
|
|
7680
|
+
const { xs: N, ys: I } = Er(s, C.value), z = {};
|
|
7680
7681
|
z.batch = y, z.size = N[0].shape[0], await p.onBatchBegin(y, z);
|
|
7681
7682
|
const _ = [];
|
|
7682
7683
|
if (e.classWeight != null) {
|
|
7683
|
-
const R =
|
|
7684
|
+
const R = Tr(e.classWeight, s.outputNames);
|
|
7684
7685
|
for (let q = 0; q < R.length; ++q)
|
|
7685
|
-
_.push(await
|
|
7686
|
+
_.push(await $r(I[q], null, R[q]));
|
|
7686
7687
|
}
|
|
7687
7688
|
const T = N.concat(I).concat(_), E = o(T);
|
|
7688
7689
|
Z(T);
|
|
@@ -7690,13 +7691,13 @@ async function Of(s, t, e) {
|
|
|
7690
7691
|
const q = l[R], bt = E[R];
|
|
7691
7692
|
z[q] = bt, Bt(bt);
|
|
7692
7693
|
}
|
|
7693
|
-
await p.onBatchEnd(y, z),
|
|
7694
|
+
await p.onBatchEnd(y, z), kr(z), y++, v++;
|
|
7694
7695
|
}
|
|
7695
7696
|
if (n ? v >= e.batchesPerEpoch : C.done) {
|
|
7696
7697
|
if (i) {
|
|
7697
7698
|
let N;
|
|
7698
7699
|
ai(e.validationData) ? N = K(await s.evaluateDataset(e.validationData, { batches: e.validationBatches })) : N = K(s.evaluate(r, a, {
|
|
7699
|
-
batchSize: e.validationBatchSize == null ?
|
|
7700
|
+
batchSize: e.validationBatchSize == null ? Mf : e.validationBatchSize,
|
|
7700
7701
|
verbose: 0
|
|
7701
7702
|
}));
|
|
7702
7703
|
for (let I = 0; I < s.metricsNames.length; ++I)
|
|
@@ -7715,30 +7716,30 @@ async function Of(s, t, e) {
|
|
|
7715
7716
|
s.isTraining = !1;
|
|
7716
7717
|
}
|
|
7717
7718
|
}
|
|
7718
|
-
function
|
|
7719
|
+
function _f(s, t) {
|
|
7719
7720
|
let e = null;
|
|
7720
7721
|
return t.batchesPerEpoch != null ? e = t.batchesPerEpoch : Number.isFinite(s.size) && (e = s.size), e;
|
|
7721
7722
|
}
|
|
7722
7723
|
function ai(s) {
|
|
7723
7724
|
return typeof s.iterator == "function";
|
|
7724
7725
|
}
|
|
7725
|
-
function
|
|
7726
|
+
function Bf(s) {
|
|
7726
7727
|
return typeof s.next == "function";
|
|
7727
7728
|
}
|
|
7728
|
-
async function
|
|
7729
|
+
async function Wf(s, t, e) {
|
|
7729
7730
|
e = e || {};
|
|
7730
7731
|
const n = e.batches != null, i = s.testFunction;
|
|
7731
7732
|
let r = [];
|
|
7732
7733
|
if (e.verbose > 0)
|
|
7733
7734
|
throw new G("Verbose mode is not implemented yet.");
|
|
7734
7735
|
k(!n || e.batches > 0 && Number.isInteger(e.batches), () => `Test loop expects \`batches\` to be a positive integer, but received ${JSON.stringify(e.batches)}`);
|
|
7735
|
-
const a =
|
|
7736
|
+
const a = Bf(t) ? t : await t.iterator();
|
|
7736
7737
|
let o = 0, l = 0;
|
|
7737
7738
|
for (; !n || l < e.batches; ) {
|
|
7738
7739
|
const u = await a.next();
|
|
7739
7740
|
if (r = x(() => {
|
|
7740
7741
|
if (u.value) {
|
|
7741
|
-
const { xs: c, ys: h } =
|
|
7742
|
+
const { xs: c, ys: h } = Er(s, u.value), p = c.concat(h), f = x(() => i(p));
|
|
7742
7743
|
if (Z(p), l === 0)
|
|
7743
7744
|
for (let b = 0; b < f.length; ++b)
|
|
7744
7745
|
r.push(tt(0));
|
|
@@ -7777,7 +7778,7 @@ function Te(s, t, e) {
|
|
|
7777
7778
|
return s == null ? [null] : Array.isArray(s) ? s.map((n) => nn(n, t, e - t)) : nn(s, t, e - t);
|
|
7778
7779
|
}
|
|
7779
7780
|
function jn(s, t) {
|
|
7780
|
-
return x(() => s == null ? null : Array.isArray(s) ? s.map((e) => jn(e, t)) :
|
|
7781
|
+
return x(() => s == null ? null : Array.isArray(s) ? s.map((e) => jn(e, t)) : Vi(s, t.dtype === "int32" ? t : L(t, "int32")));
|
|
7781
7782
|
}
|
|
7782
7783
|
function Ln(s, t) {
|
|
7783
7784
|
const e = [];
|
|
@@ -7786,13 +7787,13 @@ function Ln(s, t) {
|
|
|
7786
7787
|
i = n + t, i >= s && (i = s), e.push([n, i]), n = i;
|
|
7787
7788
|
return e;
|
|
7788
7789
|
}
|
|
7789
|
-
function
|
|
7790
|
+
function Lr(s) {
|
|
7790
7791
|
const t = [];
|
|
7791
7792
|
s instanceof xe && (s = [s]);
|
|
7792
7793
|
for (let e = 0; e < s.length; ++e) {
|
|
7793
7794
|
const n = s[e];
|
|
7794
7795
|
if (n.rank === 1)
|
|
7795
|
-
t.push(
|
|
7796
|
+
t.push(yn(n, 1));
|
|
7796
7797
|
else {
|
|
7797
7798
|
if (n.rank === 0)
|
|
7798
7799
|
throw new Error("Expected tensor to be at least 1D, but received a 0D tensor (scalar).");
|
|
@@ -7839,14 +7840,14 @@ function Nt(s, t) {
|
|
|
7839
7840
|
* https://opensource.org/licenses/MIT.
|
|
7840
7841
|
* =============================================================================
|
|
7841
7842
|
*/
|
|
7842
|
-
function
|
|
7843
|
+
function Gf(s) {
|
|
7843
7844
|
return s instanceof xe;
|
|
7844
7845
|
}
|
|
7845
7846
|
function Kn(s) {
|
|
7846
7847
|
return Array.isArray(s);
|
|
7847
7848
|
}
|
|
7848
7849
|
function oi(s) {
|
|
7849
|
-
return !
|
|
7850
|
+
return !Gf(s) && !Kn(s);
|
|
7850
7851
|
}
|
|
7851
7852
|
function li(s, t, e, n = !0, i = "") {
|
|
7852
7853
|
if (t == null || t.length === 0) {
|
|
@@ -7886,7 +7887,7 @@ function li(s, t, e, n = !0, i = "") {
|
|
|
7886
7887
|
throw new d(`The model ${i} expects ${t.length} Tensor(s), but only received one Tensor. Found: Tensor with shape ${s.shape}`);
|
|
7887
7888
|
r = [s];
|
|
7888
7889
|
}
|
|
7889
|
-
if (r =
|
|
7890
|
+
if (r = Lr(r), e != null)
|
|
7890
7891
|
for (let a = 0; a < t.length; ++a) {
|
|
7891
7892
|
if (e[a] == null)
|
|
7892
7893
|
continue;
|
|
@@ -7903,7 +7904,7 @@ function li(s, t, e, n = !0, i = "") {
|
|
|
7903
7904
|
}
|
|
7904
7905
|
return r;
|
|
7905
7906
|
}
|
|
7906
|
-
function
|
|
7907
|
+
function Pf(s, t, e) {
|
|
7907
7908
|
const n = jt(s.map((r) => r.shape[0]));
|
|
7908
7909
|
n.sort();
|
|
7909
7910
|
const i = jt(t.map((r) => r.shape[0]));
|
|
@@ -7914,7 +7915,7 @@ function Gf(s, t, e) {
|
|
|
7914
7915
|
if (n.length > 0 && i.length > 0 && !Ft(n, i))
|
|
7915
7916
|
throw new d(`Input Tensors should have the same number of samples as target Tensors. Found ${n[0]} input sample(s) and ${i[0]} target sample(s).`);
|
|
7916
7917
|
}
|
|
7917
|
-
function
|
|
7918
|
+
function Uf(s, t, e) {
|
|
7918
7919
|
const n = [
|
|
7919
7920
|
vn,
|
|
7920
7921
|
Sn,
|
|
@@ -7963,7 +7964,7 @@ function ui(s, t, e, n = !0, i = "") {
|
|
|
7963
7964
|
}
|
|
7964
7965
|
}
|
|
7965
7966
|
}
|
|
7966
|
-
function
|
|
7967
|
+
function Vf(s, t) {
|
|
7967
7968
|
if (s == null || Array.isArray(s) && s.length === 0)
|
|
7968
7969
|
return t.map((n) => []);
|
|
7969
7970
|
let e;
|
|
@@ -7984,7 +7985,7 @@ function Uf(s, t) {
|
|
|
7984
7985
|
return n;
|
|
7985
7986
|
}
|
|
7986
7987
|
}
|
|
7987
|
-
const
|
|
7988
|
+
const jf = "layers-model";
|
|
7988
7989
|
class we extends vt {
|
|
7989
7990
|
constructor(t) {
|
|
7990
7991
|
super(t), this.isTraining = !1;
|
|
@@ -8027,7 +8028,7 @@ class we extends vt {
|
|
|
8027
8028
|
summary(t, e, n = console.log) {
|
|
8028
8029
|
if (!this.built)
|
|
8029
8030
|
throw new d("This model has never been called, thus its weights have not been created yet. So no summary can be displayed. Build the model first (e.g., by calling it on some test data).");
|
|
8030
|
-
|
|
8031
|
+
If(this, t, e, n);
|
|
8031
8032
|
}
|
|
8032
8033
|
/**
|
|
8033
8034
|
* Configures and prepares the model for training and evaluation. Compiling
|
|
@@ -8041,9 +8042,9 @@ class we extends vt {
|
|
|
8041
8042
|
*/
|
|
8042
8043
|
compile(t) {
|
|
8043
8044
|
if (t.loss == null && (t.loss = []), this.loss = t.loss, typeof t.optimizer == "string")
|
|
8044
|
-
this.optimizer_ =
|
|
8045
|
+
this.optimizer_ = Cf(t.optimizer), this.isOptimizerOwned = !0;
|
|
8045
8046
|
else {
|
|
8046
|
-
if (!(t.optimizer instanceof
|
|
8047
|
+
if (!(t.optimizer instanceof gu))
|
|
8047
8048
|
throw new d("User-defined optimizer must be an instance of tf.Optimizer.");
|
|
8048
8049
|
this.optimizer_ = t.optimizer, this.isOptimizerOwned = !1;
|
|
8049
8050
|
}
|
|
@@ -8079,7 +8080,7 @@ class we extends vt {
|
|
|
8079
8080
|
this.outputs.length > 1 && (this.metricsTensors.push([o, a]), this.metricsNames.push(this.outputNames[a] + "_loss"));
|
|
8080
8081
|
}
|
|
8081
8082
|
});
|
|
8082
|
-
const i =
|
|
8083
|
+
const i = Vf(t.metrics, this.outputNames), r = (a, o, l) => {
|
|
8083
8084
|
this.outputNames.length > 1 && (o = this.outputNames[a] + "_" + o), this.metricsNames.push(o), this.metricsTensors.push([l, a]);
|
|
8084
8085
|
};
|
|
8085
8086
|
le("metric", () => {
|
|
@@ -8092,11 +8093,11 @@ class we extends vt {
|
|
|
8092
8093
|
for (const g of u) {
|
|
8093
8094
|
if (typeof g == "string" && ["accuracy", "acc", "crossentropy", "ce"].indexOf(g) !== -1) {
|
|
8094
8095
|
const m = this.internalOutputShapes[a];
|
|
8095
|
-
m[m.length - 1] === 1 || this.lossFunctions[a] === Sn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p =
|
|
8096
|
+
m[m.length - 1] === 1 || this.lossFunctions[a] === Sn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = Sr : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = gf) : this.lossFunctions[a] === dn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = bf : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Ir) : ["accuracy", "acc"].indexOf(g) !== -1 ? p = Ar : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Cr);
|
|
8096
8097
|
let v;
|
|
8097
8098
|
["accuracy", "acc"].indexOf(g) !== -1 ? v = "acc" : ["crossentropy", "ce"].indexOf(g) !== -1 && (v = "ce"), f = p, h = "" + v;
|
|
8098
8099
|
} else
|
|
8099
|
-
f =
|
|
8100
|
+
f = Af(g), h = "" + tn(g);
|
|
8100
8101
|
let b;
|
|
8101
8102
|
le(h, () => {
|
|
8102
8103
|
b = f;
|
|
@@ -8185,7 +8186,7 @@ class we extends vt {
|
|
|
8185
8186
|
* @doc {heading: 'Models', subheading: 'Classes'}
|
|
8186
8187
|
*/
|
|
8187
8188
|
async evaluateDataset(t, e) {
|
|
8188
|
-
return this.makeTestFunction(),
|
|
8189
|
+
return this.makeTestFunction(), Wf(this, t, e);
|
|
8189
8190
|
}
|
|
8190
8191
|
/**
|
|
8191
8192
|
* Get number of samples provided for training, evaluation or prediction.
|
|
@@ -8319,7 +8320,7 @@ class we extends vt {
|
|
|
8319
8320
|
* @doc {heading: 'Models', subheading: 'Classes'}
|
|
8320
8321
|
*/
|
|
8321
8322
|
predict(t, e = {}) {
|
|
8322
|
-
const n =
|
|
8323
|
+
const n = Lr(t);
|
|
8323
8324
|
ui(n, this.inputNames, this.feedInputShapes, !1);
|
|
8324
8325
|
try {
|
|
8325
8326
|
const i = e.batchSize == null ? 32 : e.batchSize;
|
|
@@ -8350,13 +8351,13 @@ class we extends vt {
|
|
|
8350
8351
|
}
|
|
8351
8352
|
standardizeUserDataXY(t, e, n = !0, i) {
|
|
8352
8353
|
if (this.optimizer_ == null)
|
|
8353
|
-
throw new
|
|
8354
|
+
throw new Et("You must compile a model before training/testing. Use LayersModel.compile(modelCompileArgs).");
|
|
8354
8355
|
const r = [];
|
|
8355
8356
|
for (let a = 0; a < this.feedOutputShapes.length; ++a) {
|
|
8356
8357
|
const o = this.feedOutputShapes[a];
|
|
8357
8358
|
this.feedLossFns[a] === dn ? r.push(o.slice(0, o.length - 1).concat([1])) : r.push(o);
|
|
8358
8359
|
}
|
|
8359
|
-
if (t = li(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = li(e, this.feedOutputNames, r, !1, "target"),
|
|
8360
|
+
if (t = li(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = li(e, this.feedOutputNames, r, !1, "target"), Pf(t, e), Uf(e, this.feedLossFns, this.feedOutputShapes), this.stateful && i != null && i > 0 && t[0].shape[0] % i !== 0)
|
|
8360
8361
|
throw new d(`In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size ${i}. Found: ${t[0].shape[0]} sample(s).`);
|
|
8361
8362
|
return [t, e];
|
|
8362
8363
|
}
|
|
@@ -8366,10 +8367,10 @@ class we extends vt {
|
|
|
8366
8367
|
throw new Error("sample weight is not supported yet.");
|
|
8367
8368
|
let u = null;
|
|
8368
8369
|
if (i != null) {
|
|
8369
|
-
const c =
|
|
8370
|
+
const c = Tr(i, this.outputNames);
|
|
8370
8371
|
u = [];
|
|
8371
8372
|
for (let h = 0; h < c.length; ++h)
|
|
8372
|
-
u.push(await
|
|
8373
|
+
u.push(await $r(l[h], null, c[h]));
|
|
8373
8374
|
}
|
|
8374
8375
|
return [o, l, u];
|
|
8375
8376
|
}
|
|
@@ -8392,7 +8393,7 @@ class we extends vt {
|
|
|
8392
8393
|
if (r != null)
|
|
8393
8394
|
throw new G("steps mode in testLoop() is not implemented yet");
|
|
8394
8395
|
{
|
|
8395
|
-
const l = Ln(a, n), u = Rn(
|
|
8396
|
+
const l = Ln(a, n), u = Rn(It(0, a));
|
|
8396
8397
|
for (let c = 0; c < l.length; ++c) {
|
|
8397
8398
|
const h = l[c][0], p = l[c][1], f = nn(u, h, p - h), g = jn(e, f), b = t(g);
|
|
8398
8399
|
if (c === 0)
|
|
@@ -8443,7 +8444,7 @@ class we extends vt {
|
|
|
8443
8444
|
for (let b = 0; b < this.lossFunctions.length; ++b) {
|
|
8444
8445
|
const m = this.lossFunctions[b];
|
|
8445
8446
|
let v = m(i[b], f[b]);
|
|
8446
|
-
r[b] != null && (v =
|
|
8447
|
+
r[b] != null && (v = Ff(v, r[b]));
|
|
8447
8448
|
const y = at(v);
|
|
8448
8449
|
e.push(y), b === 0 ? g = v : g = $(g, v);
|
|
8449
8450
|
}
|
|
@@ -8557,7 +8558,7 @@ class we extends vt {
|
|
|
8557
8558
|
const C = this.makeTrainFunction(), N = this.getDedupedMetricsNames();
|
|
8558
8559
|
let I, z;
|
|
8559
8560
|
m ? (this.makeTestFunction(), I = this.testFunction, z = N.slice().concat(N.map((E) => "val_" + E))) : (I = null, v = [], z = N.slice());
|
|
8560
|
-
const _ =
|
|
8561
|
+
const _ = xr(n.callbacks, n.yieldEvery);
|
|
8561
8562
|
return await this.fitLoop(C, y, N, f, n.epochs, n.verbose, _, I, v, n.shuffle, z, n.initialEpoch, null, null);
|
|
8562
8563
|
} finally {
|
|
8563
8564
|
this.isTraining = !1, Nt(i, t), Nt(r, e), Nt(a, t), Nt(o, e), Nt(c, l), Nt(h, u), p != null && Z(p);
|
|
@@ -8597,8 +8598,8 @@ class we extends vt {
|
|
|
8597
8598
|
throw new d("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");
|
|
8598
8599
|
const m = this.checkNumSamples(e, i, f, "steps_per_epoch");
|
|
8599
8600
|
let v;
|
|
8600
|
-
m != null && (v =
|
|
8601
|
-
const { callbackList: y, history: C } =
|
|
8601
|
+
m != null && (v = It(0, m)), a == null && (a = 1);
|
|
8602
|
+
const { callbackList: y, history: C } = Nr(o, a, r, p, m, f, i, b, h);
|
|
8602
8603
|
y.setModel(this), this.history = C, await y.onTrainBegin(), this.stopTraining_ = !1;
|
|
8603
8604
|
for (let N = p; N < r; ++N) {
|
|
8604
8605
|
await y.onEpochBegin(N);
|
|
@@ -8608,7 +8609,7 @@ class we extends vt {
|
|
|
8608
8609
|
{
|
|
8609
8610
|
if (c === "batch")
|
|
8610
8611
|
throw new G("batch shuffling is not implemneted yet");
|
|
8611
|
-
c &&
|
|
8612
|
+
c && bu(v);
|
|
8612
8613
|
const z = Rn(v), _ = Ln(m, i);
|
|
8613
8614
|
for (let T = 0; T < _.length; ++T) {
|
|
8614
8615
|
const E = {};
|
|
@@ -8617,17 +8618,17 @@ class we extends vt {
|
|
|
8617
8618
|
E.batch = T, E.size = q - R;
|
|
8618
8619
|
const ie = jn(e, bt), re = t(ie);
|
|
8619
8620
|
for (let xt = 0; xt < n.length; ++xt) {
|
|
8620
|
-
const
|
|
8621
|
-
E[
|
|
8621
|
+
const Tt = n[xt], me = re[xt];
|
|
8622
|
+
E[Tt] = me, Bt(me);
|
|
8622
8623
|
}
|
|
8623
8624
|
if (T === _.length - 1 && b) {
|
|
8624
8625
|
const xt = this.testLoop(l, u, i);
|
|
8625
|
-
for (let
|
|
8626
|
-
const me = n[
|
|
8626
|
+
for (let Tt = 0; Tt < n.length; ++Tt) {
|
|
8627
|
+
const me = n[Tt], ze = xt[Tt];
|
|
8627
8628
|
Bt(ze), I["val_" + me] = ze;
|
|
8628
8629
|
}
|
|
8629
8630
|
}
|
|
8630
|
-
}), await y.onBatchEnd(T, E),
|
|
8631
|
+
}), await y.onBatchEnd(T, E), kr(E), this.stopTraining_)
|
|
8631
8632
|
break;
|
|
8632
8633
|
}
|
|
8633
8634
|
z.dispose();
|
|
@@ -8661,7 +8662,7 @@ class we extends vt {
|
|
|
8661
8662
|
* @doc {heading: 'Models', subheading: 'Classes'}
|
|
8662
8663
|
*/
|
|
8663
8664
|
async fitDataset(t, e) {
|
|
8664
|
-
return
|
|
8665
|
+
return Rf(this, t, e);
|
|
8665
8666
|
}
|
|
8666
8667
|
/**
|
|
8667
8668
|
* Runs a single gradient update on a single batch of data.
|
|
@@ -8913,7 +8914,7 @@ class we extends vt {
|
|
|
8913
8914
|
*/
|
|
8914
8915
|
async save(t, e) {
|
|
8915
8916
|
if (typeof t == "string") {
|
|
8916
|
-
const u =
|
|
8917
|
+
const u = yu(t);
|
|
8917
8918
|
if (u.length === 0)
|
|
8918
8919
|
throw new d(`Cannot find any save handlers for URL '${t}'`);
|
|
8919
8920
|
if (u.length > 1)
|
|
@@ -8924,14 +8925,14 @@ class we extends vt {
|
|
|
8924
8925
|
throw new d("LayersModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");
|
|
8925
8926
|
const n = await Os(this.getNamedWeights(e)), o = {
|
|
8926
8927
|
modelTopology: this.toJSON(null, !1),
|
|
8927
|
-
format:
|
|
8928
|
-
generatedBy: `TensorFlow.js tfjs-layers v${
|
|
8928
|
+
format: jf,
|
|
8929
|
+
generatedBy: `TensorFlow.js tfjs-layers v${zr}`,
|
|
8929
8930
|
convertedBy: null
|
|
8930
8931
|
};
|
|
8931
8932
|
if ((e == null ? !1 : e.includeOptimizer) && this.optimizer != null) {
|
|
8932
8933
|
o.trainingConfig = this.getTrainingConfig();
|
|
8933
8934
|
const u = "optimizer", { data: c, specs: h } = await Os(await this.optimizer.getWeights(), u);
|
|
8934
|
-
n.specs.push(...h), n.data =
|
|
8935
|
+
n.specs.push(...h), n.data = wu([n.data, c]);
|
|
8935
8936
|
}
|
|
8936
8937
|
return this.userDefinedMetadata != null && (ii(this.userDefinedMetadata, this.name, !0), o.userDefinedMetadata = this.userDefinedMetadata), o.weightData = n.data, o.weightSpecs = n.specs, t.save(o);
|
|
8937
8938
|
}
|
|
@@ -8963,10 +8964,10 @@ class we extends vt {
|
|
|
8963
8964
|
}
|
|
8964
8965
|
we.className = "Model";
|
|
8965
8966
|
S(we);
|
|
8966
|
-
class
|
|
8967
|
+
class Fr extends we {
|
|
8967
8968
|
}
|
|
8968
|
-
|
|
8969
|
-
S(
|
|
8969
|
+
Fr.className = "Functional";
|
|
8970
|
+
S(Fr);
|
|
8970
8971
|
/**
|
|
8971
8972
|
* @license
|
|
8972
8973
|
* Copyright 2018 Google LLC
|
|
@@ -9036,7 +9037,7 @@ class Re extends we {
|
|
|
9036
9037
|
throw new d(`A layer added to a Sequential model must not already be connected somewhere else. LayersModel received layer ${t.name} which has ${t.inboundNodes.length} pre-existing inbound connections.`);
|
|
9037
9038
|
if (t.inboundNodes[0].outputTensors.length !== 1)
|
|
9038
9039
|
throw new d("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
|
|
9039
|
-
this.checkShape(t), this.outputs = [t.inboundNodes[0].outputTensors[0]], this.inputs =
|
|
9040
|
+
this.checkShape(t), this.outputs = [t.inboundNodes[0].outputTensors[0]], this.inputs = dr(this.outputs[0]);
|
|
9040
9041
|
}
|
|
9041
9042
|
this.inboundNodes = [], new Nn({
|
|
9042
9043
|
outboundLayer: this,
|
|
@@ -9165,7 +9166,7 @@ class Re extends we {
|
|
|
9165
9166
|
*/
|
|
9166
9167
|
evaluate(t, e, n = {}) {
|
|
9167
9168
|
if (!this.built)
|
|
9168
|
-
throw new
|
|
9169
|
+
throw new Et("The model needs to be compiled before being used.");
|
|
9169
9170
|
return this.model.evaluate(t, e, n);
|
|
9170
9171
|
}
|
|
9171
9172
|
// TODO(cais): Add code snippet below once real dataset objects are
|
|
@@ -9192,7 +9193,7 @@ class Re extends we {
|
|
|
9192
9193
|
*/
|
|
9193
9194
|
async evaluateDataset(t, e) {
|
|
9194
9195
|
if (!this.built)
|
|
9195
|
-
throw new
|
|
9196
|
+
throw new Et("The model needs to be compiled before being used.");
|
|
9196
9197
|
return this.model.evaluateDataset(t, e);
|
|
9197
9198
|
}
|
|
9198
9199
|
/**
|
|
@@ -9282,7 +9283,7 @@ class Re extends we {
|
|
|
9282
9283
|
*/
|
|
9283
9284
|
async fit(t, e, n = {}) {
|
|
9284
9285
|
if (!this.built)
|
|
9285
|
-
throw new
|
|
9286
|
+
throw new Et("The model needs to be compiled before being used.");
|
|
9286
9287
|
return this.model.fit(t, e, n);
|
|
9287
9288
|
}
|
|
9288
9289
|
/**
|
|
@@ -9372,7 +9373,7 @@ class Re extends we {
|
|
|
9372
9373
|
*/
|
|
9373
9374
|
async fitDataset(t, e) {
|
|
9374
9375
|
if (!this.built)
|
|
9375
|
-
throw new
|
|
9376
|
+
throw new Et("The model needs to be compiled before being used.");
|
|
9376
9377
|
return this.model.fitDataset(t, e);
|
|
9377
9378
|
}
|
|
9378
9379
|
/**
|
|
@@ -9485,7 +9486,7 @@ let ut = class extends Be {
|
|
|
9485
9486
|
return {};
|
|
9486
9487
|
}
|
|
9487
9488
|
};
|
|
9488
|
-
class
|
|
9489
|
+
class Mr extends ut {
|
|
9489
9490
|
/**
|
|
9490
9491
|
* Calculate the activation function.
|
|
9491
9492
|
*
|
|
@@ -9494,74 +9495,74 @@ class Lr extends ut {
|
|
|
9494
9495
|
* @return Output of the ELU activation.
|
|
9495
9496
|
*/
|
|
9496
9497
|
apply(t, e = 1) {
|
|
9497
|
-
return
|
|
9498
|
+
return Tu(t, e);
|
|
9498
9499
|
}
|
|
9499
9500
|
}
|
|
9500
|
-
|
|
9501
|
-
S(Lr);
|
|
9502
|
-
class Fr extends ut {
|
|
9503
|
-
apply(t) {
|
|
9504
|
-
return oh(t);
|
|
9505
|
-
}
|
|
9506
|
-
}
|
|
9507
|
-
Fr.className = "selu";
|
|
9508
|
-
S(Fr);
|
|
9509
|
-
class Mr extends ut {
|
|
9510
|
-
apply(t) {
|
|
9511
|
-
return Ve(t);
|
|
9512
|
-
}
|
|
9513
|
-
}
|
|
9514
|
-
Mr.className = "relu";
|
|
9501
|
+
Mr.className = "elu";
|
|
9515
9502
|
S(Mr);
|
|
9516
9503
|
class Or extends ut {
|
|
9517
9504
|
apply(t) {
|
|
9518
|
-
return
|
|
9505
|
+
return oh(t);
|
|
9519
9506
|
}
|
|
9520
9507
|
}
|
|
9521
|
-
Or.className = "
|
|
9508
|
+
Or.className = "selu";
|
|
9522
9509
|
S(Or);
|
|
9523
9510
|
class Rr extends ut {
|
|
9524
9511
|
apply(t) {
|
|
9525
|
-
return t;
|
|
9512
|
+
return Ve(t);
|
|
9526
9513
|
}
|
|
9527
9514
|
}
|
|
9528
|
-
Rr.className = "
|
|
9515
|
+
Rr.className = "relu";
|
|
9529
9516
|
S(Rr);
|
|
9530
9517
|
class _r extends ut {
|
|
9531
9518
|
apply(t) {
|
|
9532
|
-
return
|
|
9519
|
+
return x(() => ji(6, Ve(t)));
|
|
9533
9520
|
}
|
|
9534
9521
|
}
|
|
9535
|
-
_r.className = "
|
|
9522
|
+
_r.className = "relu6";
|
|
9536
9523
|
S(_r);
|
|
9537
9524
|
class Br extends ut {
|
|
9538
9525
|
apply(t) {
|
|
9539
|
-
return
|
|
9526
|
+
return t;
|
|
9540
9527
|
}
|
|
9541
9528
|
}
|
|
9542
|
-
Br.className = "
|
|
9529
|
+
Br.className = "linear";
|
|
9543
9530
|
S(Br);
|
|
9544
9531
|
class Wr extends ut {
|
|
9545
9532
|
apply(t) {
|
|
9546
|
-
return
|
|
9533
|
+
return Qn(t);
|
|
9547
9534
|
}
|
|
9548
9535
|
}
|
|
9549
|
-
Wr.className = "
|
|
9536
|
+
Wr.className = "sigmoid";
|
|
9550
9537
|
S(Wr);
|
|
9551
9538
|
class Gr extends ut {
|
|
9552
9539
|
apply(t) {
|
|
9553
9540
|
return $u(t);
|
|
9554
9541
|
}
|
|
9555
9542
|
}
|
|
9556
|
-
Gr.className = "
|
|
9543
|
+
Gr.className = "hardSigmoid";
|
|
9557
9544
|
S(Gr);
|
|
9558
9545
|
class Pr extends ut {
|
|
9559
9546
|
apply(t) {
|
|
9560
|
-
return
|
|
9547
|
+
return us(t);
|
|
9561
9548
|
}
|
|
9562
9549
|
}
|
|
9563
|
-
Pr.className = "
|
|
9550
|
+
Pr.className = "softplus";
|
|
9564
9551
|
S(Pr);
|
|
9552
|
+
class Ur extends ut {
|
|
9553
|
+
apply(t) {
|
|
9554
|
+
return Eu(t);
|
|
9555
|
+
}
|
|
9556
|
+
}
|
|
9557
|
+
Ur.className = "softsign";
|
|
9558
|
+
S(Ur);
|
|
9559
|
+
class Vr extends ut {
|
|
9560
|
+
apply(t) {
|
|
9561
|
+
return as(t);
|
|
9562
|
+
}
|
|
9563
|
+
}
|
|
9564
|
+
Vr.className = "tanh";
|
|
9565
|
+
S(Vr);
|
|
9565
9566
|
let vs = class extends ut {
|
|
9566
9567
|
/**
|
|
9567
9568
|
* Calculate the activation function.
|
|
@@ -9576,12 +9577,12 @@ let vs = class extends ut {
|
|
|
9576
9577
|
* @throws ValueError: In case `dim(x) < 2`.
|
|
9577
9578
|
*/
|
|
9578
9579
|
apply(t, e = -1) {
|
|
9579
|
-
return
|
|
9580
|
+
return tr(t, e);
|
|
9580
9581
|
}
|
|
9581
9582
|
};
|
|
9582
9583
|
vs.className = "softmax";
|
|
9583
9584
|
S(vs);
|
|
9584
|
-
class
|
|
9585
|
+
class jr extends ut {
|
|
9585
9586
|
/**
|
|
9586
9587
|
* Calculate the activation function of log softmax:
|
|
9587
9588
|
* log( exp(x_i) / sum(exp(x)) )
|
|
@@ -9599,9 +9600,9 @@ class Ur extends ut {
|
|
|
9599
9600
|
return Uc(t, e);
|
|
9600
9601
|
}
|
|
9601
9602
|
}
|
|
9602
|
-
|
|
9603
|
-
S(
|
|
9604
|
-
class
|
|
9603
|
+
jr.className = "logSoftmax";
|
|
9604
|
+
S(jr);
|
|
9605
|
+
class Kr extends ut {
|
|
9605
9606
|
/**
|
|
9606
9607
|
* Calculate the activation function.
|
|
9607
9608
|
*
|
|
@@ -9615,9 +9616,9 @@ class Vr extends ut {
|
|
|
9615
9616
|
}));
|
|
9616
9617
|
}
|
|
9617
9618
|
}
|
|
9618
|
-
|
|
9619
|
-
S(
|
|
9620
|
-
class
|
|
9619
|
+
Kr.className = "gelu";
|
|
9620
|
+
S(Kr);
|
|
9621
|
+
class Hr extends ut {
|
|
9621
9622
|
/**
|
|
9622
9623
|
* Calculate the activation function.
|
|
9623
9624
|
*
|
|
@@ -9628,9 +9629,9 @@ class jr extends ut {
|
|
|
9628
9629
|
return x(() => w(0.5, w(t, $(1, as(w(ee(P(2, Math.PI)), $(t, w(0.044715, qn(t, 3)))))))));
|
|
9629
9630
|
}
|
|
9630
9631
|
}
|
|
9631
|
-
|
|
9632
|
-
S(
|
|
9633
|
-
class
|
|
9632
|
+
Hr.className = "gelu_new";
|
|
9633
|
+
S(Hr);
|
|
9634
|
+
class qr extends ut {
|
|
9634
9635
|
/**
|
|
9635
9636
|
* Calculate the activation function.
|
|
9636
9637
|
*
|
|
@@ -9641,9 +9642,9 @@ class Kr extends ut {
|
|
|
9641
9642
|
return x(() => w(t, as(us(t))));
|
|
9642
9643
|
}
|
|
9643
9644
|
}
|
|
9644
|
-
|
|
9645
|
-
S(
|
|
9646
|
-
class
|
|
9645
|
+
qr.className = "mish";
|
|
9646
|
+
S(qr);
|
|
9647
|
+
class Zr extends ut {
|
|
9647
9648
|
/**
|
|
9648
9649
|
* Calculate the activation function.
|
|
9649
9650
|
*
|
|
@@ -9655,8 +9656,8 @@ class Hr extends ut {
|
|
|
9655
9656
|
return x(() => w(Qn(w(t, e)), t));
|
|
9656
9657
|
}
|
|
9657
9658
|
}
|
|
9658
|
-
|
|
9659
|
-
S(
|
|
9659
|
+
Zr.className = "swish";
|
|
9660
|
+
S(Zr);
|
|
9660
9661
|
function Yt(s) {
|
|
9661
9662
|
return s.getClassName();
|
|
9662
9663
|
}
|
|
@@ -9682,15 +9683,15 @@ function Qt(s) {
|
|
|
9682
9683
|
* https://opensource.org/licenses/MIT.
|
|
9683
9684
|
* =============================================================================
|
|
9684
9685
|
*/
|
|
9685
|
-
function
|
|
9686
|
+
function Kf(s) {
|
|
9686
9687
|
if (s != null && typeof s != "object")
|
|
9687
9688
|
throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an object, but received: ${s}`);
|
|
9688
9689
|
}
|
|
9689
|
-
class
|
|
9690
|
+
class Jr extends Be {
|
|
9690
9691
|
}
|
|
9691
|
-
class
|
|
9692
|
+
class Xr extends Jr {
|
|
9692
9693
|
constructor(t) {
|
|
9693
|
-
super(),
|
|
9694
|
+
super(), Kf(t), this.l1 = t == null || t.l1 == null ? 0.01 : t.l1, this.l2 = t == null || t.l2 == null ? 0.01 : t.l2, this.hasL1 = this.l1 !== 0, this.hasL2 = this.l2 !== 0;
|
|
9694
9695
|
}
|
|
9695
9696
|
/**
|
|
9696
9697
|
* Porting note: Renamed from __call__.
|
|
@@ -9710,8 +9711,8 @@ class Zr extends qr {
|
|
|
9710
9711
|
return new t({ l1: e.l1, l2: e.l2 });
|
|
9711
9712
|
}
|
|
9712
9713
|
}
|
|
9713
|
-
|
|
9714
|
-
S(
|
|
9714
|
+
Xr.className = "L1L2";
|
|
9715
|
+
S(Xr);
|
|
9715
9716
|
const ci = {
|
|
9716
9717
|
l1l2: "L1L2"
|
|
9717
9718
|
};
|
|
@@ -9727,7 +9728,7 @@ function X(s) {
|
|
|
9727
9728
|
if (typeof s == "string") {
|
|
9728
9729
|
const e = { className: s in ci ? ci[s] : s, config: {} };
|
|
9729
9730
|
return hi(e);
|
|
9730
|
-
} else return s instanceof
|
|
9731
|
+
} else return s instanceof Jr ? s : hi(s);
|
|
9731
9732
|
}
|
|
9732
9733
|
/**
|
|
9733
9734
|
* @license
|
|
@@ -9738,14 +9739,14 @@ function X(s) {
|
|
|
9738
9739
|
* https://opensource.org/licenses/MIT.
|
|
9739
9740
|
* =============================================================================
|
|
9740
9741
|
*/
|
|
9741
|
-
class
|
|
9742
|
+
class Yr extends W {
|
|
9742
9743
|
constructor(t) {
|
|
9743
9744
|
super(t ?? {}), this.supportsMasking = !0, t != null && (this.maxValue = t.maxValue);
|
|
9744
9745
|
}
|
|
9745
9746
|
call(t, e) {
|
|
9746
9747
|
t = O(t);
|
|
9747
9748
|
let n = Ve(t);
|
|
9748
|
-
return this.maxValue != null && (n =
|
|
9749
|
+
return this.maxValue != null && (n = Ct(n, 0, this.maxValue)), n;
|
|
9749
9750
|
}
|
|
9750
9751
|
computeOutputShape(t) {
|
|
9751
9752
|
return t;
|
|
@@ -9755,15 +9756,15 @@ class Jr extends W {
|
|
|
9755
9756
|
return Object.assign(t, e), t;
|
|
9756
9757
|
}
|
|
9757
9758
|
}
|
|
9758
|
-
|
|
9759
|
-
S(
|
|
9760
|
-
class
|
|
9759
|
+
Yr.className = "ReLU";
|
|
9760
|
+
S(Yr);
|
|
9761
|
+
class Qr extends W {
|
|
9761
9762
|
constructor(t) {
|
|
9762
9763
|
super(t ?? {}), this.DEFAULT_ALPHA = 0.3, t == null && (t = {}), this.alpha = t.alpha == null ? this.DEFAULT_ALPHA : t.alpha;
|
|
9763
9764
|
}
|
|
9764
9765
|
call(t, e) {
|
|
9765
9766
|
const n = O(t);
|
|
9766
|
-
return
|
|
9767
|
+
return Lu(n, this.alpha);
|
|
9767
9768
|
}
|
|
9768
9769
|
computeOutputShape(t) {
|
|
9769
9770
|
return t;
|
|
@@ -9773,9 +9774,9 @@ class Xr extends W {
|
|
|
9773
9774
|
return Object.assign(t, e), t;
|
|
9774
9775
|
}
|
|
9775
9776
|
}
|
|
9776
|
-
|
|
9777
|
-
S(
|
|
9778
|
-
class
|
|
9777
|
+
Qr.className = "LeakyReLU";
|
|
9778
|
+
S(Qr);
|
|
9779
|
+
class ta extends W {
|
|
9779
9780
|
constructor(t) {
|
|
9780
9781
|
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer = J(t.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER), this.alphaRegularizer = X(t.alphaRegularizer), this.alphaConstraint = rt(t.alphaConstraint), t.sharedAxes == null)
|
|
9781
9782
|
this.sharedAxes = null;
|
|
@@ -9803,7 +9804,7 @@ class Yr extends W {
|
|
|
9803
9804
|
})], this.built = !0;
|
|
9804
9805
|
}
|
|
9805
9806
|
call(t, e) {
|
|
9806
|
-
return t = O(t),
|
|
9807
|
+
return t = O(t), Fu(t, this.alpha.read());
|
|
9807
9808
|
}
|
|
9808
9809
|
getConfig() {
|
|
9809
9810
|
const t = {
|
|
@@ -9815,9 +9816,9 @@ class Yr extends W {
|
|
|
9815
9816
|
return Object.assign(t, e), t;
|
|
9816
9817
|
}
|
|
9817
9818
|
}
|
|
9818
|
-
|
|
9819
|
-
S(
|
|
9820
|
-
class
|
|
9819
|
+
ta.className = "PReLU";
|
|
9820
|
+
S(ta);
|
|
9821
|
+
class ea extends W {
|
|
9821
9822
|
constructor(t) {
|
|
9822
9823
|
if (super(t ?? {}), this.DEFAULT_ALPHA = 1, t == null && (t = {}), t.alpha != null && t.alpha !== this.DEFAULT_ALPHA)
|
|
9823
9824
|
throw new G(`Non-default alpha value (${t.alpha}) is not supported by the ELU layer yet.`);
|
|
@@ -9825,7 +9826,7 @@ class Qr extends W {
|
|
|
9825
9826
|
}
|
|
9826
9827
|
call(t, e) {
|
|
9827
9828
|
const n = O(t);
|
|
9828
|
-
return
|
|
9829
|
+
return Mu(n);
|
|
9829
9830
|
}
|
|
9830
9831
|
computeOutputShape(t) {
|
|
9831
9832
|
return t;
|
|
@@ -9835,9 +9836,9 @@ class Qr extends W {
|
|
|
9835
9836
|
return Object.assign(t, e), t;
|
|
9836
9837
|
}
|
|
9837
9838
|
}
|
|
9838
|
-
|
|
9839
|
-
S(
|
|
9840
|
-
class
|
|
9839
|
+
ea.className = "ELU";
|
|
9840
|
+
S(ea);
|
|
9841
|
+
class na extends W {
|
|
9841
9842
|
constructor(t) {
|
|
9842
9843
|
super(t ?? {}), this.DEFAULT_THETA = 1, t == null && (t = {}), this.theta = t.theta == null ? this.DEFAULT_THETA : t.theta;
|
|
9843
9844
|
}
|
|
@@ -9853,9 +9854,9 @@ class ta extends W {
|
|
|
9853
9854
|
return Object.assign(t, e), t;
|
|
9854
9855
|
}
|
|
9855
9856
|
}
|
|
9856
|
-
|
|
9857
|
-
S(
|
|
9858
|
-
class
|
|
9857
|
+
na.className = "ThresholdedReLU";
|
|
9858
|
+
S(na);
|
|
9859
|
+
class sa extends W {
|
|
9859
9860
|
constructor(t) {
|
|
9860
9861
|
super(t ?? {}), this.DEFAULT_AXIS = 1, t == null && (t = {}), this.softmax = new vs().apply, this.axis = t.axis == null ? this.DEFAULT_AXIS : t.axis;
|
|
9861
9862
|
}
|
|
@@ -9867,7 +9868,7 @@ class ea extends W {
|
|
|
9867
9868
|
const r = w(V(pe(n.shape), L(i, n.dtype)), tt(-1e9));
|
|
9868
9869
|
n = $(n, r);
|
|
9869
9870
|
}
|
|
9870
|
-
return this.axis instanceof Array ? this.axis.length > 1 ? Jt(V(n,
|
|
9871
|
+
return this.axis instanceof Array ? this.axis.length > 1 ? Jt(V(n, Vu(n, this.axis, !0))) : this.softmax(n, this.axis[0]) : this.softmax(n, this.axis);
|
|
9871
9872
|
});
|
|
9872
9873
|
}
|
|
9873
9874
|
computeOutputShape(t) {
|
|
@@ -9878,8 +9879,8 @@ class ea extends W {
|
|
|
9878
9879
|
return Object.assign(t, e), t;
|
|
9879
9880
|
}
|
|
9880
9881
|
}
|
|
9881
|
-
|
|
9882
|
-
S(
|
|
9882
|
+
sa.className = "Softmax";
|
|
9883
|
+
S(sa);
|
|
9883
9884
|
/**
|
|
9884
9885
|
* @license
|
|
9885
9886
|
* Copyright 2018 Google LLC
|
|
@@ -9896,19 +9897,19 @@ function ke(s, t, e) {
|
|
|
9896
9897
|
throw new d(`The ${e} argument must be an integer or tuple of ${t} integers. Received: ${s.length} elements.`);
|
|
9897
9898
|
for (let n = 0; n < t; ++n) {
|
|
9898
9899
|
const i = s[n];
|
|
9899
|
-
if (!
|
|
9900
|
+
if (!Ou(i))
|
|
9900
9901
|
throw new d(`The ${e} argument must be an integer or tuple of ${t} integers. Received: ${JSON.stringify(s)} including a non-integer number ${i}`);
|
|
9901
9902
|
}
|
|
9902
9903
|
return s;
|
|
9903
9904
|
}
|
|
9904
|
-
function
|
|
9905
|
+
function At(s, t, e, n, i = 1) {
|
|
9905
9906
|
if (s == null)
|
|
9906
9907
|
return s;
|
|
9907
9908
|
const r = t + (t - 1) * (i - 1);
|
|
9908
9909
|
let a;
|
|
9909
9910
|
return e === "same" ? a = s : a = s - r + 1, Math.floor((a + n - 1) / n);
|
|
9910
9911
|
}
|
|
9911
|
-
function
|
|
9912
|
+
function $t(s, t, e, n) {
|
|
9912
9913
|
if (s == null)
|
|
9913
9914
|
return null;
|
|
9914
9915
|
if (n === "valid")
|
|
@@ -9931,10 +9932,10 @@ function Tt(s, t, e, n) {
|
|
|
9931
9932
|
function Ss(s, t) {
|
|
9932
9933
|
return x(() => (et(t), t === "channelsFirst" ? j(s, [0, 2, 3, 1]) : s));
|
|
9933
9934
|
}
|
|
9934
|
-
function
|
|
9935
|
+
function ia(s, t) {
|
|
9935
9936
|
return x(() => (et(t), t === "channelsFirst" ? j(s, [0, 2, 3, 4, 1]) : s));
|
|
9936
9937
|
}
|
|
9937
|
-
function
|
|
9938
|
+
function Hf(s, t, e, n = 1, i = "valid", r, a = 1) {
|
|
9938
9939
|
return x(() => {
|
|
9939
9940
|
if (r == null && (r = ne()), et(r), s.shape.length !== 3)
|
|
9940
9941
|
throw new d(`The input of a conv1dWithBias operation should be 3, but is ${s.shape.length} instead.`);
|
|
@@ -9945,7 +9946,7 @@ function Kf(s, t, e, n = 1, i = "valid", r, a = 1) {
|
|
|
9945
9946
|
if (r === "channelsFirst" && (s = j(s, [0, 2, 1])), i === "causal")
|
|
9946
9947
|
throw new G("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");
|
|
9947
9948
|
let o = kc(s, t, n, i === "same" ? "same" : "valid", "NWC", a);
|
|
9948
|
-
return e != null && (o =
|
|
9949
|
+
return e != null && (o = zt(o, e)), o;
|
|
9949
9950
|
});
|
|
9950
9951
|
}
|
|
9951
9952
|
function pi(s, t, e, n = [1, 1], i = "valid", r, a, o = null) {
|
|
@@ -9969,16 +9970,16 @@ function pi(s, t, e, n = [1, 1], i = "valid", r, a, o = null) {
|
|
|
9969
9970
|
}), r === "channelsFirst" && (l = j(l, [0, 3, 1, 2])), l;
|
|
9970
9971
|
});
|
|
9971
9972
|
}
|
|
9972
|
-
function
|
|
9973
|
+
function qf(s, t, e, n = [1, 1, 1], i = "valid", r, a) {
|
|
9973
9974
|
return x(() => {
|
|
9974
9975
|
if (r == null && (r = ne()), et(r), s.rank !== 4 && s.rank !== 5)
|
|
9975
9976
|
throw new d(`conv3dWithBias expects input to be of rank 4 or 5, but received ${s.rank}.`);
|
|
9976
9977
|
if (t.rank !== 4 && t.rank !== 5)
|
|
9977
9978
|
throw new d(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ${s.rank}.`);
|
|
9978
|
-
let o =
|
|
9979
|
+
let o = ia(s, r);
|
|
9979
9980
|
if (i === "causal")
|
|
9980
9981
|
throw new G("The support for CAUSAL padding mode in conv3dWithBias is not implemented yet.");
|
|
9981
|
-
return o = Ac(o, t, n, i === "same" ? "same" : "valid", "NDHWC", a), e != null && (o =
|
|
9982
|
+
return o = Ac(o, t, n, i === "same" ? "same" : "valid", "NDHWC", a), e != null && (o = zt(o, e)), r === "channelsFirst" && (o = j(o, [0, 4, 1, 2, 3])), o;
|
|
9982
9983
|
});
|
|
9983
9984
|
}
|
|
9984
9985
|
class An extends W {
|
|
@@ -10036,16 +10037,16 @@ class De extends An {
|
|
|
10036
10037
|
return x(() => {
|
|
10037
10038
|
t = O(t);
|
|
10038
10039
|
let n;
|
|
10039
|
-
const i = this.bias == null ? null : this.bias.read(), r =
|
|
10040
|
+
const i = this.bias == null ? null : this.bias.read(), r = Ki(this.activation.getClassName());
|
|
10040
10041
|
if (r != null && this.rank === 2)
|
|
10041
10042
|
n = pi(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate, r);
|
|
10042
10043
|
else {
|
|
10043
10044
|
if (this.rank === 1)
|
|
10044
|
-
n =
|
|
10045
|
+
n = Hf(t, this.kernel.read(), i, this.strides[0], this.padding, this.dataFormat, this.dilationRate[0]);
|
|
10045
10046
|
else if (this.rank === 2)
|
|
10046
10047
|
n = pi(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate);
|
|
10047
10048
|
else if (this.rank === 3)
|
|
10048
|
-
n =
|
|
10049
|
+
n = qf(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate);
|
|
10049
10050
|
else
|
|
10050
10051
|
throw new G("convolutions greater than 3D are not implemented yet.");
|
|
10051
10052
|
this.activation != null && (n = this.activation.apply(n));
|
|
@@ -10057,7 +10058,7 @@ class De extends An {
|
|
|
10057
10058
|
t = U(t);
|
|
10058
10059
|
const e = [], n = this.dataFormat === "channelsLast" ? t.slice(1, t.length - 1) : t.slice(2);
|
|
10059
10060
|
for (let r = 0; r < n.length; ++r) {
|
|
10060
|
-
const a =
|
|
10061
|
+
const a = At(n[r], this.kernelSize[r], this.padding, this.strides[r], typeof this.dilationRate == "number" ? this.dilationRate : this.dilationRate[r]);
|
|
10061
10062
|
e.push(a);
|
|
10062
10063
|
}
|
|
10063
10064
|
let i = [t[0]];
|
|
@@ -10107,7 +10108,7 @@ class Je extends De {
|
|
|
10107
10108
|
}
|
|
10108
10109
|
Je.className = "Conv3D";
|
|
10109
10110
|
S(Je);
|
|
10110
|
-
class
|
|
10111
|
+
class ra extends Ze {
|
|
10111
10112
|
constructor(t) {
|
|
10112
10113
|
if (super(t), this.inputSpec = [new st({ ndim: 4 })], this.padding !== "same" && this.padding !== "valid")
|
|
10113
10114
|
throw new d(`Conv2DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`);
|
|
@@ -10129,10 +10130,10 @@ class sa extends Ze {
|
|
|
10129
10130
|
const i = n.shape, r = i[0];
|
|
10130
10131
|
let a, o;
|
|
10131
10132
|
this.dataFormat === "channelsFirst" ? (a = 2, o = 3) : (a = 1, o = 2);
|
|
10132
|
-
const l = i[a], u = i[o], c = this.kernelSize[0], h = this.kernelSize[1], p = this.strides[0], f = this.strides[1], g =
|
|
10133
|
+
const l = i[a], u = i[o], c = this.kernelSize[0], h = this.kernelSize[1], p = this.strides[0], f = this.strides[1], g = $t(l, p, c, this.padding), b = $t(u, f, h, this.padding), m = [r, g, b, this.filters];
|
|
10133
10134
|
this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 1]));
|
|
10134
10135
|
let v = vc(n, this.kernel.read(), m, this.strides, this.padding);
|
|
10135
|
-
return this.dataFormat !== "channelsLast" && (v = j(v, [0, 3, 1, 2])), this.bias != null && (v =
|
|
10136
|
+
return this.dataFormat !== "channelsLast" && (v = j(v, [0, 3, 1, 2])), this.bias != null && (v = zt(v, this.bias.read(), this.dataFormat)), this.activation != null && (v = this.activation.apply(v)), v;
|
|
10136
10137
|
});
|
|
10137
10138
|
}
|
|
10138
10139
|
computeOutputShape(t) {
|
|
@@ -10141,16 +10142,16 @@ class sa extends Ze {
|
|
|
10141
10142
|
let n, i, r;
|
|
10142
10143
|
this.dataFormat === "channelsFirst" ? (n = 1, i = 2, r = 3) : (n = 3, i = 1, r = 2);
|
|
10143
10144
|
const a = this.kernelSize[0], o = this.kernelSize[1], l = this.strides[0], u = this.strides[1];
|
|
10144
|
-
return e[n] = this.filters, e[i] =
|
|
10145
|
+
return e[n] = this.filters, e[i] = $t(e[i], l, a, this.padding), e[r] = $t(e[r], u, o, this.padding), e;
|
|
10145
10146
|
}
|
|
10146
10147
|
getConfig() {
|
|
10147
10148
|
const t = super.getConfig();
|
|
10148
10149
|
return delete t.dilationRate, t;
|
|
10149
10150
|
}
|
|
10150
10151
|
}
|
|
10151
|
-
|
|
10152
|
-
S(
|
|
10153
|
-
class
|
|
10152
|
+
ra.className = "Conv2DTranspose";
|
|
10153
|
+
S(ra);
|
|
10154
|
+
class aa extends Je {
|
|
10154
10155
|
constructor(t) {
|
|
10155
10156
|
if (super(t), this.inputSpec = [new st({ ndim: 5 })], this.padding !== "same" && this.padding !== "valid")
|
|
10156
10157
|
throw new d(`Conv3DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`);
|
|
@@ -10172,10 +10173,10 @@ class ia extends Je {
|
|
|
10172
10173
|
const i = n.shape, r = i[0];
|
|
10173
10174
|
let a, o, l;
|
|
10174
10175
|
this.dataFormat === "channelsFirst" ? (l = 2, a = 3, o = 4) : (l = 1, a = 2, o = 3);
|
|
10175
|
-
const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1], v = this.strides[2], y =
|
|
10176
|
+
const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1], v = this.strides[2], y = $t(u, b, p, this.padding), C = $t(c, m, f, this.padding), N = $t(h, v, g, this.padding), I = [r, y, C, N, this.filters];
|
|
10176
10177
|
this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 4, 1]));
|
|
10177
10178
|
let z = Dc(n, this.kernel.read(), I, this.strides, this.padding);
|
|
10178
|
-
return this.dataFormat !== "channelsLast" && (z = j(z, [0, 4, 1, 2, 3])), this.bias !== null && (z =
|
|
10179
|
+
return this.dataFormat !== "channelsLast" && (z = j(z, [0, 4, 1, 2, 3])), this.bias !== null && (z = zt(z, this.bias.read(), this.dataFormat)), this.activation !== null && (z = this.activation.apply(z)), z;
|
|
10179
10180
|
});
|
|
10180
10181
|
}
|
|
10181
10182
|
computeOutputShape(t) {
|
|
@@ -10184,16 +10185,16 @@ class ia extends Je {
|
|
|
10184
10185
|
let n, i, r, a;
|
|
10185
10186
|
this.dataFormat === "channelsFirst" ? (n = 1, i = 2, r = 3, a = 4) : (n = 4, i = 1, r = 2, a = 3);
|
|
10186
10187
|
const o = this.kernelSize[0], l = this.kernelSize[1], u = this.kernelSize[2], c = this.strides[0], h = this.strides[1], p = this.strides[2];
|
|
10187
|
-
return e[n] = this.filters, e[i] =
|
|
10188
|
+
return e[n] = this.filters, e[i] = $t(e[i], c, o, this.padding), e[r] = $t(e[r], h, l, this.padding), e[a] = $t(e[a], p, u, this.padding), e;
|
|
10188
10189
|
}
|
|
10189
10190
|
getConfig() {
|
|
10190
10191
|
const t = super.getConfig();
|
|
10191
10192
|
return delete t.dilationRate, t;
|
|
10192
10193
|
}
|
|
10193
10194
|
}
|
|
10194
|
-
|
|
10195
|
-
S(
|
|
10196
|
-
class
|
|
10195
|
+
aa.className = "Conv3DTranspose";
|
|
10196
|
+
S(aa);
|
|
10197
|
+
class oa extends De {
|
|
10197
10198
|
constructor(t, e) {
|
|
10198
10199
|
if (super(t, e), this.DEFAULT_DEPTHWISE_INITIALIZER = "glorotUniform", this.DEFAULT_POINTWISE_INITIALIZER = "glorotUniform", this.depthwiseKernel = null, this.pointwiseKernel = null, e.filters == null)
|
|
10199
10200
|
throw new d("The `filters` configuration field is required by SeparableConv, but is unspecified.");
|
|
@@ -10222,7 +10223,7 @@ class ra extends De {
|
|
|
10222
10223
|
let n;
|
|
10223
10224
|
if (this.rank === 1)
|
|
10224
10225
|
throw new G("1D separable convolution is not implemented yet.");
|
|
10225
|
-
return this.rank === 2 && (this.dataFormat === "channelsFirst" && (t = j(t, [0, 2, 3, 1])), n = uh(t, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, "NHWC")), this.useBias && (n =
|
|
10226
|
+
return this.rank === 2 && (this.dataFormat === "channelsFirst" && (t = j(t, [0, 2, 3, 1])), n = uh(t, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, "NHWC")), this.useBias && (n = zt(n, this.bias.read(), this.dataFormat)), this.activation != null && (n = this.activation.apply(n)), this.dataFormat === "channelsFirst" && (n = j(n, [0, 3, 1, 2])), n;
|
|
10226
10227
|
});
|
|
10227
10228
|
}
|
|
10228
10229
|
getConfig() {
|
|
@@ -10230,14 +10231,14 @@ class ra extends De {
|
|
|
10230
10231
|
return delete t.rank, delete t.kernelInitializer, delete t.kernelRegularizer, delete t.kernelConstraint, t.depthwiseInitializer = Y(this.depthwiseInitializer), t.pointwiseInitializer = Y(this.pointwiseInitializer), t.depthwiseRegularizer = H(this.depthwiseRegularizer), t.pointwiseRegularizer = H(this.pointwiseRegularizer), t.depthwiseConstraint = it(this.depthwiseConstraint), t.pointwiseConstraint = it(this.pointwiseConstraint), t;
|
|
10231
10232
|
}
|
|
10232
10233
|
}
|
|
10233
|
-
|
|
10234
|
-
class
|
|
10234
|
+
oa.className = "SeparableConv";
|
|
10235
|
+
class la extends oa {
|
|
10235
10236
|
constructor(t) {
|
|
10236
10237
|
super(2, t);
|
|
10237
10238
|
}
|
|
10238
10239
|
}
|
|
10239
|
-
|
|
10240
|
-
S(
|
|
10240
|
+
la.className = "SeparableConv2D";
|
|
10241
|
+
S(la);
|
|
10241
10242
|
class Cn extends De {
|
|
10242
10243
|
constructor(t) {
|
|
10243
10244
|
super(1, t), Cn.verifyArgs(t), this.inputSpec = [{ ndim: 3 }];
|
|
@@ -10253,7 +10254,7 @@ class Cn extends De {
|
|
|
10253
10254
|
}
|
|
10254
10255
|
Cn.className = "Conv1D";
|
|
10255
10256
|
S(Cn);
|
|
10256
|
-
class
|
|
10257
|
+
class ua extends W {
|
|
10257
10258
|
constructor(t) {
|
|
10258
10259
|
super(t), typeof t.cropping == "number" ? this.cropping = [[t.cropping, t.cropping], [t.cropping, t.cropping]] : typeof t.cropping[0] == "number" ? this.cropping = [
|
|
10259
10260
|
[t.cropping[0], t.cropping[0]],
|
|
@@ -10289,11 +10290,11 @@ class oa extends W {
|
|
|
10289
10290
|
return Object.assign(t, e), t;
|
|
10290
10291
|
}
|
|
10291
10292
|
}
|
|
10292
|
-
|
|
10293
|
-
S(
|
|
10294
|
-
class
|
|
10293
|
+
ua.className = "Cropping2D";
|
|
10294
|
+
S(ua);
|
|
10295
|
+
class ca extends W {
|
|
10295
10296
|
constructor(t) {
|
|
10296
|
-
super(t), this.DEFAULT_SIZE = [2, 2], this.inputSpec = [{ ndim: 4 }], this.size = t.size == null ? this.DEFAULT_SIZE : t.size, this.dataFormat = t.dataFormat == null ? "channelsLast" : t.dataFormat, et(this.dataFormat), this.interpolation = t.interpolation == null ? "nearest" : t.interpolation,
|
|
10297
|
+
super(t), this.DEFAULT_SIZE = [2, 2], this.inputSpec = [{ ndim: 4 }], this.size = t.size == null ? this.DEFAULT_SIZE : t.size, this.dataFormat = t.dataFormat == null ? "channelsLast" : t.dataFormat, et(this.dataFormat), this.interpolation = t.interpolation == null ? "nearest" : t.interpolation, Ru(this.interpolation);
|
|
10297
10298
|
}
|
|
10298
10299
|
computeOutputShape(t) {
|
|
10299
10300
|
if (this.dataFormat === "channelsFirst") {
|
|
@@ -10327,8 +10328,8 @@ class la extends W {
|
|
|
10327
10328
|
return Object.assign(t, e), t;
|
|
10328
10329
|
}
|
|
10329
10330
|
}
|
|
10330
|
-
|
|
10331
|
-
S(
|
|
10331
|
+
ca.className = "UpSampling2D";
|
|
10332
|
+
S(ca);
|
|
10332
10333
|
/**
|
|
10333
10334
|
* @license
|
|
10334
10335
|
* Copyright 2018 Google LLC
|
|
@@ -10338,7 +10339,7 @@ S(la);
|
|
|
10338
10339
|
* https://opensource.org/licenses/MIT.
|
|
10339
10340
|
* =============================================================================
|
|
10340
10341
|
*/
|
|
10341
|
-
function
|
|
10342
|
+
function Zf(s, t, e = [1, 1], n = "valid", i, r) {
|
|
10342
10343
|
return x(() => {
|
|
10343
10344
|
i == null && (i = ne()), et(i);
|
|
10344
10345
|
let a = Ss(s, i);
|
|
@@ -10346,10 +10347,10 @@ function qf(s, t, e = [1, 1], n = "valid", i, r) {
|
|
|
10346
10347
|
throw new d(`Input for depthwiseConv2d is required to be 4-D, but is instead ${s.rank}-D`);
|
|
10347
10348
|
if (t.rank !== 4)
|
|
10348
10349
|
throw new d(`depthwiseKernel is required to be 4-D, but is instead ${t.rank}-D`);
|
|
10349
|
-
return a =
|
|
10350
|
+
return a = nr(a, t, e, n === "same" ? "same" : "valid", "NHWC", r), i === "channelsFirst" && (a = j(a, [0, 3, 1, 2])), a;
|
|
10350
10351
|
});
|
|
10351
10352
|
}
|
|
10352
|
-
class
|
|
10353
|
+
class ha extends An {
|
|
10353
10354
|
constructor(t) {
|
|
10354
10355
|
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer = J(t.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.depthwiseConstraint = rt(t.depthwiseConstraint), this.depthwiseRegularizer = X(t.depthwiseRegularizer);
|
|
10355
10356
|
}
|
|
@@ -10370,13 +10371,13 @@ class ua extends An {
|
|
|
10370
10371
|
call(t, e) {
|
|
10371
10372
|
return x(() => {
|
|
10372
10373
|
t = O(t);
|
|
10373
|
-
let n =
|
|
10374
|
-
return this.useBias && (n =
|
|
10374
|
+
let n = Zf(t, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
|
|
10375
|
+
return this.useBias && (n = zt(n, this.bias.read(), this.dataFormat)), this.activation != null && (n = this.activation.apply(n)), n;
|
|
10375
10376
|
});
|
|
10376
10377
|
}
|
|
10377
10378
|
computeOutputShape(t) {
|
|
10378
10379
|
t = U(t);
|
|
10379
|
-
const e = this.dataFormat === "channelsFirst" ? t[2] : t[1], n = this.dataFormat === "channelsFirst" ? t[3] : t[2], i = this.dataFormat === "channelsFirst" ? t[1] * this.depthMultiplier : t[3] * this.depthMultiplier, r =
|
|
10380
|
+
const e = this.dataFormat === "channelsFirst" ? t[2] : t[1], n = this.dataFormat === "channelsFirst" ? t[3] : t[2], i = this.dataFormat === "channelsFirst" ? t[1] * this.depthMultiplier : t[3] * this.depthMultiplier, r = At(e, this.kernelSize[0], this.padding, this.strides[0]), a = At(n, this.kernelSize[1], this.padding, this.strides[1]);
|
|
10380
10381
|
return this.dataFormat === "channelsFirst" ? [t[0], i, r, a] : [t[0], r, a, i];
|
|
10381
10382
|
}
|
|
10382
10383
|
getConfig() {
|
|
@@ -10384,8 +10385,8 @@ class ua extends An {
|
|
|
10384
10385
|
return t.depthMultiplier = this.depthMultiplier, t.depthwiseInitializer = Y(this.depthwiseInitializer), t.depthwiseRegularizer = H(this.depthwiseRegularizer), t.depthwiseConstraint = it(this.depthwiseRegularizer), t;
|
|
10385
10386
|
}
|
|
10386
10387
|
}
|
|
10387
|
-
|
|
10388
|
-
S(
|
|
10388
|
+
ha.className = "DepthwiseConv2D";
|
|
10389
|
+
S(ha);
|
|
10389
10390
|
/**
|
|
10390
10391
|
* @license
|
|
10391
10392
|
* Copyright 2018 Google LLC
|
|
@@ -10395,7 +10396,7 @@ S(ua);
|
|
|
10395
10396
|
* https://opensource.org/licenses/MIT.
|
|
10396
10397
|
* =============================================================================
|
|
10397
10398
|
*/
|
|
10398
|
-
function
|
|
10399
|
+
function pa(s, t, e, n) {
|
|
10399
10400
|
if (Array.isArray(s)) {
|
|
10400
10401
|
if (t != null || e != null)
|
|
10401
10402
|
throw new d("When inputs is an array, neither initialState or constants should be provided");
|
|
@@ -10406,12 +10407,12 @@ function ca(s, t, e, n) {
|
|
|
10406
10407
|
}
|
|
10407
10408
|
return t = i(t), e = i(e), { inputs: s, initialState: t, constants: e };
|
|
10408
10409
|
}
|
|
10409
|
-
function
|
|
10410
|
+
function da(s, t, e, n = !1, i, r, a = !1, o = !1) {
|
|
10410
10411
|
return x(() => {
|
|
10411
10412
|
const l = t.shape.length;
|
|
10412
10413
|
if (l < 3)
|
|
10413
10414
|
throw new d(`Input should be at least 3D, but is ${l}D.`);
|
|
10414
|
-
const u = [1, 0].concat(
|
|
10415
|
+
const u = [1, 0].concat(It(2, l));
|
|
10415
10416
|
t = j(t, u), a && console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend."), i != null && (i = L(L(i, "bool"), "float32"), i.rank === l - 1 && (i = ue(i, -1)), i = j(i, u)), n && (t = on(t, 0), i != null && (i = on(i, 0)));
|
|
10416
10417
|
const c = [];
|
|
10417
10418
|
let h, p = e;
|
|
@@ -10424,7 +10425,7 @@ function ha(s, t, e, n = !1, i, r, a = !1, o = !1) {
|
|
|
10424
10425
|
h = C[0], p = C[1];
|
|
10425
10426
|
else {
|
|
10426
10427
|
const N = x(() => {
|
|
10427
|
-
const I = b[v], z = V(
|
|
10428
|
+
const I = b[v], z = V(Dt(I), I), _ = $(w(C[0], I), w(p[0], z)), T = p.map((E, R) => $(w(C[1][R], I), w(E, z)));
|
|
10428
10429
|
return { output: _, newStates: T };
|
|
10429
10430
|
});
|
|
10430
10431
|
h = N.output, p = N.newStates;
|
|
@@ -10450,7 +10451,7 @@ class se extends W {
|
|
|
10450
10451
|
getStates() {
|
|
10451
10452
|
if (this.states_ == null) {
|
|
10452
10453
|
const t = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
|
|
10453
|
-
return
|
|
10454
|
+
return It(0, t).map((e) => null);
|
|
10454
10455
|
} else
|
|
10455
10456
|
return this.states_;
|
|
10456
10457
|
}
|
|
@@ -10563,7 +10564,7 @@ class se extends W {
|
|
|
10563
10564
|
apply(t, e) {
|
|
10564
10565
|
let n = e == null ? null : e.initialState, i = e == null ? null : e.constants;
|
|
10565
10566
|
e == null && (e = {});
|
|
10566
|
-
const r =
|
|
10567
|
+
const r = pa(t, n, i, this.numConstants);
|
|
10567
10568
|
t = r.inputs, n = r.initialState, i = r.constants;
|
|
10568
10569
|
let a = [], o = [];
|
|
10569
10570
|
if (n != null) {
|
|
@@ -10590,7 +10591,7 @@ class se extends W {
|
|
|
10590
10591
|
if (r.length !== a)
|
|
10591
10592
|
throw new d(`RNN Layer has ${a} state(s) but was passed ${r.length} initial state(s).`);
|
|
10592
10593
|
this.unroll && console.warn("Ignoring unroll = true for RNN layer, due to imperative backend.");
|
|
10593
|
-
const o = { training: i }, u =
|
|
10594
|
+
const o = { training: i }, u = da((g, b) => {
|
|
10594
10595
|
const m = this.cell.call([g].concat(b), o);
|
|
10595
10596
|
return [m[0], m.slice(1)];
|
|
10596
10597
|
}, t, r, this.goBackwards, n, null, this.unroll, this.returnSequences), c = u[0], h = u[1], p = u[2];
|
|
@@ -10602,7 +10603,7 @@ class se extends W {
|
|
|
10602
10603
|
getInitialState(t) {
|
|
10603
10604
|
return x(() => {
|
|
10604
10605
|
let e = mt(t.shape);
|
|
10605
|
-
return e = B(e, [1, 2]), e =
|
|
10606
|
+
return e = B(e, [1, 2]), e = yn(e), Array.isArray(this.cell.stateSize) ? this.cell.stateSize.map((n) => n > 1 ? _s(e, [1, n]) : e) : this.cell.stateSize > 1 ? [_s(e, [1, this.cell.stateSize])] : [e];
|
|
10606
10607
|
});
|
|
10607
10608
|
}
|
|
10608
10609
|
get trainableWeights() {
|
|
@@ -10663,20 +10664,20 @@ class As extends In {
|
|
|
10663
10664
|
t = t[0];
|
|
10664
10665
|
const i = e.training == null ? !1 : e.training;
|
|
10665
10666
|
0 < this.dropout && this.dropout < 1 && this.dropoutMask == null && (this.dropoutMask = te({
|
|
10666
|
-
ones: () =>
|
|
10667
|
+
ones: () => Dt(t),
|
|
10667
10668
|
rate: this.dropout,
|
|
10668
10669
|
training: i,
|
|
10669
10670
|
dropoutFunc: this.dropoutFunc
|
|
10670
10671
|
})), 0 < this.recurrentDropout && this.recurrentDropout < 1 && this.recurrentDropoutMask == null && (this.recurrentDropoutMask = te({
|
|
10671
|
-
ones: () =>
|
|
10672
|
+
ones: () => Dt(n),
|
|
10672
10673
|
rate: this.recurrentDropout,
|
|
10673
10674
|
training: i,
|
|
10674
10675
|
dropoutFunc: this.dropoutFunc
|
|
10675
10676
|
}));
|
|
10676
10677
|
let r;
|
|
10677
10678
|
const a = this.dropoutMask, o = this.recurrentDropoutMask;
|
|
10678
|
-
a != null ? r =
|
|
10679
|
-
let l = $(r,
|
|
10679
|
+
a != null ? r = St(w(t, a), this.kernel.read()) : r = St(t, this.kernel.read()), this.bias != null && (r = zt(r, this.bias.read())), o != null && (n = w(n, o));
|
|
10680
|
+
let l = $(r, St(n, this.recurrentKernel.read()));
|
|
10680
10681
|
return this.activation != null && (l = this.activation.apply(l)), [l, l];
|
|
10681
10682
|
});
|
|
10682
10683
|
}
|
|
@@ -10703,7 +10704,7 @@ class As extends In {
|
|
|
10703
10704
|
}
|
|
10704
10705
|
As.className = "SimpleRNNCell";
|
|
10705
10706
|
S(As);
|
|
10706
|
-
class
|
|
10707
|
+
class fa extends se {
|
|
10707
10708
|
constructor(t) {
|
|
10708
10709
|
t.cell = new As(t), super(t);
|
|
10709
10710
|
}
|
|
@@ -10719,8 +10720,8 @@ class pa extends se {
|
|
|
10719
10720
|
return new t(e);
|
|
10720
10721
|
}
|
|
10721
10722
|
}
|
|
10722
|
-
|
|
10723
|
-
S(
|
|
10723
|
+
fa.className = "SimpleRNN";
|
|
10724
|
+
S(fa);
|
|
10724
10725
|
class Cs extends In {
|
|
10725
10726
|
constructor(t) {
|
|
10726
10727
|
if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
|
|
@@ -10742,13 +10743,13 @@ class Cs extends In {
|
|
|
10742
10743
|
const n = e.training == null ? !1 : e.training;
|
|
10743
10744
|
let i = t[1];
|
|
10744
10745
|
t = t[0], 0 < this.dropout && this.dropout < 1 && this.dropoutMask == null && (this.dropoutMask = te({
|
|
10745
|
-
ones: () =>
|
|
10746
|
+
ones: () => Dt(t),
|
|
10746
10747
|
rate: this.dropout,
|
|
10747
10748
|
training: n,
|
|
10748
10749
|
count: 3,
|
|
10749
10750
|
dropoutFunc: this.dropoutFunc
|
|
10750
10751
|
})), 0 < this.recurrentDropout && this.recurrentDropout < 1 && this.recurrentDropoutMask == null && (this.recurrentDropoutMask = te({
|
|
10751
|
-
ones: () =>
|
|
10752
|
+
ones: () => Dt(i),
|
|
10752
10753
|
rate: this.recurrentDropout,
|
|
10753
10754
|
training: n,
|
|
10754
10755
|
count: 3,
|
|
@@ -10757,11 +10758,11 @@ class Cs extends In {
|
|
|
10757
10758
|
const r = this.dropoutMask, a = this.recurrentDropoutMask;
|
|
10758
10759
|
let o, l, u;
|
|
10759
10760
|
0 < this.dropout && this.dropout < 1 && (t = w(t, r[0]));
|
|
10760
|
-
let c =
|
|
10761
|
-
this.useBias && (c =
|
|
10762
|
-
const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g =
|
|
10761
|
+
let c = St(t, this.kernel.read());
|
|
10762
|
+
this.useBias && (c = zt(c, this.bias.read())), 0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, a[0]));
|
|
10763
|
+
const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g = St(i, p), [b, m, v] = Kt(c, 3, c.rank - 1), [y, C] = Kt(g, 2, g.rank - 1);
|
|
10763
10764
|
o = this.recurrentActivation.apply($(b, y)), l = this.recurrentActivation.apply($(m, C));
|
|
10764
|
-
const N =
|
|
10765
|
+
const N = St(w(l, i), f);
|
|
10765
10766
|
u = this.activation.apply($(v, N));
|
|
10766
10767
|
const I = $(w(o, i), w($(1, pt(o)), u));
|
|
10767
10768
|
return [I, I];
|
|
@@ -10793,7 +10794,7 @@ class Cs extends In {
|
|
|
10793
10794
|
}
|
|
10794
10795
|
Cs.className = "GRUCell";
|
|
10795
10796
|
S(Cs);
|
|
10796
|
-
class
|
|
10797
|
+
class ma extends se {
|
|
10797
10798
|
constructor(t) {
|
|
10798
10799
|
t.implementation === 0 && console.warn("`implementation=0` has been deprecated, and now defaults to `implementation=1`. Please update your layer call."), t.cell = new Cs(t), super(t);
|
|
10799
10800
|
}
|
|
@@ -10809,8 +10810,8 @@ class da extends se {
|
|
|
10809
10810
|
return e.implmentation === 0 && (e.implementation = 1), new t(e);
|
|
10810
10811
|
}
|
|
10811
10812
|
}
|
|
10812
|
-
|
|
10813
|
-
S(
|
|
10813
|
+
ma.className = "GRU";
|
|
10814
|
+
S(ma);
|
|
10814
10815
|
class Dn extends In {
|
|
10815
10816
|
constructor(t) {
|
|
10816
10817
|
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = Ne([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Ne([
|
|
@@ -10849,13 +10850,13 @@ class Dn extends In {
|
|
|
10849
10850
|
let i = t[1];
|
|
10850
10851
|
const r = t[2];
|
|
10851
10852
|
t = t[0], 0 < this.dropout && this.dropout < 1 && this.dropoutMask == null && (this.dropoutMask = te({
|
|
10852
|
-
ones: () =>
|
|
10853
|
+
ones: () => Dt(t),
|
|
10853
10854
|
rate: this.dropout,
|
|
10854
10855
|
training: n,
|
|
10855
10856
|
count: 4,
|
|
10856
10857
|
dropoutFunc: this.dropoutFunc
|
|
10857
10858
|
})), 0 < this.recurrentDropout && this.recurrentDropout < 1 && this.recurrentDropoutMask == null && (this.recurrentDropoutMask = te({
|
|
10858
|
-
ones: () =>
|
|
10859
|
+
ones: () => Dt(i),
|
|
10859
10860
|
rate: this.recurrentDropout,
|
|
10860
10861
|
training: n,
|
|
10861
10862
|
count: 4,
|
|
@@ -10864,8 +10865,8 @@ class Dn extends In {
|
|
|
10864
10865
|
const a = this.dropoutMask, o = this.recurrentDropoutMask;
|
|
10865
10866
|
let l, u, c, h;
|
|
10866
10867
|
0 < this.dropout && this.dropout < 1 && (t = w(t, a[0]));
|
|
10867
|
-
let p =
|
|
10868
|
-
0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, o[0])), p = $(p,
|
|
10868
|
+
let p = St(t, this.kernel.read());
|
|
10869
|
+
0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, o[0])), p = $(p, St(i, this.recurrentKernel.read())), this.useBias && (p = zt(p, this.bias.read()));
|
|
10869
10870
|
const [f, g, b, m] = Kt(p, 4, p.rank - 1);
|
|
10870
10871
|
l = this.recurrentActivation.apply(f), u = this.recurrentActivation.apply(g), c = $(w(u, r), w(l, this.activation.apply(b))), h = this.recurrentActivation.apply(m);
|
|
10871
10872
|
const v = w(h, this.activation.apply(c));
|
|
@@ -10898,7 +10899,7 @@ class Dn extends In {
|
|
|
10898
10899
|
}
|
|
10899
10900
|
Dn.className = "LSTMCell";
|
|
10900
10901
|
S(Dn);
|
|
10901
|
-
class
|
|
10902
|
+
class ga extends se {
|
|
10902
10903
|
constructor(t) {
|
|
10903
10904
|
t.implementation === 0 && console.warn("`implementation=0` has been deprecated, and now defaults to `implementation=1`. Please update your layer call."), t.cell = new Dn(t), super(t);
|
|
10904
10905
|
}
|
|
@@ -10914,8 +10915,8 @@ class fa extends se {
|
|
|
10914
10915
|
return e.implmentation === 0 && (e.implementation = 1), new t(e);
|
|
10915
10916
|
}
|
|
10916
10917
|
}
|
|
10917
|
-
|
|
10918
|
-
S(
|
|
10918
|
+
ga.className = "LSTM";
|
|
10919
|
+
S(ga);
|
|
10919
10920
|
class Is extends In {
|
|
10920
10921
|
constructor(t) {
|
|
10921
10922
|
super(t), this.cells = t.cells;
|
|
@@ -11019,7 +11020,7 @@ class Is extends In {
|
|
|
11019
11020
|
Is.className = "StackedRNNCells";
|
|
11020
11021
|
S(Is);
|
|
11021
11022
|
function te(s) {
|
|
11022
|
-
const { ones: t, rate: e, training: n = !1, count: i = 1, dropoutFunc: r } = s, a = () => r != null ? r(t(), e) :
|
|
11023
|
+
const { ones: t, rate: e, training: n = !1, count: i = 1, dropoutFunc: r } = s, a = () => r != null ? r(t(), e) : Hi(t(), e), o = () => Ke(a, t, n);
|
|
11023
11024
|
return !i || i <= 1 ? Bt(o().clone()) : Array(i).fill(void 0).map(o).map((u) => Bt(u.clone()));
|
|
11024
11025
|
}
|
|
11025
11026
|
/**
|
|
@@ -11031,7 +11032,7 @@ function te(s) {
|
|
|
11031
11032
|
* https://opensource.org/licenses/MIT.
|
|
11032
11033
|
* =============================================================================
|
|
11033
11034
|
*/
|
|
11034
|
-
var
|
|
11035
|
+
var Jf = function(s, t) {
|
|
11035
11036
|
var e = {};
|
|
11036
11037
|
for (var n in s) Object.prototype.hasOwnProperty.call(s, n) && t.indexOf(n) < 0 && (e[n] = s[n]);
|
|
11037
11038
|
if (s != null && typeof Object.getOwnPropertySymbols == "function")
|
|
@@ -11039,7 +11040,7 @@ var Zf = function(s, t) {
|
|
|
11039
11040
|
t.indexOf(n[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, n[i]) && (e[n[i]] = s[n[i]]);
|
|
11040
11041
|
return e;
|
|
11041
11042
|
};
|
|
11042
|
-
class
|
|
11043
|
+
class ba extends se {
|
|
11043
11044
|
constructor(t) {
|
|
11044
11045
|
if (t.unroll)
|
|
11045
11046
|
throw new G("Unrolling is not possible with convolutional RNNs.");
|
|
@@ -11091,14 +11092,14 @@ class ma extends se {
|
|
|
11091
11092
|
});
|
|
11092
11093
|
}
|
|
11093
11094
|
computeSingleOutputShape(t) {
|
|
11094
|
-
const { dataFormat: e, filters: n, kernelSize: i, padding: r, strides: a, dilationRate: o } = this.cell, l = e === "channelsFirst", u = t[l ? 3 : 2], c = t[l ? 4 : 3], h =
|
|
11095
|
+
const { dataFormat: e, filters: n, kernelSize: i, padding: r, strides: a, dilationRate: o } = this.cell, l = e === "channelsFirst", u = t[l ? 3 : 2], c = t[l ? 4 : 3], h = At(u, i[0], r, a[0], o[0]), p = At(c, i[1], r, a[1], o[1]);
|
|
11095
11096
|
return [
|
|
11096
11097
|
...t.slice(0, 2),
|
|
11097
11098
|
...l ? [n, h, p] : [h, p, n]
|
|
11098
11099
|
];
|
|
11099
11100
|
}
|
|
11100
11101
|
}
|
|
11101
|
-
|
|
11102
|
+
ba.className = "ConvRNN2D";
|
|
11102
11103
|
class Ds extends Dn {
|
|
11103
11104
|
constructor(t) {
|
|
11104
11105
|
const { filters: e, kernelSize: n, strides: i, padding: r, dataFormat: a, dilationRate: o } = t;
|
|
@@ -11136,7 +11137,7 @@ class Ds extends Dn {
|
|
|
11136
11137
|
throw new d(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ${t.length}.`);
|
|
11137
11138
|
const n = e.training || !1, i = t[0], r = t[1], a = t[2], o = 4;
|
|
11138
11139
|
0 < this.dropout && this.dropout < 1 && this.dropoutMask == null && (this.dropoutMask = te({
|
|
11139
|
-
ones: () =>
|
|
11140
|
+
ones: () => Dt(i),
|
|
11140
11141
|
rate: this.dropout,
|
|
11141
11142
|
training: n,
|
|
11142
11143
|
count: o,
|
|
@@ -11145,7 +11146,7 @@ class Ds extends Dn {
|
|
|
11145
11146
|
const l = this.dropoutMask, u = ($s, Tn, Es) => !Tn || !Tn[Es] ? $s : w(Tn[Es], $s);
|
|
11146
11147
|
let c = u(i, l, 0), h = u(i, l, 1), p = u(i, l, 2), f = u(i, l, 3);
|
|
11147
11148
|
0 < this.recurrentDropout && this.recurrentDropout < 1 && this.recurrentDropoutMask == null && (this.recurrentDropoutMask = te({
|
|
11148
|
-
ones: () =>
|
|
11149
|
+
ones: () => Dt(r),
|
|
11149
11150
|
rate: this.recurrentDropout,
|
|
11150
11151
|
training: n,
|
|
11151
11152
|
count: o,
|
|
@@ -11157,12 +11158,12 @@ class Ds extends Dn {
|
|
|
11157
11158
|
c = this.inputConv(c, N, T, this.padding), h = this.inputConv(h, I, E, this.padding), p = this.inputConv(p, z, R, this.padding), f = this.inputConv(f, _, q, this.padding);
|
|
11158
11159
|
const [bt, ie, re, xt] = Kt(this.recurrentKernel.read(), o, C);
|
|
11159
11160
|
b = this.recurrentConv(b, bt), m = this.recurrentConv(m, ie), v = this.recurrentConv(v, re), y = this.recurrentConv(y, xt);
|
|
11160
|
-
const
|
|
11161
|
+
const Tt = this.recurrentActivation.apply($(c, b)), me = this.recurrentActivation.apply($(h, m)), ze = $(w(me, a), w(Tt, this.activation.apply($(p, v)))), Ts = w(this.recurrentActivation.apply($(f, y)), this.activation.apply(ze));
|
|
11161
11162
|
return [Ts, Ts, ze];
|
|
11162
11163
|
});
|
|
11163
11164
|
}
|
|
11164
11165
|
getConfig() {
|
|
11165
|
-
const t = super.getConfig(), { units: e } = t, n =
|
|
11166
|
+
const t = super.getConfig(), { units: e } = t, n = Jf(t, ["units"]), i = {
|
|
11166
11167
|
filters: this.filters,
|
|
11167
11168
|
kernelSize: this.kernelSize,
|
|
11168
11169
|
padding: this.padding,
|
|
@@ -11174,7 +11175,7 @@ class Ds extends Dn {
|
|
|
11174
11175
|
}
|
|
11175
11176
|
inputConv(t, e, n, i) {
|
|
11176
11177
|
const r = Ce(t, e, this.strides, i || "valid", this.dataFormat === "channelsFirst" ? "NCHW" : "NHWC", this.dilationRate);
|
|
11177
|
-
return n ?
|
|
11178
|
+
return n ? zt(r, n, this.dataFormat) : r;
|
|
11178
11179
|
}
|
|
11179
11180
|
recurrentConv(t, e) {
|
|
11180
11181
|
return Ce(t, e, 1, "same", this.dataFormat === "channelsFirst" ? "NCHW" : "NHWC");
|
|
@@ -11182,7 +11183,7 @@ class Ds extends Dn {
|
|
|
11182
11183
|
}
|
|
11183
11184
|
Ds.className = "ConvLSTM2DCell";
|
|
11184
11185
|
S(Ds);
|
|
11185
|
-
class
|
|
11186
|
+
class ya extends ba {
|
|
11186
11187
|
constructor(t) {
|
|
11187
11188
|
const e = new Ds(t);
|
|
11188
11189
|
super(Object.assign(Object.assign({}, t), { cell: e }));
|
|
@@ -11192,8 +11193,8 @@ class ga extends ma {
|
|
|
11192
11193
|
return new t(e);
|
|
11193
11194
|
}
|
|
11194
11195
|
}
|
|
11195
|
-
|
|
11196
|
-
S(
|
|
11196
|
+
ya.className = "ConvLSTM2D";
|
|
11197
|
+
S(ya);
|
|
11197
11198
|
/**
|
|
11198
11199
|
* @license
|
|
11199
11200
|
* Copyright 2018 Google LLC
|
|
@@ -11221,7 +11222,7 @@ class zs extends W {
|
|
|
11221
11222
|
const n = O(t);
|
|
11222
11223
|
if (0 < this.rate && this.rate < 1) {
|
|
11223
11224
|
const i = e.training == null ? !1 : e.training, r = this.getNoiseShape(n);
|
|
11224
|
-
return Ke(() =>
|
|
11225
|
+
return Ke(() => Hi(n, this.rate, r, this.seed), () => n, i);
|
|
11225
11226
|
}
|
|
11226
11227
|
return t;
|
|
11227
11228
|
});
|
|
@@ -11240,7 +11241,7 @@ class zs extends W {
|
|
|
11240
11241
|
}
|
|
11241
11242
|
zs.className = "Dropout";
|
|
11242
11243
|
S(zs);
|
|
11243
|
-
class
|
|
11244
|
+
class wa extends zs {
|
|
11244
11245
|
constructor(t) {
|
|
11245
11246
|
super(t), this.inputSpec = [{ ndim: 3 }];
|
|
11246
11247
|
}
|
|
@@ -11249,9 +11250,9 @@ class ba extends zs {
|
|
|
11249
11250
|
return [e[0], 1, e[2]];
|
|
11250
11251
|
}
|
|
11251
11252
|
}
|
|
11252
|
-
|
|
11253
|
-
S(
|
|
11254
|
-
class
|
|
11253
|
+
wa.className = "SpatialDropout1D";
|
|
11254
|
+
S(wa);
|
|
11255
|
+
class ka extends W {
|
|
11255
11256
|
constructor(t) {
|
|
11256
11257
|
if (super(t), this.activation = null, this.useBias = !0, this.kernel = null, this.bias = null, this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.batchInputShape == null && t.inputShape == null && t.inputDim != null) {
|
|
11257
11258
|
let e = null;
|
|
@@ -11272,9 +11273,9 @@ class ya extends W {
|
|
|
11272
11273
|
call(t, e) {
|
|
11273
11274
|
return x(() => {
|
|
11274
11275
|
this.invokeCallHook(t, e);
|
|
11275
|
-
const n = O(t), i =
|
|
11276
|
+
const n = O(t), i = Ki(this.activation.getClassName());
|
|
11276
11277
|
let r;
|
|
11277
|
-
return i != null ? r =
|
|
11278
|
+
return i != null ? r = St(n, this.kernel.read(), i, this.bias ? this.bias.read() : null) : (r = St(n, this.kernel.read()), this.bias != null && (r = zt(r, this.bias.read())), this.activation != null && (r = this.activation.apply(r))), r;
|
|
11278
11279
|
});
|
|
11279
11280
|
}
|
|
11280
11281
|
getConfig() {
|
|
@@ -11293,9 +11294,9 @@ class ya extends W {
|
|
|
11293
11294
|
return Object.assign(t, e), t;
|
|
11294
11295
|
}
|
|
11295
11296
|
}
|
|
11296
|
-
|
|
11297
|
-
S(
|
|
11298
|
-
class
|
|
11297
|
+
ka.className = "Dense";
|
|
11298
|
+
S(ka);
|
|
11299
|
+
class xa extends W {
|
|
11299
11300
|
constructor(t) {
|
|
11300
11301
|
t = t || {}, super(t), this.inputSpec = [{ minNDim: 3 }], this.dataFormat = t.dataFormat;
|
|
11301
11302
|
}
|
|
@@ -11316,7 +11317,7 @@ class wa extends W {
|
|
|
11316
11317
|
i.push(r);
|
|
11317
11318
|
i.push(1), n = j(n, i);
|
|
11318
11319
|
}
|
|
11319
|
-
return
|
|
11320
|
+
return _u(n);
|
|
11320
11321
|
});
|
|
11321
11322
|
}
|
|
11322
11323
|
getConfig() {
|
|
@@ -11326,9 +11327,9 @@ class wa extends W {
|
|
|
11326
11327
|
return Object.assign(t, e), t;
|
|
11327
11328
|
}
|
|
11328
11329
|
}
|
|
11329
|
-
|
|
11330
|
-
S(
|
|
11331
|
-
class
|
|
11330
|
+
xa.className = "Flatten";
|
|
11331
|
+
S(xa);
|
|
11332
|
+
class Na extends W {
|
|
11332
11333
|
constructor(t) {
|
|
11333
11334
|
super(t), this.supportsMasking = !0, this.activation = Qt(t.activation);
|
|
11334
11335
|
}
|
|
@@ -11344,9 +11345,9 @@ class ka extends W {
|
|
|
11344
11345
|
return Object.assign(t, e), t;
|
|
11345
11346
|
}
|
|
11346
11347
|
}
|
|
11347
|
-
|
|
11348
|
-
S(
|
|
11349
|
-
class
|
|
11348
|
+
Na.className = "Activation";
|
|
11349
|
+
S(Na);
|
|
11350
|
+
class va extends W {
|
|
11350
11351
|
constructor(t) {
|
|
11351
11352
|
super(t), this.n = t.n, this.inputSpec = [{ ndim: 2 }];
|
|
11352
11353
|
}
|
|
@@ -11354,7 +11355,7 @@ class xa extends W {
|
|
|
11354
11355
|
return [t[0], this.n, t[1]];
|
|
11355
11356
|
}
|
|
11356
11357
|
call(t, e) {
|
|
11357
|
-
return x(() => (t = O(t),
|
|
11358
|
+
return x(() => (t = O(t), Bu(t, this.n)));
|
|
11358
11359
|
}
|
|
11359
11360
|
getConfig() {
|
|
11360
11361
|
const t = {
|
|
@@ -11363,9 +11364,9 @@ class xa extends W {
|
|
|
11363
11364
|
return Object.assign(t, e), t;
|
|
11364
11365
|
}
|
|
11365
11366
|
}
|
|
11366
|
-
|
|
11367
|
-
S(
|
|
11368
|
-
class
|
|
11367
|
+
va.className = "RepeatVector";
|
|
11368
|
+
S(va);
|
|
11369
|
+
class Sa extends W {
|
|
11369
11370
|
constructor(t) {
|
|
11370
11371
|
super(t), this.targetShape = t.targetShape;
|
|
11371
11372
|
for (let e = 0; e < this.targetShape.length; ++e)
|
|
@@ -11433,15 +11434,15 @@ class Na extends W {
|
|
|
11433
11434
|
return Object.assign(t, e), t;
|
|
11434
11435
|
}
|
|
11435
11436
|
}
|
|
11436
|
-
|
|
11437
|
-
S(
|
|
11438
|
-
class
|
|
11437
|
+
Sa.className = "Reshape";
|
|
11438
|
+
S(Sa);
|
|
11439
|
+
class Aa extends W {
|
|
11439
11440
|
constructor(t) {
|
|
11440
11441
|
if (super(t), t.dims == null)
|
|
11441
11442
|
throw new Error("Required configuration field `dims` is missing during Permute constructor call.");
|
|
11442
11443
|
if (!Array.isArray(t.dims))
|
|
11443
11444
|
throw new Error(`Permute constructor requires \`dims\` to be an Array, but received ${t.dims} instead.`);
|
|
11444
|
-
const e =
|
|
11445
|
+
const e = It(1, t.dims.length + 1);
|
|
11445
11446
|
if (!Ft(t.dims.slice().sort(), e))
|
|
11446
11447
|
throw new Error("Invalid permutation `dims`: " + JSON.stringify(t.dims) + " `dims` must contain consecutive integers starting from 1.");
|
|
11447
11448
|
this.dims = t.dims, this.dimsIncludingBatch = [0].concat(this.dims), this.inputSpec = [new st({ ndim: this.dims.length + 1 })];
|
|
@@ -11463,9 +11464,9 @@ class va extends W {
|
|
|
11463
11464
|
return Object.assign(t, e), t;
|
|
11464
11465
|
}
|
|
11465
11466
|
}
|
|
11466
|
-
|
|
11467
|
-
S(
|
|
11468
|
-
class
|
|
11467
|
+
Aa.className = "Permute";
|
|
11468
|
+
S(Aa);
|
|
11469
|
+
class Ca extends W {
|
|
11469
11470
|
constructor(t) {
|
|
11470
11471
|
super(t ?? {}), this.supportsMasking = !0, t != null ? this.maskValue = t.maskValue == null ? 0 : t.maskValue : this.maskValue = 0;
|
|
11471
11472
|
}
|
|
@@ -11488,8 +11489,8 @@ class Sa extends W {
|
|
|
11488
11489
|
});
|
|
11489
11490
|
}
|
|
11490
11491
|
}
|
|
11491
|
-
|
|
11492
|
-
S(
|
|
11492
|
+
Ca.className = "Masking";
|
|
11493
|
+
S(Ca);
|
|
11493
11494
|
/**
|
|
11494
11495
|
* @license
|
|
11495
11496
|
* Copyright 2018 Google LLC
|
|
@@ -11499,7 +11500,7 @@ S(Sa);
|
|
|
11499
11500
|
* https://opensource.org/licenses/MIT.
|
|
11500
11501
|
* =============================================================================
|
|
11501
11502
|
*/
|
|
11502
|
-
class
|
|
11503
|
+
class Ia extends W {
|
|
11503
11504
|
constructor(t) {
|
|
11504
11505
|
if (super(t), this.embeddings = null, this.DEFAULT_EMBEDDINGS_INITIALIZER = "randomUniform", t.batchInputShape == null && t.inputShape == null) {
|
|
11505
11506
|
let e = null;
|
|
@@ -11538,8 +11539,8 @@ class Aa extends W {
|
|
|
11538
11539
|
return x(() => {
|
|
11539
11540
|
this.invokeCallHook(t, e);
|
|
11540
11541
|
let n = O(t);
|
|
11541
|
-
n.dtype !== "int32" && (n =
|
|
11542
|
-
const i =
|
|
11542
|
+
n.dtype !== "int32" && (n = Lt(n, "int32"));
|
|
11543
|
+
const i = Vi(this.embeddings.read(), A(n, [n.size]));
|
|
11543
11544
|
return A(i, U(this.computeOutputShape(n.shape)));
|
|
11544
11545
|
});
|
|
11545
11546
|
}
|
|
@@ -11557,8 +11558,8 @@ class Aa extends W {
|
|
|
11557
11558
|
return Object.assign(t, e), t;
|
|
11558
11559
|
}
|
|
11559
11560
|
}
|
|
11560
|
-
|
|
11561
|
-
S(
|
|
11561
|
+
Ia.className = "Embedding";
|
|
11562
|
+
S(Ia);
|
|
11562
11563
|
/**
|
|
11563
11564
|
* @license
|
|
11564
11565
|
* Copyright 2018 Google LLC
|
|
@@ -11638,7 +11639,7 @@ class fe extends W {
|
|
|
11638
11639
|
for (let a of t) {
|
|
11639
11640
|
const o = a.rank;
|
|
11640
11641
|
for (let l = 0; l < r - o; ++l)
|
|
11641
|
-
a =
|
|
11642
|
+
a = yn(a, 1);
|
|
11642
11643
|
n.push(a);
|
|
11643
11644
|
}
|
|
11644
11645
|
return this.mergeFunction(n);
|
|
@@ -11651,7 +11652,7 @@ class fe extends W {
|
|
|
11651
11652
|
let f = A(l, [h].concat(ye(c.slice(1))));
|
|
11652
11653
|
f = j(f, [1, 0]), f = A(f, p), n.push(f), r = !0;
|
|
11653
11654
|
} else if (u > 1) {
|
|
11654
|
-
const c =
|
|
11655
|
+
const c = It(1, u).concat([0]);
|
|
11655
11656
|
n.push(j(l, c)), r = !0;
|
|
11656
11657
|
} else
|
|
11657
11658
|
n.push(l);
|
|
@@ -11663,7 +11664,7 @@ class fe extends W {
|
|
|
11663
11664
|
const l = a.shape, u = l.length, c = l[u - 1], h = [c].concat(l.slice(0, l.length - 1));
|
|
11664
11665
|
a = A(j(A(a, [-1, c]), [1, 0]), h);
|
|
11665
11666
|
} else if (o > 1) {
|
|
11666
|
-
const l = [o - 1].concat(
|
|
11667
|
+
const l = [o - 1].concat(It(0, o - 1));
|
|
11667
11668
|
a = j(a, l);
|
|
11668
11669
|
}
|
|
11669
11670
|
}
|
|
@@ -11706,7 +11707,7 @@ class fe extends W {
|
|
|
11706
11707
|
});
|
|
11707
11708
|
}
|
|
11708
11709
|
}
|
|
11709
|
-
class
|
|
11710
|
+
class Da extends fe {
|
|
11710
11711
|
constructor(t) {
|
|
11711
11712
|
super(t);
|
|
11712
11713
|
}
|
|
@@ -11719,9 +11720,9 @@ class Ca extends fe {
|
|
|
11719
11720
|
});
|
|
11720
11721
|
}
|
|
11721
11722
|
}
|
|
11722
|
-
|
|
11723
|
-
S(
|
|
11724
|
-
class
|
|
11723
|
+
Da.className = "Add";
|
|
11724
|
+
S(Da);
|
|
11725
|
+
class za extends fe {
|
|
11725
11726
|
constructor(t) {
|
|
11726
11727
|
super(t);
|
|
11727
11728
|
}
|
|
@@ -11734,9 +11735,9 @@ class Ia extends fe {
|
|
|
11734
11735
|
});
|
|
11735
11736
|
}
|
|
11736
11737
|
}
|
|
11737
|
-
|
|
11738
|
-
S(
|
|
11739
|
-
class
|
|
11738
|
+
za.className = "Multiply";
|
|
11739
|
+
S(za);
|
|
11740
|
+
class Ta extends fe {
|
|
11740
11741
|
constructor(t) {
|
|
11741
11742
|
super(t);
|
|
11742
11743
|
}
|
|
@@ -11749,9 +11750,9 @@ class Da extends fe {
|
|
|
11749
11750
|
});
|
|
11750
11751
|
}
|
|
11751
11752
|
}
|
|
11752
|
-
|
|
11753
|
-
S(
|
|
11754
|
-
class
|
|
11753
|
+
Ta.className = "Average";
|
|
11754
|
+
S(Ta);
|
|
11755
|
+
class $a extends fe {
|
|
11755
11756
|
constructor(t) {
|
|
11756
11757
|
super(t);
|
|
11757
11758
|
}
|
|
@@ -11764,9 +11765,9 @@ class za extends fe {
|
|
|
11764
11765
|
});
|
|
11765
11766
|
}
|
|
11766
11767
|
}
|
|
11767
|
-
|
|
11768
|
-
S(
|
|
11769
|
-
class
|
|
11768
|
+
$a.className = "Maximum";
|
|
11769
|
+
S($a);
|
|
11770
|
+
class Ea extends fe {
|
|
11770
11771
|
constructor(t) {
|
|
11771
11772
|
super(t);
|
|
11772
11773
|
}
|
|
@@ -11774,14 +11775,14 @@ class Ta extends fe {
|
|
|
11774
11775
|
return x(() => {
|
|
11775
11776
|
let e = t[0];
|
|
11776
11777
|
for (let n = 1; n < t.length; ++n)
|
|
11777
|
-
e =
|
|
11778
|
+
e = ji(e, t[n]);
|
|
11778
11779
|
return e;
|
|
11779
11780
|
});
|
|
11780
11781
|
}
|
|
11781
11782
|
}
|
|
11782
|
-
|
|
11783
|
-
S(
|
|
11784
|
-
class
|
|
11783
|
+
Ea.className = "Minimum";
|
|
11784
|
+
S(Ea);
|
|
11785
|
+
class La extends fe {
|
|
11785
11786
|
constructor(t) {
|
|
11786
11787
|
super(t), this.DEFAULT_AXIS = -1, t == null && (t = {}), this.axis = t.axis == null ? this.DEFAULT_AXIS : t.axis, this.supportsMasking = !0, this.reshapeRequired = !1;
|
|
11787
11788
|
}
|
|
@@ -11848,7 +11849,7 @@ class $a extends fe {
|
|
|
11848
11849
|
return null;
|
|
11849
11850
|
const i = [];
|
|
11850
11851
|
for (let a = 0; a < t.length; ++a)
|
|
11851
|
-
e[a] == null ? i.push(L(
|
|
11852
|
+
e[a] == null ? i.push(L(Dt(t[a]), "bool")) : e[a].rank < t[a].rank ? i.push(ue(e[a], -1)) : i.push(e[a]);
|
|
11852
11853
|
const r = is(i, this.axis);
|
|
11853
11854
|
return Ju(r, -1, !1);
|
|
11854
11855
|
});
|
|
@@ -11860,14 +11861,14 @@ class $a extends fe {
|
|
|
11860
11861
|
return Object.assign(t, e), t;
|
|
11861
11862
|
}
|
|
11862
11863
|
}
|
|
11863
|
-
|
|
11864
|
-
S(
|
|
11864
|
+
La.className = "Concatenate";
|
|
11865
|
+
S(La);
|
|
11865
11866
|
function $e(s, t) {
|
|
11866
11867
|
for (; s < 0; )
|
|
11867
11868
|
s += t;
|
|
11868
11869
|
return s;
|
|
11869
11870
|
}
|
|
11870
|
-
function
|
|
11871
|
+
function Xf(s, t, e) {
|
|
11871
11872
|
if (s.shape.length > 3 || t.shape.length > 3)
|
|
11872
11873
|
throw new G("batchDot is not implemented for tensors of 4D or higher rank yet");
|
|
11873
11874
|
if (k(s.shape.length >= 2, () => `batchDot requires the rank of x to be >= 2, but got ${s.shape.length}`), k(s.shape.length >= 2, () => `batchDot requires the rank of y to be >= 2, but got ${t.shape.length}`), typeof e == "number" && (e = [e, e]), s.dtype === "complex64" || t.dtype === "complex64")
|
|
@@ -11909,7 +11910,7 @@ function Jf(s, t, e) {
|
|
|
11909
11910
|
return o.shape.length === 1 && (o = ue(o, 1)), o;
|
|
11910
11911
|
});
|
|
11911
11912
|
}
|
|
11912
|
-
class
|
|
11913
|
+
class Fa extends fe {
|
|
11913
11914
|
constructor(t) {
|
|
11914
11915
|
super(t), this.axes = t.axes, this.normalize = t.normalize == null ? !1 : t.normalize, this.supportsMasking = !0, this.reshapeRequired = !1;
|
|
11915
11916
|
}
|
|
@@ -11929,7 +11930,7 @@ class Ea extends fe {
|
|
|
11929
11930
|
return Array.isArray(this.axes) ? i = this.axes.map((r, a) => $e(r, t[a].shape.length)) : i = [
|
|
11930
11931
|
$e(this.axes, e.shape.length),
|
|
11931
11932
|
$e(this.axes, n.shape.length)
|
|
11932
|
-
], this.normalize && (e = pn(e, i[0]), n = pn(n, i[1])),
|
|
11933
|
+
], this.normalize && (e = pn(e, i[0]), n = pn(n, i[1])), Xf(e, n, i);
|
|
11933
11934
|
}
|
|
11934
11935
|
interpretAxes(t, e) {
|
|
11935
11936
|
let n;
|
|
@@ -11959,8 +11960,8 @@ class Ea extends fe {
|
|
|
11959
11960
|
return Object.assign(t, e), t;
|
|
11960
11961
|
}
|
|
11961
11962
|
}
|
|
11962
|
-
|
|
11963
|
-
S(
|
|
11963
|
+
Fa.className = "Dot";
|
|
11964
|
+
S(Fa);
|
|
11964
11965
|
/**
|
|
11965
11966
|
* @license
|
|
11966
11967
|
* Copyright 2018 Google LLC
|
|
@@ -11970,7 +11971,7 @@ S(Ea);
|
|
|
11970
11971
|
* https://opensource.org/licenses/MIT.
|
|
11971
11972
|
* =============================================================================
|
|
11972
11973
|
*/
|
|
11973
|
-
class
|
|
11974
|
+
class Ma extends W {
|
|
11974
11975
|
constructor(t) {
|
|
11975
11976
|
super(t), this.supportsMasking = !0, this.stddev = t.stddev;
|
|
11976
11977
|
}
|
|
@@ -11985,13 +11986,13 @@ class La extends W {
|
|
|
11985
11986
|
return x(() => {
|
|
11986
11987
|
this.invokeCallHook(t, e);
|
|
11987
11988
|
const n = O(t);
|
|
11988
|
-
return Ke(() => $(
|
|
11989
|
+
return Ke(() => $(bn(n.shape, 0, this.stddev), n), () => n, e.training || !1);
|
|
11989
11990
|
});
|
|
11990
11991
|
}
|
|
11991
11992
|
}
|
|
11992
|
-
|
|
11993
|
-
S(
|
|
11994
|
-
class
|
|
11993
|
+
Ma.className = "GaussianNoise";
|
|
11994
|
+
S(Ma);
|
|
11995
|
+
class Oa extends W {
|
|
11995
11996
|
constructor(t) {
|
|
11996
11997
|
super(t), this.supportsMasking = !0, this.rate = t.rate;
|
|
11997
11998
|
}
|
|
@@ -12008,14 +12009,14 @@ class Fa extends W {
|
|
|
12008
12009
|
const n = O(t);
|
|
12009
12010
|
return this.rate > 0 && this.rate < 1 ? Ke(() => {
|
|
12010
12011
|
const r = Math.sqrt(this.rate / (1 - this.rate));
|
|
12011
|
-
return w(n,
|
|
12012
|
+
return w(n, bn(n.shape, 1, r));
|
|
12012
12013
|
}, () => n, e.training || !1) : n;
|
|
12013
12014
|
});
|
|
12014
12015
|
}
|
|
12015
12016
|
}
|
|
12016
|
-
|
|
12017
|
-
S(
|
|
12018
|
-
class
|
|
12017
|
+
Oa.className = "GaussianDropout";
|
|
12018
|
+
S(Oa);
|
|
12019
|
+
class Ra extends W {
|
|
12019
12020
|
constructor(t) {
|
|
12020
12021
|
super(t), this.supportsMasking = !0, this.rate = t.rate, this.noiseShape = t.noiseShape;
|
|
12021
12022
|
}
|
|
@@ -12035,8 +12036,8 @@ class Ma extends W {
|
|
|
12035
12036
|
const n = this._getNoiseShape(t);
|
|
12036
12037
|
return Ke(() => {
|
|
12037
12038
|
const r = O(t), o = -1.6732632423543772 * 1.0507009873554805;
|
|
12038
|
-
let l = Ue(
|
|
12039
|
-
l =
|
|
12039
|
+
let l = Ue(wn(n), this.rate);
|
|
12040
|
+
l = Lt(l, "float32");
|
|
12040
12041
|
const u = ((1 - this.rate) * (1 + this.rate * o ** 2)) ** -0.5, c = -u * o * this.rate, h = $(w(r, l), w($(l, -1), o));
|
|
12041
12042
|
return $(w(h, u), c);
|
|
12042
12043
|
}, () => O(t), e.training || !1);
|
|
@@ -12045,8 +12046,8 @@ class Ma extends W {
|
|
|
12045
12046
|
});
|
|
12046
12047
|
}
|
|
12047
12048
|
}
|
|
12048
|
-
|
|
12049
|
-
S(
|
|
12049
|
+
Ra.className = "AlphaDropout";
|
|
12050
|
+
S(Ra);
|
|
12050
12051
|
/**
|
|
12051
12052
|
* @license
|
|
12052
12053
|
* Copyright 2018 Google LLC
|
|
@@ -12068,25 +12069,25 @@ function _e(s, t, e, n, i, r = 1e-3) {
|
|
|
12068
12069
|
throw new G(`batchNormalization is not implemented for array of rank ${s.rank} yet`);
|
|
12069
12070
|
return a;
|
|
12070
12071
|
}
|
|
12071
|
-
function
|
|
12072
|
+
function Yf(s, t, e, n, i = 1e-3) {
|
|
12072
12073
|
return x(() => {
|
|
12073
12074
|
const r = rs(s, n), a = r.mean, o = r.variance;
|
|
12074
12075
|
return [_e(s, a, o, e, t, i), a, o];
|
|
12075
12076
|
});
|
|
12076
12077
|
}
|
|
12077
|
-
function
|
|
12078
|
+
function Qf(s, t, e, n, i = 1e-3) {
|
|
12078
12079
|
return x(() => {
|
|
12079
12080
|
const r = rs(s, n), a = r.mean, o = r.variance, l = [];
|
|
12080
|
-
for (const g of
|
|
12081
|
+
for (const g of It(0, s.rank))
|
|
12081
12082
|
n.indexOf(g) !== -1 ? l.push(1) : l.push(s.shape[g]);
|
|
12082
12083
|
const u = A(a, l), c = A(o, l), h = t == null ? null : A(t, l), p = e == null ? null : A(e, l);
|
|
12083
12084
|
return [_e(s, u, c, p, h, i), a, o];
|
|
12084
12085
|
});
|
|
12085
12086
|
}
|
|
12086
|
-
function
|
|
12087
|
-
return Ft(n.slice().sort(),
|
|
12087
|
+
function tm(s, t, e, n, i = 1e-3) {
|
|
12088
|
+
return Ft(n.slice().sort(), It(0, s.rank - 1)) ? Yf(s, t, e, n, i) : Qf(s, t, e, n, i);
|
|
12088
12089
|
}
|
|
12089
|
-
class
|
|
12090
|
+
class _a extends W {
|
|
12090
12091
|
constructor(t) {
|
|
12091
12092
|
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = J(t.betaInitializer || "zeros"), this.gammaInitializer = J(t.gammaInitializer || "ones"), this.movingMeanInitializer = J(t.movingMeanInitializer || "zeros"), this.movingVarianceInitializer = J(t.movingVarianceInitializer || "ones"), this.betaConstraint = rt(t.betaConstraint), this.gammaConstraint = rt(t.gammaConstraint), this.betaRegularizer = X(t.betaRegularizer), this.gammaRegularizer = X(t.gammaRegularizer);
|
|
12092
12093
|
}
|
|
@@ -12101,13 +12102,13 @@ class Oa extends W {
|
|
|
12101
12102
|
}
|
|
12102
12103
|
call(t, e) {
|
|
12103
12104
|
return x(() => {
|
|
12104
|
-
const n = e.training == null ? !1 : e.training, i = O(t), r = i.shape, a = r.length, o =
|
|
12105
|
+
const n = e.training == null ? !1 : e.training, i = O(t), r = i.shape, a = r.length, o = It(0, a), l = this.axis >= 0 ? this.axis : this.axis + a;
|
|
12105
12106
|
o.splice(l, 1);
|
|
12106
12107
|
const u = ce(1, a);
|
|
12107
12108
|
u[l] = r[l];
|
|
12108
12109
|
const c = o.slice();
|
|
12109
12110
|
c.sort();
|
|
12110
|
-
const h = !Ft(c,
|
|
12111
|
+
const h = !Ft(c, It(0, a).slice(0, a - 1)), p = () => {
|
|
12111
12112
|
if (h) {
|
|
12112
12113
|
const y = A(this.movingMean.read(), u), C = A(this.movingVariance.read(), u), N = this.center ? A(this.beta.read(), u) : null, I = this.scale ? A(this.gamma.read(), u) : null;
|
|
12113
12114
|
return _e(i, y, C, N, I, this.epsilon);
|
|
@@ -12116,7 +12117,7 @@ class Oa extends W {
|
|
|
12116
12117
|
};
|
|
12117
12118
|
if (!n)
|
|
12118
12119
|
return p();
|
|
12119
|
-
const [f, g, b] =
|
|
12120
|
+
const [f, g, b] = tm(i, this.gamma.read(), this.beta.read(), o, this.epsilon), m = (y, C, N) => {
|
|
12120
12121
|
x(() => {
|
|
12121
12122
|
const I = 1 - N, z = y.read(), _ = w(V(z, C), I);
|
|
12122
12123
|
y.write(V(z, _));
|
|
@@ -12146,9 +12147,9 @@ class Oa extends W {
|
|
|
12146
12147
|
return Object.assign(t, e), t;
|
|
12147
12148
|
}
|
|
12148
12149
|
}
|
|
12149
|
-
|
|
12150
|
-
S(
|
|
12151
|
-
class
|
|
12150
|
+
_a.className = "BatchNormalization";
|
|
12151
|
+
S(_a);
|
|
12152
|
+
class Ba extends W {
|
|
12152
12153
|
constructor(t) {
|
|
12153
12154
|
if (t == null && (t = {}), super(t), this.axis = t.axis == null ? -1 : t.axis, typeof this.axis == "number") {
|
|
12154
12155
|
if (!Number.isInteger(this.axis))
|
|
@@ -12204,8 +12205,8 @@ class Ra extends W {
|
|
|
12204
12205
|
return Object.assign(t, e), t;
|
|
12205
12206
|
}
|
|
12206
12207
|
}
|
|
12207
|
-
|
|
12208
|
-
S(
|
|
12208
|
+
Ba.className = "LayerNormalization";
|
|
12209
|
+
S(Ba);
|
|
12209
12210
|
/**
|
|
12210
12211
|
* @license
|
|
12211
12212
|
* Copyright 2018 Google LLC
|
|
@@ -12215,7 +12216,7 @@ S(Ra);
|
|
|
12215
12216
|
* https://opensource.org/licenses/MIT.
|
|
12216
12217
|
* =============================================================================
|
|
12217
12218
|
*/
|
|
12218
|
-
function
|
|
12219
|
+
function em(s, t, e) {
|
|
12219
12220
|
return x(() => {
|
|
12220
12221
|
if (s.rank !== 4)
|
|
12221
12222
|
throw new d(`temporalPadding expects input tensor to be 4-D, but received a ${s.rank}-D tensor.`);
|
|
@@ -12224,10 +12225,10 @@ function tm(s, t, e) {
|
|
|
12224
12225
|
if (e == null && (e = ne()), e !== "channelsLast" && e !== "channelsFirst")
|
|
12225
12226
|
throw new d(`Unknown data format: ${e}. Supported data formats are 'channelsLast' and 'channelsFirst.`);
|
|
12226
12227
|
let n;
|
|
12227
|
-
return e === "channelsFirst" ? n = [[0, 0], [0, 0], t[0], t[1]] : n = [[0, 0], t[0], t[1], [0, 0]],
|
|
12228
|
+
return e === "channelsFirst" ? n = [[0, 0], [0, 0], t[0], t[1]] : n = [[0, 0], t[0], t[1], [0, 0]], sr(s, n);
|
|
12228
12229
|
});
|
|
12229
12230
|
}
|
|
12230
|
-
class
|
|
12231
|
+
class Wa extends W {
|
|
12231
12232
|
constructor(t) {
|
|
12232
12233
|
if (t == null && (t = {}), super(t), this.dataFormat = t.dataFormat == null ? ne() : t.dataFormat, t.padding == null)
|
|
12233
12234
|
this.padding = [[1, 1], [1, 1]];
|
|
@@ -12256,7 +12257,7 @@ class _a extends W {
|
|
|
12256
12257
|
return this.dataFormat === "channelsFirst" ? (t[2] != null && t[2] >= 0 ? e = t[2] + this.padding[0][0] + this.padding[0][1] : e = null, t[3] != null && t[3] >= 0 ? n = t[3] + this.padding[1][0] + this.padding[1][1] : n = null, [t[0], t[1], e, n]) : (t[1] != null && t[1] >= 0 ? e = t[1] + this.padding[0][0] + this.padding[0][1] : e = null, t[2] != null && t[2] >= 0 ? n = t[2] + this.padding[1][0] + this.padding[1][1] : n = null, [t[0], e, n, t[3]]);
|
|
12257
12258
|
}
|
|
12258
12259
|
call(t, e) {
|
|
12259
|
-
return x(() =>
|
|
12260
|
+
return x(() => em(O(t), this.padding, this.dataFormat));
|
|
12260
12261
|
}
|
|
12261
12262
|
getConfig() {
|
|
12262
12263
|
const t = {
|
|
@@ -12266,8 +12267,8 @@ class _a extends W {
|
|
|
12266
12267
|
return Object.assign(t, e), t;
|
|
12267
12268
|
}
|
|
12268
12269
|
}
|
|
12269
|
-
|
|
12270
|
-
S(
|
|
12270
|
+
Wa.className = "ZeroPadding2D";
|
|
12271
|
+
S(Wa);
|
|
12271
12272
|
/**
|
|
12272
12273
|
* @license
|
|
12273
12274
|
* Copyright 2018 Google LLC
|
|
@@ -12279,7 +12280,7 @@ S(_a);
|
|
|
12279
12280
|
*/
|
|
12280
12281
|
function zn(s, t, e, n, i, r) {
|
|
12281
12282
|
return x(() => {
|
|
12282
|
-
et(i),
|
|
12283
|
+
et(i), qi(r), gt(n), e == null && (e = [1, 1]), n == null && (n = "valid"), i == null && (i = ne()), r == null && (r = "max"), s = Ss(s, i);
|
|
12283
12284
|
let a;
|
|
12284
12285
|
const o = n === "same" ? "same" : "valid";
|
|
12285
12286
|
return r === "max" ? a = Hc(s, t, e, o) : a = ic(
|
|
@@ -12291,15 +12292,15 @@ function zn(s, t, e, n, i, r) {
|
|
|
12291
12292
|
), i === "channelsFirst" && (a = j(a, [0, 3, 1, 2])), a;
|
|
12292
12293
|
});
|
|
12293
12294
|
}
|
|
12294
|
-
function
|
|
12295
|
+
function Ga(s, t, e, n, i, r) {
|
|
12295
12296
|
return x(() => {
|
|
12296
|
-
et(i),
|
|
12297
|
+
et(i), qi(r), gt(n), e == null && (e = [1, 1, 1]), n == null && (n = "valid"), i == null && (i = ne()), r == null && (r = "max"), s = ia(s, i);
|
|
12297
12298
|
let a;
|
|
12298
12299
|
const o = n === "same" ? "same" : "valid";
|
|
12299
12300
|
return r === "max" ? a = Zc(s, t, e, o) : a = ac(s, t, e, o), i === "channelsFirst" && (a = j(a, [0, 4, 1, 2, 3])), a;
|
|
12300
12301
|
});
|
|
12301
12302
|
}
|
|
12302
|
-
class
|
|
12303
|
+
class Pa extends W {
|
|
12303
12304
|
/**
|
|
12304
12305
|
*
|
|
12305
12306
|
* @param args Parameters for the Pooling layer.
|
|
@@ -12325,12 +12326,12 @@ class Wa extends W {
|
|
|
12325
12326
|
}
|
|
12326
12327
|
computeOutputShape(t) {
|
|
12327
12328
|
t = U(t);
|
|
12328
|
-
const e =
|
|
12329
|
+
const e = At(t[1], this.poolSize[0], this.padding, this.strides[0]);
|
|
12329
12330
|
return [t[0], e, t[2]];
|
|
12330
12331
|
}
|
|
12331
12332
|
call(t, e) {
|
|
12332
12333
|
return x(() => {
|
|
12333
|
-
this.invokeCallHook(t, e), t =
|
|
12334
|
+
this.invokeCallHook(t, e), t = yn(O(t), 2);
|
|
12334
12335
|
const n = this.poolingFunction(O(t), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, "channelsLast");
|
|
12335
12336
|
return ts(n, [2]);
|
|
12336
12337
|
});
|
|
@@ -12344,7 +12345,7 @@ class Wa extends W {
|
|
|
12344
12345
|
return Object.assign(t, e), t;
|
|
12345
12346
|
}
|
|
12346
12347
|
}
|
|
12347
|
-
class
|
|
12348
|
+
class Ua extends Pa {
|
|
12348
12349
|
constructor(t) {
|
|
12349
12350
|
super(t);
|
|
12350
12351
|
}
|
|
@@ -12352,9 +12353,9 @@ class Ga extends Wa {
|
|
|
12352
12353
|
return et(r), gt(i), zn(t, e, n, i, r, "max");
|
|
12353
12354
|
}
|
|
12354
12355
|
}
|
|
12355
|
-
|
|
12356
|
-
S(
|
|
12357
|
-
class
|
|
12356
|
+
Ua.className = "MaxPooling1D";
|
|
12357
|
+
S(Ua);
|
|
12358
|
+
class Va extends Pa {
|
|
12358
12359
|
constructor(t) {
|
|
12359
12360
|
super(t);
|
|
12360
12361
|
}
|
|
@@ -12362,9 +12363,9 @@ class Pa extends Wa {
|
|
|
12362
12363
|
return et(r), gt(i), zn(t, e, n, i, r, "avg");
|
|
12363
12364
|
}
|
|
12364
12365
|
}
|
|
12365
|
-
|
|
12366
|
-
S(
|
|
12367
|
-
class
|
|
12366
|
+
Va.className = "AveragePooling1D";
|
|
12367
|
+
S(Va);
|
|
12368
|
+
class ja extends W {
|
|
12368
12369
|
constructor(t) {
|
|
12369
12370
|
if (t.poolSize == null && (t.poolSize = [2, 2]), super(t), this.poolSize = Array.isArray(t.poolSize) ? t.poolSize : [t.poolSize, t.poolSize], t.strides == null)
|
|
12370
12371
|
this.strides = this.poolSize;
|
|
@@ -12379,7 +12380,7 @@ class Ua extends W {
|
|
|
12379
12380
|
computeOutputShape(t) {
|
|
12380
12381
|
t = U(t);
|
|
12381
12382
|
let e = this.dataFormat === "channelsFirst" ? t[2] : t[1], n = this.dataFormat === "channelsFirst" ? t[3] : t[2];
|
|
12382
|
-
return e =
|
|
12383
|
+
return e = At(e, this.poolSize[0], this.padding, this.strides[0]), n = At(n, this.poolSize[1], this.padding, this.strides[1]), this.dataFormat === "channelsFirst" ? [t[0], t[1], e, n] : [t[0], e, n, t[3]];
|
|
12383
12384
|
}
|
|
12384
12385
|
call(t, e) {
|
|
12385
12386
|
return x(() => (this.invokeCallHook(t, e), this.poolingFunction(O(t), this.poolSize, this.strides, this.padding, this.dataFormat)));
|
|
@@ -12394,7 +12395,7 @@ class Ua extends W {
|
|
|
12394
12395
|
return Object.assign(t, e), t;
|
|
12395
12396
|
}
|
|
12396
12397
|
}
|
|
12397
|
-
class
|
|
12398
|
+
class Ka extends ja {
|
|
12398
12399
|
constructor(t) {
|
|
12399
12400
|
super(t);
|
|
12400
12401
|
}
|
|
@@ -12402,9 +12403,9 @@ class Va extends Ua {
|
|
|
12402
12403
|
return et(r), gt(i), zn(t, e, n, i, r, "max");
|
|
12403
12404
|
}
|
|
12404
12405
|
}
|
|
12405
|
-
|
|
12406
|
-
S(
|
|
12407
|
-
class
|
|
12406
|
+
Ka.className = "MaxPooling2D";
|
|
12407
|
+
S(Ka);
|
|
12408
|
+
class Ha extends ja {
|
|
12408
12409
|
constructor(t) {
|
|
12409
12410
|
super(t);
|
|
12410
12411
|
}
|
|
@@ -12412,9 +12413,9 @@ class ja extends Ua {
|
|
|
12412
12413
|
return et(r), gt(i), zn(t, e, n, i, r, "avg");
|
|
12413
12414
|
}
|
|
12414
12415
|
}
|
|
12415
|
-
|
|
12416
|
-
S(
|
|
12417
|
-
class
|
|
12416
|
+
Ha.className = "AveragePooling2D";
|
|
12417
|
+
S(Ha);
|
|
12418
|
+
class qa extends W {
|
|
12418
12419
|
constructor(t) {
|
|
12419
12420
|
if (t.poolSize == null && (t.poolSize = [2, 2, 2]), super(t), this.poolSize = Array.isArray(t.poolSize) ? t.poolSize : [t.poolSize, t.poolSize, t.poolSize], t.strides == null)
|
|
12420
12421
|
this.strides = this.poolSize;
|
|
@@ -12429,7 +12430,7 @@ class Ka extends W {
|
|
|
12429
12430
|
computeOutputShape(t) {
|
|
12430
12431
|
t = U(t);
|
|
12431
12432
|
let e = this.dataFormat === "channelsFirst" ? t[2] : t[1], n = this.dataFormat === "channelsFirst" ? t[3] : t[2], i = this.dataFormat === "channelsFirst" ? t[4] : t[3];
|
|
12432
|
-
return e =
|
|
12433
|
+
return e = At(e, this.poolSize[0], this.padding, this.strides[0]), n = At(n, this.poolSize[1], this.padding, this.strides[1]), i = At(i, this.poolSize[2], this.padding, this.strides[2]), this.dataFormat === "channelsFirst" ? [t[0], t[1], e, n, i] : [t[0], e, n, i, t[4]];
|
|
12433
12434
|
}
|
|
12434
12435
|
call(t, e) {
|
|
12435
12436
|
return x(() => (this.invokeCallHook(t, e), this.poolingFunction(O(t), this.poolSize, this.strides, this.padding, this.dataFormat)));
|
|
@@ -12444,27 +12445,27 @@ class Ka extends W {
|
|
|
12444
12445
|
return Object.assign(t, e), t;
|
|
12445
12446
|
}
|
|
12446
12447
|
}
|
|
12447
|
-
class
|
|
12448
|
+
class Za extends qa {
|
|
12448
12449
|
constructor(t) {
|
|
12449
12450
|
super(t);
|
|
12450
12451
|
}
|
|
12451
12452
|
poolingFunction(t, e, n, i, r) {
|
|
12452
|
-
return et(r), gt(i),
|
|
12453
|
+
return et(r), gt(i), Ga(t, e, n, i, r, "max");
|
|
12453
12454
|
}
|
|
12454
12455
|
}
|
|
12455
|
-
|
|
12456
|
-
S(
|
|
12457
|
-
class
|
|
12456
|
+
Za.className = "MaxPooling3D";
|
|
12457
|
+
S(Za);
|
|
12458
|
+
class Ja extends qa {
|
|
12458
12459
|
constructor(t) {
|
|
12459
12460
|
super(t);
|
|
12460
12461
|
}
|
|
12461
12462
|
poolingFunction(t, e, n, i, r) {
|
|
12462
|
-
return et(r), gt(i),
|
|
12463
|
+
return et(r), gt(i), Ga(t, e, n, i, r, "avg");
|
|
12463
12464
|
}
|
|
12464
12465
|
}
|
|
12465
|
-
|
|
12466
|
-
S(
|
|
12467
|
-
class
|
|
12466
|
+
Ja.className = "AveragePooling3D";
|
|
12467
|
+
S(Ja);
|
|
12468
|
+
class Xa extends W {
|
|
12468
12469
|
constructor(t) {
|
|
12469
12470
|
super(t), this.inputSpec = [new st({ ndim: 3 })];
|
|
12470
12471
|
}
|
|
@@ -12475,7 +12476,7 @@ class Za extends W {
|
|
|
12475
12476
|
throw new G();
|
|
12476
12477
|
}
|
|
12477
12478
|
}
|
|
12478
|
-
class
|
|
12479
|
+
class Ya extends Xa {
|
|
12479
12480
|
constructor(t) {
|
|
12480
12481
|
super(t || {});
|
|
12481
12482
|
}
|
|
@@ -12486,9 +12487,9 @@ class Ja extends Za {
|
|
|
12486
12487
|
});
|
|
12487
12488
|
}
|
|
12488
12489
|
}
|
|
12489
|
-
|
|
12490
|
-
S(
|
|
12491
|
-
class
|
|
12490
|
+
Ya.className = "GlobalAveragePooling1D";
|
|
12491
|
+
S(Ya);
|
|
12492
|
+
class Qa extends Xa {
|
|
12492
12493
|
constructor(t) {
|
|
12493
12494
|
super(t || {});
|
|
12494
12495
|
}
|
|
@@ -12499,9 +12500,9 @@ class Xa extends Za {
|
|
|
12499
12500
|
});
|
|
12500
12501
|
}
|
|
12501
12502
|
}
|
|
12502
|
-
|
|
12503
|
-
S(
|
|
12504
|
-
class
|
|
12503
|
+
Qa.className = "GlobalMaxPooling1D";
|
|
12504
|
+
S(Qa);
|
|
12505
|
+
class to extends W {
|
|
12505
12506
|
constructor(t) {
|
|
12506
12507
|
super(t), this.dataFormat = t.dataFormat == null ? "channelsLast" : t.dataFormat, et(this.dataFormat), this.inputSpec = [new st({ ndim: 4 })];
|
|
12507
12508
|
}
|
|
@@ -12516,7 +12517,7 @@ class Ya extends W {
|
|
|
12516
12517
|
return Object.assign(t, e), t;
|
|
12517
12518
|
}
|
|
12518
12519
|
}
|
|
12519
|
-
class
|
|
12520
|
+
class eo extends to {
|
|
12520
12521
|
call(t, e) {
|
|
12521
12522
|
return x(() => {
|
|
12522
12523
|
const n = O(t);
|
|
@@ -12524,9 +12525,9 @@ class Qa extends Ya {
|
|
|
12524
12525
|
});
|
|
12525
12526
|
}
|
|
12526
12527
|
}
|
|
12527
|
-
|
|
12528
|
-
S(
|
|
12529
|
-
class
|
|
12528
|
+
eo.className = "GlobalAveragePooling2D";
|
|
12529
|
+
S(eo);
|
|
12530
|
+
class no extends to {
|
|
12530
12531
|
call(t, e) {
|
|
12531
12532
|
return x(() => {
|
|
12532
12533
|
const n = O(t);
|
|
@@ -12534,8 +12535,8 @@ class to extends Ya {
|
|
|
12534
12535
|
});
|
|
12535
12536
|
}
|
|
12536
12537
|
}
|
|
12537
|
-
|
|
12538
|
-
S(
|
|
12538
|
+
no.className = "GlobalMaxPooling2D";
|
|
12539
|
+
S(no);
|
|
12539
12540
|
/**
|
|
12540
12541
|
* @license
|
|
12541
12542
|
* Copyright 2018 Google LLC
|
|
@@ -12545,7 +12546,7 @@ S(to);
|
|
|
12545
12546
|
* https://opensource.org/licenses/MIT.
|
|
12546
12547
|
* =============================================================================
|
|
12547
12548
|
*/
|
|
12548
|
-
class
|
|
12549
|
+
class so extends W {
|
|
12549
12550
|
constructor(t) {
|
|
12550
12551
|
super(t), this.layer = t.layer;
|
|
12551
12552
|
}
|
|
@@ -12601,7 +12602,7 @@ class eo extends W {
|
|
|
12601
12602
|
return Object.assign(a, e), new t(a);
|
|
12602
12603
|
}
|
|
12603
12604
|
}
|
|
12604
|
-
class
|
|
12605
|
+
class io extends so {
|
|
12605
12606
|
constructor(t) {
|
|
12606
12607
|
super(t), this.supportsMasking = !0;
|
|
12607
12608
|
}
|
|
@@ -12618,7 +12619,7 @@ class no extends eo {
|
|
|
12618
12619
|
return [n[0], i].concat(n.slice(1));
|
|
12619
12620
|
}
|
|
12620
12621
|
call(t, e) {
|
|
12621
|
-
return x(() => (t = O(t),
|
|
12622
|
+
return x(() => (t = O(t), da(
|
|
12622
12623
|
(a, o) => [O(this.layer.call(a, e)), []],
|
|
12623
12624
|
t,
|
|
12624
12625
|
[],
|
|
@@ -12631,19 +12632,19 @@ class no extends eo {
|
|
|
12631
12632
|
)[1]));
|
|
12632
12633
|
}
|
|
12633
12634
|
}
|
|
12634
|
-
|
|
12635
|
-
S(
|
|
12636
|
-
function
|
|
12637
|
-
Zn(
|
|
12635
|
+
io.className = "TimeDistributed";
|
|
12636
|
+
S(io);
|
|
12637
|
+
function nm(s) {
|
|
12638
|
+
Zn(Wu, "BidirectionalMergeMode", s);
|
|
12638
12639
|
}
|
|
12639
|
-
const
|
|
12640
|
-
class
|
|
12640
|
+
const sm = "concat";
|
|
12641
|
+
class ro extends so {
|
|
12641
12642
|
constructor(t) {
|
|
12642
12643
|
super(t);
|
|
12643
12644
|
const e = t.layer.getConfig(), n = {};
|
|
12644
12645
|
n.className = t.layer.getClassName(), n.config = e, this.forwardLayer = Wt(n), e.goBackwards = e.goBackwards !== !0;
|
|
12645
12646
|
const i = {};
|
|
12646
|
-
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer = Wt(i), this.forwardLayer.name = "forward_" + this.forwardLayer.name, this.backwardLayer.name = "backward_" + this.backwardLayer.name, this.mergeMode = t.mergeMode === void 0 ?
|
|
12647
|
+
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer = Wt(i), this.forwardLayer.name = "forward_" + this.forwardLayer.name, this.backwardLayer.name = "backward_" + this.backwardLayer.name, this.mergeMode = t.mergeMode === void 0 ? sm : t.mergeMode, nm(this.mergeMode), t.weights)
|
|
12647
12648
|
throw new G("weights support is not implemented for Bidirectional layer yet.");
|
|
12648
12649
|
this._stateful = t.layer.stateful, this.returnSequences = t.layer.returnSequences, this.returnState = t.layer.returnState, this.supportsMasking = !0, this._trainable = !0, this.inputSpec = t.layer.inputSpec, this.numConstants = null;
|
|
12649
12650
|
}
|
|
@@ -12669,7 +12670,7 @@ class so extends eo {
|
|
|
12669
12670
|
apply(t, e) {
|
|
12670
12671
|
let n = e == null ? null : e.initialState, i = e == null ? null : e.constants;
|
|
12671
12672
|
e == null && (e = {});
|
|
12672
|
-
const r =
|
|
12673
|
+
const r = pa(t, n, i, this.numConstants);
|
|
12673
12674
|
if (t = r.inputs, n = r.initialState, i = r.constants, Array.isArray(t) && (n = t.slice(1), t = t[0]), (n == null || n.length === 0) && i == null)
|
|
12674
12675
|
return super.apply(t, e);
|
|
12675
12676
|
const a = [], o = [];
|
|
@@ -12755,8 +12756,8 @@ class so extends eo {
|
|
|
12755
12756
|
return i.layer = n, new t(i);
|
|
12756
12757
|
}
|
|
12757
12758
|
}
|
|
12758
|
-
|
|
12759
|
-
S(
|
|
12759
|
+
ro.className = "Bidirectional";
|
|
12760
|
+
S(ro);
|
|
12760
12761
|
/**
|
|
12761
12762
|
* @license
|
|
12762
12763
|
* Copyright 2022 CodeSmith LLC
|
|
@@ -12766,7 +12767,7 @@ S(so);
|
|
|
12766
12767
|
* https://opensource.org/licenses/MIT.
|
|
12767
12768
|
* =============================================================================
|
|
12768
12769
|
*/
|
|
12769
|
-
class
|
|
12770
|
+
class ao extends W {
|
|
12770
12771
|
constructor(t) {
|
|
12771
12772
|
super(t), this.scale = t.scale, t.offset ? this.offset = t.offset : this.offset = 0;
|
|
12772
12773
|
}
|
|
@@ -12778,11 +12779,11 @@ class io extends W {
|
|
|
12778
12779
|
return Object.assign(t, e), t;
|
|
12779
12780
|
}
|
|
12780
12781
|
call(t, e) {
|
|
12781
|
-
return x(() => (t = O(t), t.dtype !== "float32" && (t =
|
|
12782
|
+
return x(() => (t = O(t), t.dtype !== "float32" && (t = Lt(t, "float32")), $(w(t, this.scale), this.offset)));
|
|
12782
12783
|
}
|
|
12783
12784
|
}
|
|
12784
|
-
|
|
12785
|
-
S(
|
|
12785
|
+
ao.className = "Rescaling";
|
|
12786
|
+
S(ao);
|
|
12786
12787
|
/**
|
|
12787
12788
|
* @license
|
|
12788
12789
|
* Copyright 2022 CodeSmith LLC
|
|
@@ -12792,8 +12793,8 @@ S(io);
|
|
|
12792
12793
|
* https://opensource.org/licenses/MIT.
|
|
12793
12794
|
* =============================================================================
|
|
12794
12795
|
*/
|
|
12795
|
-
const { resizeBilinear:
|
|
12796
|
-
class
|
|
12796
|
+
const { resizeBilinear: im, cropAndResize: rm } = _t;
|
|
12797
|
+
class oo extends W {
|
|
12797
12798
|
constructor(t) {
|
|
12798
12799
|
super(t), this.height = t.height, this.width = t.width;
|
|
12799
12800
|
}
|
|
@@ -12804,14 +12805,14 @@ class ro extends W {
|
|
|
12804
12805
|
t.rank === 3 ? (c = !0, u = kn([t])) : u = t;
|
|
12805
12806
|
for (let I = 0; I < u.shape[0]; I++)
|
|
12806
12807
|
m.push(b);
|
|
12807
|
-
const v =
|
|
12808
|
-
return c ?
|
|
12808
|
+
const v = Ku(m, [m.length, 4]), y = Hu(0, m.length, 1, "int32"), N = rm(u, v, y, [i, r], "nearest");
|
|
12809
|
+
return c ? Lt(O(en(N)), l) : Lt(N, l);
|
|
12809
12810
|
});
|
|
12810
12811
|
}
|
|
12811
12812
|
upsize(t, e, n, i) {
|
|
12812
12813
|
return x(() => {
|
|
12813
|
-
const r =
|
|
12814
|
-
return
|
|
12814
|
+
const r = im(t, [e, n]);
|
|
12815
|
+
return Lt(r, i);
|
|
12815
12816
|
});
|
|
12816
12817
|
}
|
|
12817
12818
|
call(t, e) {
|
|
@@ -12836,8 +12837,8 @@ class ro extends W {
|
|
|
12836
12837
|
return t[e] = this.height, t[n] = this.width, t;
|
|
12837
12838
|
}
|
|
12838
12839
|
}
|
|
12839
|
-
|
|
12840
|
-
S(
|
|
12840
|
+
oo.className = "CenterCrop";
|
|
12841
|
+
S(oo);
|
|
12841
12842
|
/**
|
|
12842
12843
|
* @license
|
|
12843
12844
|
* Copyright 2022 CodeSmith LLC
|
|
@@ -12847,9 +12848,9 @@ S(ro);
|
|
|
12847
12848
|
* https://opensource.org/licenses/MIT.
|
|
12848
12849
|
* =============================================================================
|
|
12849
12850
|
*/
|
|
12850
|
-
function
|
|
12851
|
+
function am(s, t, e, n) {
|
|
12851
12852
|
let i = O(s);
|
|
12852
|
-
if (i.dtype !== "int32" && (i =
|
|
12853
|
+
if (i.dtype !== "int32" && (i = Lt(i, "int32")), t === "int")
|
|
12853
12854
|
return i;
|
|
12854
12855
|
const r = i.shape;
|
|
12855
12856
|
if (i.rank === 0 && (i = ue(i, -1)), t === "oneHot" && i.shape[i.shape.length - 1] !== 1 && (i = ue(i, -1)), i.rank > 2)
|
|
@@ -12871,7 +12872,7 @@ function rm(s, t, e, n) {
|
|
|
12871
12872
|
* https://opensource.org/licenses/MIT.
|
|
12872
12873
|
* =============================================================================
|
|
12873
12874
|
*/
|
|
12874
|
-
class
|
|
12875
|
+
class lo extends W {
|
|
12875
12876
|
constructor(t) {
|
|
12876
12877
|
super(t), this.numTokens = t.numTokens, t.outputMode ? this.outputMode = t.outputMode : this.outputMode = "multiHot";
|
|
12877
12878
|
}
|
|
@@ -12887,7 +12888,7 @@ class ao extends W {
|
|
|
12887
12888
|
}
|
|
12888
12889
|
call(t, e) {
|
|
12889
12890
|
return x(() => {
|
|
12890
|
-
t = O(t), t.dtype !== "int32" && (t =
|
|
12891
|
+
t = O(t), t.dtype !== "int32" && (t = Lt(t, "int32"));
|
|
12891
12892
|
let n;
|
|
12892
12893
|
if (typeof e.countWeights < "u") {
|
|
12893
12894
|
if (this.outputMode !== "count")
|
|
@@ -12895,15 +12896,15 @@ class ao extends W {
|
|
|
12895
12896
|
Received countWeights=${e.countWeights}`);
|
|
12896
12897
|
n = O(e.countWeights);
|
|
12897
12898
|
}
|
|
12898
|
-
const i = ve(t), r =
|
|
12899
|
+
const i = ve(t), r = qu(t), a = Gt(this.numTokens, i).bufferSync().get(0), o = Ue(r, 0).bufferSync().get(0);
|
|
12899
12900
|
if (!(a && o))
|
|
12900
12901
|
throw new d(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
|
|
12901
|
-
return
|
|
12902
|
+
return am(t, this.outputMode, this.numTokens, n);
|
|
12902
12903
|
});
|
|
12903
12904
|
}
|
|
12904
12905
|
}
|
|
12905
|
-
|
|
12906
|
-
S(
|
|
12906
|
+
lo.className = "CategoryEncoding";
|
|
12907
|
+
S(lo);
|
|
12907
12908
|
/**
|
|
12908
12909
|
* @license
|
|
12909
12910
|
* Copyright 2022 CodeSmith LLC
|
|
@@ -12913,8 +12914,8 @@ S(ao);
|
|
|
12913
12914
|
* https://opensource.org/licenses/MIT.
|
|
12914
12915
|
* =============================================================================
|
|
12915
12916
|
*/
|
|
12916
|
-
const
|
|
12917
|
-
class
|
|
12917
|
+
const om = ["bilinear", "nearest"], di = new Set(om);
|
|
12918
|
+
class uo extends W {
|
|
12918
12919
|
constructor(t) {
|
|
12919
12920
|
if (super(t), this.height = t.height, this.width = t.width, t.interpolation)
|
|
12920
12921
|
if (di.has(t.interpolation))
|
|
@@ -12950,8 +12951,8 @@ class oo extends W {
|
|
|
12950
12951
|
});
|
|
12951
12952
|
}
|
|
12952
12953
|
}
|
|
12953
|
-
|
|
12954
|
-
S(
|
|
12954
|
+
uo.className = "Resizing";
|
|
12955
|
+
S(uo);
|
|
12955
12956
|
/**
|
|
12956
12957
|
* @license
|
|
12957
12958
|
* Copyright 2023 CodeSmith LLC
|
|
@@ -12961,7 +12962,7 @@ S(oo);
|
|
|
12961
12962
|
* https://opensource.org/licenses/MIT.
|
|
12962
12963
|
* =============================================================================
|
|
12963
12964
|
*/
|
|
12964
|
-
class
|
|
12965
|
+
class co {
|
|
12965
12966
|
constructor(t) {
|
|
12966
12967
|
this.seed = t;
|
|
12967
12968
|
}
|
|
@@ -12970,7 +12971,7 @@ class lo {
|
|
|
12970
12971
|
return this.seed++;
|
|
12971
12972
|
}
|
|
12972
12973
|
}
|
|
12973
|
-
|
|
12974
|
+
co.className = "RandomSeed";
|
|
12974
12975
|
/**
|
|
12975
12976
|
* @license
|
|
12976
12977
|
* Copyright 2023 CodeSmith LLC
|
|
@@ -12980,9 +12981,9 @@ lo.className = "RandomSeed";
|
|
|
12980
12981
|
* https://opensource.org/licenses/MIT.
|
|
12981
12982
|
* =============================================================================
|
|
12982
12983
|
*/
|
|
12983
|
-
class
|
|
12984
|
+
class ho extends W {
|
|
12984
12985
|
constructor(t) {
|
|
12985
|
-
super(t), this.randomGenerator = new
|
|
12986
|
+
super(t), this.randomGenerator = new co(t.seed);
|
|
12986
12987
|
}
|
|
12987
12988
|
getConfig() {
|
|
12988
12989
|
const t = {
|
|
@@ -12991,7 +12992,7 @@ class uo extends W {
|
|
|
12991
12992
|
return Object.assign(t, e), t;
|
|
12992
12993
|
}
|
|
12993
12994
|
}
|
|
12994
|
-
|
|
12995
|
+
ho.className = "BaseRandomLayer";
|
|
12995
12996
|
/**
|
|
12996
12997
|
* @license
|
|
12997
12998
|
* Copyright 2023 CodeSmith LLC
|
|
@@ -13001,8 +13002,8 @@ uo.className = "BaseRandomLayer";
|
|
|
13001
13002
|
* https://opensource.org/licenses/MIT.
|
|
13002
13003
|
* =============================================================================
|
|
13003
13004
|
*/
|
|
13004
|
-
const
|
|
13005
|
-
class
|
|
13005
|
+
const lm = ["bilinear", "nearest"], fi = new Set(lm);
|
|
13006
|
+
class po extends ho {
|
|
13006
13007
|
constructor(t) {
|
|
13007
13008
|
super(t);
|
|
13008
13009
|
const { factor: e, interpolation: n = "bilinear" } = t;
|
|
@@ -13042,7 +13043,7 @@ class co extends uo {
|
|
|
13042
13043
|
const n = O(t);
|
|
13043
13044
|
this.imgHeight = n.shape[n.shape.length - 3];
|
|
13044
13045
|
const i = n.shape[n.shape.length - 2];
|
|
13045
|
-
this.widthFactor =
|
|
13046
|
+
this.widthFactor = wn([1], 1 + this.widthLower, 1 + this.widthUpper, "float32", this.randomGenerator.next());
|
|
13046
13047
|
let r = this.widthFactor.dataSync()[0] * i;
|
|
13047
13048
|
r = Math.round(r);
|
|
13048
13049
|
const a = [this.imgHeight, r];
|
|
@@ -13058,12 +13059,52 @@ class co extends uo {
|
|
|
13058
13059
|
});
|
|
13059
13060
|
}
|
|
13060
13061
|
}
|
|
13061
|
-
|
|
13062
|
-
S(
|
|
13062
|
+
po.className = "RandomWidth";
|
|
13063
|
+
S(po);
|
|
13064
|
+
class Om {
|
|
13065
|
+
vocabSize;
|
|
13066
|
+
embedDim;
|
|
13067
|
+
tiedWeights;
|
|
13068
|
+
initializer;
|
|
13069
|
+
constructor(t, e) {
|
|
13070
|
+
this.vocabSize = t.vocabSize, this.embedDim = t.embedDim, this.initializer = Yd({
|
|
13071
|
+
mean: 0,
|
|
13072
|
+
stddev: 0.02
|
|
13073
|
+
}), this.tiedWeights = Ji(
|
|
13074
|
+
this.initializer.apply([this.vocabSize, this.embedDim]),
|
|
13075
|
+
!0,
|
|
13076
|
+
e || "tied_embedding"
|
|
13077
|
+
);
|
|
13078
|
+
}
|
|
13079
|
+
get variables() {
|
|
13080
|
+
return [this.tiedWeights];
|
|
13081
|
+
}
|
|
13082
|
+
embed(t) {
|
|
13083
|
+
return Qi(this.tiedWeights, t, 0);
|
|
13084
|
+
}
|
|
13085
|
+
project(t) {
|
|
13086
|
+
return St(t, this.tiedWeights.transpose());
|
|
13087
|
+
}
|
|
13088
|
+
getWeights() {
|
|
13089
|
+
return [this.tiedWeights];
|
|
13090
|
+
}
|
|
13091
|
+
setWeights(t) {
|
|
13092
|
+
this.tiedWeights.assign(t[0]);
|
|
13093
|
+
}
|
|
13094
|
+
getConfig() {
|
|
13095
|
+
return {
|
|
13096
|
+
vocabSize: this.vocabSize,
|
|
13097
|
+
embedDim: this.embedDim
|
|
13098
|
+
};
|
|
13099
|
+
}
|
|
13100
|
+
dispose() {
|
|
13101
|
+
this.tiedWeights.dispose();
|
|
13102
|
+
}
|
|
13103
|
+
}
|
|
13063
13104
|
export {
|
|
13064
13105
|
zs as D,
|
|
13065
|
-
|
|
13066
|
-
|
|
13067
|
-
|
|
13068
|
-
|
|
13106
|
+
Ia as E,
|
|
13107
|
+
Om as T,
|
|
13108
|
+
sr as p,
|
|
13109
|
+
Yd as r
|
|
13069
13110
|
};
|