@genai-fi/nanogpt 0.4.4 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +44 -41
- package/dist/NanoGPTModel.d.ts +12 -16
- package/dist/NanoGPTModel.js +128 -138
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +8 -5
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +22 -24
- package/dist/layers/CausalSelfAttention.js +73 -127
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -11
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +22 -19
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +3 -2
- package/dist/ops/webgl/matMulGelu.js +77 -75
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +86 -0
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +10 -8
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
|
@@ -1,26 +1,27 @@
|
|
|
1
|
-
import { o as F,
|
|
2
|
-
import { s as ku, a as xu, g as
|
|
3
|
-
import { M as
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
import {
|
|
7
|
-
import {
|
|
8
|
-
import {
|
|
9
|
-
import {
|
|
10
|
-
import {
|
|
11
|
-
import { s as
|
|
12
|
-
import {
|
|
13
|
-
import {
|
|
14
|
-
import {
|
|
15
|
-
import {
|
|
16
|
-
import {
|
|
17
|
-
import {
|
|
18
|
-
import {
|
|
19
|
-
import {
|
|
20
|
-
import {
|
|
21
|
-
import {
|
|
22
|
-
import {
|
|
23
|
-
import {
|
|
1
|
+
import { o as F, i as D, E as M, bb as fo, bc as mo, bd as mi, k, b9 as Mn, be as gi, y as L, bf as bi, bg as yi, bh as wi, bi as ki, bj as xi, bk as Ni, bl as vi, bm as go, bn as Si, bo, bp as Ai, bq as yo, br as Ci, q as Hn, N as wt, bs as wo, bt as Ii, bu as Di, bv as zi, c as On, s as V, b as w, bw as ko, bx as Ti, by as $i, bz as xo, bA as Ei, bB as Li, bC as Fi, bD as Mi, bE as Oi, bF as Ri, bG as _i, bH as Bi, l as No, ah as vo, bI as Wi, bJ as So, a9 as $, bK as Ls, bL as Ao, bM as Co, bN as Io, bO as Do, bP as zo, A as To, bQ as $o, bR as Eo, bS as Lo, bT as S, t as x, f as tt, p as Gi, bU as Be, bV as We, ae as Ft, a as Z, K as Fo, bW as Mo, bX as Oo, a2 as ct, a5 as ee, ai as P, bY as Ro, bZ as _o, ax as lt, b_ as Bo, z as Q, b$ as Wo, c0 as Go, c1 as Po, c2 as Uo, c3 as Vo, c4 as jo, c5 as Ko, c6 as Ho, B as qo, c7 as Zo, c8 as Jo, c9 as Xo, aq as Yo, ca as Qo, C as tl, a1 as he, cb as el, _ as nl, cc as sl, cd as il, ce as rl, ar as al, cf as ol, a7 as ll, aE as ul, cg as cl, ag as hl, ch as pl, G as dl, aG as fl, ci as ml, cj as gl, ck as bl, cl as yl, as as wl, a8 as kl, cm as xl, cn as Nl, co as vl, M as Sl, cp as Al, cq as Cl, cr as Il, a0 as Dl, a3 as zl, aM as Tl, cs as $l, aa as El, ct as Ll, aK as Fl, P as Ml, cu as Ol, U as qn, at as Rl, cv as _l, cw as Bl, Q as Wl, av as Gl, au as Pl, u as Ul, aY as Vl, cx as jl, aZ as Kl, cy as Hl, aO as ql, aC as Zl, ap as Jl, cz as Xl, $ as Yl, aA as Ql, S as tu, w as eu, cA as nu, cB as su, cC as iu, aw as ru, cD as au, D as ou, cE as lu, T as uu, aR as cu, aQ as hu, cF as Ie, cG as pu, h as du, cH as Fs, F as Bt, a4 as Fe, g as fu, x as mu, V as xe, cI as gu, cJ as bu, m as Ms, cK as yu, cL as Os, cM as wu } from "./index-iNhkcAEQ.js";
|
|
2
|
+
import { s as ku, a as xu, g as Nu, b as vu, V as d, N as G, r as bn, c as Zn, e as et, f as ye, h as Ge, i as Jn, j as Pi, k as Ui, t as Rt, R as Et, l as ht, A as Pt, m as K, n as le, o as Xn, p as Yn, q as Ct, u as nt, v as Pe, w as Su, x as Au, y as Ue, z as Ut, B as Lt, C as oe, D as jt, E as Xe, F as ue, G as Cu, H as en, I as yn, J as Vi, K as It, L as Rs, M as Iu, O as Du, P as zu, Q as Tu, S as $u, T as Eu, U as Ht, W as Lu, X as Qn, Y as zt, Z as Ye, _ as Fu, $ as ot, a0 as ji, a1 as gt, a2 as ne, a3 as _s, a4 as Ne, d as St, a5 as Bs, a6 as Ve, a7 as Ki, a8 as ts, a9 as Mu, aa as Ou, ab as Hi, ac as Ru } from "./tfjs_backend-NucKez4s.js";
|
|
3
|
+
import { M as _u, a as wn, f as qi } from "./dropout-kbDY39Ci.js";
|
|
4
|
+
import { e as Bu, l as Wu, n as pt, w as qt, c as je, g as Ke, d as es, b as j, a as Ee, f as Gt, h as Gu, s as be, u as nn, i as ce, j as ns, t as Rn, m as Zi, k as _t } from "./ops-ObfXLHYQ.js";
|
|
5
|
+
import { z as mt } from "./zeros-NMYTayy7.js";
|
|
6
|
+
import { o as pe } from "./ones-BIeFnPHR.js";
|
|
7
|
+
import { v as Ji, B as Pu } from "./BaseLayer-BhrMN8JO.js";
|
|
8
|
+
import { r as A } from "./reshape-DxTPgnwL.js";
|
|
9
|
+
import { s as B } from "./sum-B_92TaHD.js";
|
|
10
|
+
import { m as Ot } from "./mat_mul-D0SifYfJ.js";
|
|
11
|
+
import { s as Kt } from "./split-BCbrzthj.js";
|
|
12
|
+
import { s as Uu, c as Xi } from "./sin-BOX-JVAj.js";
|
|
13
|
+
import { g as Yi, d as ss, e as Ws, c as Vu } from "./axis_util-97KkkyRQ.js";
|
|
14
|
+
import { a as Zt, e as Jt, l as ju } from "./log_sum_exp-CkumwesB.js";
|
|
15
|
+
import { s as kn } from "./stack--cqr9Dgc.js";
|
|
16
|
+
import { p as Ku } from "./slice_util-D-kaD4ZV.js";
|
|
17
|
+
import { c as is } from "./concat-Cxbo2sOz.js";
|
|
18
|
+
import { g as Qi } from "./gather-Bxe1Qip8.js";
|
|
19
|
+
import { m as at, a as rs } from "./moments-B06NlR_V.js";
|
|
20
|
+
import { s as tr } from "./softmax-BjsptB07.js";
|
|
21
|
+
import { m as ve } from "./max-CYaAjEEp.js";
|
|
22
|
+
import { t as Hu } from "./tensor-CfiPXsW4.js";
|
|
23
|
+
import { r as qu } from "./range-BsFU-SNG.js";
|
|
24
|
+
import { m as Zu } from "./norm-D3676xIo.js";
|
|
24
25
|
/**
|
|
25
26
|
* @license
|
|
26
27
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -37,11 +38,11 @@ import { m as qu } from "./norm-DSva3hI3.js";
|
|
|
37
38
|
* limitations under the License.
|
|
38
39
|
* =============================================================================
|
|
39
40
|
*/
|
|
40
|
-
function
|
|
41
|
+
function Ju(s, t = null, e = !1) {
|
|
41
42
|
const i = { x: D(s, "x", "all", "bool") }, r = { axis: t, keepDims: e };
|
|
42
43
|
return M.runKernel(fo, i, r);
|
|
43
44
|
}
|
|
44
|
-
const
|
|
45
|
+
const Xu = /* @__PURE__ */ F({ all_: Ju });
|
|
45
46
|
/**
|
|
46
47
|
* @license
|
|
47
48
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -58,11 +59,11 @@ const Ju = /* @__PURE__ */ F({ all_: Zu });
|
|
|
58
59
|
* limitations under the License.
|
|
59
60
|
* =============================================================================
|
|
60
61
|
*/
|
|
61
|
-
function
|
|
62
|
+
function Yu(s, t = null, e = !1) {
|
|
62
63
|
const i = { x: D(s, "x", "any", "bool") }, r = { axis: t, keepDims: e };
|
|
63
64
|
return M.runKernel(mo, i, r);
|
|
64
65
|
}
|
|
65
|
-
const Gs = /* @__PURE__ */ F({ any_:
|
|
66
|
+
const Gs = /* @__PURE__ */ F({ any_: Yu });
|
|
66
67
|
/**
|
|
67
68
|
* @license
|
|
68
69
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -79,11 +80,11 @@ const Gs = /* @__PURE__ */ F({ any_: Xu });
|
|
|
79
80
|
* limitations under the License.
|
|
80
81
|
* =============================================================================
|
|
81
82
|
*/
|
|
82
|
-
function
|
|
83
|
+
function Qu(s, t = 0) {
|
|
83
84
|
const n = { x: D(s, "x", "argMax") }, i = { axis: t };
|
|
84
85
|
return M.runKernel(mi, n, i);
|
|
85
86
|
}
|
|
86
|
-
const sn = /* @__PURE__ */ F({ argMax_:
|
|
87
|
+
const sn = /* @__PURE__ */ F({ argMax_: Qu });
|
|
87
88
|
/**
|
|
88
89
|
* @license
|
|
89
90
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -100,7 +101,7 @@ const sn = /* @__PURE__ */ F({ argMax_: Yu });
|
|
|
100
101
|
* limitations under the License.
|
|
101
102
|
* =============================================================================
|
|
102
103
|
*/
|
|
103
|
-
function
|
|
104
|
+
function tc(s, t, e, n, i, r, a = !1, o = "channelsLast") {
|
|
104
105
|
let [l, u, c, h] = [-1, -1, -1, -1];
|
|
105
106
|
if (o === "channelsLast")
|
|
106
107
|
[l, u, c, h] = s;
|
|
@@ -108,7 +109,7 @@ function Qu(s, t, e, n, i, r, a = !1, o = "channelsLast") {
|
|
|
108
109
|
[l, h, u, c] = s;
|
|
109
110
|
else
|
|
110
111
|
throw new Error(`Unknown dataFormat ${o}`);
|
|
111
|
-
const [p, f, , g] = t, [b, m] = rn(e), [
|
|
112
|
+
const [p, f, , g] = t, [b, m] = rn(e), [v, y] = rn(n), C = _n(p, v), N = _n(f, y), { padInfo: I, outHeight: z, outWidth: _ } = sc(i, u, c, b, m, C, N, r, o), T = a ? g * h : g;
|
|
112
113
|
let E;
|
|
113
114
|
return o === "channelsFirst" ? E = [l, T, z, _] : o === "channelsLast" && (E = [l, z, _, T]), {
|
|
114
115
|
batchSize: l,
|
|
@@ -125,20 +126,20 @@ function Qu(s, t, e, n, i, r, a = !1, o = "channelsLast") {
|
|
|
125
126
|
filterHeight: p,
|
|
126
127
|
filterWidth: f,
|
|
127
128
|
effectiveFilterHeight: C,
|
|
128
|
-
effectiveFilterWidth:
|
|
129
|
-
dilationHeight:
|
|
129
|
+
effectiveFilterWidth: N,
|
|
130
|
+
dilationHeight: v,
|
|
130
131
|
dilationWidth: y,
|
|
131
132
|
inShape: s,
|
|
132
133
|
outShape: E,
|
|
133
134
|
filterShape: t
|
|
134
135
|
};
|
|
135
136
|
}
|
|
136
|
-
function
|
|
137
|
-
n == null && (n =
|
|
137
|
+
function ec(s, t, e, n, i) {
|
|
138
|
+
n == null && (n = nc(s, t, e));
|
|
138
139
|
const r = s[0], a = s[1], o = an((r - t + 2 * n) / e + 1, i), l = an((a - t + 2 * n) / e + 1, i);
|
|
139
140
|
return [o, l];
|
|
140
141
|
}
|
|
141
|
-
function
|
|
142
|
+
function nc(s, t, e, n = 1) {
|
|
142
143
|
const i = _n(t, n);
|
|
143
144
|
return Math.floor((s[0] * (e - 1) - e + i) / 2);
|
|
144
145
|
}
|
|
@@ -148,16 +149,16 @@ function rn(s) {
|
|
|
148
149
|
function _n(s, t) {
|
|
149
150
|
return t <= 1 ? s : s + (s - 1) * (t - 1);
|
|
150
151
|
}
|
|
151
|
-
function
|
|
152
|
+
function sc(s, t, e, n, i, r, a, o, l) {
|
|
152
153
|
let u, c, h;
|
|
153
154
|
if (typeof s == "number") {
|
|
154
155
|
u = { top: s, bottom: s, left: s, right: s, type: s === 0 ? "VALID" : "NUMBER" };
|
|
155
|
-
const f =
|
|
156
|
+
const f = ec([t, e], r, n, s, o);
|
|
156
157
|
c = f[0], h = f[1];
|
|
157
158
|
} else if (s === "same") {
|
|
158
159
|
c = Math.ceil(t / n), h = Math.ceil(e / i);
|
|
159
|
-
const p = Math.max(0, (c - 1) * n + r - t), f = Math.max(0, (h - 1) * i + a - e), g = Math.floor(p / 2), b = p - g, m = Math.floor(f / 2),
|
|
160
|
-
u = { top: g, bottom: b, left: m, right:
|
|
160
|
+
const p = Math.max(0, (c - 1) * n + r - t), f = Math.max(0, (h - 1) * i + a - e), g = Math.floor(p / 2), b = p - g, m = Math.floor(f / 2), v = f - m;
|
|
161
|
+
u = { top: g, bottom: b, left: m, right: v, type: "SAME" };
|
|
161
162
|
} else if (s === "valid")
|
|
162
163
|
u = { top: 0, bottom: 0, left: 0, right: 0, type: "VALID" }, c = Math.ceil((t - r + 1) / n), h = Math.ceil((e - a + 1) / i);
|
|
163
164
|
else if (typeof s == "object") {
|
|
@@ -223,7 +224,7 @@ function ft(s, t, e) {
|
|
|
223
224
|
* limitations under the License.
|
|
224
225
|
* =============================================================================
|
|
225
226
|
*/
|
|
226
|
-
function
|
|
227
|
+
function ic(s, t, e, n, i) {
|
|
227
228
|
const r = D(s, "x", "avgPool", "float32"), a = 1;
|
|
228
229
|
k(de(e, a), () => `Error in avgPool: Either strides or dilations must be 1. Got strides ${e} and dilations '${a}'`);
|
|
229
230
|
let o = r, l = !1;
|
|
@@ -232,7 +233,7 @@ function sc(s, t, e, n, i) {
|
|
|
232
233
|
let h = M.runKernel(gi, u, c);
|
|
233
234
|
return h = L(h, r.dtype), l ? A(h, [h.shape[1], h.shape[2], h.shape[3]]) : h;
|
|
234
235
|
}
|
|
235
|
-
const
|
|
236
|
+
const rc = /* @__PURE__ */ F({ avgPool_: ic });
|
|
236
237
|
/**
|
|
237
238
|
* @license
|
|
238
239
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -249,7 +250,7 @@ const ic = /* @__PURE__ */ F({ avgPool_: sc });
|
|
|
249
250
|
* limitations under the License.
|
|
250
251
|
* =============================================================================
|
|
251
252
|
*/
|
|
252
|
-
function
|
|
253
|
+
function ac(s, t, e, n, i, r = "NDHWC") {
|
|
253
254
|
const a = D(s, "x", "avgPool3d", "float32");
|
|
254
255
|
let o = a, l = !1;
|
|
255
256
|
a.rank === 4 && (l = !0, o = A(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]])), k(o.rank === 5, () => `Error in avgPool3d: x must be rank 5 but got rank ${o.rank}.`), k(r === "NDHWC", () => `Error in avgPool3d: Only NDHWC is currently supported, but got dataFormat of ${r}`), k(typeof e == "number" && e > 0 || Array.isArray(e) && e[0] > 0 && e[1] > 0 && e[2] > 0, () => `Error in avgPool3d: Stride must be > 0, but got '${e}'`), ft("avgPool3d", n, i);
|
|
@@ -257,7 +258,7 @@ function rc(s, t, e, n, i, r = "NDHWC") {
|
|
|
257
258
|
let h = M.runKernel(bi, u, c);
|
|
258
259
|
return h = L(h, o.dtype), l ? A(h, [h.shape[1], h.shape[2], h.shape[3], h.shape[4]]) : h;
|
|
259
260
|
}
|
|
260
|
-
const
|
|
261
|
+
const oc = /* @__PURE__ */ F({ avgPool3d_: ac });
|
|
261
262
|
/**
|
|
262
263
|
* @license
|
|
263
264
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -274,11 +275,11 @@ const ac = /* @__PURE__ */ F({ avgPool3d_: rc });
|
|
|
274
275
|
* limitations under the License.
|
|
275
276
|
* =============================================================================
|
|
276
277
|
*/
|
|
277
|
-
function
|
|
278
|
+
function lc(s) {
|
|
278
279
|
const e = { x: D(s, "x", "tanh", "float32") };
|
|
279
280
|
return M.runKernel(yi, e);
|
|
280
281
|
}
|
|
281
|
-
const as = /* @__PURE__ */ F({ tanh_:
|
|
282
|
+
const as = /* @__PURE__ */ F({ tanh_: lc });
|
|
282
283
|
/**
|
|
283
284
|
* @license
|
|
284
285
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -295,14 +296,14 @@ const as = /* @__PURE__ */ F({ tanh_: oc });
|
|
|
295
296
|
* limitations under the License.
|
|
296
297
|
* =============================================================================
|
|
297
298
|
*/
|
|
298
|
-
function
|
|
299
|
+
function uc(s, t, e) {
|
|
299
300
|
const n = D(s, "x", "batchToSpaceND"), i = t.reduce((o, l) => o * l);
|
|
300
301
|
k(n.rank >= 1 + t.length, () => `input rank is ${n.rank} but should be > than blockShape.length ${t.length}`), k(e.length === t.length, () => `crops.length is ${e.length} but should be equal to blockShape.length ${t.length}`), k(n.shape[0] % i === 0, () => `input tensor batch is ${n.shape[0]} but is not divisible by the product of the elements of blockShape ${t.join(" * ")} === ${i}`);
|
|
301
302
|
const r = { x: n }, a = { blockShape: t, crops: e };
|
|
302
303
|
return M.runKernel(wi, r, a);
|
|
303
304
|
}
|
|
304
|
-
const
|
|
305
|
-
function
|
|
305
|
+
const cc = /* @__PURE__ */ F({ batchToSpaceND_: uc });
|
|
306
|
+
function hc(s) {
|
|
306
307
|
let t;
|
|
307
308
|
return s.rank === 0 || s.rank === 1 ? t = A(s, [1, 1, 1, s.size]) : s.rank === 2 ? t = A(s, [1, 1, s.shape[0], s.shape[1]]) : s.rank === 3 ? t = A(s, [1, s.shape[0], s.shape[1], s.shape[2]]) : t = s, t;
|
|
308
309
|
}
|
|
@@ -322,7 +323,7 @@ function cc(s) {
|
|
|
322
323
|
* limitations under the License.
|
|
323
324
|
* =============================================================================
|
|
324
325
|
*/
|
|
325
|
-
function
|
|
326
|
+
function pc(s, t, e, n, i, r) {
|
|
326
327
|
r == null && (r = 1e-3);
|
|
327
328
|
const a = D(s, "x", "batchNorm"), o = D(t, "mean", "batchNorm"), l = D(e, "variance", "batchNorm");
|
|
328
329
|
let u;
|
|
@@ -330,7 +331,7 @@ function hc(s, t, e, n, i, r) {
|
|
|
330
331
|
let c;
|
|
331
332
|
n != null && (c = D(n, "offset", "batchNorm")), k(o.rank === l.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), k(c == null || o.rank === c.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), k(u == null || o.rank === u.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
|
|
332
333
|
const p = {
|
|
333
|
-
x:
|
|
334
|
+
x: hc(a),
|
|
334
335
|
scale: u,
|
|
335
336
|
offset: c,
|
|
336
337
|
mean: o,
|
|
@@ -338,31 +339,31 @@ function hc(s, t, e, n, i, r) {
|
|
|
338
339
|
}, f = { varianceEpsilon: r }, g = M.runKernel(ki, p, f);
|
|
339
340
|
return A(g, a.shape);
|
|
340
341
|
}
|
|
341
|
-
const os = /* @__PURE__ */ F({ batchNorm_:
|
|
342
|
-
function
|
|
342
|
+
const os = /* @__PURE__ */ F({ batchNorm_: pc });
|
|
343
|
+
function dc(s, t, e, n, i, r) {
|
|
343
344
|
const a = D(s, "x", "batchNorm"), o = D(t, "mean", "batchNorm"), l = D(e, "variance", "batchNorm");
|
|
344
345
|
let u;
|
|
345
346
|
i != null && (u = D(i, "scale", "batchNorm"));
|
|
346
347
|
let c;
|
|
347
348
|
return n != null && (c = D(n, "offset", "batchNorm")), k(a.rank === 2, () => `Error in batchNorm2D: x must be rank 2 but got rank ${a.rank}.`), k(o.rank === 2 || o.rank === 1, () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but got rank ${o.rank}.`), k(l.rank === 2 || l.rank === 1, () => `Error in batchNorm2D: variance must be rank 2 or rank 1 but got rank ${l.rank}.`), u != null && k(u.rank === 2 || u.rank === 1, () => `Error in batchNorm2D: scale must be rank 2 or rank 1 but got rank ${u.rank}.`), c != null && k(c.rank === 2 || c.rank === 1, () => `Error in batchNorm2D: offset must be rank 2 or rank 1 but got rank ${c.rank}.`), os(a, o, l, c, u, r);
|
|
348
349
|
}
|
|
349
|
-
const
|
|
350
|
-
function
|
|
350
|
+
const fc = /* @__PURE__ */ F({ batchNorm2d_: dc });
|
|
351
|
+
function mc(s, t, e, n, i, r) {
|
|
351
352
|
const a = D(s, "x", "batchNorm"), o = D(t, "mean", "batchNorm"), l = D(e, "variance", "batchNorm");
|
|
352
353
|
let u;
|
|
353
354
|
i != null && (u = D(i, "scale", "batchNorm"));
|
|
354
355
|
let c;
|
|
355
356
|
return n != null && (c = D(n, "offset", "batchNorm")), k(a.rank === 3, () => `Error in batchNorm3D: x must be rank 3 but got rank ${a.rank}.`), k(o.rank === 3 || o.rank === 1, () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but got rank ${o.rank}.`), k(l.rank === 3 || l.rank === 1, () => `Error in batchNorm3D: variance must be rank 3 or rank 1 but got rank ${l.rank}.`), u != null && k(u.rank === 3 || u.rank === 1, () => `Error in batchNorm3D: scale must be rank 3 or rank 1 but got rank ${u.rank}.`), c != null && k(c.rank === 3 || c.rank === 1, () => `Error in batchNorm3D: offset must be rank 3 or rank 1 but got rank ${c.rank}.`), os(a, o, l, c, u, r);
|
|
356
357
|
}
|
|
357
|
-
const
|
|
358
|
-
function
|
|
358
|
+
const gc = /* @__PURE__ */ F({ batchNorm3d_: mc });
|
|
359
|
+
function bc(s, t, e, n, i, r) {
|
|
359
360
|
const a = D(s, "x", "batchNorm"), o = D(t, "mean", "batchNorm"), l = D(e, "variance", "batchNorm");
|
|
360
361
|
let u;
|
|
361
362
|
i != null && (u = D(i, "scale", "batchNorm"));
|
|
362
363
|
let c;
|
|
363
364
|
return n != null && (c = D(n, "offset", "batchNorm")), k(a.rank === 4, () => `Error in batchNorm4D: x must be rank 4 but got rank ${a.rank}.`), k(o.rank === 4 || o.rank === 1, () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but got rank ${o.rank}.`), k(l.rank === 4 || l.rank === 1, () => `Error in batchNorm4D: variance must be rank 4 or rank 1 but got rank ${l.rank}.`), u != null && k(u.rank === 4 || u.rank === 1, () => `Error in batchNorm4D: scale must be rank 4 or rank 1 but got rank ${u.rank}.`), c != null && k(c.rank === 4 || c.rank === 1, () => `Error in batchNorm4D: offset must be rank 4 or rank 1 but got rank ${c.rank}.`), os(a, o, l, c, u, r);
|
|
364
365
|
}
|
|
365
|
-
const
|
|
366
|
+
const yc = /* @__PURE__ */ F({ batchNorm4d_: bc });
|
|
366
367
|
/**
|
|
367
368
|
* @license
|
|
368
369
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -379,7 +380,7 @@ const bc = /* @__PURE__ */ F({ batchNorm4d_: gc });
|
|
|
379
380
|
* limitations under the License.
|
|
380
381
|
* =============================================================================
|
|
381
382
|
*/
|
|
382
|
-
function
|
|
383
|
+
function wc(s, t, e, n, i = "NHWC", r = [1, 1], a) {
|
|
383
384
|
const o = D(s, "x", "conv2d", "float32"), l = D(t, "filter", "conv2d", "float32");
|
|
384
385
|
let u = o, c = !1;
|
|
385
386
|
o.rank === 3 && (c = !0, u = A(o, [1, o.shape[0], o.shape[1], o.shape[2]])), k(u.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${u.rank}.`), k(l.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${l.rank}.`), ft("conv2d", n, a);
|
|
@@ -388,15 +389,15 @@ function yc(s, t, e, n, i = "NHWC", r = [1, 1], a) {
|
|
|
388
389
|
const p = { x: u, filter: l }, f = { strides: e, pad: n, dataFormat: i, dilations: r, dimRoundingMode: a }, g = M.runKernel(xi, p, f);
|
|
389
390
|
return c ? A(g, [g.shape[1], g.shape[2], g.shape[3]]) : g;
|
|
390
391
|
}
|
|
391
|
-
const Ce = /* @__PURE__ */ F({ conv2d_:
|
|
392
|
-
function
|
|
392
|
+
const Ce = /* @__PURE__ */ F({ conv2d_: wc });
|
|
393
|
+
function kc(s, t, e, n, i = "NWC", r = 1, a) {
|
|
393
394
|
const o = D(s, "x", "conv1d"), l = D(t, "filter", "conv1d");
|
|
394
395
|
let u = o, c = !1;
|
|
395
396
|
o.rank === 2 && (c = !0, u = A(o, [1, o.shape[0], o.shape[1]])), k(u.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${u.rank}.`), k(l.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${l.rank}.`), ft("conv1d", n, a), k(u.shape[2] === l.shape[1], () => `Error in conv1d: depth of input (${u.shape[2]}) must match input depth for filter ${l.shape[1]}.`), k(de(e, r), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${e} and dilation '${r}'`), k(Ae(r), () => "Error in conv1D: Dilated rates should be larger than 0."), k(Ae(e), () => "Error in conv1D: Stride should be larger than 0."), k(i === "NWC", () => `Error in conv1d: got dataFormat of ${i} but only NWC is currently supported.`);
|
|
396
397
|
const h = A(l, [1, l.shape[0], l.shape[1], l.shape[2]]), p = A(u, [u.shape[0], 1, u.shape[1], u.shape[2]]), m = Ce(p, h, [1, e], n, "NHWC", [1, r], a);
|
|
397
398
|
return c ? A(m, [m.shape[2], m.shape[3]]) : A(m, [m.shape[0], m.shape[2], m.shape[3]]);
|
|
398
399
|
}
|
|
399
|
-
const
|
|
400
|
+
const xc = /* @__PURE__ */ F({ conv1d_: kc });
|
|
400
401
|
/**
|
|
401
402
|
* @license
|
|
402
403
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -413,21 +414,21 @@ const kc = /* @__PURE__ */ F({ conv1d_: wc });
|
|
|
413
414
|
* limitations under the License.
|
|
414
415
|
* =============================================================================
|
|
415
416
|
*/
|
|
416
|
-
function
|
|
417
|
+
function Nc(s, t, e, n, i, r = "NHWC", a) {
|
|
417
418
|
k(s.length === t.rank, () => `Length of inShape (${s.length}) and rank of dy (${t.rank}) must match`);
|
|
418
419
|
let o = s, l = t, u = !1;
|
|
419
420
|
t.rank === 3 && (u = !0, l = A(t, [1, t.shape[0], t.shape[1], t.shape[2]]), o = [1, s[0], s[1], s[2]]), k(o.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${o.length}.`), k(l.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${l.rank}`), k(e.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${e.rank}`);
|
|
420
421
|
const c = r === "NHWC" ? o[3] : o[1], h = r === "NHWC" ? l.shape[3] : l.shape[1];
|
|
421
422
|
k(c === e.shape[2], () => `Error in conv2dDerInput: depth of input (${c}) must match input depth for filter ${e.shape[2]}.`), k(h === e.shape[3], () => `Error in conv2dDerInput: depth of output (${h}) must match output depth for filter ${e.shape[3]}.`), ft("conv2dDerInput", i, a);
|
|
422
|
-
const p = { dy: l, filter: e }, f = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, inputShape: o }, g = M.runKernel(
|
|
423
|
+
const p = { dy: l, filter: e }, f = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, inputShape: o }, g = M.runKernel(Ni, p, f);
|
|
423
424
|
return u ? A(g, [g.shape[1], g.shape[2], g.shape[3]]) : g;
|
|
424
425
|
}
|
|
425
|
-
const ls = /* @__PURE__ */ F({ conv2DBackpropInput_:
|
|
426
|
+
const ls = /* @__PURE__ */ F({ conv2DBackpropInput_: Nc });
|
|
426
427
|
function vc(s, t, e, n, i, r) {
|
|
427
428
|
const a = D(s, "x", "conv2dTranspose"), o = D(t, "filter", "conv2dTranspose");
|
|
428
429
|
return ls(e, a, o, n, i, "NHWC", r);
|
|
429
430
|
}
|
|
430
|
-
const
|
|
431
|
+
const Sc = /* @__PURE__ */ F({ conv2dTranspose_: vc });
|
|
431
432
|
/**
|
|
432
433
|
* @license
|
|
433
434
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -444,14 +445,14 @@ const Nc = /* @__PURE__ */ F({ conv2dTranspose_: vc });
|
|
|
444
445
|
* limitations under the License.
|
|
445
446
|
* =============================================================================
|
|
446
447
|
*/
|
|
447
|
-
function
|
|
448
|
+
function Ac(s, t, e, n, i = "NDHWC", r = [1, 1, 1]) {
|
|
448
449
|
const a = D(s, "x", "conv3d"), o = D(t, "filter", "conv3d");
|
|
449
450
|
let l = a, u = !1;
|
|
450
451
|
a.rank === 4 && (u = !0, l = A(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]])), k(l.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${l.rank}.`), k(o.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ${o.rank}.`), k(l.shape[4] === o.shape[3], () => `Error in conv3d: depth of input (${l.shape[4]}) must match input depth for filter ${o.shape[3]}.`), k(de(e, r), () => `Error in conv3D: Either strides or dilations must be 1. Got strides ${e} and dilations '${r}'`), k(i === "NDHWC", () => `Error in conv3d: got dataFormat of ${i} but only NDHWC is currently supported.`), k(Ae(r), () => "Error in conv3D: Dilated rates should be larger than 0."), k(Ae(e), () => "Error in conv3D: Strides should be larger than 0.");
|
|
451
|
-
const c = { x: l, filter: o }, h = { strides: e, pad: n, dataFormat: i, dilations: r }, p = M.runKernel(
|
|
452
|
+
const c = { x: l, filter: o }, h = { strides: e, pad: n, dataFormat: i, dilations: r }, p = M.runKernel(vi, c, h);
|
|
452
453
|
return u ? A(p, [p.shape[1], p.shape[2], p.shape[3], p.shape[4]]) : p;
|
|
453
454
|
}
|
|
454
|
-
const
|
|
455
|
+
const Cc = /* @__PURE__ */ F({ conv3d_: Ac });
|
|
455
456
|
/**
|
|
456
457
|
* @license
|
|
457
458
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -468,7 +469,7 @@ const Ac = /* @__PURE__ */ F({ conv3d_: Sc });
|
|
|
468
469
|
* limitations under the License.
|
|
469
470
|
* =============================================================================
|
|
470
471
|
*/
|
|
471
|
-
function
|
|
472
|
+
function Ic(s, t, e, n, i) {
|
|
472
473
|
k(s.length === t.rank, () => `Length of inShape (${s.length}) and rank of dy (${t.rank}) must match`);
|
|
473
474
|
let r = s, a = t, o = !1;
|
|
474
475
|
t.rank === 4 && (o = !0, a = A(t, [1, t.shape[0], t.shape[1], t.shape[2], t.shape[3]]), r = [1, s[0], s[1], s[2], s[3]]);
|
|
@@ -477,12 +478,12 @@ function Cc(s, t, e, n, i) {
|
|
|
477
478
|
const c = { dy: a, filter: e }, h = { pad: i, strides: n, inputShape: r }, p = M.runKernel(go, c, h);
|
|
478
479
|
return o ? A(p, [p.shape[1], p.shape[2], p.shape[3], p.shape[4]]) : p;
|
|
479
480
|
}
|
|
480
|
-
const er = /* @__PURE__ */ F({ conv3DBackpropInput_:
|
|
481
|
-
function
|
|
481
|
+
const er = /* @__PURE__ */ F({ conv3DBackpropInput_: Ic });
|
|
482
|
+
function Dc(s, t, e, n, i) {
|
|
482
483
|
const r = D(s, "x", "conv3dTranspose"), a = D(t, "filter", "conv3dTranspose");
|
|
483
484
|
return er(e, r, a, n, i);
|
|
484
485
|
}
|
|
485
|
-
const
|
|
486
|
+
const zc = /* @__PURE__ */ F({ conv3dTranspose_: Dc });
|
|
486
487
|
/**
|
|
487
488
|
* @license
|
|
488
489
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -499,11 +500,11 @@ const Dc = /* @__PURE__ */ F({ conv3dTranspose_: Ic });
|
|
|
499
500
|
* limitations under the License.
|
|
500
501
|
* =============================================================================
|
|
501
502
|
*/
|
|
502
|
-
function
|
|
503
|
+
function Tc(s) {
|
|
503
504
|
const e = { x: D(s, "x", "cosh", "float32") };
|
|
504
505
|
return M.runKernel(Si, e);
|
|
505
506
|
}
|
|
506
|
-
const
|
|
507
|
+
const $c = /* @__PURE__ */ F({ cosh_: Tc });
|
|
507
508
|
/**
|
|
508
509
|
* @license
|
|
509
510
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
|
@@ -520,11 +521,11 @@ const Tc = /* @__PURE__ */ F({ cosh_: zc });
|
|
|
520
521
|
* limitations under the License.
|
|
521
522
|
* =============================================================================
|
|
522
523
|
*/
|
|
523
|
-
function
|
|
524
|
+
function Ec(s, t = 0, e = !1, n = !1) {
|
|
524
525
|
const r = { x: D(s, "x", "cumprod") }, a = { axis: t, exclusive: e, reverse: n };
|
|
525
526
|
return M.runKernel(bo, r, a);
|
|
526
527
|
}
|
|
527
|
-
const Ps = /* @__PURE__ */ F({ cumprod_:
|
|
528
|
+
const Ps = /* @__PURE__ */ F({ cumprod_: Ec });
|
|
528
529
|
/**
|
|
529
530
|
* @license
|
|
530
531
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -541,11 +542,11 @@ const Ps = /* @__PURE__ */ F({ cumprod_: $c });
|
|
|
541
542
|
* limitations under the License.
|
|
542
543
|
* =============================================================================
|
|
543
544
|
*/
|
|
544
|
-
function
|
|
545
|
+
function Lc(s, t = 0, e = !1, n = !1) {
|
|
545
546
|
const r = { x: D(s, "x", "cumsum") }, a = { axis: t, exclusive: e, reverse: n };
|
|
546
547
|
return M.runKernel(Ai, r, a);
|
|
547
548
|
}
|
|
548
|
-
const
|
|
549
|
+
const Fc = /* @__PURE__ */ F({ cumsum_: Lc });
|
|
549
550
|
/**
|
|
550
551
|
* @license
|
|
551
552
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -562,13 +563,13 @@ const Lc = /* @__PURE__ */ F({ cumsum_: Ec });
|
|
|
562
563
|
* limitations under the License.
|
|
563
564
|
* =============================================================================
|
|
564
565
|
*/
|
|
565
|
-
function
|
|
566
|
+
function Mc(s, t, e, n = !1) {
|
|
566
567
|
const i = D(s, "x", "denseBincount"), r = D(t, "weights", "denseBincount");
|
|
567
568
|
k(i.dtype === "int32", () => `Error in denseBincount: input dtype must be int32, but got ${i.dtype}`), k(i.rank <= 2, () => `Error in denseBincount: input must be at most rank 2, but got rank ${i.rank}.`), k(e >= 0, () => `size must be non-negative, but got ${e}.`), k(r.size === i.size || r.size === 0, () => `Error in denseBincount: weights must have the same shape as x or 0-length, but got x shape: ${i.shape}, weights shape: ${r.shape}.`);
|
|
568
569
|
const a = { x: i, weights: r }, o = { size: e, binaryOutput: n };
|
|
569
570
|
return M.runKernel(yo, a, o);
|
|
570
571
|
}
|
|
571
|
-
const Us = /* @__PURE__ */ F({ denseBincount_:
|
|
572
|
+
const Us = /* @__PURE__ */ F({ denseBincount_: Mc });
|
|
572
573
|
/**
|
|
573
574
|
* @license
|
|
574
575
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -585,7 +586,7 @@ const Us = /* @__PURE__ */ F({ denseBincount_: Fc });
|
|
|
585
586
|
* limitations under the License.
|
|
586
587
|
* =============================================================================
|
|
587
588
|
*/
|
|
588
|
-
function
|
|
589
|
+
function Oc(s, t, e, n, i = "NHWC", r = [1, 1], a) {
|
|
589
590
|
const o = D(s, "x", "depthwiseConv2d", "float32"), l = D(t, "filter", "depthwiseConv2d", "float32");
|
|
590
591
|
let u = o, c = !1;
|
|
591
592
|
o.rank === 3 && (c = !0, u = A(o, [1, o.shape[0], o.shape[1], o.shape[2]])), k(u.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${u.rank}.`), k(l.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${l.rank}.`);
|
|
@@ -594,7 +595,7 @@ function Mc(s, t, e, n, i = "NHWC", r = [1, 1], a) {
|
|
|
594
595
|
const p = { x: u, filter: l }, f = { strides: e, pad: n, dataFormat: i, dilations: r, dimRoundingMode: a }, g = M.runKernel(Ci, p, f);
|
|
595
596
|
return c ? A(g, [g.shape[1], g.shape[2], g.shape[3]]) : g;
|
|
596
597
|
}
|
|
597
|
-
const nr = /* @__PURE__ */ F({ depthwiseConv2d_:
|
|
598
|
+
const nr = /* @__PURE__ */ F({ depthwiseConv2d_: Oc });
|
|
598
599
|
/**
|
|
599
600
|
* @license
|
|
600
601
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -611,13 +612,13 @@ const nr = /* @__PURE__ */ F({ depthwiseConv2d_: Mc });
|
|
|
611
612
|
* limitations under the License.
|
|
612
613
|
* =============================================================================
|
|
613
614
|
*/
|
|
614
|
-
function
|
|
615
|
+
function Rc(s, t) {
|
|
615
616
|
let e = D(s, "a", "equal", "string_or_numeric"), n = D(t, "b", "equal", "string_or_numeric");
|
|
616
617
|
[e, n] = Hn(e, n), wt(e.shape, n.shape);
|
|
617
618
|
const i = { a: e, b: n };
|
|
618
619
|
return M.runKernel(wo, i);
|
|
619
620
|
}
|
|
620
|
-
const Xt = /* @__PURE__ */ F({ equal_:
|
|
621
|
+
const Xt = /* @__PURE__ */ F({ equal_: Rc });
|
|
621
622
|
/**
|
|
622
623
|
* @license
|
|
623
624
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -634,13 +635,13 @@ const Xt = /* @__PURE__ */ F({ equal_: Oc });
|
|
|
634
635
|
* limitations under the License.
|
|
635
636
|
* =============================================================================
|
|
636
637
|
*/
|
|
637
|
-
function
|
|
638
|
+
function _c(s) {
|
|
638
639
|
let t = D(s, "x", "erf");
|
|
639
640
|
k(t.dtype === "int32" || t.dtype === "float32", () => "Input dtype must be `int32` or `float32`."), t.dtype === "int32" && (t = L(t, "float32"));
|
|
640
641
|
const e = { x: t };
|
|
641
642
|
return M.runKernel(Ii, e);
|
|
642
643
|
}
|
|
643
|
-
const
|
|
644
|
+
const Bc = /* @__PURE__ */ F({ erf_: _c });
|
|
644
645
|
/**
|
|
645
646
|
* @license
|
|
646
647
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -657,11 +658,11 @@ const _c = /* @__PURE__ */ F({ erf_: Rc });
|
|
|
657
658
|
* limitations under the License.
|
|
658
659
|
* =============================================================================
|
|
659
660
|
*/
|
|
660
|
-
function
|
|
661
|
+
function Wc(s) {
|
|
661
662
|
const e = { x: D(s, "x", "log1p") };
|
|
662
663
|
return M.runKernel(Di, e);
|
|
663
664
|
}
|
|
664
|
-
const
|
|
665
|
+
const Gc = /* @__PURE__ */ F({ log1p_: Wc });
|
|
665
666
|
/**
|
|
666
667
|
* @license
|
|
667
668
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -678,11 +679,11 @@ const Wc = /* @__PURE__ */ F({ log1p_: Bc });
|
|
|
678
679
|
* limitations under the License.
|
|
679
680
|
* =============================================================================
|
|
680
681
|
*/
|
|
681
|
-
function
|
|
682
|
+
function Pc(s) {
|
|
682
683
|
const e = { x: D(s, "x", "softplus") };
|
|
683
684
|
return M.runKernel(zi, e);
|
|
684
685
|
}
|
|
685
|
-
const us = /* @__PURE__ */ F({ softplus_:
|
|
686
|
+
const us = /* @__PURE__ */ F({ softplus_: Pc });
|
|
686
687
|
/**
|
|
687
688
|
* @license
|
|
688
689
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -699,19 +700,19 @@ const us = /* @__PURE__ */ F({ softplus_: Gc });
|
|
|
699
700
|
* limitations under the License.
|
|
700
701
|
* =============================================================================
|
|
701
702
|
*/
|
|
702
|
-
function
|
|
703
|
+
function Uc(s, t = -1) {
|
|
703
704
|
const e = D(s, "logits", "logSoftmax");
|
|
704
705
|
if (t === -1 && (t = e.rank - 1), t !== e.rank - 1)
|
|
705
706
|
throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${e.rank} and axis was ${t}`);
|
|
706
707
|
return On((i, r) => {
|
|
707
|
-
const o =
|
|
708
|
+
const o = ve(i, t, !0), l = V(i, o), u = V(L(l, "float32"), Zt(B(Jt(l), t, !0)));
|
|
708
709
|
return r([u]), { value: u, gradFunc: (h, p) => {
|
|
709
710
|
const [f] = p, g = !0, b = Jt(f);
|
|
710
711
|
return V(h, w(B(h, t, g), b));
|
|
711
712
|
} };
|
|
712
713
|
})(e);
|
|
713
714
|
}
|
|
714
|
-
const
|
|
715
|
+
const Vc = /* @__PURE__ */ F({ logSoftmax_: Uc });
|
|
715
716
|
/**
|
|
716
717
|
* @license
|
|
717
718
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -728,11 +729,11 @@ const Uc = /* @__PURE__ */ F({ logSoftmax_: Pc });
|
|
|
728
729
|
* limitations under the License.
|
|
729
730
|
* =============================================================================
|
|
730
731
|
*/
|
|
731
|
-
function
|
|
732
|
+
function jc(s) {
|
|
732
733
|
const e = { x: D(s, "x", "logicalNot", "bool") };
|
|
733
734
|
return M.runKernel(ko, e);
|
|
734
735
|
}
|
|
735
|
-
const
|
|
736
|
+
const Kc = /* @__PURE__ */ F({ logicalNot_: jc });
|
|
736
737
|
/**
|
|
737
738
|
* @license
|
|
738
739
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -749,14 +750,14 @@ const jc = /* @__PURE__ */ F({ logicalNot_: Vc });
|
|
|
749
750
|
* limitations under the License.
|
|
750
751
|
* =============================================================================
|
|
751
752
|
*/
|
|
752
|
-
function
|
|
753
|
+
function Hc(s, t, e, n, i) {
|
|
753
754
|
const r = D(s, "x", "maxPool"), a = 1;
|
|
754
755
|
let o = r, l = !1;
|
|
755
756
|
r.rank === 3 && (l = !0, o = A(r, [1, r.shape[0], r.shape[1], r.shape[2]])), k(o.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${o.rank}.`), k(de(e, a), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${e} and dilations '${a}'`), ft("maxPool", n, i);
|
|
756
757
|
const u = { x: o }, c = { filterSize: t, strides: e, pad: n, dimRoundingMode: i }, h = M.runKernel(Ti, u, c);
|
|
757
758
|
return l ? A(h, [h.shape[1], h.shape[2], h.shape[3]]) : h;
|
|
758
759
|
}
|
|
759
|
-
const
|
|
760
|
+
const qc = /* @__PURE__ */ F({ maxPool_: Hc });
|
|
760
761
|
/**
|
|
761
762
|
* @license
|
|
762
763
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -773,14 +774,14 @@ const Hc = /* @__PURE__ */ F({ maxPool_: Kc });
|
|
|
773
774
|
* limitations under the License.
|
|
774
775
|
* =============================================================================
|
|
775
776
|
*/
|
|
776
|
-
function
|
|
777
|
+
function Zc(s, t = [1, 1, 1], e, n, i, r = "NDHWC") {
|
|
777
778
|
const a = D(s, "x", "maxPool3d");
|
|
778
779
|
let o = a, l = !1;
|
|
779
780
|
a.rank === 4 && (l = !0, o = A(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]])), k(o.rank === 5, () => `Error in maxPool3d: x must be rank 5 but got rank ${o.rank}.`), k(r === "NDHWC", () => `Error in maxPool3d: Only NDHWC is currently supported, but got dataFormat of ${r}`), ft("maxPool3d", n, i);
|
|
780
781
|
const u = { x: o }, c = { filterSize: t, strides: e, pad: n, dimRoundingMode: i, dataFormat: r }, h = M.runKernel($i, u, c);
|
|
781
782
|
return l ? A(h, [h.shape[1], h.shape[2], h.shape[3], h.shape[4]]) : h;
|
|
782
783
|
}
|
|
783
|
-
const
|
|
784
|
+
const Jc = /* @__PURE__ */ F({ maxPool3d_: Zc });
|
|
784
785
|
/**
|
|
785
786
|
* @license
|
|
786
787
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -797,13 +798,13 @@ const Zc = /* @__PURE__ */ F({ maxPool3d_: qc });
|
|
|
797
798
|
* limitations under the License.
|
|
798
799
|
* =============================================================================
|
|
799
800
|
*/
|
|
800
|
-
function
|
|
801
|
+
function Xc(s, t) {
|
|
801
802
|
let e = D(s, "a", "notEqual", "string_or_numeric"), n = D(t, "b", "notEqual", "string_or_numeric");
|
|
802
803
|
[e, n] = Hn(e, n), wt(e.shape, n.shape);
|
|
803
804
|
const i = { a: e, b: n };
|
|
804
805
|
return M.runKernel(xo, i);
|
|
805
806
|
}
|
|
806
|
-
const Bn = /* @__PURE__ */ F({ notEqual_:
|
|
807
|
+
const Bn = /* @__PURE__ */ F({ notEqual_: Xc });
|
|
807
808
|
/**
|
|
808
809
|
* @license
|
|
809
810
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -820,13 +821,13 @@ const Bn = /* @__PURE__ */ F({ notEqual_: Jc });
|
|
|
820
821
|
* limitations under the License.
|
|
821
822
|
* =============================================================================
|
|
822
823
|
*/
|
|
823
|
-
function
|
|
824
|
+
function Yc(s, t, e = 1, n = 0, i = "int32") {
|
|
824
825
|
if (t < 2)
|
|
825
826
|
throw new Error(`Error in oneHot: depth must be >=2, but it is ${t}`);
|
|
826
827
|
const a = { indices: D(s, "indices", "oneHot", "int32") }, o = { dtype: i, depth: t, onValue: e, offValue: n };
|
|
827
828
|
return M.runKernel(Ei, a, o);
|
|
828
829
|
}
|
|
829
|
-
const
|
|
830
|
+
const Qc = /* @__PURE__ */ F({ oneHot_: Yc });
|
|
830
831
|
/**
|
|
831
832
|
* @license
|
|
832
833
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -843,11 +844,11 @@ const Yc = /* @__PURE__ */ F({ oneHot_: Xc });
|
|
|
843
844
|
* limitations under the License.
|
|
844
845
|
* =============================================================================
|
|
845
846
|
*/
|
|
846
|
-
function
|
|
847
|
+
function th(s) {
|
|
847
848
|
const e = { x: D(s, "x", "onesLike") };
|
|
848
849
|
return M.runKernel(Li, e);
|
|
849
850
|
}
|
|
850
|
-
const Dt = /* @__PURE__ */ F({ onesLike_:
|
|
851
|
+
const Dt = /* @__PURE__ */ F({ onesLike_: th });
|
|
851
852
|
/**
|
|
852
853
|
* @license
|
|
853
854
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -864,14 +865,14 @@ const Dt = /* @__PURE__ */ F({ onesLike_: Qc });
|
|
|
864
865
|
* limitations under the License.
|
|
865
866
|
* =============================================================================
|
|
866
867
|
*/
|
|
867
|
-
function
|
|
868
|
+
function eh(s, t, e = 0) {
|
|
868
869
|
const n = D(s, "x", "pad");
|
|
869
870
|
if (n.rank === 0)
|
|
870
871
|
throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
|
|
871
872
|
const i = { paddings: t, constantValue: e }, r = { x: n };
|
|
872
873
|
return M.runKernel(Fi, r, i);
|
|
873
874
|
}
|
|
874
|
-
const sr = /* @__PURE__ */ F({ pad_:
|
|
875
|
+
const sr = /* @__PURE__ */ F({ pad_: eh });
|
|
875
876
|
/**
|
|
876
877
|
* @license
|
|
877
878
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -888,13 +889,13 @@ const sr = /* @__PURE__ */ F({ pad_: th });
|
|
|
888
889
|
* limitations under the License.
|
|
889
890
|
* =============================================================================
|
|
890
891
|
*/
|
|
891
|
-
function
|
|
892
|
+
function nh(s, t, e) {
|
|
892
893
|
const n = D(s, "x", "spaceToBatchND");
|
|
893
894
|
k(n.rank >= 1 + t.length, () => `input rank ${n.rank} should be > than [blockShape] ${t.length}`), k(e.length === t.length, () => `paddings.shape[0] ${e.length} must be equal to [blockShape] ${t.length}`), k(n.shape.reduce((a, o, l) => l > 0 && l <= t.length ? a && (o + e[l - 1][0] + e[l - 1][1]) % t[l - 1] === 0 : a, !0), () => `input spatial dimensions ${n.shape.slice(1)} with paddings ${e.toString()} must be divisible by blockShapes ${t.toString()}`);
|
|
894
895
|
const i = { x: n }, r = { blockShape: t, paddings: e };
|
|
895
896
|
return M.runKernel(Mi, i, r);
|
|
896
897
|
}
|
|
897
|
-
const
|
|
898
|
+
const sh = /* @__PURE__ */ F({ spaceToBatchND_: nh });
|
|
898
899
|
/**
|
|
899
900
|
* @license
|
|
900
901
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -911,11 +912,11 @@ const nh = /* @__PURE__ */ F({ spaceToBatchND_: eh });
|
|
|
911
912
|
* limitations under the License.
|
|
912
913
|
* =============================================================================
|
|
913
914
|
*/
|
|
914
|
-
function
|
|
915
|
+
function ih(s, t) {
|
|
915
916
|
const n = { x: D(s, "x", "reverse") }, i = { dims: t };
|
|
916
917
|
return M.runKernel(Oi, n, i);
|
|
917
918
|
}
|
|
918
|
-
const on = /* @__PURE__ */ F({ reverse_:
|
|
919
|
+
const on = /* @__PURE__ */ F({ reverse_: ih });
|
|
919
920
|
/**
|
|
920
921
|
* @license
|
|
921
922
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -932,11 +933,11 @@ const on = /* @__PURE__ */ F({ reverse_: sh });
|
|
|
932
933
|
* limitations under the License.
|
|
933
934
|
* =============================================================================
|
|
934
935
|
*/
|
|
935
|
-
function
|
|
936
|
+
function rh(s) {
|
|
936
937
|
const e = { x: D(s, "x", "rsqrt", "float32") };
|
|
937
938
|
return M.runKernel(Ri, e);
|
|
938
939
|
}
|
|
939
|
-
const
|
|
940
|
+
const ah = /* @__PURE__ */ F({ rsqrt_: rh });
|
|
940
941
|
/**
|
|
941
942
|
* @license
|
|
942
943
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -953,12 +954,12 @@ const rh = /* @__PURE__ */ F({ rsqrt_: ih });
|
|
|
953
954
|
* limitations under the License.
|
|
954
955
|
* =============================================================================
|
|
955
956
|
*/
|
|
956
|
-
function
|
|
957
|
+
function oh(s) {
|
|
957
958
|
const e = { x: D(s, "x", "selu") };
|
|
958
959
|
return M.runKernel(_i, e);
|
|
959
960
|
}
|
|
960
|
-
const
|
|
961
|
-
function
|
|
961
|
+
const lh = /* @__PURE__ */ F({ selu_: oh });
|
|
962
|
+
function uh(s, t, e, n, i, r = [1, 1], a = "NHWC") {
|
|
962
963
|
const o = D(s, "x", "separableConv2d"), l = D(t, "depthwiseFilter", "separableConv2d"), u = D(e, "pointwiseFilter", "separableConv2d");
|
|
963
964
|
let c = o, h = !1;
|
|
964
965
|
if (o.rank === 3 && (h = !0, c = A(o, [1, o.shape[0], o.shape[1], o.shape[2]])), a === "NCHW")
|
|
@@ -969,7 +970,7 @@ function lh(s, t, e, n, i, r = [1, 1], a = "NHWC") {
|
|
|
969
970
|
const g = nr(c, l, n, i, a, r), m = Ce(g, u, 1, "valid", a);
|
|
970
971
|
return h ? A(m, [m.shape[1], m.shape[2], m.shape[3]]) : m;
|
|
971
972
|
}
|
|
972
|
-
const
|
|
973
|
+
const ch = /* @__PURE__ */ F({ separableConv2d_: uh });
|
|
973
974
|
/**
|
|
974
975
|
* @license
|
|
975
976
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -986,11 +987,11 @@ const uh = /* @__PURE__ */ F({ separableConv2d_: lh });
|
|
|
986
987
|
* limitations under the License.
|
|
987
988
|
* =============================================================================
|
|
988
989
|
*/
|
|
989
|
-
function
|
|
990
|
+
function hh(s) {
|
|
990
991
|
const e = { x: D(s, "x", "sinh") };
|
|
991
992
|
return M.runKernel(Bi, e);
|
|
992
993
|
}
|
|
993
|
-
const
|
|
994
|
+
const ph = /* @__PURE__ */ F({ sinh_: hh });
|
|
994
995
|
/**
|
|
995
996
|
* @license
|
|
996
997
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1007,15 +1008,15 @@ const hh = /* @__PURE__ */ F({ sinh_: ch });
|
|
|
1007
1008
|
* limitations under the License.
|
|
1008
1009
|
* =============================================================================
|
|
1009
1010
|
*/
|
|
1010
|
-
function
|
|
1011
|
-
if (
|
|
1011
|
+
function dh(s, t = 0, e = 1, n, i) {
|
|
1012
|
+
if (No(s), n != null && n === "bool")
|
|
1012
1013
|
throw new Error("Unsupported data type $ { dtype }");
|
|
1013
|
-
const r = new
|
|
1014
|
+
const r = new _u(t, e, n, !0, i), a = vo(s, n);
|
|
1014
1015
|
for (let o = 0; o < a.values.length; o++)
|
|
1015
1016
|
a.values[o] = r.nextValue();
|
|
1016
1017
|
return a.toTensor();
|
|
1017
1018
|
}
|
|
1018
|
-
const ir = /* @__PURE__ */ F({ truncatedNormal_:
|
|
1019
|
+
const ir = /* @__PURE__ */ F({ truncatedNormal_: dh });
|
|
1019
1020
|
/**
|
|
1020
1021
|
* @license
|
|
1021
1022
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1032,13 +1033,13 @@ const ir = /* @__PURE__ */ F({ truncatedNormal_: ph });
|
|
|
1032
1033
|
* limitations under the License.
|
|
1033
1034
|
* =============================================================================
|
|
1034
1035
|
*/
|
|
1035
|
-
function
|
|
1036
|
+
function fh(s, t, e) {
|
|
1036
1037
|
const n = D(s, "x", "unsortedSegmentSum"), i = D(t, "segmentIds", "unsortedSegmentSum", "int32");
|
|
1037
1038
|
k(Mn(e), () => "numSegments must be of dtype int");
|
|
1038
1039
|
const r = { x: n, segmentIds: i }, a = { numSegments: e };
|
|
1039
1040
|
return M.runKernel(Wi, r, a);
|
|
1040
1041
|
}
|
|
1041
|
-
const
|
|
1042
|
+
const mh = /* @__PURE__ */ F({ unsortedSegmentSum_: fh });
|
|
1042
1043
|
/**
|
|
1043
1044
|
* @license
|
|
1044
1045
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1055,7 +1056,7 @@ const fh = /* @__PURE__ */ F({ unsortedSegmentSum_: dh });
|
|
|
1055
1056
|
* limitations under the License.
|
|
1056
1057
|
* =============================================================================
|
|
1057
1058
|
*/
|
|
1058
|
-
function
|
|
1059
|
+
function gh(s, t, e, n, i, r = "NHWC", a) {
|
|
1059
1060
|
let o = s;
|
|
1060
1061
|
s.rank === 3 && (o = A(s, [1, s.shape[0], s.shape[1], s.shape[2]]));
|
|
1061
1062
|
let l = t;
|
|
@@ -1065,7 +1066,7 @@ function mh(s, t, e, n, i, r = "NHWC", a) {
|
|
|
1065
1066
|
const h = { x: o, dy: l }, p = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, filterShape: e };
|
|
1066
1067
|
return M.runKernel(So, h, p);
|
|
1067
1068
|
}
|
|
1068
|
-
const cs = /* @__PURE__ */ F({ conv2DBackpropFilter_:
|
|
1069
|
+
const cs = /* @__PURE__ */ F({ conv2DBackpropFilter_: gh });
|
|
1069
1070
|
/**
|
|
1070
1071
|
* @license
|
|
1071
1072
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -1082,7 +1083,7 @@ const cs = /* @__PURE__ */ F({ conv2DBackpropFilter_: mh });
|
|
|
1082
1083
|
* limitations under the License.
|
|
1083
1084
|
* =============================================================================
|
|
1084
1085
|
*/
|
|
1085
|
-
function
|
|
1086
|
+
function bh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilations: r = [1, 1], dimRoundingMode: a, bias: o, activation: l = "linear", preluActivationWeights: u, leakyreluAlpha: c }) {
|
|
1086
1087
|
if (l = l || "linear", ku(M.state.gradientDepth, l) === !1) {
|
|
1087
1088
|
k(i === "NHWC", () => `Error in fused conv2d: got dataFormat of ${i} but only NHWC is currently supported for the case of gradient depth is 0 and the activation is not linear.`);
|
|
1088
1089
|
let z = Ce(s, t, e, n, i, r, a);
|
|
@@ -1093,9 +1094,9 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1093
1094
|
h.rank === 3 && (g = !0, f = A(h, [1, h.shape[0], h.shape[1], h.shape[2]])), k(f.rank === 4, () => `Error in fused conv2d: input must be rank 4, but got rank ${f.rank}.`), k(p.rank === 4, () => `Error in fused conv2d: filter must be rank 4, but got rank ${p.rank}.`), ft("fused conv2d", n, a);
|
|
1094
1095
|
const b = i === "NHWC" ? f.shape[3] : f.shape[1];
|
|
1095
1096
|
k(p.shape[2] === b, () => `Error in conv2d: depth of input (${b}) must match input depth for filter ${p.shape[2]}.`), k(de(e, r), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${e} and dilations '${r}'`);
|
|
1096
|
-
const m =
|
|
1097
|
-
let
|
|
1098
|
-
o != null && (
|
|
1097
|
+
const m = tc(f.shape, p.shape, e, r, n, a);
|
|
1098
|
+
let v;
|
|
1099
|
+
o != null && (v = D(o, "bias", "fused conv2d"), [v] = Hn(v, h), i === "NHWC" ? wt(m.outShape, v.shape) : (k(v.shape.length <= 1, () => `Error in fused conv2d: only supports scalar or 1-D Tensor bias for NCHW format but got the bias of rank-${v.shape.length}.`), k(v.shape.length === 0 || v.shape[0] === m.outChannels || v.shape[0] === 1, () => `Error in fused conv2d: bias shape (${v.shape}) is not compatible with the number of output channels (${m.outChannels})`)));
|
|
1099
1100
|
let y;
|
|
1100
1101
|
if (u != null) {
|
|
1101
1102
|
const z = u.shape;
|
|
@@ -1112,18 +1113,18 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1112
1113
|
}
|
|
1113
1114
|
const C = (z, _) => {
|
|
1114
1115
|
k(i === "NHWC", () => `Error in gradient of fused conv2D: got dataFormat of ${i} but only NHWC is currently supported.`);
|
|
1115
|
-
const [T, E, R, q] = _, bt =
|
|
1116
|
+
const [T, E, R, q] = _, bt = Nu(z, R, l);
|
|
1116
1117
|
k(Se(r), () => `Error in gradient of fused conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${r}'`);
|
|
1117
1118
|
const ie = ls(E.shape, bt, T, e, n), re = cs(E, bt, T.shape, e, n), xt = [ie, re];
|
|
1118
1119
|
if (q != null) {
|
|
1119
|
-
const Tt =
|
|
1120
|
+
const Tt = vu(q, bt);
|
|
1120
1121
|
xt.push(Tt);
|
|
1121
1122
|
}
|
|
1122
1123
|
return xt;
|
|
1123
|
-
},
|
|
1124
|
+
}, N = {
|
|
1124
1125
|
x: f,
|
|
1125
1126
|
filter: p,
|
|
1126
|
-
bias:
|
|
1127
|
+
bias: v,
|
|
1127
1128
|
preluActivationWeights: y
|
|
1128
1129
|
}, I = {
|
|
1129
1130
|
strides: e,
|
|
@@ -1137,15 +1138,15 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1137
1138
|
return o == null ? On((_, T, E) => {
|
|
1138
1139
|
let R = (
|
|
1139
1140
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1140
|
-
M.runKernel(Ls,
|
|
1141
|
+
M.runKernel(Ls, N, I)
|
|
1141
1142
|
);
|
|
1142
1143
|
return E([T, _, R]), g && (R = A(R, [R.shape[1], R.shape[2], R.shape[3]])), { value: R, gradFunc: C };
|
|
1143
1144
|
})(f, p) : On((_, T, E, R) => {
|
|
1144
|
-
let q = M.runKernel(Ls,
|
|
1145
|
+
let q = M.runKernel(Ls, N, I);
|
|
1145
1146
|
return R([T, _, q, E]), g && (q = A(q, [q.shape[1], q.shape[2], q.shape[3]])), { value: q, gradFunc: C };
|
|
1146
|
-
})(f, p,
|
|
1147
|
+
})(f, p, v);
|
|
1147
1148
|
}
|
|
1148
|
-
const
|
|
1149
|
+
const yh = /* @__PURE__ */ F({ fusedConv2d_: bh });
|
|
1149
1150
|
/**
|
|
1150
1151
|
* @license
|
|
1151
1152
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1162,7 +1163,7 @@ const bh = /* @__PURE__ */ F({ fusedConv2d_: gh });
|
|
|
1162
1163
|
* limitations under the License.
|
|
1163
1164
|
* =============================================================================
|
|
1164
1165
|
*/
|
|
1165
|
-
function
|
|
1166
|
+
function wh(s, t, e, n, i, r = [1, 1], a) {
|
|
1166
1167
|
let o = s;
|
|
1167
1168
|
s.rank === 3 && (o = A(s, [1, s.shape[0], s.shape[1], s.shape[2]]));
|
|
1168
1169
|
let l = t;
|
|
@@ -1170,7 +1171,7 @@ function yh(s, t, e, n, i, r = [1, 1], a) {
|
|
|
1170
1171
|
const u = { x: o, dy: l }, c = { strides: n, pad: i, dimRoundingMode: a, dilations: r, filterShape: e };
|
|
1171
1172
|
return M.runKernel(Ao, u, c);
|
|
1172
1173
|
}
|
|
1173
|
-
const
|
|
1174
|
+
const kh = F({ depthwiseConv2dNativeBackpropFilter_: wh });
|
|
1174
1175
|
/**
|
|
1175
1176
|
* @license
|
|
1176
1177
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1187,7 +1188,7 @@ const wh = F({ depthwiseConv2dNativeBackpropFilter_: yh });
|
|
|
1187
1188
|
* limitations under the License.
|
|
1188
1189
|
* =============================================================================
|
|
1189
1190
|
*/
|
|
1190
|
-
function
|
|
1191
|
+
function xh(s, t, e, n, i, r = [1, 1], a) {
|
|
1191
1192
|
let o = t, l = !1;
|
|
1192
1193
|
t.rank === 3 && (l = !0, o = A(t, [1, t.shape[0], t.shape[1], t.shape[2]]));
|
|
1193
1194
|
const u = { dy: o, filter: e }, c = { strides: n, pad: i, dimRoundingMode: a, dilations: r, inputShape: s }, h = (
|
|
@@ -1196,7 +1197,7 @@ function kh(s, t, e, n, i, r = [1, 1], a) {
|
|
|
1196
1197
|
);
|
|
1197
1198
|
return l ? A(h, [h.shape[1], h.shape[2], h.shape[3]]) : h;
|
|
1198
1199
|
}
|
|
1199
|
-
const
|
|
1200
|
+
const Nh = F({ depthwiseConv2dNativeBackpropInput_: xh });
|
|
1200
1201
|
/**
|
|
1201
1202
|
* @license
|
|
1202
1203
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -1394,9 +1395,9 @@ const ge = vh;
|
|
|
1394
1395
|
* limitations under the License.
|
|
1395
1396
|
* =============================================================================
|
|
1396
1397
|
*/
|
|
1397
|
-
const
|
|
1398
|
-
function
|
|
1399
|
-
return new Promise((s) =>
|
|
1398
|
+
const Sh = typeof requestAnimationFrame < "u" ? requestAnimationFrame : typeof setImmediate < "u" ? setImmediate : (s) => s();
|
|
1399
|
+
function Ah() {
|
|
1400
|
+
return new Promise((s) => Sh(() => s()));
|
|
1400
1401
|
}
|
|
1401
1402
|
/**
|
|
1402
1403
|
* @license
|
|
@@ -1414,7 +1415,7 @@ function Sh() {
|
|
|
1414
1415
|
* limitations under the License.
|
|
1415
1416
|
* =============================================================================
|
|
1416
1417
|
*/
|
|
1417
|
-
const
|
|
1418
|
+
const Ch = 1.7580993408473768, Ih = 1.0507009873554805;
|
|
1418
1419
|
/**
|
|
1419
1420
|
* @license
|
|
1420
1421
|
* Copyright 2022 Google LLC
|
|
@@ -1478,9 +1479,9 @@ class rr {
|
|
|
1478
1479
|
* https://opensource.org/licenses/MIT.
|
|
1479
1480
|
* =============================================================================
|
|
1480
1481
|
*/
|
|
1481
|
-
let
|
|
1482
|
+
let Dh = 0;
|
|
1482
1483
|
function ar() {
|
|
1483
|
-
return
|
|
1484
|
+
return Dh++;
|
|
1484
1485
|
}
|
|
1485
1486
|
const Qe = {};
|
|
1486
1487
|
function xn(s = "") {
|
|
@@ -1495,7 +1496,7 @@ function xn(s = "") {
|
|
|
1495
1496
|
* https://opensource.org/licenses/MIT.
|
|
1496
1497
|
* =============================================================================
|
|
1497
1498
|
*/
|
|
1498
|
-
const
|
|
1499
|
+
const zh = ["fanIn", "fanOut", "fanAvg"], Th = ["normal", "uniform", "truncatedNormal"];
|
|
1499
1500
|
/**
|
|
1500
1501
|
* @license
|
|
1501
1502
|
* Copyright 2018 Google LLC
|
|
@@ -1505,11 +1506,11 @@ const Dh = ["fanIn", "fanOut", "fanAvg"], zh = ["normal", "uniform", "truncatedN
|
|
|
1505
1506
|
* https://opensource.org/licenses/MIT.
|
|
1506
1507
|
* =============================================================================
|
|
1507
1508
|
*/
|
|
1508
|
-
function Th(s) {
|
|
1509
|
-
Zn(Dh, "FanMode", s);
|
|
1510
|
-
}
|
|
1511
1509
|
function $h(s) {
|
|
1512
|
-
Zn(zh, "
|
|
1510
|
+
Zn(zh, "FanMode", s);
|
|
1511
|
+
}
|
|
1512
|
+
function Eh(s) {
|
|
1513
|
+
Zn(Th, "Distribution", s);
|
|
1513
1514
|
}
|
|
1514
1515
|
class kt extends Be {
|
|
1515
1516
|
fromConfigUsesCustomObjects() {
|
|
@@ -1603,7 +1604,7 @@ class hr extends kt {
|
|
|
1603
1604
|
return x(() => {
|
|
1604
1605
|
if (t.length !== 2 || t[0] !== t[1])
|
|
1605
1606
|
throw new d("Identity matrix initializer can only be used for 2D square matrices.");
|
|
1606
|
-
return w(this.gain,
|
|
1607
|
+
return w(this.gain, Bu(t[0]));
|
|
1607
1608
|
});
|
|
1608
1609
|
}
|
|
1609
1610
|
getConfig() {
|
|
@@ -1612,7 +1613,7 @@ class hr extends kt {
|
|
|
1612
1613
|
}
|
|
1613
1614
|
hr.className = "Identity";
|
|
1614
1615
|
S(hr);
|
|
1615
|
-
function
|
|
1616
|
+
function Lh(s, t = "channelsLast") {
|
|
1616
1617
|
let e, n;
|
|
1617
1618
|
if (et(t), s.length === 2)
|
|
1618
1619
|
e = s[0], n = s[1];
|
|
@@ -1638,10 +1639,10 @@ class dt extends kt {
|
|
|
1638
1639
|
constructor(t) {
|
|
1639
1640
|
if (super(), t.scale < 0)
|
|
1640
1641
|
throw new d(`scale must be a positive float. Got: ${t.scale}`);
|
|
1641
|
-
this.scale = t.scale == null ? 1 : t.scale, this.mode = t.mode == null ? "fanIn" : t.mode,
|
|
1642
|
+
this.scale = t.scale == null ? 1 : t.scale, this.mode = t.mode == null ? "fanIn" : t.mode, $h(this.mode), this.distribution = t.distribution == null ? "normal" : t.distribution, Eh(this.distribution), this.seed = t.seed;
|
|
1642
1643
|
}
|
|
1643
1644
|
apply(t, e) {
|
|
1644
|
-
const n =
|
|
1645
|
+
const n = Lh(t), i = n[0], r = n[1];
|
|
1645
1646
|
let a = this.scale;
|
|
1646
1647
|
if (this.mode === "fanIn" ? a /= Math.max(1, i) : this.mode === "fanOut" ? a /= Math.max(1, r) : a /= Math.max(1, (i + r) / 2), this.distribution === "normal") {
|
|
1647
1648
|
const o = Math.sqrt(a);
|
|
@@ -1781,7 +1782,7 @@ class pr extends kt {
|
|
|
1781
1782
|
e = e;
|
|
1782
1783
|
const n = Gi(t.slice(0, -1)), i = t[t.length - 1], r = n * i;
|
|
1783
1784
|
r > this.ELEMENTS_WARN_SLOW && console.warn(`Orthogonal initializer is being called on a matrix with more than ${this.ELEMENTS_WARN_SLOW} (${r}) elements: Slowness may result.`);
|
|
1784
|
-
const a = [Math.max(i, n), Math.min(i, n)], o = bn(a, 0, 1, e, this.seed), l =
|
|
1785
|
+
const a = [Math.max(i, n), Math.min(i, n)], o = bn(a, 0, 1, e, this.seed), l = Wu.qr(o, !1);
|
|
1785
1786
|
let u = l[0];
|
|
1786
1787
|
const h = l[1].flatten().stridedSlice([0], [Math.min(i, n) * Math.min(i, n)], [Math.min(i, n) + 1]);
|
|
1787
1788
|
return u = w(u, h.sign()), n < i && (u = u.transpose()), w(tt(this.gain), u.reshape(t));
|
|
@@ -1898,7 +1899,7 @@ function un(s) {
|
|
|
1898
1899
|
* =============================================================================
|
|
1899
1900
|
*/
|
|
1900
1901
|
const Ks = "Variable";
|
|
1901
|
-
class
|
|
1902
|
+
class Fh {
|
|
1902
1903
|
/**
|
|
1903
1904
|
* Construct Variable from a `tf.Tensor`.
|
|
1904
1905
|
*
|
|
@@ -1934,7 +1935,7 @@ class Lh {
|
|
|
1934
1935
|
* @return This Variable.
|
|
1935
1936
|
*/
|
|
1936
1937
|
write(t) {
|
|
1937
|
-
return this.assertNotDisposed(),
|
|
1938
|
+
return this.assertNotDisposed(), Mh(this.val, t), this.val.id !== t.id && (this.val.assign(t), this.constraint != null && this.val.assign(this.constraint.apply(this.val))), this;
|
|
1938
1939
|
}
|
|
1939
1940
|
/**
|
|
1940
1941
|
* Dispose this LayersVariable instance from memory.
|
|
@@ -1953,7 +1954,7 @@ class Lh {
|
|
|
1953
1954
|
this.trainable_ = t, this.val.trainable = t;
|
|
1954
1955
|
}
|
|
1955
1956
|
}
|
|
1956
|
-
function
|
|
1957
|
+
function Mh(s, t) {
|
|
1957
1958
|
if (s.shape.toString() !== t.shape.toString())
|
|
1958
1959
|
throw new Error("Shape mismatch: " + JSON.stringify(s.shape) + " vs. " + JSON.stringify(t.shape));
|
|
1959
1960
|
}
|
|
@@ -1997,10 +1998,10 @@ class Mt {
|
|
|
1997
1998
|
this.dtype = t, this.shape = e, this.sourceLayer = n, this.inputs = i, this.callArgs = r, this.outputTensorIndex = o, this.id = ar(), a != null && (this.originalName = Pi(a), this.name = Ui(this.originalName)), this.rank = e.length;
|
|
1998
1999
|
}
|
|
1999
2000
|
}
|
|
2000
|
-
let
|
|
2001
|
-
class
|
|
2001
|
+
let Oh = 0;
|
|
2002
|
+
class Nn {
|
|
2002
2003
|
constructor(t, e) {
|
|
2003
|
-
this.callArgs = e, this.id =
|
|
2004
|
+
this.callArgs = e, this.id = Oh++, this.outboundLayer = t.outboundLayer, this.inboundLayers = t.inboundLayers, this.nodeIndices = t.nodeIndices, this.tensorIndices = t.tensorIndices, this.inputTensors = t.inputTensors, this.outputTensors = t.outputTensors, this.inputMasks = t.inputMasks, this.outputMasks = t.outputMasks, this.inputShapes = t.inputShapes, this.outputShapes = t.outputShapes;
|
|
2004
2005
|
for (const n of t.inboundLayers)
|
|
2005
2006
|
n?.outboundNodes.push(this);
|
|
2006
2007
|
t.outboundLayer.inboundNodes.push(this);
|
|
@@ -2017,10 +2018,10 @@ class vn {
|
|
|
2017
2018
|
};
|
|
2018
2019
|
}
|
|
2019
2020
|
}
|
|
2020
|
-
let
|
|
2021
|
+
let Rh = 0;
|
|
2021
2022
|
class W extends Be {
|
|
2022
2023
|
constructor(t = {}) {
|
|
2023
|
-
super(), this._callHook = null, this._addedWeightNames = [], this._stateful = !1, this.id =
|
|
2024
|
+
super(), this._callHook = null, this._addedWeightNames = [], this._stateful = !1, this.id = Rh++, this.activityRegularizer = null, this.inputSpec = null, this.supportsMasking = !1, this._trainableWeights = [], this._nonTrainableWeights = [], this._losses = [], this._updates = [], this._built = !1, this.inboundNodes = [], this.outboundNodes = [];
|
|
2024
2025
|
let e = t.name;
|
|
2025
2026
|
if (!e) {
|
|
2026
2027
|
const n = this.getClassName();
|
|
@@ -2334,7 +2335,7 @@ class W extends Be {
|
|
|
2334
2335
|
// Porting Note: This is a replacement for __call__() in Python.
|
|
2335
2336
|
apply(t, e) {
|
|
2336
2337
|
e = e || {}, this.assertNotDisposed();
|
|
2337
|
-
const n = K(t), i =
|
|
2338
|
+
const n = K(t), i = Wh(t), r = Gh(t);
|
|
2338
2339
|
if (i === r)
|
|
2339
2340
|
throw new d("Arguments to apply() must be all SymbolicTensors or all Tensors");
|
|
2340
2341
|
return le(this.name, () => {
|
|
@@ -2355,9 +2356,9 @@ class W extends Be {
|
|
|
2355
2356
|
throw new G("Layer invocation in the presence of activity regularizer(s) is not supported yet.");
|
|
2356
2357
|
return a;
|
|
2357
2358
|
} else {
|
|
2358
|
-
const a =
|
|
2359
|
+
const a = _h(t), o = this.computeOutputShape(a);
|
|
2359
2360
|
let l;
|
|
2360
|
-
const u =
|
|
2361
|
+
const u = Bh(t);
|
|
2361
2362
|
if (this.warnOnIncompatibleInputShape(Array.isArray(t) ? a[0] : a), o != null && o.length > 0 && Array.isArray(o[0]) ? l = o.map((c, h) => new Mt(u, c, this, K(t), e, this.name, h)) : l = new Mt(u, o, this, K(t), e, this.name), this.addInboundNode(t, l, null, null, a, o, e), this._refCount++, this.activityRegularizer != null)
|
|
2362
2363
|
throw new G("Layer invocation in the presence of activity regularizer(s) is not supported yet.");
|
|
2363
2364
|
return l;
|
|
@@ -2496,7 +2497,7 @@ class W extends Be {
|
|
|
2496
2497
|
if (this._addedWeightNames.indexOf(t) !== -1)
|
|
2497
2498
|
throw new d(`Duplicate weight name ${t} for layer ${this.name}`);
|
|
2498
2499
|
this._addedWeightNames.push(t), n == null && (n = "float32"), this.fastWeightInitDuringBuild && (i = l != null ? l() : J("zeros"));
|
|
2499
|
-
const u = i.apply(e, n), c = new
|
|
2500
|
+
const u = i.apply(e, n), c = new Fh(u, n, t, a, o);
|
|
2500
2501
|
return u.dispose(), r != null && this.addLoss(() => r.apply(c.read())), a == null && (a = !0), a ? this._trainableWeights.push(c) : this._nonTrainableWeights.push(c), c;
|
|
2501
2502
|
}
|
|
2502
2503
|
/**
|
|
@@ -2587,7 +2588,7 @@ class W extends Be {
|
|
|
2587
2588
|
const u = [], c = [], h = [];
|
|
2588
2589
|
for (const p of l)
|
|
2589
2590
|
u.push(p.sourceLayer), c.push(p.nodeIndex), h.push(p.tensorIndex);
|
|
2590
|
-
new
|
|
2591
|
+
new Nn({
|
|
2591
2592
|
outboundLayer: this,
|
|
2592
2593
|
inboundLayers: u,
|
|
2593
2594
|
nodeIndices: c,
|
|
@@ -2679,14 +2680,14 @@ class W extends Be {
|
|
|
2679
2680
|
return --this._refCount === 0 && (t = this.disposeWeights()), { refCountAfterDispose: this._refCount, numDisposedVariables: t };
|
|
2680
2681
|
}
|
|
2681
2682
|
}
|
|
2682
|
-
function
|
|
2683
|
+
function _h(s) {
|
|
2683
2684
|
s = K(s);
|
|
2684
2685
|
const t = [];
|
|
2685
2686
|
for (const e of s)
|
|
2686
2687
|
t.push(e.shape);
|
|
2687
2688
|
return ht(t);
|
|
2688
2689
|
}
|
|
2689
|
-
function
|
|
2690
|
+
function Bh(s) {
|
|
2690
2691
|
return "float32";
|
|
2691
2692
|
}
|
|
2692
2693
|
function dr(s, t, e) {
|
|
@@ -2707,7 +2708,7 @@ function dr(s, t, e) {
|
|
|
2707
2708
|
}
|
|
2708
2709
|
}
|
|
2709
2710
|
}
|
|
2710
|
-
function
|
|
2711
|
+
function Wh(s) {
|
|
2711
2712
|
let t = !0;
|
|
2712
2713
|
for (const e of K(s))
|
|
2713
2714
|
if (!(e instanceof Mt)) {
|
|
@@ -2716,7 +2717,7 @@ function Bh(s) {
|
|
|
2716
2717
|
}
|
|
2717
2718
|
return t;
|
|
2718
2719
|
}
|
|
2719
|
-
function
|
|
2720
|
+
function Gh(s) {
|
|
2720
2721
|
let t = !0;
|
|
2721
2722
|
for (const e of K(s))
|
|
2722
2723
|
if (e instanceof Mt) {
|
|
@@ -2751,7 +2752,7 @@ class He extends W {
|
|
|
2751
2752
|
const n = t.dtype || "float32";
|
|
2752
2753
|
this.batchInputShape = e, this.dtype = n, this.inputSpec = [{ shape: e }];
|
|
2753
2754
|
const i = new Mt(this.dtype, this.batchInputShape, this, [], {}, this.name);
|
|
2754
|
-
i.nodeIndex = 0, i.tensorIndex = 0, new
|
|
2755
|
+
i.nodeIndex = 0, i.tensorIndex = 0, new Nn({
|
|
2755
2756
|
outboundLayer: this,
|
|
2756
2757
|
inboundLayers: [],
|
|
2757
2758
|
nodeIndices: [],
|
|
@@ -2781,7 +2782,7 @@ class He extends W {
|
|
|
2781
2782
|
}
|
|
2782
2783
|
He.className = "InputLayer";
|
|
2783
2784
|
S(He);
|
|
2784
|
-
function
|
|
2785
|
+
function Ph(s) {
|
|
2785
2786
|
if (s.batchShape == null && s.shape == null)
|
|
2786
2787
|
throw new Error("Please provide to Input either a `shape` or a `batchShape` argument. Note that `shape` does not include the batch dimension.");
|
|
2787
2788
|
if (s.batchShape != null && s.shape != null)
|
|
@@ -2805,7 +2806,7 @@ function Gh(s) {
|
|
|
2805
2806
|
* https://opensource.org/licenses/MIT.
|
|
2806
2807
|
* =============================================================================
|
|
2807
2808
|
*/
|
|
2808
|
-
function
|
|
2809
|
+
function Uh(s, t) {
|
|
2809
2810
|
if (s.dtype == null || s.dtype === t.dtype)
|
|
2810
2811
|
return t;
|
|
2811
2812
|
try {
|
|
@@ -2843,7 +2844,7 @@ class Vt {
|
|
|
2843
2844
|
*/
|
|
2844
2845
|
add(t, e, n) {
|
|
2845
2846
|
if (this.id2Value[t.id] == null)
|
|
2846
|
-
this.id2Value[t.id] =
|
|
2847
|
+
this.id2Value[t.id] = Uh(t, e), this.name2Id[t.name] = t.id, n != null && (this.id2Mask[t.id] = n);
|
|
2847
2848
|
else
|
|
2848
2849
|
throw new d(`Duplicate key: name=${t.name}, id=${t.id}`);
|
|
2849
2850
|
return this;
|
|
@@ -2913,7 +2914,7 @@ class Vt {
|
|
|
2913
2914
|
}
|
|
2914
2915
|
}
|
|
2915
2916
|
const cn = new rr(), hn = new rr();
|
|
2916
|
-
function
|
|
2917
|
+
function Vh(s) {
|
|
2917
2918
|
cn?.setMaxEntries(s), hn?.setMaxEntries(s);
|
|
2918
2919
|
}
|
|
2919
2920
|
function Le(s, t, e, n) {
|
|
@@ -2923,7 +2924,7 @@ function Le(s, t, e, n) {
|
|
|
2923
2924
|
const c = o.join(",") + "|" + t.names().sort().join(",");
|
|
2924
2925
|
let h = cn.get(c), p;
|
|
2925
2926
|
if (h == null) {
|
|
2926
|
-
const g =
|
|
2927
|
+
const g = jh(a, t);
|
|
2927
2928
|
h = g.sorted, p = g.recipientCounts, cn.put(c, h), hn.put(c, p);
|
|
2928
2929
|
}
|
|
2929
2930
|
p = {}, i || Object.assign(p, hn.get(c));
|
|
@@ -2932,17 +2933,17 @@ function Le(s, t, e, n) {
|
|
|
2932
2933
|
const b = h[g], m = b.sourceLayer;
|
|
2933
2934
|
if (m instanceof He)
|
|
2934
2935
|
continue;
|
|
2935
|
-
const
|
|
2936
|
-
let
|
|
2936
|
+
const v = [], y = [], C = [];
|
|
2937
|
+
let N = !1;
|
|
2937
2938
|
for (const E of b.inputs) {
|
|
2938
2939
|
const R = f.getValue(E), q = f.getMask(E);
|
|
2939
|
-
|
|
2940
|
+
v.push(R), y.push(q), q != null && (N = !0), i || (p[E.name]--, p[E.name] === 0 && !t.hasKey(E) && o.indexOf(E.name) === -1 && !R.isDisposed && E.sourceLayer.stateful !== !0 && C.push(R));
|
|
2940
2941
|
}
|
|
2941
|
-
|
|
2942
|
-
const I = K(m.apply(
|
|
2942
|
+
N && (e = e || {}, e.mask = y[0]);
|
|
2943
|
+
const I = K(m.apply(v, e));
|
|
2943
2944
|
let z = null;
|
|
2944
|
-
m.supportsMasking && (z = m.computeMask(
|
|
2945
|
-
const _ =
|
|
2945
|
+
m.supportsMasking && (z = m.computeMask(v, y));
|
|
2946
|
+
const _ = Hh(b), T = Array.isArray(_) ? _ : [_];
|
|
2946
2947
|
for (let E = 0; E < T.length; ++E) {
|
|
2947
2948
|
f.hasKey(T[E]) || f.add(T[E], I[E], Array.isArray(z) ? z[0] : z);
|
|
2948
2949
|
const R = o.indexOf(T[E].name);
|
|
@@ -2952,7 +2953,7 @@ function Le(s, t, e, n) {
|
|
|
2952
2953
|
}
|
|
2953
2954
|
return f.disposeMasks(), r ? l : l[0];
|
|
2954
2955
|
}
|
|
2955
|
-
function
|
|
2956
|
+
function jh(s, t) {
|
|
2956
2957
|
k(s != null && s.length > 0, () => "Expected at least one fetch, got none");
|
|
2957
2958
|
let e = [], n = {};
|
|
2958
2959
|
if (s.length === 1) {
|
|
@@ -2970,10 +2971,10 @@ function Vh(s, t) {
|
|
|
2970
2971
|
}
|
|
2971
2972
|
return {
|
|
2972
2973
|
sorted: e,
|
|
2973
|
-
recipientCounts:
|
|
2974
|
+
recipientCounts: Kh(n)
|
|
2974
2975
|
};
|
|
2975
2976
|
}
|
|
2976
|
-
function
|
|
2977
|
+
function Kh(s) {
|
|
2977
2978
|
const t = {};
|
|
2978
2979
|
for (const e in s)
|
|
2979
2980
|
t[e] = s[e].size;
|
|
@@ -3001,7 +3002,7 @@ function Hs(s, t) {
|
|
|
3001
3002
|
}
|
|
3002
3003
|
return { sorted: n, recipientMap: i };
|
|
3003
3004
|
}
|
|
3004
|
-
function
|
|
3005
|
+
function Hh(s) {
|
|
3005
3006
|
let t;
|
|
3006
3007
|
if (s.sourceLayer.inboundNodes.length === 1)
|
|
3007
3008
|
t = s.sourceLayer.output;
|
|
@@ -3033,8 +3034,8 @@ function Kh(s) {
|
|
|
3033
3034
|
* limitations under the License.
|
|
3034
3035
|
* =============================================================================
|
|
3035
3036
|
*/
|
|
3036
|
-
const
|
|
3037
|
-
|
|
3037
|
+
const qh = Fo();
|
|
3038
|
+
qh.registerFlag("TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES", () => 100, Vh);
|
|
3038
3039
|
/**
|
|
3039
3040
|
* @license
|
|
3040
3041
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3075,7 +3076,7 @@ const fr = {
|
|
|
3075
3076
|
* limitations under the License.
|
|
3076
3077
|
* =============================================================================
|
|
3077
3078
|
*/
|
|
3078
|
-
const
|
|
3079
|
+
const Zh = {
|
|
3079
3080
|
kernelName: Oo,
|
|
3080
3081
|
inputsToSave: ["x"],
|
|
3081
3082
|
gradFunc: (s, t) => {
|
|
@@ -3104,7 +3105,7 @@ const qh = {
|
|
|
3104
3105
|
* limitations under the License.
|
|
3105
3106
|
* =============================================================================
|
|
3106
3107
|
*/
|
|
3107
|
-
const
|
|
3108
|
+
const Jh = {
|
|
3108
3109
|
kernelName: Ro,
|
|
3109
3110
|
inputsToSave: ["x"],
|
|
3110
3111
|
gradFunc: (s, t) => {
|
|
@@ -3133,7 +3134,7 @@ const Zh = {
|
|
|
3133
3134
|
* limitations under the License.
|
|
3134
3135
|
* =============================================================================
|
|
3135
3136
|
*/
|
|
3136
|
-
const
|
|
3137
|
+
const Xh = {
|
|
3137
3138
|
kernelName: _o,
|
|
3138
3139
|
inputsToSave: ["a", "b"],
|
|
3139
3140
|
gradFunc: (s, t) => {
|
|
@@ -3165,7 +3166,7 @@ const Jh = {
|
|
|
3165
3166
|
* limitations under the License.
|
|
3166
3167
|
* =============================================================================
|
|
3167
3168
|
*/
|
|
3168
|
-
const
|
|
3169
|
+
const Yh = {
|
|
3169
3170
|
kernelName: Bo,
|
|
3170
3171
|
saveAllInputs: !0,
|
|
3171
3172
|
gradFunc: (s, t) => {
|
|
@@ -3191,7 +3192,7 @@ const Xh = {
|
|
|
3191
3192
|
* limitations under the License.
|
|
3192
3193
|
* =============================================================================
|
|
3193
3194
|
*/
|
|
3194
|
-
const
|
|
3195
|
+
const Qh = {
|
|
3195
3196
|
kernelName: mi,
|
|
3196
3197
|
inputsToSave: ["x"],
|
|
3197
3198
|
gradFunc: (s, t) => {
|
|
@@ -3215,7 +3216,7 @@ const Yh = {
|
|
|
3215
3216
|
* limitations under the License.
|
|
3216
3217
|
* =============================================================================
|
|
3217
3218
|
*/
|
|
3218
|
-
const
|
|
3219
|
+
const tp = {
|
|
3219
3220
|
kernelName: Wo,
|
|
3220
3221
|
inputsToSave: ["x"],
|
|
3221
3222
|
gradFunc: (s, t) => {
|
|
@@ -3239,7 +3240,7 @@ const Qh = {
|
|
|
3239
3240
|
* limitations under the License.
|
|
3240
3241
|
* =============================================================================
|
|
3241
3242
|
*/
|
|
3242
|
-
const
|
|
3243
|
+
const ep = {
|
|
3243
3244
|
kernelName: Go,
|
|
3244
3245
|
inputsToSave: ["x"],
|
|
3245
3246
|
gradFunc: (s, t) => {
|
|
@@ -3263,7 +3264,7 @@ const tp = {
|
|
|
3263
3264
|
* limitations under the License.
|
|
3264
3265
|
* =============================================================================
|
|
3265
3266
|
*/
|
|
3266
|
-
const
|
|
3267
|
+
const np = {
|
|
3267
3268
|
kernelName: Po,
|
|
3268
3269
|
inputsToSave: ["x"],
|
|
3269
3270
|
gradFunc: (s, t) => {
|
|
@@ -3292,7 +3293,7 @@ const ep = {
|
|
|
3292
3293
|
* limitations under the License.
|
|
3293
3294
|
* =============================================================================
|
|
3294
3295
|
*/
|
|
3295
|
-
const
|
|
3296
|
+
const sp = {
|
|
3296
3297
|
kernelName: Uo,
|
|
3297
3298
|
inputsToSave: ["a", "b"],
|
|
3298
3299
|
gradFunc: (s, t) => {
|
|
@@ -3326,7 +3327,7 @@ const np = {
|
|
|
3326
3327
|
* limitations under the License.
|
|
3327
3328
|
* =============================================================================
|
|
3328
3329
|
*/
|
|
3329
|
-
const
|
|
3330
|
+
const ip = {
|
|
3330
3331
|
kernelName: Vo,
|
|
3331
3332
|
inputsToSave: ["x"],
|
|
3332
3333
|
gradFunc: (s, t) => {
|
|
@@ -3350,7 +3351,7 @@ const sp = {
|
|
|
3350
3351
|
* limitations under the License.
|
|
3351
3352
|
* =============================================================================
|
|
3352
3353
|
*/
|
|
3353
|
-
const
|
|
3354
|
+
const rp = {
|
|
3354
3355
|
kernelName: jo,
|
|
3355
3356
|
inputsToSave: ["x"],
|
|
3356
3357
|
gradFunc: (s, t) => {
|
|
@@ -3374,7 +3375,7 @@ const ip = {
|
|
|
3374
3375
|
* limitations under the License.
|
|
3375
3376
|
* =============================================================================
|
|
3376
3377
|
*/
|
|
3377
|
-
function
|
|
3378
|
+
function ap(s, t, e, n, i, r) {
|
|
3378
3379
|
const a = D(s, "dy", "avgPool3dGrad"), o = D(t, "input", "avgPool3dGrad");
|
|
3379
3380
|
let l = a, u = o, c = !1;
|
|
3380
3381
|
o.rank === 4 && (c = !0, l = A(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]]), u = A(o, [
|
|
@@ -3387,7 +3388,7 @@ function rp(s, t, e, n, i, r) {
|
|
|
3387
3388
|
const h = { dy: l, input: u }, p = { filterSize: e, strides: n, pad: i, dimRoundingMode: r }, f = M.runKernel(Ko, h, p);
|
|
3388
3389
|
return c ? A(f, [f.shape[1], f.shape[2], f.shape[3], f.shape[4]]) : f;
|
|
3389
3390
|
}
|
|
3390
|
-
const
|
|
3391
|
+
const op = /* @__PURE__ */ F({ avgPool3dGrad_: ap });
|
|
3391
3392
|
/**
|
|
3392
3393
|
* @license
|
|
3393
3394
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3404,13 +3405,13 @@ const ap = /* @__PURE__ */ F({ avgPool3dGrad_: rp });
|
|
|
3404
3405
|
* limitations under the License.
|
|
3405
3406
|
* =============================================================================
|
|
3406
3407
|
*/
|
|
3407
|
-
const
|
|
3408
|
+
const lp = {
|
|
3408
3409
|
kernelName: bi,
|
|
3409
3410
|
inputsToSave: ["x"],
|
|
3410
3411
|
gradFunc: (s, t, e) => {
|
|
3411
3412
|
const [n] = t, { filterSize: i, strides: r, pad: a, dimRoundingMode: o } = e;
|
|
3412
3413
|
return {
|
|
3413
|
-
x: () =>
|
|
3414
|
+
x: () => op(s, n, i, r, a, o)
|
|
3414
3415
|
};
|
|
3415
3416
|
}
|
|
3416
3417
|
};
|
|
@@ -3430,7 +3431,7 @@ const op = {
|
|
|
3430
3431
|
* limitations under the License.
|
|
3431
3432
|
* =============================================================================
|
|
3432
3433
|
*/
|
|
3433
|
-
function
|
|
3434
|
+
function up(s, t, e, n, i) {
|
|
3434
3435
|
const r = D(s, "dy", "avgPoolGrad"), a = D(t, "input", "avgPoolGrad");
|
|
3435
3436
|
k(a.rank === r.rank, () => `Rank of input (${a.rank}) does not match rank of dy (${r.rank})`);
|
|
3436
3437
|
let o = a, l = r, u = !1;
|
|
@@ -3438,7 +3439,7 @@ function lp(s, t, e, n, i) {
|
|
|
3438
3439
|
const c = { dy: l, input: o }, h = { filterSize: e, strides: n, pad: i }, p = M.runKernel(Ho, c, h);
|
|
3439
3440
|
return u ? A(p, [p.shape[1], p.shape[2], p.shape[3]]) : p;
|
|
3440
3441
|
}
|
|
3441
|
-
const
|
|
3442
|
+
const cp = /* @__PURE__ */ F({ avgPoolGrad_: up });
|
|
3442
3443
|
/**
|
|
3443
3444
|
* @license
|
|
3444
3445
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3455,12 +3456,12 @@ const up = /* @__PURE__ */ F({ avgPoolGrad_: lp });
|
|
|
3455
3456
|
* limitations under the License.
|
|
3456
3457
|
* =============================================================================
|
|
3457
3458
|
*/
|
|
3458
|
-
const
|
|
3459
|
+
const hp = {
|
|
3459
3460
|
kernelName: gi,
|
|
3460
3461
|
inputsToSave: ["x"],
|
|
3461
3462
|
gradFunc: (s, t, e) => {
|
|
3462
3463
|
const [n] = t, { filterSize: i, strides: r, pad: a } = e;
|
|
3463
|
-
return { x: () =>
|
|
3464
|
+
return { x: () => cp(s, n, i, r, a) };
|
|
3464
3465
|
}
|
|
3465
3466
|
};
|
|
3466
3467
|
/**
|
|
@@ -3479,7 +3480,7 @@ const cp = {
|
|
|
3479
3480
|
* limitations under the License.
|
|
3480
3481
|
* =============================================================================
|
|
3481
3482
|
*/
|
|
3482
|
-
const
|
|
3483
|
+
const pp = {
|
|
3483
3484
|
kernelName: qo,
|
|
3484
3485
|
inputsToSave: ["a", "b"],
|
|
3485
3486
|
gradFunc: (s, t, e) => {
|
|
@@ -3515,11 +3516,11 @@ const hp = {
|
|
|
3515
3516
|
* limitations under the License.
|
|
3516
3517
|
* =============================================================================
|
|
3517
3518
|
*/
|
|
3518
|
-
const
|
|
3519
|
+
const dp = {
|
|
3519
3520
|
kernelName: wi,
|
|
3520
3521
|
gradFunc: (s, t, e) => {
|
|
3521
3522
|
const { blockShape: n, crops: i } = e;
|
|
3522
|
-
return { x: () =>
|
|
3523
|
+
return { x: () => sh(s, n, i) };
|
|
3523
3524
|
}
|
|
3524
3525
|
};
|
|
3525
3526
|
/**
|
|
@@ -3538,7 +3539,7 @@ const pp = {
|
|
|
3538
3539
|
* limitations under the License.
|
|
3539
3540
|
* =============================================================================
|
|
3540
3541
|
*/
|
|
3541
|
-
const
|
|
3542
|
+
const fp = {
|
|
3542
3543
|
kernelName: Zo,
|
|
3543
3544
|
gradFunc: (s, t, e) => {
|
|
3544
3545
|
const n = e, i = n.inputShape, r = n.shape, a = Array.from(r);
|
|
@@ -3574,7 +3575,7 @@ const dp = {
|
|
|
3574
3575
|
* limitations under the License.
|
|
3575
3576
|
* =============================================================================
|
|
3576
3577
|
*/
|
|
3577
|
-
const
|
|
3578
|
+
const mp = {
|
|
3578
3579
|
kernelName: Jo,
|
|
3579
3580
|
gradFunc: (s) => ({ x: () => s.clone() })
|
|
3580
3581
|
};
|
|
@@ -3594,7 +3595,7 @@ const fp = {
|
|
|
3594
3595
|
* limitations under the License.
|
|
3595
3596
|
* =============================================================================
|
|
3596
3597
|
*/
|
|
3597
|
-
const
|
|
3598
|
+
const gp = {
|
|
3598
3599
|
kernelName: Xo,
|
|
3599
3600
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
3600
3601
|
};
|
|
@@ -3614,13 +3615,13 @@ const mp = {
|
|
|
3614
3615
|
* limitations under the License.
|
|
3615
3616
|
* =============================================================================
|
|
3616
3617
|
*/
|
|
3617
|
-
const
|
|
3618
|
+
const bp = {
|
|
3618
3619
|
kernelName: Yo,
|
|
3619
3620
|
inputsToSave: ["x"],
|
|
3620
3621
|
gradFunc: (s, t, e) => {
|
|
3621
3622
|
const [n] = t, { clipValueMin: i, clipValueMax: r } = e;
|
|
3622
3623
|
return {
|
|
3623
|
-
x: () =>
|
|
3624
|
+
x: () => qt(je(Ke(n, i), es(n, r)), s, Q(s))
|
|
3624
3625
|
};
|
|
3625
3626
|
}
|
|
3626
3627
|
};
|
|
@@ -3640,7 +3641,7 @@ const gp = {
|
|
|
3640
3641
|
* limitations under the License.
|
|
3641
3642
|
* =============================================================================
|
|
3642
3643
|
*/
|
|
3643
|
-
const
|
|
3644
|
+
const yp = {
|
|
3644
3645
|
kernelName: Qo,
|
|
3645
3646
|
inputsToSave: ["x"],
|
|
3646
3647
|
gradFunc: fr.gradFunc
|
|
@@ -3661,7 +3662,7 @@ const bp = {
|
|
|
3661
3662
|
* limitations under the License.
|
|
3662
3663
|
* =============================================================================
|
|
3663
3664
|
*/
|
|
3664
|
-
const
|
|
3665
|
+
const wp = {
|
|
3665
3666
|
kernelName: tl,
|
|
3666
3667
|
saveAllInputs: !0,
|
|
3667
3668
|
gradFunc: (s, t, e) => {
|
|
@@ -3685,7 +3686,7 @@ const yp = {
|
|
|
3685
3686
|
* limitations under the License.
|
|
3686
3687
|
* =============================================================================
|
|
3687
3688
|
*/
|
|
3688
|
-
const
|
|
3689
|
+
const kp = {
|
|
3689
3690
|
kernelName: xi,
|
|
3690
3691
|
inputsToSave: ["x", "filter"],
|
|
3691
3692
|
gradFunc: (s, t, e) => {
|
|
@@ -3712,8 +3713,8 @@ const wp = {
|
|
|
3712
3713
|
* limitations under the License.
|
|
3713
3714
|
* =============================================================================
|
|
3714
3715
|
*/
|
|
3715
|
-
const
|
|
3716
|
-
kernelName:
|
|
3716
|
+
const xp = {
|
|
3717
|
+
kernelName: Ni,
|
|
3717
3718
|
inputsToSave: ["dy", "filter"],
|
|
3718
3719
|
gradFunc: (s, t, e) => {
|
|
3719
3720
|
const [n, i] = t, { strides: r, pad: a, dataFormat: o, dimRoundingMode: l } = e;
|
|
@@ -3739,7 +3740,7 @@ const kp = {
|
|
|
3739
3740
|
* limitations under the License.
|
|
3740
3741
|
* =============================================================================
|
|
3741
3742
|
*/
|
|
3742
|
-
function
|
|
3743
|
+
function Np(s, t, e, n, i) {
|
|
3743
3744
|
let r = s;
|
|
3744
3745
|
s.rank === 4 && (r = A(s, [1, s.shape[0], s.shape[1], s.shape[2], s.shape[3]]));
|
|
3745
3746
|
let a = t;
|
|
@@ -3747,7 +3748,7 @@ function xp(s, t, e, n, i) {
|
|
|
3747
3748
|
const o = { x: r, dy: a }, l = { strides: n, pad: i, filterShape: e };
|
|
3748
3749
|
return M.runKernel(el, o, l);
|
|
3749
3750
|
}
|
|
3750
|
-
const vp = /* @__PURE__ */ F({ conv3DBackpropFilter_:
|
|
3751
|
+
const vp = /* @__PURE__ */ F({ conv3DBackpropFilter_: Np });
|
|
3751
3752
|
/**
|
|
3752
3753
|
* @license
|
|
3753
3754
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3764,8 +3765,8 @@ const vp = /* @__PURE__ */ F({ conv3DBackpropFilter_: xp });
|
|
|
3764
3765
|
* limitations under the License.
|
|
3765
3766
|
* =============================================================================
|
|
3766
3767
|
*/
|
|
3767
|
-
const
|
|
3768
|
-
kernelName:
|
|
3768
|
+
const Sp = {
|
|
3769
|
+
kernelName: vi,
|
|
3769
3770
|
inputsToSave: ["x", "filter"],
|
|
3770
3771
|
gradFunc: (s, t, e) => {
|
|
3771
3772
|
const { dilations: n, strides: i, pad: r } = e;
|
|
@@ -3793,12 +3794,12 @@ const Np = {
|
|
|
3793
3794
|
* limitations under the License.
|
|
3794
3795
|
* =============================================================================
|
|
3795
3796
|
*/
|
|
3796
|
-
const
|
|
3797
|
+
const Ap = {
|
|
3797
3798
|
kernelName: nl,
|
|
3798
3799
|
inputsToSave: ["x"],
|
|
3799
3800
|
gradFunc: (s, t) => {
|
|
3800
3801
|
const [e] = t;
|
|
3801
|
-
return { x: () => w(pt(
|
|
3802
|
+
return { x: () => w(pt(Uu(L(e, "float32"))), s) };
|
|
3802
3803
|
}
|
|
3803
3804
|
};
|
|
3804
3805
|
/**
|
|
@@ -3817,12 +3818,12 @@ const Sp = {
|
|
|
3817
3818
|
* limitations under the License.
|
|
3818
3819
|
* =============================================================================
|
|
3819
3820
|
*/
|
|
3820
|
-
const
|
|
3821
|
+
const Cp = {
|
|
3821
3822
|
kernelName: Si,
|
|
3822
3823
|
inputsToSave: ["x"],
|
|
3823
3824
|
gradFunc: (s, t) => {
|
|
3824
3825
|
const [e] = t;
|
|
3825
|
-
return { x: () => w(
|
|
3826
|
+
return { x: () => w(ph(L(e, "float32")), s) };
|
|
3826
3827
|
}
|
|
3827
3828
|
};
|
|
3828
3829
|
/**
|
|
@@ -3841,7 +3842,7 @@ const Ap = {
|
|
|
3841
3842
|
* limitations under the License.
|
|
3842
3843
|
* =============================================================================
|
|
3843
3844
|
*/
|
|
3844
|
-
const
|
|
3845
|
+
const Ip = {
|
|
3845
3846
|
kernelName: Ai,
|
|
3846
3847
|
inputsToSave: ["x"],
|
|
3847
3848
|
gradFunc: (s, t, e) => {
|
|
@@ -3849,7 +3850,7 @@ const Cp = {
|
|
|
3849
3850
|
return {
|
|
3850
3851
|
x: () => {
|
|
3851
3852
|
const o = Yi([i], n.rank);
|
|
3852
|
-
let l =
|
|
3853
|
+
let l = Fc(s, i, r, !a);
|
|
3853
3854
|
return o != null && (l = j(l, o)), l;
|
|
3854
3855
|
}
|
|
3855
3856
|
};
|
|
@@ -3871,7 +3872,7 @@ const Cp = {
|
|
|
3871
3872
|
* limitations under the License.
|
|
3872
3873
|
* =============================================================================
|
|
3873
3874
|
*/
|
|
3874
|
-
const
|
|
3875
|
+
const Dp = {
|
|
3875
3876
|
kernelName: Ci,
|
|
3876
3877
|
inputsToSave: ["x", "filter"],
|
|
3877
3878
|
gradFunc: (s, t, e) => {
|
|
@@ -3879,8 +3880,8 @@ const Ip = {
|
|
|
3879
3880
|
k(Se(o), () => `Error in gradient of depthwiseConv2dNative: dilation rates greater than 1 are not yet supported. Got dilations '${o}'`);
|
|
3880
3881
|
const [l, u] = t;
|
|
3881
3882
|
return k(l.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be rank 4, but got rank ${l.rank}.`), k(u.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be rank 4, but got rank ${u.rank}.`), k(l.shape[3] === u.shape[2], () => `Error in gradient of depthwiseConv2d: number of input channels (${l.shape[3]}) must match the inChannels dimension in filter ${u.shape[2]}.`), k(de(i, o), () => `Error in gradient of depthwiseConv2d: Either strides or dilations must be 1. Got strides ${i} and dilations '${o}'.`), ft("depthwiseConv2d", r, a), {
|
|
3882
|
-
x: () =>
|
|
3883
|
-
filter: () =>
|
|
3883
|
+
x: () => Nh(l.shape, s, u, i, r, o, a),
|
|
3884
|
+
filter: () => kh(l, s, u.shape, i, r, o, a)
|
|
3884
3885
|
};
|
|
3885
3886
|
}
|
|
3886
3887
|
};
|
|
@@ -3900,7 +3901,7 @@ const Ip = {
|
|
|
3900
3901
|
* limitations under the License.
|
|
3901
3902
|
* =============================================================================
|
|
3902
3903
|
*/
|
|
3903
|
-
const
|
|
3904
|
+
const zp = {
|
|
3904
3905
|
kernelName: sl,
|
|
3905
3906
|
inputsToSave: ["x", "filter"],
|
|
3906
3907
|
gradFunc: (s, t, e) => {
|
|
@@ -3927,7 +3928,7 @@ const Dp = {
|
|
|
3927
3928
|
* limitations under the License.
|
|
3928
3929
|
* =============================================================================
|
|
3929
3930
|
*/
|
|
3930
|
-
const
|
|
3931
|
+
const Tp = {
|
|
3931
3932
|
kernelName: al,
|
|
3932
3933
|
outputsToSave: [!0],
|
|
3933
3934
|
gradFunc: (s, t) => {
|
|
@@ -3951,7 +3952,7 @@ const zp = {
|
|
|
3951
3952
|
* limitations under the License.
|
|
3952
3953
|
* =============================================================================
|
|
3953
3954
|
*/
|
|
3954
|
-
const
|
|
3955
|
+
const $p = {
|
|
3955
3956
|
kernelName: Ii,
|
|
3956
3957
|
inputsToSave: ["x"],
|
|
3957
3958
|
gradFunc: (s, t) => {
|
|
@@ -3975,7 +3976,7 @@ const Tp = {
|
|
|
3975
3976
|
* limitations under the License.
|
|
3976
3977
|
* =============================================================================
|
|
3977
3978
|
*/
|
|
3978
|
-
const
|
|
3979
|
+
const Ep = {
|
|
3979
3980
|
kernelName: ll,
|
|
3980
3981
|
outputsToSave: [!0],
|
|
3981
3982
|
gradFunc: (s, t) => {
|
|
@@ -3999,7 +4000,7 @@ const $p = {
|
|
|
3999
4000
|
* limitations under the License.
|
|
4000
4001
|
* =============================================================================
|
|
4001
4002
|
*/
|
|
4002
|
-
const
|
|
4003
|
+
const Lp = {
|
|
4003
4004
|
kernelName: ul,
|
|
4004
4005
|
inputsToSave: ["input"],
|
|
4005
4006
|
gradFunc: (s, t) => {
|
|
@@ -4023,7 +4024,7 @@ const Ep = {
|
|
|
4023
4024
|
* limitations under the License.
|
|
4024
4025
|
* =============================================================================
|
|
4025
4026
|
*/
|
|
4026
|
-
const
|
|
4027
|
+
const Fp = {
|
|
4027
4028
|
kernelName: cl,
|
|
4028
4029
|
inputsToSave: ["x"],
|
|
4029
4030
|
gradFunc: (s, t) => {
|
|
@@ -4047,7 +4048,7 @@ const Lp = {
|
|
|
4047
4048
|
* limitations under the License.
|
|
4048
4049
|
* =============================================================================
|
|
4049
4050
|
*/
|
|
4050
|
-
const
|
|
4051
|
+
const Mp = {
|
|
4051
4052
|
kernelName: hl,
|
|
4052
4053
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4053
4054
|
};
|
|
@@ -4067,7 +4068,7 @@ const Fp = {
|
|
|
4067
4068
|
* limitations under the License.
|
|
4068
4069
|
* =============================================================================
|
|
4069
4070
|
*/
|
|
4070
|
-
const
|
|
4071
|
+
const Op = {
|
|
4071
4072
|
kernelName: pl,
|
|
4072
4073
|
inputsToSave: ["a", "b"],
|
|
4073
4074
|
gradFunc: (s, t) => {
|
|
@@ -4100,35 +4101,35 @@ const Mp = {
|
|
|
4100
4101
|
* limitations under the License.
|
|
4101
4102
|
* =============================================================================
|
|
4102
4103
|
*/
|
|
4103
|
-
const
|
|
4104
|
+
const Rp = {
|
|
4104
4105
|
kernelName: ki,
|
|
4105
4106
|
inputsToSave: ["x", "mean", "variance", "scale"],
|
|
4106
4107
|
gradFunc: (s, t, e) => {
|
|
4107
4108
|
const { varianceEpsilon: n } = e, [i, r, a, o] = t, l = o ?? tt(1), u = lt(r.shape, i.shape), c = [];
|
|
4108
4109
|
if (r.rank === 1) {
|
|
4109
|
-
for (let
|
|
4110
|
-
c.push(i.shape[
|
|
4110
|
+
for (let N = 0; N < i.shape.length - 1; ++N)
|
|
4111
|
+
c.push(i.shape[N]);
|
|
4111
4112
|
c.push(1);
|
|
4112
4113
|
}
|
|
4113
|
-
const h = V(i, r), p = w(s, l), f =
|
|
4114
|
+
const h = V(i, r), p = w(s, l), f = ah($(a, tt(n))), g = w(w(w(f, f), f), tt(-0.5));
|
|
4114
4115
|
return {
|
|
4115
4116
|
x: () => r.rank === 1 ? A(w(w(s, Ee(A(f, [1, 1, 1, r.shape[0]]), c)), l), i.shape) : A(w(w(s, f), l), i.shape),
|
|
4116
4117
|
mean: () => {
|
|
4117
|
-
let
|
|
4118
|
-
return r.rank === 1 && (
|
|
4118
|
+
let N = w(w(f, tt(-1)), p);
|
|
4119
|
+
return r.rank === 1 && (N = B(N, u)), A(N, r.shape);
|
|
4119
4120
|
},
|
|
4120
4121
|
variance: () => {
|
|
4121
|
-
let
|
|
4122
|
-
return r.rank === 1 && (
|
|
4122
|
+
let N = w(w(g, h), p);
|
|
4123
|
+
return r.rank === 1 && (N = B(N, u)), A(N, r.shape);
|
|
4123
4124
|
},
|
|
4124
4125
|
scale: () => {
|
|
4125
|
-
const
|
|
4126
|
-
let I = w(s,
|
|
4126
|
+
const N = w(h, f);
|
|
4127
|
+
let I = w(s, N);
|
|
4127
4128
|
return r.rank === 1 && (I = B(I, u)), A(I, r.shape);
|
|
4128
4129
|
},
|
|
4129
4130
|
offset: () => {
|
|
4130
|
-
let
|
|
4131
|
-
return r.rank === 1 && (
|
|
4131
|
+
let N = s;
|
|
4132
|
+
return r.rank === 1 && (N = B(N, u)), A(N, r.shape);
|
|
4132
4133
|
}
|
|
4133
4134
|
};
|
|
4134
4135
|
}
|
|
@@ -4149,17 +4150,17 @@ const Op = {
|
|
|
4149
4150
|
* limitations under the License.
|
|
4150
4151
|
* =============================================================================
|
|
4151
4152
|
*/
|
|
4152
|
-
const
|
|
4153
|
+
const _p = {
|
|
4153
4154
|
kernelName: dl,
|
|
4154
4155
|
inputsToSave: ["x", "indices"],
|
|
4155
4156
|
gradFunc: (s, t, e) => {
|
|
4156
4157
|
const [n, i] = t, { axis: r, batchDims: a } = e, o = he(r, n.shape)[0], l = (u, c, h) => () => {
|
|
4157
|
-
const p = u.shape, f = c.size, g = p.slice(0, o), b = g.length, m = p.slice(r, p.length).slice(1),
|
|
4158
|
+
const p = u.shape, f = c.size, g = p.slice(0, o), b = g.length, m = p.slice(r, p.length).slice(1), v = m.length, y = qs(0, b), C = qs(b + 1, b + 1 + v), N = Zs([
|
|
4158
4159
|
g,
|
|
4159
4160
|
[f],
|
|
4160
4161
|
m
|
|
4161
|
-
]), I = A(h,
|
|
4162
|
-
let E =
|
|
4162
|
+
]), I = A(h, N), z = A(c, [f]), _ = Zs([[b], y, C]), T = j(I, _);
|
|
4163
|
+
let E = mh(T, z, u.shape[o]);
|
|
4163
4164
|
const R = ss(_);
|
|
4164
4165
|
return E = j(E, R), E;
|
|
4165
4166
|
};
|
|
@@ -4199,7 +4200,7 @@ function Zs(s) {
|
|
|
4199
4200
|
* limitations under the License.
|
|
4200
4201
|
* =============================================================================
|
|
4201
4202
|
*/
|
|
4202
|
-
const
|
|
4203
|
+
const Bp = {
|
|
4203
4204
|
kernelName: fl,
|
|
4204
4205
|
inputsToSave: ["a", "b"],
|
|
4205
4206
|
gradFunc: (s, t) => {
|
|
@@ -4223,7 +4224,7 @@ const _p = {
|
|
|
4223
4224
|
* limitations under the License.
|
|
4224
4225
|
* =============================================================================
|
|
4225
4226
|
*/
|
|
4226
|
-
const
|
|
4227
|
+
const Wp = {
|
|
4227
4228
|
kernelName: ml,
|
|
4228
4229
|
gradFunc: (s) => ({ x: () => L(s, "float32") })
|
|
4229
4230
|
};
|
|
@@ -4243,7 +4244,7 @@ const Bp = {
|
|
|
4243
4244
|
* limitations under the License.
|
|
4244
4245
|
* =============================================================================
|
|
4245
4246
|
*/
|
|
4246
|
-
const
|
|
4247
|
+
const Gp = {
|
|
4247
4248
|
kernelName: gl,
|
|
4248
4249
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4249
4250
|
};
|
|
@@ -4263,7 +4264,7 @@ const Wp = {
|
|
|
4263
4264
|
* limitations under the License.
|
|
4264
4265
|
* =============================================================================
|
|
4265
4266
|
*/
|
|
4266
|
-
const
|
|
4267
|
+
const Pp = {
|
|
4267
4268
|
kernelName: bl,
|
|
4268
4269
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4269
4270
|
};
|
|
@@ -4283,7 +4284,7 @@ const Gp = {
|
|
|
4283
4284
|
* limitations under the License.
|
|
4284
4285
|
* =============================================================================
|
|
4285
4286
|
*/
|
|
4286
|
-
const
|
|
4287
|
+
const Up = {
|
|
4287
4288
|
kernelName: yl,
|
|
4288
4289
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4289
4290
|
};
|
|
@@ -4303,12 +4304,12 @@ const Pp = {
|
|
|
4303
4304
|
* limitations under the License.
|
|
4304
4305
|
* =============================================================================
|
|
4305
4306
|
*/
|
|
4306
|
-
const
|
|
4307
|
+
const Vp = {
|
|
4307
4308
|
kernelName: wl,
|
|
4308
4309
|
inputsToSave: ["x"],
|
|
4309
4310
|
gradFunc: (s, t, e) => {
|
|
4310
4311
|
const [n] = t, { alpha: i } = e, r = Gt(n, 0);
|
|
4311
|
-
return { x: () =>
|
|
4312
|
+
return { x: () => qt(r, s, w(s, i)) };
|
|
4312
4313
|
}
|
|
4313
4314
|
};
|
|
4314
4315
|
/**
|
|
@@ -4327,7 +4328,7 @@ const Up = {
|
|
|
4327
4328
|
* limitations under the License.
|
|
4328
4329
|
* =============================================================================
|
|
4329
4330
|
*/
|
|
4330
|
-
const
|
|
4331
|
+
const jp = {
|
|
4331
4332
|
kernelName: Di,
|
|
4332
4333
|
inputsToSave: ["x"],
|
|
4333
4334
|
gradFunc: (s, t) => {
|
|
@@ -4351,7 +4352,7 @@ const Vp = {
|
|
|
4351
4352
|
* limitations under the License.
|
|
4352
4353
|
* =============================================================================
|
|
4353
4354
|
*/
|
|
4354
|
-
const
|
|
4355
|
+
const Kp = {
|
|
4355
4356
|
kernelName: kl,
|
|
4356
4357
|
inputsToSave: ["x"],
|
|
4357
4358
|
gradFunc: (s, t) => {
|
|
@@ -4375,7 +4376,7 @@ const jp = {
|
|
|
4375
4376
|
* limitations under the License.
|
|
4376
4377
|
* =============================================================================
|
|
4377
4378
|
*/
|
|
4378
|
-
const
|
|
4379
|
+
const Hp = {
|
|
4379
4380
|
kernelName: xl,
|
|
4380
4381
|
inputsToSave: [],
|
|
4381
4382
|
outputsToSave: [!0],
|
|
@@ -4405,11 +4406,11 @@ const Kp = {
|
|
|
4405
4406
|
* limitations under the License.
|
|
4406
4407
|
* =============================================================================
|
|
4407
4408
|
*/
|
|
4408
|
-
function
|
|
4409
|
+
function qp(s, t, e, n = 5, i = 1, r = 1, a = 0.5) {
|
|
4409
4410
|
const o = { x: s, y: t, dy: e }, l = { depthRadius: n, bias: i, alpha: r, beta: a };
|
|
4410
|
-
return M.runKernel(
|
|
4411
|
+
return M.runKernel(Nl, o, l);
|
|
4411
4412
|
}
|
|
4412
|
-
const
|
|
4413
|
+
const Zp = F({ localResponseNormalizationBackprop_: qp });
|
|
4413
4414
|
/**
|
|
4414
4415
|
* @license
|
|
4415
4416
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -4426,14 +4427,14 @@ const qp = F({ localResponseNormalizationBackprop_: Hp });
|
|
|
4426
4427
|
* limitations under the License.
|
|
4427
4428
|
* =============================================================================
|
|
4428
4429
|
*/
|
|
4429
|
-
const
|
|
4430
|
-
kernelName:
|
|
4430
|
+
const Jp = {
|
|
4431
|
+
kernelName: vl,
|
|
4431
4432
|
inputsToSave: ["x"],
|
|
4432
4433
|
outputsToSave: [!0],
|
|
4433
4434
|
gradFunc: (s, t, e) => {
|
|
4434
4435
|
const [n, i] = t, { depthRadius: r, bias: a, alpha: o, beta: l } = e;
|
|
4435
4436
|
return {
|
|
4436
|
-
x: () =>
|
|
4437
|
+
x: () => Zp(n, i, s, r, a, o, l)
|
|
4437
4438
|
};
|
|
4438
4439
|
}
|
|
4439
4440
|
};
|
|
@@ -4501,12 +4502,12 @@ const Js = {
|
|
|
4501
4502
|
* limitations under the License.
|
|
4502
4503
|
* =============================================================================
|
|
4503
4504
|
*/
|
|
4504
|
-
const
|
|
4505
|
+
const Xp = {
|
|
4505
4506
|
kernelName: Al,
|
|
4506
4507
|
inputsToSave: ["a", "b"],
|
|
4507
4508
|
gradFunc: (s, t) => {
|
|
4508
4509
|
const [e, n] = t;
|
|
4509
|
-
return { a: () => w(s, L(
|
|
4510
|
+
return { a: () => w(s, L(Ke(e, n), "float32")), b: () => w(s, L(Gu(e, n), "float32")) };
|
|
4510
4511
|
}
|
|
4511
4512
|
};
|
|
4512
4513
|
/**
|
|
@@ -4525,7 +4526,7 @@ const Jp = {
|
|
|
4525
4526
|
* limitations under the License.
|
|
4526
4527
|
* =============================================================================
|
|
4527
4528
|
*/
|
|
4528
|
-
function
|
|
4529
|
+
function Yp(s, t, e, n, i, r, a) {
|
|
4529
4530
|
const o = D(s, "dy", "maxPool3dGrad"), l = D(t, "input", "maxPool3dGrad"), u = D(e, "output", "maxPool3dGrad");
|
|
4530
4531
|
let c = o, h = l, p = u, f = !1;
|
|
4531
4532
|
l.rank === 4 && (f = !0, c = A(o, [1, o.shape[0], o.shape[1], o.shape[2], o.shape[3]]), h = A(l, [
|
|
@@ -4544,7 +4545,7 @@ function Xp(s, t, e, n, i, r, a) {
|
|
|
4544
4545
|
const g = { dy: c, input: h, output: p }, b = { filterSize: n, strides: i, pad: r, dimRoundingMode: a }, m = M.runKernel(Cl, g, b);
|
|
4545
4546
|
return f ? A(m, [m.shape[1], m.shape[2], m.shape[3], m.shape[4]]) : m;
|
|
4546
4547
|
}
|
|
4547
|
-
const
|
|
4548
|
+
const Qp = /* @__PURE__ */ F({ maxPool3dGrad_: Yp });
|
|
4548
4549
|
/**
|
|
4549
4550
|
* @license
|
|
4550
4551
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -4561,14 +4562,14 @@ const Yp = /* @__PURE__ */ F({ maxPool3dGrad_: Xp });
|
|
|
4561
4562
|
* limitations under the License.
|
|
4562
4563
|
* =============================================================================
|
|
4563
4564
|
*/
|
|
4564
|
-
const
|
|
4565
|
+
const td = {
|
|
4565
4566
|
kernelName: $i,
|
|
4566
4567
|
inputsToSave: ["x"],
|
|
4567
4568
|
outputsToSave: [!0],
|
|
4568
4569
|
gradFunc: (s, t, e) => {
|
|
4569
4570
|
const [n, i] = t, { filterSize: r, strides: a, pad: o, dimRoundingMode: l } = e;
|
|
4570
4571
|
return {
|
|
4571
|
-
x: () =>
|
|
4572
|
+
x: () => Qp(s, n, i, r, a, o, l)
|
|
4572
4573
|
};
|
|
4573
4574
|
}
|
|
4574
4575
|
};
|
|
@@ -4588,13 +4589,13 @@ const Qp = {
|
|
|
4588
4589
|
* limitations under the License.
|
|
4589
4590
|
* =============================================================================
|
|
4590
4591
|
*/
|
|
4591
|
-
function
|
|
4592
|
+
function ed(s, t, e, n, i, r, a) {
|
|
4592
4593
|
const o = D(s, "dy", "maxPoolGrad"), l = D(t, "input", "maxPoolGrad"), u = D(e, "output", "maxPoolGrad");
|
|
4593
4594
|
k(l.rank === o.rank, () => `Rank of input (${l.rank}) does not match rank of dy (${o.rank})`), k(o.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ${o.rank}.`), k(l.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ${l.rank}.`), ft("maxPoolGrad", r, a);
|
|
4594
4595
|
const c = { dy: o, input: l, output: u }, h = { filterSize: n, strides: i, pad: r, dimRoundingMode: a };
|
|
4595
4596
|
return M.runKernel(Il, c, h);
|
|
4596
4597
|
}
|
|
4597
|
-
const
|
|
4598
|
+
const nd = /* @__PURE__ */ F({ maxPoolGrad_: ed });
|
|
4598
4599
|
/**
|
|
4599
4600
|
* @license
|
|
4600
4601
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -4611,14 +4612,14 @@ const ed = /* @__PURE__ */ F({ maxPoolGrad_: td });
|
|
|
4611
4612
|
* limitations under the License.
|
|
4612
4613
|
* =============================================================================
|
|
4613
4614
|
*/
|
|
4614
|
-
const
|
|
4615
|
+
const sd = {
|
|
4615
4616
|
kernelName: Ti,
|
|
4616
4617
|
inputsToSave: ["x"],
|
|
4617
4618
|
outputsToSave: [!0],
|
|
4618
4619
|
gradFunc: (s, t, e) => {
|
|
4619
4620
|
const [n, i] = t, { filterSize: r, strides: a, pad: o } = e;
|
|
4620
4621
|
return {
|
|
4621
|
-
x: () =>
|
|
4622
|
+
x: () => nd(s, n, i, r, a, o)
|
|
4622
4623
|
};
|
|
4623
4624
|
}
|
|
4624
4625
|
};
|
|
@@ -4638,11 +4639,11 @@ const nd = {
|
|
|
4638
4639
|
* limitations under the License.
|
|
4639
4640
|
* =============================================================================
|
|
4640
4641
|
*/
|
|
4641
|
-
const
|
|
4642
|
+
const id = {
|
|
4642
4643
|
kernelName: Dl,
|
|
4643
4644
|
inputsToSave: ["x"],
|
|
4644
4645
|
gradFunc: (s, t, e) => {
|
|
4645
|
-
const [n] = t, { axis: i } = e, r = he(i, n.shape), o =
|
|
4646
|
+
const [n] = t, { axis: i } = e, r = he(i, n.shape), o = Vu(n.shape, r)[1], l = Gi(o);
|
|
4646
4647
|
return { x: () => {
|
|
4647
4648
|
const c = n.shape.slice();
|
|
4648
4649
|
r.forEach((f) => {
|
|
@@ -4669,7 +4670,7 @@ const sd = {
|
|
|
4669
4670
|
* limitations under the License.
|
|
4670
4671
|
* =============================================================================
|
|
4671
4672
|
*/
|
|
4672
|
-
const
|
|
4673
|
+
const rd = {
|
|
4673
4674
|
kernelName: zl,
|
|
4674
4675
|
inputsToSave: ["x"],
|
|
4675
4676
|
outputsToSave: [!0],
|
|
@@ -4696,12 +4697,12 @@ const id = {
|
|
|
4696
4697
|
* limitations under the License.
|
|
4697
4698
|
* =============================================================================
|
|
4698
4699
|
*/
|
|
4699
|
-
const
|
|
4700
|
+
const ad = {
|
|
4700
4701
|
kernelName: Tl,
|
|
4701
4702
|
inputsToSave: ["a", "b"],
|
|
4702
4703
|
gradFunc: (s, t) => {
|
|
4703
4704
|
const [e, n] = t;
|
|
4704
|
-
return { a: () => w(s, L(
|
|
4705
|
+
return { a: () => w(s, L(es(e, n), "float32")), b: () => w(s, L(Gt(e, n), "float32")) };
|
|
4705
4706
|
}
|
|
4706
4707
|
};
|
|
4707
4708
|
/**
|
|
@@ -4720,7 +4721,7 @@ const rd = {
|
|
|
4720
4721
|
* limitations under the License.
|
|
4721
4722
|
* =============================================================================
|
|
4722
4723
|
*/
|
|
4723
|
-
const
|
|
4724
|
+
const od = {
|
|
4724
4725
|
kernelName: $l,
|
|
4725
4726
|
inputsToSave: ["x"],
|
|
4726
4727
|
gradFunc: (s, t, e) => {
|
|
@@ -4744,7 +4745,7 @@ const ad = {
|
|
|
4744
4745
|
* limitations under the License.
|
|
4745
4746
|
* =============================================================================
|
|
4746
4747
|
*/
|
|
4747
|
-
const
|
|
4748
|
+
const ld = {
|
|
4748
4749
|
kernelName: El,
|
|
4749
4750
|
inputsToSave: ["a", "b"],
|
|
4750
4751
|
gradFunc: (s, t) => {
|
|
@@ -4753,7 +4754,7 @@ const od = {
|
|
|
4753
4754
|
const o = lt(e.shape, i);
|
|
4754
4755
|
return o.length > 0 ? A(B(s, o), e.shape) : s;
|
|
4755
4756
|
}, b: () => {
|
|
4756
|
-
const o = w(s, pt(
|
|
4757
|
+
const o = w(s, pt(qi(P(e, n)))), l = lt(n.shape, i);
|
|
4757
4758
|
return l.length > 0 ? A(B(o, l), n.shape) : o;
|
|
4758
4759
|
} };
|
|
4759
4760
|
}
|
|
@@ -4774,7 +4775,7 @@ const od = {
|
|
|
4774
4775
|
* limitations under the License.
|
|
4775
4776
|
* =============================================================================
|
|
4776
4777
|
*/
|
|
4777
|
-
const
|
|
4778
|
+
const ud = {
|
|
4778
4779
|
kernelName: Ll,
|
|
4779
4780
|
inputsToSave: ["a", "b"],
|
|
4780
4781
|
gradFunc: (s, t) => {
|
|
@@ -4804,7 +4805,7 @@ const ld = {
|
|
|
4804
4805
|
* limitations under the License.
|
|
4805
4806
|
* =============================================================================
|
|
4806
4807
|
*/
|
|
4807
|
-
const
|
|
4808
|
+
const cd = {
|
|
4808
4809
|
kernelName: Fl,
|
|
4809
4810
|
gradFunc: (s) => ({ x: () => pt(s) })
|
|
4810
4811
|
};
|
|
@@ -4824,7 +4825,7 @@ const ud = {
|
|
|
4824
4825
|
* limitations under the License.
|
|
4825
4826
|
* =============================================================================
|
|
4826
4827
|
*/
|
|
4827
|
-
const
|
|
4828
|
+
const hd = {
|
|
4828
4829
|
kernelName: Ei,
|
|
4829
4830
|
inputsToSave: ["indices"],
|
|
4830
4831
|
gradFunc: (s, t) => {
|
|
@@ -4848,7 +4849,7 @@ const cd = {
|
|
|
4848
4849
|
* limitations under the License.
|
|
4849
4850
|
* =============================================================================
|
|
4850
4851
|
*/
|
|
4851
|
-
const
|
|
4852
|
+
const pd = {
|
|
4852
4853
|
kernelName: Li,
|
|
4853
4854
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
4854
4855
|
};
|
|
@@ -4868,12 +4869,12 @@ const hd = {
|
|
|
4868
4869
|
* limitations under the License.
|
|
4869
4870
|
* =============================================================================
|
|
4870
4871
|
*/
|
|
4871
|
-
const
|
|
4872
|
+
const dd = {
|
|
4872
4873
|
kernelName: Ml,
|
|
4873
4874
|
saveAllInputs: !0,
|
|
4874
4875
|
gradFunc: (s, t, e) => {
|
|
4875
4876
|
const { axis: n } = e;
|
|
4876
|
-
return
|
|
4877
|
+
return nn(s, n).map((r) => () => r);
|
|
4877
4878
|
}
|
|
4878
4879
|
};
|
|
4879
4880
|
/**
|
|
@@ -4916,7 +4917,7 @@ const Xs = {
|
|
|
4916
4917
|
* limitations under the License.
|
|
4917
4918
|
* =============================================================================
|
|
4918
4919
|
*/
|
|
4919
|
-
const
|
|
4920
|
+
const fd = {
|
|
4920
4921
|
kernelName: Ol,
|
|
4921
4922
|
inputsToSave: ["a", "b"],
|
|
4922
4923
|
outputsToSave: [!0],
|
|
@@ -4928,7 +4929,7 @@ const dd = {
|
|
|
4928
4929
|
const p = lt(r.shape, o);
|
|
4929
4930
|
return p.length > 0 && (h = B(h, p)), A(h, r.shape);
|
|
4930
4931
|
}, b: () => {
|
|
4931
|
-
const c = Gt(r, 0), h =
|
|
4932
|
+
const c = Gt(r, 0), h = qt(c, Zt(r), Q(r));
|
|
4932
4933
|
let p = w(s, w(i, h));
|
|
4933
4934
|
const f = lt(a.shape, o);
|
|
4934
4935
|
return f.length > 0 && (p = B(p, f)), A(p, a.shape);
|
|
@@ -4951,15 +4952,15 @@ const dd = {
|
|
|
4951
4952
|
* limitations under the License.
|
|
4952
4953
|
* =============================================================================
|
|
4953
4954
|
*/
|
|
4954
|
-
const
|
|
4955
|
+
const md = {
|
|
4955
4956
|
kernelName: Rl,
|
|
4956
4957
|
inputsToSave: ["x", "alpha"],
|
|
4957
4958
|
gradFunc: (s, t) => {
|
|
4958
4959
|
const [e, n] = t, i = Gt(e, 0);
|
|
4959
4960
|
return {
|
|
4960
|
-
x: () =>
|
|
4961
|
+
x: () => qt(i, s, w(s, n)),
|
|
4961
4962
|
alpha: () => {
|
|
4962
|
-
let r =
|
|
4963
|
+
let r = qt(i, Q(s), w(s, e));
|
|
4963
4964
|
const a = lt(n.shape, s.shape);
|
|
4964
4965
|
return a.length > 0 && (r = B(r, a)), A(r, n.shape);
|
|
4965
4966
|
}
|
|
@@ -4982,33 +4983,33 @@ const fd = {
|
|
|
4982
4983
|
* limitations under the License.
|
|
4983
4984
|
* =============================================================================
|
|
4984
4985
|
*/
|
|
4985
|
-
function
|
|
4986
|
+
function gd(s, t, e) {
|
|
4986
4987
|
const n = s.shape.slice();
|
|
4987
4988
|
n[e] = 1;
|
|
4988
4989
|
const i = A(t, n), r = Ps(s, e, !0, !1), a = Ps(s, e, !0, !0), o = w(r, a);
|
|
4989
4990
|
return w(i, o);
|
|
4990
4991
|
}
|
|
4991
|
-
function
|
|
4992
|
+
function bd(s, t, e) {
|
|
4992
4993
|
const n = s.shape.length, i = n - e.length, r = Yi(e, n);
|
|
4993
4994
|
let a = s;
|
|
4994
4995
|
r != null && (a = j(s, r));
|
|
4995
4996
|
const o = a.shape.slice(), u = o.splice(n - e.length, e.length).reduce((p, f) => p * f, 1);
|
|
4996
4997
|
o.push(u);
|
|
4997
4998
|
const c = a.reshape(o);
|
|
4998
|
-
let h =
|
|
4999
|
+
let h = gd(c, t, i);
|
|
4999
5000
|
if (h = h.reshape(a.shape), r != null) {
|
|
5000
5001
|
const p = ss(r);
|
|
5001
5002
|
h = j(h, p);
|
|
5002
5003
|
}
|
|
5003
5004
|
return h;
|
|
5004
5005
|
}
|
|
5005
|
-
const
|
|
5006
|
+
const yd = {
|
|
5006
5007
|
kernelName: _l,
|
|
5007
5008
|
inputsToSave: ["x"],
|
|
5008
5009
|
gradFunc: (s, t, e) => {
|
|
5009
5010
|
const [n] = t, { axis: i } = e;
|
|
5010
5011
|
let r = [];
|
|
5011
|
-
return i == null ? r = n.shape.map((a, o) => o) : typeof i == "number" ? r = [i] : r = i, { x: () =>
|
|
5012
|
+
return i == null ? r = n.shape.map((a, o) => o) : typeof i == "number" ? r = [i] : r = i, { x: () => bd(n, s, r) };
|
|
5012
5013
|
}
|
|
5013
5014
|
};
|
|
5014
5015
|
/**
|
|
@@ -5027,7 +5028,7 @@ const bd = {
|
|
|
5027
5028
|
* limitations under the License.
|
|
5028
5029
|
* =============================================================================
|
|
5029
5030
|
*/
|
|
5030
|
-
const
|
|
5031
|
+
const wd = {
|
|
5031
5032
|
kernelName: Bl,
|
|
5032
5033
|
inputsToSave: ["a", "b"],
|
|
5033
5034
|
gradFunc: (s, t) => {
|
|
@@ -5060,7 +5061,7 @@ const yd = {
|
|
|
5060
5061
|
* limitations under the License.
|
|
5061
5062
|
* =============================================================================
|
|
5062
5063
|
*/
|
|
5063
|
-
const
|
|
5064
|
+
const kd = {
|
|
5064
5065
|
kernelName: Wl,
|
|
5065
5066
|
inputsToSave: ["x"],
|
|
5066
5067
|
gradFunc: (s, t) => {
|
|
@@ -5084,11 +5085,11 @@ const wd = {
|
|
|
5084
5085
|
* limitations under the License.
|
|
5085
5086
|
* =============================================================================
|
|
5086
5087
|
*/
|
|
5087
|
-
const
|
|
5088
|
+
const xd = {
|
|
5088
5089
|
kernelName: Gl,
|
|
5089
5090
|
inputsToSave: ["x"],
|
|
5090
5091
|
gradFunc: (s, t) => {
|
|
5091
|
-
const [e] = t, n = w(
|
|
5092
|
+
const [e] = t, n = w(es(e, 6), Xn(e));
|
|
5092
5093
|
return { x: () => w(s, L(n, "float32")) };
|
|
5093
5094
|
}
|
|
5094
5095
|
};
|
|
@@ -5108,7 +5109,7 @@ const kd = {
|
|
|
5108
5109
|
* limitations under the License.
|
|
5109
5110
|
* =============================================================================
|
|
5110
5111
|
*/
|
|
5111
|
-
const
|
|
5112
|
+
const Nd = {
|
|
5112
5113
|
kernelName: Pl,
|
|
5113
5114
|
inputsToSave: ["x"],
|
|
5114
5115
|
gradFunc: (s, t) => {
|
|
@@ -5156,7 +5157,7 @@ const vd = {
|
|
|
5156
5157
|
* limitations under the License.
|
|
5157
5158
|
* =============================================================================
|
|
5158
5159
|
*/
|
|
5159
|
-
const
|
|
5160
|
+
const Sd = {
|
|
5160
5161
|
kernelName: Vl,
|
|
5161
5162
|
inputsToSave: ["images"],
|
|
5162
5163
|
gradFunc: (s, t, e) => {
|
|
@@ -5183,7 +5184,7 @@ const Nd = {
|
|
|
5183
5184
|
* limitations under the License.
|
|
5184
5185
|
* =============================================================================
|
|
5185
5186
|
*/
|
|
5186
|
-
const
|
|
5187
|
+
const Ad = {
|
|
5187
5188
|
kernelName: Kl,
|
|
5188
5189
|
inputsToSave: ["images"],
|
|
5189
5190
|
gradFunc: (s, t, e) => {
|
|
@@ -5210,7 +5211,7 @@ const Sd = {
|
|
|
5210
5211
|
* limitations under the License.
|
|
5211
5212
|
* =============================================================================
|
|
5212
5213
|
*/
|
|
5213
|
-
const
|
|
5214
|
+
const Cd = {
|
|
5214
5215
|
kernelName: Oi,
|
|
5215
5216
|
gradFunc: (s, t, e) => {
|
|
5216
5217
|
const { dims: n } = e, i = he(n, s.shape);
|
|
@@ -5233,7 +5234,7 @@ const Ad = {
|
|
|
5233
5234
|
* limitations under the License.
|
|
5234
5235
|
* =============================================================================
|
|
5235
5236
|
*/
|
|
5236
|
-
const
|
|
5237
|
+
const Id = {
|
|
5237
5238
|
kernelName: ql,
|
|
5238
5239
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5239
5240
|
};
|
|
@@ -5253,7 +5254,7 @@ const Cd = {
|
|
|
5253
5254
|
* limitations under the License.
|
|
5254
5255
|
* =============================================================================
|
|
5255
5256
|
*/
|
|
5256
|
-
const
|
|
5257
|
+
const Dd = {
|
|
5257
5258
|
kernelName: Ri,
|
|
5258
5259
|
inputsToSave: ["x"],
|
|
5259
5260
|
gradFunc: (s, t) => {
|
|
@@ -5277,7 +5278,7 @@ const Id = {
|
|
|
5277
5278
|
* limitations under the License.
|
|
5278
5279
|
* =============================================================================
|
|
5279
5280
|
*/
|
|
5280
|
-
const
|
|
5281
|
+
const zd = {
|
|
5281
5282
|
kernelName: Zl,
|
|
5282
5283
|
inputsToSave: ["condition"],
|
|
5283
5284
|
gradFunc: (s, t) => {
|
|
@@ -5287,7 +5288,7 @@ const Dd = {
|
|
|
5287
5288
|
// when backprop supports it.
|
|
5288
5289
|
condition: () => L(Q(e), "float32"),
|
|
5289
5290
|
t: () => w(s, L(e, s.dtype)),
|
|
5290
|
-
e: () => w(s, L(
|
|
5291
|
+
e: () => w(s, L(Kc(e), s.dtype))
|
|
5291
5292
|
};
|
|
5292
5293
|
}
|
|
5293
5294
|
};
|
|
@@ -5307,15 +5308,15 @@ const Dd = {
|
|
|
5307
5308
|
* limitations under the License.
|
|
5308
5309
|
* =============================================================================
|
|
5309
5310
|
*/
|
|
5310
|
-
const
|
|
5311
|
+
const Td = {
|
|
5311
5312
|
kernelName: _i,
|
|
5312
5313
|
inputsToSave: ["x"],
|
|
5313
5314
|
gradFunc: (s, t) => {
|
|
5314
5315
|
const [e] = t;
|
|
5315
5316
|
return {
|
|
5316
5317
|
x: () => {
|
|
5317
|
-
const n = Gt(e, tt(0)), i = tt(
|
|
5318
|
-
return
|
|
5318
|
+
const n = Gt(e, tt(0)), i = tt(Ch), r = tt(Ih), a = w(s, r), o = w(w(s, i), Jt(L(e, "float32")));
|
|
5319
|
+
return qt(n, a, o);
|
|
5319
5320
|
}
|
|
5320
5321
|
};
|
|
5321
5322
|
}
|
|
@@ -5336,7 +5337,7 @@ const zd = {
|
|
|
5336
5337
|
* limitations under the License.
|
|
5337
5338
|
* =============================================================================
|
|
5338
5339
|
*/
|
|
5339
|
-
const
|
|
5340
|
+
const $d = {
|
|
5340
5341
|
kernelName: Jl,
|
|
5341
5342
|
outputsToSave: [!0],
|
|
5342
5343
|
gradFunc: (s, t) => {
|
|
@@ -5360,7 +5361,7 @@ const Td = {
|
|
|
5360
5361
|
* limitations under the License.
|
|
5361
5362
|
* =============================================================================
|
|
5362
5363
|
*/
|
|
5363
|
-
const
|
|
5364
|
+
const Ed = {
|
|
5364
5365
|
kernelName: Xl,
|
|
5365
5366
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5366
5367
|
};
|
|
@@ -5380,7 +5381,7 @@ const $d = {
|
|
|
5380
5381
|
* limitations under the License.
|
|
5381
5382
|
* =============================================================================
|
|
5382
5383
|
*/
|
|
5383
|
-
const
|
|
5384
|
+
const Ld = {
|
|
5384
5385
|
kernelName: Yl,
|
|
5385
5386
|
inputsToSave: ["x"],
|
|
5386
5387
|
gradFunc: (s, t) => {
|
|
@@ -5404,12 +5405,12 @@ const Ed = {
|
|
|
5404
5405
|
* limitations under the License.
|
|
5405
5406
|
* =============================================================================
|
|
5406
5407
|
*/
|
|
5407
|
-
const
|
|
5408
|
+
const Fd = {
|
|
5408
5409
|
kernelName: Bi,
|
|
5409
5410
|
inputsToSave: ["x"],
|
|
5410
5411
|
gradFunc: (s, t) => {
|
|
5411
5412
|
const [e] = t;
|
|
5412
|
-
return { x: () => w(
|
|
5413
|
+
return { x: () => w($c(L(e, "float32")), s) };
|
|
5413
5414
|
}
|
|
5414
5415
|
};
|
|
5415
5416
|
/**
|
|
@@ -5428,11 +5429,11 @@ const Ld = {
|
|
|
5428
5429
|
* limitations under the License.
|
|
5429
5430
|
* =============================================================================
|
|
5430
5431
|
*/
|
|
5431
|
-
const
|
|
5432
|
+
const Md = {
|
|
5432
5433
|
kernelName: Ql,
|
|
5433
5434
|
inputsToSave: ["x"],
|
|
5434
5435
|
gradFunc: (s, t, e) => {
|
|
5435
|
-
const [n] = t, { begin: i, size: r } = e, a = n.shape, [o, l] =
|
|
5436
|
+
const [n] = t, { begin: i, size: r } = e, a = n.shape, [o, l] = Ku(n, i, r), u = [];
|
|
5436
5437
|
for (let c = 0; c < s.rank; c++)
|
|
5437
5438
|
u.push([o[c], a[c] - o[c] - l[c]]);
|
|
5438
5439
|
return { x: () => sr(s, u) };
|
|
@@ -5454,7 +5455,7 @@ const Fd = {
|
|
|
5454
5455
|
* limitations under the License.
|
|
5455
5456
|
* =============================================================================
|
|
5456
5457
|
*/
|
|
5457
|
-
const
|
|
5458
|
+
const Od = {
|
|
5458
5459
|
kernelName: tu,
|
|
5459
5460
|
outputsToSave: [!0],
|
|
5460
5461
|
gradFunc: (s, t, e) => {
|
|
@@ -5480,12 +5481,12 @@ const Md = {
|
|
|
5480
5481
|
* limitations under the License.
|
|
5481
5482
|
* =============================================================================
|
|
5482
5483
|
*/
|
|
5483
|
-
const
|
|
5484
|
+
const Rd = {
|
|
5484
5485
|
kernelName: zi,
|
|
5485
5486
|
inputsToSave: ["x"],
|
|
5486
5487
|
gradFunc: (s, t) => {
|
|
5487
5488
|
const [e] = t;
|
|
5488
|
-
return { x: () => w(s,
|
|
5489
|
+
return { x: () => w(s, Yn(e)) };
|
|
5489
5490
|
}
|
|
5490
5491
|
};
|
|
5491
5492
|
/**
|
|
@@ -5508,7 +5509,7 @@ const Ys = {
|
|
|
5508
5509
|
kernelName: Mi,
|
|
5509
5510
|
gradFunc: (s, t, e) => {
|
|
5510
5511
|
const { blockShape: n, paddings: i } = e;
|
|
5511
|
-
return { x: () =>
|
|
5512
|
+
return { x: () => cc(s, n, i) };
|
|
5512
5513
|
}
|
|
5513
5514
|
};
|
|
5514
5515
|
/**
|
|
@@ -5550,7 +5551,7 @@ const Qs = {
|
|
|
5550
5551
|
* limitations under the License.
|
|
5551
5552
|
* =============================================================================
|
|
5552
5553
|
*/
|
|
5553
|
-
const
|
|
5554
|
+
const _d = {
|
|
5554
5555
|
kernelName: nu,
|
|
5555
5556
|
inputsToSave: ["x"],
|
|
5556
5557
|
gradFunc: (s, t) => {
|
|
@@ -5574,7 +5575,7 @@ const Rd = {
|
|
|
5574
5575
|
* limitations under the License.
|
|
5575
5576
|
* =============================================================================
|
|
5576
5577
|
*/
|
|
5577
|
-
const
|
|
5578
|
+
const Bd = {
|
|
5578
5579
|
kernelName: su,
|
|
5579
5580
|
inputsToSave: ["x"],
|
|
5580
5581
|
gradFunc: (s, t) => {
|
|
@@ -5598,7 +5599,7 @@ const _d = {
|
|
|
5598
5599
|
* limitations under the License.
|
|
5599
5600
|
* =============================================================================
|
|
5600
5601
|
*/
|
|
5601
|
-
const
|
|
5602
|
+
const Wd = {
|
|
5602
5603
|
kernelName: iu,
|
|
5603
5604
|
inputsToSave: ["a", "b"],
|
|
5604
5605
|
gradFunc: (s, t) => {
|
|
@@ -5622,7 +5623,7 @@ const Bd = {
|
|
|
5622
5623
|
* limitations under the License.
|
|
5623
5624
|
* =============================================================================
|
|
5624
5625
|
*/
|
|
5625
|
-
const
|
|
5626
|
+
const Gd = {
|
|
5626
5627
|
kernelName: ru,
|
|
5627
5628
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5628
5629
|
};
|
|
@@ -5642,7 +5643,7 @@ const Wd = {
|
|
|
5642
5643
|
* limitations under the License.
|
|
5643
5644
|
* =============================================================================
|
|
5644
5645
|
*/
|
|
5645
|
-
const
|
|
5646
|
+
const Pd = {
|
|
5646
5647
|
kernelName: au,
|
|
5647
5648
|
inputsToSave: ["a", "b"],
|
|
5648
5649
|
gradFunc: (s, t) => {
|
|
@@ -5674,7 +5675,7 @@ const Gd = {
|
|
|
5674
5675
|
* limitations under the License.
|
|
5675
5676
|
* =============================================================================
|
|
5676
5677
|
*/
|
|
5677
|
-
const
|
|
5678
|
+
const Ud = {
|
|
5678
5679
|
kernelName: ou,
|
|
5679
5680
|
inputsToSave: ["x"],
|
|
5680
5681
|
gradFunc: (s, t, e) => {
|
|
@@ -5702,7 +5703,7 @@ const Pd = {
|
|
|
5702
5703
|
* limitations under the License.
|
|
5703
5704
|
* =============================================================================
|
|
5704
5705
|
*/
|
|
5705
|
-
const
|
|
5706
|
+
const Vd = {
|
|
5706
5707
|
kernelName: lu,
|
|
5707
5708
|
inputsToSave: ["x"],
|
|
5708
5709
|
gradFunc: (s, t) => {
|
|
@@ -5726,7 +5727,7 @@ const Ud = {
|
|
|
5726
5727
|
* limitations under the License.
|
|
5727
5728
|
* =============================================================================
|
|
5728
5729
|
*/
|
|
5729
|
-
const
|
|
5730
|
+
const jd = {
|
|
5730
5731
|
kernelName: yi,
|
|
5731
5732
|
outputsToSave: [!0],
|
|
5732
5733
|
gradFunc: (s, t) => {
|
|
@@ -5750,7 +5751,7 @@ const Vd = {
|
|
|
5750
5751
|
* limitations under the License.
|
|
5751
5752
|
* =============================================================================
|
|
5752
5753
|
*/
|
|
5753
|
-
const
|
|
5754
|
+
const Kd = {
|
|
5754
5755
|
kernelName: uu,
|
|
5755
5756
|
inputsToSave: ["x"],
|
|
5756
5757
|
gradFunc: (s, t, e) => {
|
|
@@ -5805,7 +5806,7 @@ const jd = {
|
|
|
5805
5806
|
* limitations under the License.
|
|
5806
5807
|
* =============================================================================
|
|
5807
5808
|
*/
|
|
5808
|
-
const
|
|
5809
|
+
const Hd = {
|
|
5809
5810
|
kernelName: cu,
|
|
5810
5811
|
gradFunc: (s, t, e) => {
|
|
5811
5812
|
const n = e, { perm: i } = n, r = ss(i);
|
|
@@ -5828,7 +5829,7 @@ const Kd = {
|
|
|
5828
5829
|
* limitations under the License.
|
|
5829
5830
|
* =============================================================================
|
|
5830
5831
|
*/
|
|
5831
|
-
const
|
|
5832
|
+
const qd = {
|
|
5832
5833
|
kernelName: hu,
|
|
5833
5834
|
gradFunc: (s, t, e) => {
|
|
5834
5835
|
const n = e, { axis: i } = n;
|
|
@@ -5851,23 +5852,23 @@ const Hd = {
|
|
|
5851
5852
|
* limitations under the License.
|
|
5852
5853
|
* =============================================================================
|
|
5853
5854
|
*/
|
|
5854
|
-
const
|
|
5855
|
+
const Zd = {
|
|
5855
5856
|
kernelName: Wi,
|
|
5856
5857
|
inputsToSave: ["segmentIds"],
|
|
5857
5858
|
gradFunc: (s, t) => {
|
|
5858
5859
|
const [e] = t;
|
|
5859
|
-
return { x: () =>
|
|
5860
|
+
return { x: () => Jd(s, e) };
|
|
5860
5861
|
}
|
|
5861
5862
|
};
|
|
5862
|
-
function
|
|
5863
|
+
function Jd(s, t) {
|
|
5863
5864
|
const e = Ie(t, Q(t)), n = Qi(s, e);
|
|
5864
|
-
let i =
|
|
5865
|
+
let i = Ke(t, tt(0, "int32"));
|
|
5865
5866
|
const r = n.rank - i.rank;
|
|
5866
5867
|
for (let o = 0; o < r; ++o)
|
|
5867
|
-
i =
|
|
5868
|
-
i =
|
|
5868
|
+
i = ce(i, o + 1);
|
|
5869
|
+
i = je(i, pe(n.shape, "bool"));
|
|
5869
5870
|
const a = Q(n);
|
|
5870
|
-
return
|
|
5871
|
+
return qt(i, n, a);
|
|
5871
5872
|
}
|
|
5872
5873
|
/**
|
|
5873
5874
|
* @license
|
|
@@ -5885,7 +5886,7 @@ function Zd(s, t) {
|
|
|
5885
5886
|
* limitations under the License.
|
|
5886
5887
|
* =============================================================================
|
|
5887
5888
|
*/
|
|
5888
|
-
const
|
|
5889
|
+
const Xd = {
|
|
5889
5890
|
kernelName: pu,
|
|
5890
5891
|
gradFunc: (s) => ({ x: () => Q(s) })
|
|
5891
5892
|
};
|
|
@@ -5905,9 +5906,8 @@ const Jd = {
|
|
|
5905
5906
|
* limitations under the License.
|
|
5906
5907
|
* =============================================================================
|
|
5907
5908
|
*/
|
|
5908
|
-
const
|
|
5909
|
+
const Yd = [
|
|
5909
5910
|
fr,
|
|
5910
|
-
qh,
|
|
5911
5911
|
Zh,
|
|
5912
5912
|
Jh,
|
|
5913
5913
|
Xh,
|
|
@@ -5918,8 +5918,8 @@ const Xd = [
|
|
|
5918
5918
|
np,
|
|
5919
5919
|
sp,
|
|
5920
5920
|
ip,
|
|
5921
|
-
|
|
5922
|
-
|
|
5921
|
+
rp,
|
|
5922
|
+
lp,
|
|
5923
5923
|
hp,
|
|
5924
5924
|
pp,
|
|
5925
5925
|
dp,
|
|
@@ -5928,23 +5928,23 @@ const Xd = [
|
|
|
5928
5928
|
gp,
|
|
5929
5929
|
bp,
|
|
5930
5930
|
yp,
|
|
5931
|
-
kp,
|
|
5932
5931
|
wp,
|
|
5933
|
-
|
|
5932
|
+
xp,
|
|
5933
|
+
kp,
|
|
5934
5934
|
Sp,
|
|
5935
5935
|
Ap,
|
|
5936
5936
|
Cp,
|
|
5937
5937
|
Ip,
|
|
5938
5938
|
Dp,
|
|
5939
|
-
yd,
|
|
5940
5939
|
zp,
|
|
5940
|
+
wd,
|
|
5941
5941
|
Tp,
|
|
5942
5942
|
$p,
|
|
5943
5943
|
Ep,
|
|
5944
5944
|
Lp,
|
|
5945
|
-
Mp,
|
|
5946
5945
|
Fp,
|
|
5947
5946
|
Op,
|
|
5947
|
+
Mp,
|
|
5948
5948
|
Rp,
|
|
5949
5949
|
_p,
|
|
5950
5950
|
Bp,
|
|
@@ -5955,12 +5955,12 @@ const Xd = [
|
|
|
5955
5955
|
Vp,
|
|
5956
5956
|
jp,
|
|
5957
5957
|
Kp,
|
|
5958
|
-
|
|
5958
|
+
Hp,
|
|
5959
|
+
Jp,
|
|
5959
5960
|
Js,
|
|
5960
5961
|
Js,
|
|
5961
|
-
|
|
5962
|
-
|
|
5963
|
-
nd,
|
|
5962
|
+
Xp,
|
|
5963
|
+
td,
|
|
5964
5964
|
sd,
|
|
5965
5965
|
id,
|
|
5966
5966
|
rd,
|
|
@@ -5971,16 +5971,16 @@ const Xd = [
|
|
|
5971
5971
|
cd,
|
|
5972
5972
|
hd,
|
|
5973
5973
|
pd,
|
|
5974
|
+
dd,
|
|
5974
5975
|
Xs,
|
|
5975
5976
|
Xs,
|
|
5976
|
-
dd,
|
|
5977
5977
|
fd,
|
|
5978
|
-
|
|
5979
|
-
|
|
5978
|
+
md,
|
|
5979
|
+
yd,
|
|
5980
5980
|
kd,
|
|
5981
5981
|
xd,
|
|
5982
|
-
vd,
|
|
5983
5982
|
Nd,
|
|
5983
|
+
vd,
|
|
5984
5984
|
Sd,
|
|
5985
5985
|
Ad,
|
|
5986
5986
|
Cd,
|
|
@@ -5994,14 +5994,14 @@ const Xd = [
|
|
|
5994
5994
|
Fd,
|
|
5995
5995
|
Md,
|
|
5996
5996
|
Od,
|
|
5997
|
+
Rd,
|
|
5997
5998
|
Ys,
|
|
5998
5999
|
Ys,
|
|
5999
6000
|
Qs,
|
|
6000
6001
|
Qs,
|
|
6001
|
-
Rd,
|
|
6002
|
-
Bd,
|
|
6003
6002
|
_d,
|
|
6004
6003
|
Wd,
|
|
6004
|
+
Bd,
|
|
6005
6005
|
Gd,
|
|
6006
6006
|
Pd,
|
|
6007
6007
|
Ud,
|
|
@@ -6010,9 +6010,10 @@ const Xd = [
|
|
|
6010
6010
|
Kd,
|
|
6011
6011
|
Hd,
|
|
6012
6012
|
qd,
|
|
6013
|
-
|
|
6013
|
+
Zd,
|
|
6014
|
+
Xd
|
|
6014
6015
|
];
|
|
6015
|
-
for (const s of
|
|
6016
|
+
for (const s of Yd)
|
|
6016
6017
|
du(s);
|
|
6017
6018
|
/**
|
|
6018
6019
|
* @license
|
|
@@ -6062,7 +6063,7 @@ br.className = "UnitNorm";
|
|
|
6062
6063
|
S(br);
|
|
6063
6064
|
class yr extends qe {
|
|
6064
6065
|
apply(t) {
|
|
6065
|
-
return
|
|
6066
|
+
return Pe(t);
|
|
6066
6067
|
}
|
|
6067
6068
|
}
|
|
6068
6069
|
yr.className = "NonNeg";
|
|
@@ -6117,7 +6118,7 @@ function rt(s) {
|
|
|
6117
6118
|
* https://opensource.org/licenses/MIT.
|
|
6118
6119
|
* =============================================================================
|
|
6119
6120
|
*/
|
|
6120
|
-
function
|
|
6121
|
+
function Qd(s) {
|
|
6121
6122
|
return new ps(s);
|
|
6122
6123
|
}
|
|
6123
6124
|
/**
|
|
@@ -6167,7 +6168,7 @@ var ni;
|
|
|
6167
6168
|
(function(s) {
|
|
6168
6169
|
s[s.SILENT = 0] = "SILENT", s[s.VERBOSE = 1] = "VERBOSE";
|
|
6169
6170
|
})(ni || (ni = {}));
|
|
6170
|
-
const
|
|
6171
|
+
const tf = 125;
|
|
6171
6172
|
class Me {
|
|
6172
6173
|
constructor() {
|
|
6173
6174
|
this.validationData = null;
|
|
@@ -6197,7 +6198,7 @@ class Me {
|
|
|
6197
6198
|
setModel(t) {
|
|
6198
6199
|
}
|
|
6199
6200
|
}
|
|
6200
|
-
class
|
|
6201
|
+
class ef {
|
|
6201
6202
|
// TODO(cais): When the need arises, uncomment the following lines and
|
|
6202
6203
|
// implement the queue for time values.
|
|
6203
6204
|
// private deltaTBatch: number;
|
|
@@ -6282,7 +6283,7 @@ class tf {
|
|
|
6282
6283
|
await e.onTrainEnd(t);
|
|
6283
6284
|
}
|
|
6284
6285
|
}
|
|
6285
|
-
class
|
|
6286
|
+
class nf extends Me {
|
|
6286
6287
|
constructor() {
|
|
6287
6288
|
super();
|
|
6288
6289
|
}
|
|
@@ -6314,7 +6315,7 @@ class ef extends Me {
|
|
|
6314
6315
|
}));
|
|
6315
6316
|
}
|
|
6316
6317
|
}
|
|
6317
|
-
class
|
|
6318
|
+
class sf extends Me {
|
|
6318
6319
|
async onTrainBegin(t) {
|
|
6319
6320
|
this.epoch = [], this.history = {};
|
|
6320
6321
|
}
|
|
@@ -6341,11 +6342,11 @@ class nf extends Me {
|
|
|
6341
6342
|
this.history[e[r]][n[r]].dispose(), this.history[e[r]][n[r]] = i[r][0];
|
|
6342
6343
|
}
|
|
6343
6344
|
}
|
|
6344
|
-
class
|
|
6345
|
+
class rf extends Me {
|
|
6345
6346
|
constructor(t, e) {
|
|
6346
|
-
if (super(), this.currentEpoch = 0, this.nowFunc = t.nowFunc, this.nextFrameFunc = t.nextFrameFunc ||
|
|
6347
|
+
if (super(), this.currentEpoch = 0, this.nowFunc = t.nowFunc, this.nextFrameFunc = t.nextFrameFunc || Ah, this.yieldEvery = e || "auto", this.yieldEvery === "auto" && (this.yieldEvery = tf), this.yieldEvery === "never" && t.onYield != null)
|
|
6347
6348
|
throw new Error("yieldEvery is `never` but you provided an `onYield` callback. Either change `yieldEvery` or remove the callback");
|
|
6348
|
-
Fs(this.yieldEvery) && (this.maybeWait =
|
|
6349
|
+
Fs(this.yieldEvery) && (this.maybeWait = Su(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc)), this.trainBegin = t.onTrainBegin, this.trainEnd = t.onTrainEnd, this.epochBegin = t.onEpochBegin, this.epochEnd = t.onEpochEnd, this.batchBegin = t.onBatchBegin, this.batchEnd = t.onBatchEnd, this.yield = t.onYield;
|
|
6349
6350
|
}
|
|
6350
6351
|
async maybeWait(t, e, n) {
|
|
6351
6352
|
const i = [];
|
|
@@ -6373,7 +6374,7 @@ class sf extends Me {
|
|
|
6373
6374
|
}
|
|
6374
6375
|
}
|
|
6375
6376
|
function xr(s, t) {
|
|
6376
|
-
return s == null && (s = {}), s instanceof Me ? [s] : Array.isArray(s) && s[0] instanceof Me ? s : K(s).map((n) => new
|
|
6377
|
+
return s == null && (s = {}), s instanceof Me ? [s] : Array.isArray(s) && s[0] instanceof Me ? s : K(s).map((n) => new rf(n, t));
|
|
6377
6378
|
}
|
|
6378
6379
|
class yt {
|
|
6379
6380
|
/**
|
|
@@ -6427,13 +6428,13 @@ class yt {
|
|
|
6427
6428
|
}
|
|
6428
6429
|
}
|
|
6429
6430
|
yt.constructors = {};
|
|
6430
|
-
function
|
|
6431
|
-
const u = new
|
|
6432
|
-
new
|
|
6431
|
+
function Nr(s, t, e, n, i, r, a, o, l) {
|
|
6432
|
+
const u = new sf(), c = [
|
|
6433
|
+
new nf(),
|
|
6433
6434
|
...yt.createCallbacks(t)
|
|
6434
6435
|
];
|
|
6435
6436
|
s != null && c.push(...s), c.push(u);
|
|
6436
|
-
const h = new
|
|
6437
|
+
const h = new ef(c);
|
|
6437
6438
|
return h.setParams({
|
|
6438
6439
|
epochs: e,
|
|
6439
6440
|
initialEpoch: n,
|
|
@@ -6469,47 +6470,47 @@ function Wt(s, t = {}, e = !1) {
|
|
|
6469
6470
|
function pn(s, t) {
|
|
6470
6471
|
return x(() => {
|
|
6471
6472
|
s.dtype !== "float32" && (s = L(s, "float32"));
|
|
6472
|
-
const e = B(
|
|
6473
|
+
const e = B(Ue(s), t, !0), n = fu(e.shape, nt()), i = ee(Ie(e, n));
|
|
6473
6474
|
return P(s, i);
|
|
6474
6475
|
});
|
|
6475
6476
|
}
|
|
6476
|
-
function
|
|
6477
|
-
return x(() => at(
|
|
6477
|
+
function vn(s, t) {
|
|
6478
|
+
return x(() => at(Ue(V(t, s)), -1));
|
|
6478
6479
|
}
|
|
6479
6480
|
function xs(s, t) {
|
|
6480
6481
|
return x(() => at(Fe(V(t, s)), -1));
|
|
6481
6482
|
}
|
|
6482
|
-
function
|
|
6483
|
+
function Ns(s, t) {
|
|
6483
6484
|
return x(() => {
|
|
6484
6485
|
const e = V(s, t), n = Ct(Fe(s), nt(), Number.MAX_VALUE), i = Fe(P(e, n));
|
|
6485
6486
|
return w(100, at(i, -1));
|
|
6486
6487
|
});
|
|
6487
6488
|
}
|
|
6488
|
-
function
|
|
6489
|
+
function af(s, t) {
|
|
6489
6490
|
return x(() => {
|
|
6490
6491
|
const e = Ct(t, nt(), Number.MAX_VALUE), n = Zt($(1, e)), i = Ct(s, nt(), Number.MAX_VALUE), r = Zt($(1, i));
|
|
6491
|
-
return at(
|
|
6492
|
+
return at(Ue(V(n, r)), -1);
|
|
6492
6493
|
});
|
|
6493
6494
|
}
|
|
6494
|
-
function
|
|
6495
|
+
function of(s, t) {
|
|
6495
6496
|
return x(() => {
|
|
6496
6497
|
const e = Ie(0, V(1, w(s, t)));
|
|
6497
|
-
return at(
|
|
6498
|
+
return at(Ue(e), -1);
|
|
6498
6499
|
});
|
|
6499
6500
|
}
|
|
6500
|
-
function
|
|
6501
|
+
function lf(s, t) {
|
|
6501
6502
|
return x(() => {
|
|
6502
6503
|
const e = Ie(0, V(1, w(s, t)));
|
|
6503
6504
|
return at(e, -1);
|
|
6504
6505
|
});
|
|
6505
6506
|
}
|
|
6506
|
-
function
|
|
6507
|
+
function uf(s, t) {
|
|
6507
6508
|
return x(() => {
|
|
6508
|
-
const e = B(w(s, t), -1), n =
|
|
6509
|
+
const e = B(w(s, t), -1), n = ve(w(V(1, s), t), -1);
|
|
6509
6510
|
return Ie(0, $(1, V(n, e)));
|
|
6510
6511
|
});
|
|
6511
6512
|
}
|
|
6512
|
-
function
|
|
6513
|
+
function cf(s, t) {
|
|
6513
6514
|
return x(() => {
|
|
6514
6515
|
const e = Math.log(2), n = V(t, s), i = V($(n, us(w(-2, n))), e);
|
|
6515
6516
|
return at(i, -1);
|
|
@@ -6528,59 +6529,59 @@ function Oe(s, t, e = !1) {
|
|
|
6528
6529
|
}
|
|
6529
6530
|
function dn(s, t, e = !1) {
|
|
6530
6531
|
return x(() => {
|
|
6531
|
-
const n = L(
|
|
6532
|
+
const n = L(qi(Au(s)), "int32");
|
|
6532
6533
|
t = Ct(t, nt(), 1 - nt());
|
|
6533
|
-
const i = t.shape, r = A(
|
|
6534
|
+
const i = t.shape, r = A(Qc(n, i[i.length - 1]), i);
|
|
6534
6535
|
return Oe(r, t, e);
|
|
6535
6536
|
});
|
|
6536
6537
|
}
|
|
6537
|
-
function
|
|
6538
|
+
function hf(s, t) {
|
|
6538
6539
|
if (!Ft(s.shape, t.shape))
|
|
6539
6540
|
throw new d(`logits and labels must have the same shape, but got shapes ${JSON.stringify(s.shape)} and ${JSON.stringify(t.shape)}`);
|
|
6540
6541
|
return x(() => {
|
|
6541
|
-
const e =
|
|
6542
|
-
return $(V(e, w(t, s)),
|
|
6542
|
+
const e = Pe(t), n = pt(Fe(t));
|
|
6543
|
+
return $(V(e, w(t, s)), Gc(Jt(n)));
|
|
6543
6544
|
});
|
|
6544
6545
|
}
|
|
6545
6546
|
function Sn(s, t) {
|
|
6546
6547
|
return x(() => {
|
|
6547
6548
|
let e;
|
|
6548
|
-
return e = Ct(t, nt(), 1 - nt()), e = Zt(P(e, V(1, e))), at(
|
|
6549
|
+
return e = Ct(t, nt(), 1 - nt()), e = Zt(P(e, V(1, e))), at(hf(s, e), -1);
|
|
6549
6550
|
});
|
|
6550
6551
|
}
|
|
6551
|
-
function
|
|
6552
|
+
function pf(s, t) {
|
|
6552
6553
|
return x(() => {
|
|
6553
6554
|
const e = Ct(s, nt(), 1), n = Ct(t, nt(), 1);
|
|
6554
6555
|
return B(w(s, Zt(P(e, n))), -1);
|
|
6555
6556
|
});
|
|
6556
6557
|
}
|
|
6557
|
-
function
|
|
6558
|
+
function df(s, t) {
|
|
6558
6559
|
return x(() => {
|
|
6559
6560
|
const e = Zt($(nt(), t));
|
|
6560
6561
|
return at(V(t, w(s, e)), -1);
|
|
6561
6562
|
});
|
|
6562
6563
|
}
|
|
6563
|
-
function
|
|
6564
|
+
function vr(s, t) {
|
|
6564
6565
|
return x(() => {
|
|
6565
6566
|
const e = pn(s, -1), n = pn(t, -1), i = w(e, n);
|
|
6566
6567
|
return pt(B(i, -1));
|
|
6567
6568
|
});
|
|
6568
6569
|
}
|
|
6569
6570
|
const fn = {
|
|
6570
|
-
meanSquaredError:
|
|
6571
|
+
meanSquaredError: vn,
|
|
6571
6572
|
meanAbsoluteError: xs,
|
|
6572
|
-
meanAbsolutePercentageError:
|
|
6573
|
-
meanSquaredLogarithmicError:
|
|
6574
|
-
squaredHinge:
|
|
6575
|
-
hinge:
|
|
6576
|
-
categoricalHinge:
|
|
6577
|
-
logcosh:
|
|
6573
|
+
meanAbsolutePercentageError: Ns,
|
|
6574
|
+
meanSquaredLogarithmicError: af,
|
|
6575
|
+
squaredHinge: of,
|
|
6576
|
+
hinge: lf,
|
|
6577
|
+
categoricalHinge: uf,
|
|
6578
|
+
logcosh: cf,
|
|
6578
6579
|
categoricalCrossentropy: Oe,
|
|
6579
6580
|
sparseCategoricalCrossentropy: dn,
|
|
6580
6581
|
binaryCrossentropy: Sn,
|
|
6581
|
-
kullbackLeiblerDivergence:
|
|
6582
|
-
poisson:
|
|
6583
|
-
cosineProximity:
|
|
6582
|
+
kullbackLeiblerDivergence: pf,
|
|
6583
|
+
poisson: df,
|
|
6584
|
+
cosineProximity: vr
|
|
6584
6585
|
};
|
|
6585
6586
|
function $n(s) {
|
|
6586
6587
|
if (typeof s == "string") {
|
|
@@ -6609,39 +6610,39 @@ function Sr(s, t) {
|
|
|
6609
6610
|
function Ar(s, t) {
|
|
6610
6611
|
return x(() => Lt(Xt(sn(s, -1), sn(t, -1)), "float32"));
|
|
6611
6612
|
}
|
|
6612
|
-
function df(s, t) {
|
|
6613
|
-
return x(() => L(B(Pe(Xt(s, 1), Xt(t, 1))), "float32"));
|
|
6614
|
-
}
|
|
6615
6613
|
function ff(s, t) {
|
|
6616
|
-
return x(() => L(B(
|
|
6614
|
+
return x(() => L(B(je(Xt(s, 1), Xt(t, 1))), "float32"));
|
|
6617
6615
|
}
|
|
6618
6616
|
function mf(s, t) {
|
|
6617
|
+
return x(() => L(B(je(Xt(s, 0), Xt(t, 1))), "float32"));
|
|
6618
|
+
}
|
|
6619
|
+
function gf(s, t) {
|
|
6619
6620
|
return x(() => {
|
|
6620
|
-
const e =
|
|
6621
|
-
return L(
|
|
6621
|
+
const e = ff(s, t), n = mf(s, t), i = $(e, n);
|
|
6622
|
+
return L(qt(Gt(i, 0), P(e, i), 0), "float32");
|
|
6622
6623
|
});
|
|
6623
6624
|
}
|
|
6624
|
-
function
|
|
6625
|
+
function bf(s, t) {
|
|
6625
6626
|
return Sn(s, t);
|
|
6626
6627
|
}
|
|
6627
|
-
function
|
|
6628
|
-
return s.rank === t.rank && (s =
|
|
6628
|
+
function yf(s, t) {
|
|
6629
|
+
return s.rank === t.rank && (s = ns(s, [s.rank - 1])), t = sn(t, -1), t.dtype !== s.dtype && (t = L(t, s.dtype)), L(Xt(s, t), "float32");
|
|
6629
6630
|
}
|
|
6630
|
-
const
|
|
6631
|
+
const wf = vn, kf = vn, xf = xs, Nf = xs, vf = Ns, Sf = Ns, Cr = Oe, Af = vr, Ir = dn, mn = {
|
|
6631
6632
|
binaryAccuracy: Sr,
|
|
6632
6633
|
categoricalAccuracy: Ar,
|
|
6633
|
-
precision:
|
|
6634
|
+
precision: gf,
|
|
6634
6635
|
categoricalCrossentropy: Cr,
|
|
6635
6636
|
sparseCategoricalCrossentropy: Ir,
|
|
6636
|
-
mse:
|
|
6637
|
-
MSE:
|
|
6638
|
-
mae:
|
|
6639
|
-
MAE:
|
|
6637
|
+
mse: wf,
|
|
6638
|
+
MSE: kf,
|
|
6639
|
+
mae: xf,
|
|
6640
|
+
MAE: Nf,
|
|
6640
6641
|
mape: vf,
|
|
6641
|
-
MAPE:
|
|
6642
|
-
cosine:
|
|
6642
|
+
MAPE: Sf,
|
|
6643
|
+
cosine: Af
|
|
6643
6644
|
};
|
|
6644
|
-
function
|
|
6645
|
+
function Cf(s) {
|
|
6645
6646
|
if (typeof s == "string" && s in mn)
|
|
6646
6647
|
return mn[s];
|
|
6647
6648
|
if (typeof s != "string" && s != null)
|
|
@@ -6677,7 +6678,7 @@ function tn(s) {
|
|
|
6677
6678
|
* https://opensource.org/licenses/MIT.
|
|
6678
6679
|
* =============================================================================
|
|
6679
6680
|
*/
|
|
6680
|
-
function
|
|
6681
|
+
function If(s) {
|
|
6681
6682
|
const t = {
|
|
6682
6683
|
Adagrad: () => ge.adagrad(0.01),
|
|
6683
6684
|
Adadelta: () => ge.adadelta(1, 0.95, nt()),
|
|
@@ -6739,8 +6740,8 @@ function Pn(s) {
|
|
|
6739
6740
|
* https://opensource.org/licenses/MIT.
|
|
6740
6741
|
* =============================================================================
|
|
6741
6742
|
*/
|
|
6742
|
-
function
|
|
6743
|
-
const i =
|
|
6743
|
+
function Df(s, t, e, n = console.log) {
|
|
6744
|
+
const i = Tf(s), r = ["Layer (type)", "Input Shape", "Output shape", "Param #"];
|
|
6744
6745
|
i ? (t = t || 90, e = e || [0.32, 0.61, 0.89, 1]) : (t = t || 115, e = e || [0.24, 0.48, 0.7, 0.8, 1]), e[e.length - 1] <= 1 && (e = e.map((c) => Math.floor(t * c)));
|
|
6745
6746
|
let a;
|
|
6746
6747
|
if (!i) {
|
|
@@ -6751,16 +6752,16 @@ function If(s, t, e, n = console.log) {
|
|
|
6751
6752
|
n("_".repeat(t)), gn(r, e, n), n("=".repeat(t));
|
|
6752
6753
|
const o = s.layers;
|
|
6753
6754
|
for (let c = 0; c < o.length; ++c)
|
|
6754
|
-
i ?
|
|
6755
|
+
i ? $f(o[c], e, n) : Ef(o[c], e, a, n), n((c === o.length - 1 ? "=" : "_").repeat(t));
|
|
6755
6756
|
s.checkTrainableWeightsConsistency();
|
|
6756
|
-
const l =
|
|
6757
|
+
const l = zf(s), u = un(s.nonTrainableWeights);
|
|
6757
6758
|
n(`Total params: ${l + u}`), n(`Trainable params: ${l}`), n(`Non-trainable params: ${u}`), n("_".repeat(t));
|
|
6758
6759
|
}
|
|
6759
|
-
function
|
|
6760
|
+
function zf(s) {
|
|
6760
6761
|
let t;
|
|
6761
6762
|
return s.collectedTrainableWeights != null ? t = un(s.collectedTrainableWeights) : t = un(s.trainableWeights), t;
|
|
6762
6763
|
}
|
|
6763
|
-
function
|
|
6764
|
+
function Tf(s) {
|
|
6764
6765
|
let t = !0;
|
|
6765
6766
|
const e = [], n = [];
|
|
6766
6767
|
for (const i in s.nodesByDepth)
|
|
@@ -6793,7 +6794,7 @@ function gn(s, t, e = console.log) {
|
|
|
6793
6794
|
i > 0 && (n = n.slice(0, n.length - 1) + " "), n += s[i], n = n.slice(0, t[i]), n += " ".repeat(t[i] - n.length);
|
|
6794
6795
|
e(n);
|
|
6795
6796
|
}
|
|
6796
|
-
function
|
|
6797
|
+
function $f(s, t, e) {
|
|
6797
6798
|
let n, i;
|
|
6798
6799
|
try {
|
|
6799
6800
|
i = s.inboundNodes.map((l) => JSON.stringify(l.inputShapes)).join(",");
|
|
@@ -6813,7 +6814,7 @@ function Tf(s, t, e) {
|
|
|
6813
6814
|
];
|
|
6814
6815
|
gn(o, t, e);
|
|
6815
6816
|
}
|
|
6816
|
-
function
|
|
6817
|
+
function Ef(s, t, e, n) {
|
|
6817
6818
|
let i, r;
|
|
6818
6819
|
try {
|
|
6819
6820
|
r = s.inboundNodes.map((h) => JSON.stringify(h.inputShapes)).join(",");
|
|
@@ -6917,14 +6918,14 @@ const zr = "4.22.0";
|
|
|
6917
6918
|
* https://opensource.org/licenses/MIT.
|
|
6918
6919
|
* =============================================================================
|
|
6919
6920
|
*/
|
|
6920
|
-
const
|
|
6921
|
+
const Lf = (s) => {
|
|
6921
6922
|
const t = Object.keys(s);
|
|
6922
6923
|
if (t.length === 0)
|
|
6923
6924
|
return !1;
|
|
6924
6925
|
const e = t[0].split("/");
|
|
6925
6926
|
return !isNaN(parseInt(e[e.length - 1], 10));
|
|
6926
6927
|
};
|
|
6927
|
-
class
|
|
6928
|
+
class vt extends W {
|
|
6928
6929
|
constructor(t) {
|
|
6929
6930
|
if (super({}), this.containerNodes = /* @__PURE__ */ new Set(), this.name = t.name, this.name == null) {
|
|
6930
6931
|
const y = this.getClassName().toLowerCase();
|
|
@@ -6934,12 +6935,12 @@ class Nt extends W {
|
|
|
6934
6935
|
throw new d(`The list of inputs passed to the model is redundant. All inputs should only appear once. Found: ${this.inputs.map((y) => y.name)}`);
|
|
6935
6936
|
jt(this.outputs).length !== this.outputs.length && console.warn(`The list of outputs passed to the model is redundant. All outputs should only appear once. Found: ${this.outputs.map((y) => y.name)}`), this.inputLayers = [], this.inputLayersNodeIndices = [], this.inputLayersTensorIndices = [], this.outputLayers = [], this.outputLayersNodeIndices = [], this.outputLayersTensorIndices = [], this.layers = [], this.internalContainerRefs = [];
|
|
6936
6937
|
for (const y of this.outputs) {
|
|
6937
|
-
const C = y.sourceLayer,
|
|
6938
|
-
this.outputLayers.push(C), this.outputLayersNodeIndices.push(
|
|
6938
|
+
const C = y.sourceLayer, N = y.nodeIndex, I = y.tensorIndex;
|
|
6939
|
+
this.outputLayers.push(C), this.outputLayersNodeIndices.push(N), this.outputLayersTensorIndices.push(I);
|
|
6939
6940
|
}
|
|
6940
6941
|
for (const y of this.inputs) {
|
|
6941
|
-
const C = y.sourceLayer,
|
|
6942
|
-
Ut(
|
|
6942
|
+
const C = y.sourceLayer, N = y.nodeIndex, I = y.tensorIndex;
|
|
6943
|
+
Ut(N === 0, "input layer has >1 nodes"), Ut(I === 0, "input layer has >1 tensors"), this.inputLayers.push(C), this.inputLayersNodeIndices.push(N), this.inputLayersTensorIndices.push(I);
|
|
6943
6944
|
}
|
|
6944
6945
|
this.inputNames = [], this.outputNames = [], this.feedInputShapes = [], this.feedInputNames = [], this.feedOutputNames = [];
|
|
6945
6946
|
for (let y = 0; y < this.inputLayers.length; y++) {
|
|
@@ -6951,21 +6952,21 @@ class Nt extends W {
|
|
|
6951
6952
|
for (const y of this.outputLayers)
|
|
6952
6953
|
this.outputNames.push(y.name);
|
|
6953
6954
|
this.internalInputShapes = this.inputs.map((y) => y.shape), this.internalOutputShapes = this.outputs.map((y) => y.shape);
|
|
6954
|
-
const e = {}, n = {}, i = {}, r = {}, a = {}, o = [], l = (y, C,
|
|
6955
|
+
const e = {}, n = {}, i = {}, r = {}, a = {}, o = [], l = (y, C, N, I, z, _) => {
|
|
6955
6956
|
(I == null || z == null || _ == null) && (I = y.sourceLayer, z = y.nodeIndex, _ = y.tensorIndex);
|
|
6956
6957
|
const T = I.inboundNodes[z];
|
|
6957
|
-
if (
|
|
6958
|
+
if (N.indexOf(T) !== -1)
|
|
6958
6959
|
throw new Et(`The tensor ${y.name} at layer "${I.name}" is part of a cycle.`);
|
|
6959
6960
|
if (C.indexOf(T) !== -1)
|
|
6960
6961
|
return;
|
|
6961
|
-
this.containerNodes.add(
|
|
6962
|
+
this.containerNodes.add(vt.nodeKey(I, z)), I.id in a || (a[I.id] = Object.keys(a).length), N.indexOf(T) === -1 && N.push(T);
|
|
6962
6963
|
const E = T.inboundLayers.length;
|
|
6963
6964
|
for (let R = 0; R < E; R++) {
|
|
6964
6965
|
const q = T.inputTensors[R], bt = T.inboundLayers[R], ie = T.nodeIndices[R], re = T.tensorIndices[R];
|
|
6965
|
-
l(q, C,
|
|
6966
|
+
l(q, C, N, bt, ie, re);
|
|
6966
6967
|
}
|
|
6967
|
-
for (C.push(T);
|
|
6968
|
-
|
|
6968
|
+
for (C.push(T); N.indexOf(T) >= 0; )
|
|
6969
|
+
N.splice(N.indexOf(T), 1);
|
|
6969
6970
|
o.push(T);
|
|
6970
6971
|
}, u = [], c = [];
|
|
6971
6972
|
for (const y of this.outputs)
|
|
@@ -6974,8 +6975,8 @@ class Nt extends W {
|
|
|
6974
6975
|
for (const y of h) {
|
|
6975
6976
|
n[y.id] = y, y.id in e || (e[y.id] = 0);
|
|
6976
6977
|
let C = e[y.id];
|
|
6977
|
-
const
|
|
6978
|
-
C = Math.max(C,
|
|
6978
|
+
const N = i[y.outboundLayer.id] == null ? 0 : i[y.outboundLayer.id];
|
|
6979
|
+
C = Math.max(C, N), i[y.outboundLayer.id] = C, r[y.outboundLayer.id] = y.outboundLayer, e[y.id] = C;
|
|
6979
6980
|
for (let I = 0; I < y.inboundLayers.length; I++) {
|
|
6980
6981
|
const z = y.inboundLayers[I], _ = y.nodeIndices[I], T = z.inboundNodes[_], E = e[T.id] == null ? 0 : e[T.id];
|
|
6981
6982
|
e[T.id] = Math.max(C + 1, E), n[T.id] = T;
|
|
@@ -6995,35 +6996,35 @@ class Nt extends W {
|
|
|
6995
6996
|
this.layers = [];
|
|
6996
6997
|
for (const y of g) {
|
|
6997
6998
|
const C = f[y];
|
|
6998
|
-
C.sort((
|
|
6999
|
-
const z = a[
|
|
6999
|
+
C.sort((N, I) => {
|
|
7000
|
+
const z = a[N.id], _ = a[I.id];
|
|
7000
7001
|
return z < _ ? -1 : z > _ ? 1 : 0;
|
|
7001
7002
|
});
|
|
7002
|
-
for (const
|
|
7003
|
-
|
|
7003
|
+
for (const N of C)
|
|
7004
|
+
N instanceof vt && this.internalContainerRefs.push(N), this.layers.push(N);
|
|
7004
7005
|
}
|
|
7005
7006
|
this.layersByDepth = f, g = Object.keys(p).map((y) => parseInt(y, 10)).sort(Xe);
|
|
7006
7007
|
const b = this.inputs.slice(), m = [];
|
|
7007
7008
|
for (const y of g)
|
|
7008
7009
|
for (const C of p[y]) {
|
|
7009
|
-
const
|
|
7010
|
-
if (
|
|
7010
|
+
const N = C.outboundLayer;
|
|
7011
|
+
if (N != null) {
|
|
7011
7012
|
for (const I of C.inputTensors)
|
|
7012
7013
|
if (b.indexOf(I) === -1)
|
|
7013
|
-
throw new Et(`Graph disconnected: cannot obtain value for tensor ${I} at layer "${
|
|
7014
|
+
throw new Et(`Graph disconnected: cannot obtain value for tensor ${I} at layer "${N.name}". The following previous layers were accessed without issue: ${m}`);
|
|
7014
7015
|
for (const I of C.outputTensors)
|
|
7015
7016
|
b.push(I);
|
|
7016
|
-
m.push(
|
|
7017
|
+
m.push(N.name);
|
|
7017
7018
|
}
|
|
7018
7019
|
}
|
|
7019
7020
|
this.nodesByDepth = p;
|
|
7020
|
-
const
|
|
7021
|
-
for (const y of
|
|
7022
|
-
const C =
|
|
7021
|
+
const v = this.layers.map((y) => y.name);
|
|
7022
|
+
for (const y of v) {
|
|
7023
|
+
const C = v.filter((N) => N === y).length;
|
|
7023
7024
|
if (C !== 1)
|
|
7024
|
-
throw new Et(`The name "${y}" is used ${C} times in the model. All layer names should be unique. Layer names: ` + JSON.stringify(
|
|
7025
|
+
throw new Et(`The name "${y}" is used ${C} times in the model. All layer names should be unique. Layer names: ` + JSON.stringify(v));
|
|
7025
7026
|
}
|
|
7026
|
-
this.outboundNodes = [], this.inboundNodes = [], new
|
|
7027
|
+
this.outboundNodes = [], this.inboundNodes = [], new Nn({
|
|
7027
7028
|
outboundLayer: this,
|
|
7028
7029
|
inboundLayers: [],
|
|
7029
7030
|
nodeIndices: [],
|
|
@@ -7128,7 +7129,7 @@ class Nt extends W {
|
|
|
7128
7129
|
loadWeights(t, e = !0) {
|
|
7129
7130
|
const n = {};
|
|
7130
7131
|
let i = 0;
|
|
7131
|
-
const r =
|
|
7132
|
+
const r = Lf(t);
|
|
7132
7133
|
r && this.parseWeights(t);
|
|
7133
7134
|
for (const o of this.layers)
|
|
7134
7135
|
for (const [l, u] of o.weights.entries()) {
|
|
@@ -7224,7 +7225,7 @@ class Nt extends W {
|
|
|
7224
7225
|
return x(() => {
|
|
7225
7226
|
t = K(t);
|
|
7226
7227
|
let n;
|
|
7227
|
-
return e == null ? n =
|
|
7228
|
+
return e == null ? n = ue(null, t.length) : n = K(e), this.runInternalGraph(t, n)[1];
|
|
7228
7229
|
});
|
|
7229
7230
|
}
|
|
7230
7231
|
/**
|
|
@@ -7255,8 +7256,8 @@ class Nt extends W {
|
|
|
7255
7256
|
continue;
|
|
7256
7257
|
const h = [];
|
|
7257
7258
|
for (let b = 0; b < u.inboundLayers.length; b++) {
|
|
7258
|
-
const m = u.inboundLayers[b],
|
|
7259
|
-
h.push(
|
|
7259
|
+
const m = u.inboundLayers[b], v = u.nodeIndices[b], y = u.tensorIndices[b], C = `${m.name}_${v}_${y}`, N = n[C];
|
|
7260
|
+
h.push(N);
|
|
7260
7261
|
}
|
|
7261
7262
|
const p = c.computeOutputShape(ht(h)), f = ln(p), g = c.inboundNodes.indexOf(u);
|
|
7262
7263
|
for (let b = 0; b < f.length; b++) {
|
|
@@ -7287,7 +7288,7 @@ class Nt extends W {
|
|
|
7287
7288
|
* @return Three lists: outputTensors, outputMasks, outputShapes
|
|
7288
7289
|
*/
|
|
7289
7290
|
runInternalGraph(t, e) {
|
|
7290
|
-
e == null && (e =
|
|
7291
|
+
e == null && (e = ue(null, t.length));
|
|
7291
7292
|
const n = {};
|
|
7292
7293
|
for (let l = 0; l < this.inputs.length; ++l) {
|
|
7293
7294
|
const u = this.inputs[l], c = t[l], h = e[l];
|
|
@@ -7301,16 +7302,16 @@ class Nt extends W {
|
|
|
7301
7302
|
for (const b of p)
|
|
7302
7303
|
b.id in n && g.push(n[b.id]);
|
|
7303
7304
|
if (g.length === p.length) {
|
|
7304
|
-
let b = {}, m,
|
|
7305
|
+
let b = {}, m, v, y, C;
|
|
7305
7306
|
if (c.callArgs != null && (b = c.callArgs), g.length === 1) {
|
|
7306
|
-
const [
|
|
7307
|
-
b.mask == null && (b.mask = I), y = K(h.call(
|
|
7307
|
+
const [N, I] = g[0];
|
|
7308
|
+
b.mask == null && (b.mask = I), y = K(h.call(N, b)), C = K(h.computeMask(N, I)), m = [N], v = [I];
|
|
7308
7309
|
} else
|
|
7309
|
-
m = g.map((
|
|
7310
|
+
m = g.map((N) => N[0]), v = g.map((N) => N[1]), b.mask == null && (b.mask = v), y = K(h.call(m, b)), C = K(h.computeMask(m, v));
|
|
7310
7311
|
if (h.activityRegularizer)
|
|
7311
7312
|
throw new G("LayersModel invocation with concrete Tensor value(s) in the presence of activity regularizer(s) is not supported yet.");
|
|
7312
|
-
for (let
|
|
7313
|
-
const I = f[
|
|
7313
|
+
for (let N = 0; N < f.length; ++N) {
|
|
7314
|
+
const I = f[N], z = y[N], _ = C[N];
|
|
7314
7315
|
n[I.id] = [z, _];
|
|
7315
7316
|
}
|
|
7316
7317
|
}
|
|
@@ -7336,9 +7337,9 @@ class Nt extends W {
|
|
|
7336
7337
|
const e = {};
|
|
7337
7338
|
let n;
|
|
7338
7339
|
for (const i of this.layers) {
|
|
7339
|
-
n = i instanceof
|
|
7340
|
+
n = i instanceof vt ? 1 : 0;
|
|
7340
7341
|
for (let r = 0; r < i.inboundNodes.length; r++) {
|
|
7341
|
-
const a =
|
|
7342
|
+
const a = vt.nodeKey(i, r);
|
|
7342
7343
|
this.containerNodes.has(a) && (e[a] = n, n += 1);
|
|
7343
7344
|
}
|
|
7344
7345
|
}
|
|
@@ -7371,7 +7372,7 @@ class Nt extends W {
|
|
|
7371
7372
|
const t = [];
|
|
7372
7373
|
for (const e of this.layers)
|
|
7373
7374
|
for (let n = 0; n < e.inboundNodes.length; ++n) {
|
|
7374
|
-
const i =
|
|
7375
|
+
const i = vt.nodeKey(e, n);
|
|
7375
7376
|
this.containerNodes.has(i) && t.push(...e.calculateLosses());
|
|
7376
7377
|
}
|
|
7377
7378
|
return t;
|
|
@@ -7382,7 +7383,7 @@ class Nt extends W {
|
|
|
7382
7383
|
for (const a of this.layers) {
|
|
7383
7384
|
const o = a.getClassName(), l = a.getConfig(), u = [];
|
|
7384
7385
|
for (let h = 0; h < a.inboundNodes.length; h++) {
|
|
7385
|
-
const p = a.inboundNodes[h], f =
|
|
7386
|
+
const p = a.inboundNodes[h], f = vt.nodeKey(a, h);
|
|
7386
7387
|
let g = {};
|
|
7387
7388
|
if (this.containerNodes.has(f)) {
|
|
7388
7389
|
if (p.callArgs)
|
|
@@ -7394,9 +7395,9 @@ class Nt extends W {
|
|
|
7394
7395
|
if (p.inboundLayers.length > 0) {
|
|
7395
7396
|
const b = [];
|
|
7396
7397
|
for (let m = 0; m < p.inboundLayers.length; m++) {
|
|
7397
|
-
const
|
|
7398
|
-
let I = e[
|
|
7399
|
-
I == null && (I = 0), b.push([
|
|
7398
|
+
const v = p.inboundLayers[m], y = p.nodeIndices[m], C = p.tensorIndices[m], N = vt.nodeKey(v, y);
|
|
7399
|
+
let I = e[N];
|
|
7400
|
+
I == null && (I = 0), b.push([v.name, I, C, g]);
|
|
7400
7401
|
}
|
|
7401
7402
|
u.push(b);
|
|
7402
7403
|
}
|
|
@@ -7408,7 +7409,7 @@ class Nt extends W {
|
|
|
7408
7409
|
t.layers = n;
|
|
7409
7410
|
const i = [];
|
|
7410
7411
|
for (let a = 0; a < this.inputLayers.length; a++) {
|
|
7411
|
-
const o = this.inputLayers[a], l = this.inputLayersNodeIndices[a], u =
|
|
7412
|
+
const o = this.inputLayers[a], l = this.inputLayersNodeIndices[a], u = vt.nodeKey(o, l);
|
|
7412
7413
|
if (!this.containerNodes.has(u))
|
|
7413
7414
|
continue;
|
|
7414
7415
|
let c = e[u];
|
|
@@ -7419,7 +7420,7 @@ class Nt extends W {
|
|
|
7419
7420
|
t.inputLayers = i;
|
|
7420
7421
|
const r = [];
|
|
7421
7422
|
for (let a = 0; a < this.outputLayers.length; a++) {
|
|
7422
|
-
const o = this.outputLayers[a], l = this.outputLayersNodeIndices[a], u =
|
|
7423
|
+
const o = this.outputLayers[a], l = this.outputLayersNodeIndices[a], u = vt.nodeKey(o, l);
|
|
7423
7424
|
if (!this.containerNodes.has(u))
|
|
7424
7425
|
continue;
|
|
7425
7426
|
let c = e[u];
|
|
@@ -7444,21 +7445,21 @@ class Nt extends W {
|
|
|
7444
7445
|
/** @nocollapse */
|
|
7445
7446
|
static fromConfig(t, e, n = {}, i = !1) {
|
|
7446
7447
|
const r = {}, a = {};
|
|
7447
|
-
function o(m,
|
|
7448
|
-
m.name in a ? a[m.name].push(
|
|
7448
|
+
function o(m, v) {
|
|
7449
|
+
m.name in a ? a[m.name].push(v) : a[m.name] = [v];
|
|
7449
7450
|
}
|
|
7450
|
-
function l(m,
|
|
7451
|
+
function l(m, v) {
|
|
7451
7452
|
const y = [];
|
|
7452
7453
|
let C;
|
|
7453
|
-
for (const
|
|
7454
|
-
const I =
|
|
7455
|
-
if (C =
|
|
7456
|
-
o(m,
|
|
7454
|
+
for (const N of v) {
|
|
7455
|
+
const I = N[0], z = N[1], _ = N[2];
|
|
7456
|
+
if (C = N[3] == null ? {} : N[3], !(I in r)) {
|
|
7457
|
+
o(m, v);
|
|
7457
7458
|
return;
|
|
7458
7459
|
}
|
|
7459
7460
|
const T = r[I];
|
|
7460
7461
|
if (T.inboundNodes.length <= z) {
|
|
7461
|
-
o(m,
|
|
7462
|
+
o(m, v);
|
|
7462
7463
|
return;
|
|
7463
7464
|
}
|
|
7464
7465
|
const E = T.inboundNodes[z];
|
|
@@ -7467,38 +7468,38 @@ class Nt extends W {
|
|
|
7467
7468
|
y.length > 0 && m.apply(ht(y), C);
|
|
7468
7469
|
}
|
|
7469
7470
|
function u(m) {
|
|
7470
|
-
const
|
|
7471
|
-
y.setFastWeightInitDuringBuild(i), r[
|
|
7472
|
-
if (!(
|
|
7473
|
-
throw new d(`Corrupted configuration, expected array for nodeData: ${
|
|
7474
|
-
o(y,
|
|
7471
|
+
const v = m.name, y = Wt(m, e.customObjects != null ? e.customObjects : {});
|
|
7472
|
+
y.setFastWeightInitDuringBuild(i), r[v] = y, m.inboundNodes.forEach((N) => {
|
|
7473
|
+
if (!(N instanceof Array))
|
|
7474
|
+
throw new d(`Corrupted configuration, expected array for nodeData: ${N}`);
|
|
7475
|
+
o(y, N);
|
|
7475
7476
|
});
|
|
7476
7477
|
}
|
|
7477
7478
|
const c = e.name, h = e.layers;
|
|
7478
7479
|
for (const m of h)
|
|
7479
7480
|
u(m);
|
|
7480
|
-
for (; !
|
|
7481
|
+
for (; !Cu(a); )
|
|
7481
7482
|
for (const m of h) {
|
|
7482
|
-
const
|
|
7483
|
-
if (
|
|
7484
|
-
const y = a[
|
|
7485
|
-
delete a[
|
|
7483
|
+
const v = r[m.name];
|
|
7484
|
+
if (v.name in a) {
|
|
7485
|
+
const y = a[v.name];
|
|
7486
|
+
delete a[v.name];
|
|
7486
7487
|
for (const C of y)
|
|
7487
|
-
l(
|
|
7488
|
+
l(v, C);
|
|
7488
7489
|
}
|
|
7489
7490
|
}
|
|
7490
7491
|
const p = [], f = [], g = e.inputLayers;
|
|
7491
7492
|
for (const m of g) {
|
|
7492
|
-
const
|
|
7493
|
-
Ut(
|
|
7494
|
-
const I = r[
|
|
7493
|
+
const v = m[0], y = m[1], C = m[2];
|
|
7494
|
+
Ut(v in r);
|
|
7495
|
+
const I = r[v].inboundNodes[y].outputTensors;
|
|
7495
7496
|
p.push(I[C]);
|
|
7496
7497
|
}
|
|
7497
7498
|
const b = e.outputLayers;
|
|
7498
7499
|
for (const m of b) {
|
|
7499
|
-
const
|
|
7500
|
-
Ut(
|
|
7501
|
-
const I = r[
|
|
7500
|
+
const v = m[0], y = m[1], C = m[2];
|
|
7501
|
+
Ut(v in r);
|
|
7502
|
+
const I = r[v].inboundNodes[y].outputTensors;
|
|
7502
7503
|
f.push(I[C]);
|
|
7503
7504
|
}
|
|
7504
7505
|
return new t({ inputs: p, outputs: f, name: c });
|
|
@@ -7540,7 +7541,7 @@ class Nt extends W {
|
|
|
7540
7541
|
* https://opensource.org/licenses/MIT.
|
|
7541
7542
|
* =============================================================================
|
|
7542
7543
|
*/
|
|
7543
|
-
function
|
|
7544
|
+
function Ff(s, t, e) {
|
|
7544
7545
|
const n = t.length;
|
|
7545
7546
|
if (s == null || Array.isArray(s) && s.length === 0)
|
|
7546
7547
|
return t.map((i) => null);
|
|
@@ -7559,7 +7560,7 @@ function Lf(s, t, e) {
|
|
|
7559
7560
|
throw new Error(`The model has multiple (${n}) outputs, so ${e} must be either an array with ${n} elements or an object with ${t} keys. Provided ${e} not understood: ${JSON.stringify(s)}`);
|
|
7560
7561
|
}
|
|
7561
7562
|
function Tr(s, t) {
|
|
7562
|
-
return
|
|
7563
|
+
return Ff(s, t, "classWeight");
|
|
7563
7564
|
}
|
|
7564
7565
|
async function $r(s, t, e, n) {
|
|
7565
7566
|
if (e != null) {
|
|
@@ -7585,7 +7586,7 @@ async function $r(s, t, e, n) {
|
|
|
7585
7586
|
} else
|
|
7586
7587
|
return null;
|
|
7587
7588
|
}
|
|
7588
|
-
function
|
|
7589
|
+
function Mf(s, t) {
|
|
7589
7590
|
return w(s, t);
|
|
7590
7591
|
}
|
|
7591
7592
|
/**
|
|
@@ -7597,7 +7598,7 @@ function Ff(s, t) {
|
|
|
7597
7598
|
* https://opensource.org/licenses/MIT.
|
|
7598
7599
|
* =============================================================================
|
|
7599
7600
|
*/
|
|
7600
|
-
const
|
|
7601
|
+
const Of = 32;
|
|
7601
7602
|
function Er(s, t) {
|
|
7602
7603
|
let e, n;
|
|
7603
7604
|
const i = t;
|
|
@@ -7625,12 +7626,12 @@ function ri(s, t, e) {
|
|
|
7625
7626
|
return n;
|
|
7626
7627
|
}
|
|
7627
7628
|
}
|
|
7628
|
-
function
|
|
7629
|
+
function Rf(s) {
|
|
7629
7630
|
if (s.length === 3)
|
|
7630
7631
|
throw new G("Validation with sample weights is not implemented yet.");
|
|
7631
7632
|
return { xs: s[0], ys: s[1] };
|
|
7632
7633
|
}
|
|
7633
|
-
async function
|
|
7634
|
+
async function _f(s, t, e) {
|
|
7634
7635
|
const n = e.batchesPerEpoch != null;
|
|
7635
7636
|
if (k(s.optimizer != null, () => "You must compile a model before training/testing. Use LayersModel.compile(modelCompileConfig)."), k(e != null, () => "For fitDataset(), the 2nd argument (config) is required, but it is not provided in this call."), k(e.epochs != null && e.epochs > 0 && Number.isInteger(e.epochs), () => `For fitDataset(), config.epochs is expected to be a positive integer, but got ${e.epochs}`), k(!n || e.batchesPerEpoch > 0 && Number.isInteger(e.batchesPerEpoch), () => `For fitDataset(), config.batchesPerEpoch is expected to be a positive integer if specified, but got ${e.batchesPerEpoch}`), k(
|
|
7636
7637
|
// tslint:disable-next-line:no-any
|
|
@@ -7646,19 +7647,19 @@ async function Rf(s, t, e) {
|
|
|
7646
7647
|
if (ai(e.validationData))
|
|
7647
7648
|
k(e.validationBatches == null || e.validationBatches > 0 && Number.isInteger(e.validationBatches), () => `For fitDataset() with dataset-based validation, config.validationBatches is expected not to be provided, or to be a positive integer, but got ${e.validationBatches}`);
|
|
7648
7649
|
else {
|
|
7649
|
-
const m =
|
|
7650
|
+
const m = Rf(e.validationData);
|
|
7650
7651
|
r = m.xs, a = m.ys;
|
|
7651
7652
|
}
|
|
7652
7653
|
const o = s.makeTrainFunction(), l = s.getDedupedMetricsNames();
|
|
7653
7654
|
let u;
|
|
7654
7655
|
i ? u = l.slice().concat(l.map((m) => "val_" + m)) : u = l.slice();
|
|
7655
|
-
const c = xr(e.callbacks, e.yieldEvery), h = e.verbose == null ? 1 : e.verbose, { callbackList: p, history: f } =
|
|
7656
|
+
const c = xr(e.callbacks, e.yieldEvery), h = e.verbose == null ? 1 : e.verbose, { callbackList: p, history: f } = Nr(
|
|
7656
7657
|
c,
|
|
7657
7658
|
h,
|
|
7658
7659
|
e.epochs,
|
|
7659
7660
|
null,
|
|
7660
7661
|
null,
|
|
7661
|
-
|
|
7662
|
+
Bf(t, e),
|
|
7662
7663
|
null,
|
|
7663
7664
|
// Batch size determined by the dataset itself.
|
|
7664
7665
|
i,
|
|
@@ -7669,39 +7670,39 @@ async function Rf(s, t, e) {
|
|
|
7669
7670
|
for (; g < e.epochs; ) {
|
|
7670
7671
|
const m = {};
|
|
7671
7672
|
await p.onEpochBegin(g);
|
|
7672
|
-
let
|
|
7673
|
-
for (n || (b = await t.iterator()); !n ||
|
|
7673
|
+
let v = 0, y = 0;
|
|
7674
|
+
for (n || (b = await t.iterator()); !n || v < e.batchesPerEpoch; ) {
|
|
7674
7675
|
const C = await b.next();
|
|
7675
7676
|
if (n && C.done) {
|
|
7676
|
-
console.warn(`You provided \`batchesPerEpoch\` as ${e.batchesPerEpoch}, but your dataset iterator ran out of data after ${
|
|
7677
|
+
console.warn(`You provided \`batchesPerEpoch\` as ${e.batchesPerEpoch}, but your dataset iterator ran out of data after ${v} batches; interrupting training. Make sure that your dataset can generate at least \`batchesPerEpoch * epochs\` batches (in this case, ${e.batchesPerEpoch * e.epochs} batches). You may need to use the repeat() function when building your dataset.`);
|
|
7677
7678
|
break;
|
|
7678
7679
|
}
|
|
7679
7680
|
if (C.value != null) {
|
|
7680
|
-
const { xs:
|
|
7681
|
-
z.batch = y, z.size =
|
|
7681
|
+
const { xs: N, ys: I } = Er(s, C.value), z = {};
|
|
7682
|
+
z.batch = y, z.size = N[0].shape[0], await p.onBatchBegin(y, z);
|
|
7682
7683
|
const _ = [];
|
|
7683
7684
|
if (e.classWeight != null) {
|
|
7684
7685
|
const R = Tr(e.classWeight, s.outputNames);
|
|
7685
7686
|
for (let q = 0; q < R.length; ++q)
|
|
7686
7687
|
_.push(await $r(I[q], null, R[q]));
|
|
7687
7688
|
}
|
|
7688
|
-
const T =
|
|
7689
|
+
const T = N.concat(I).concat(_), E = o(T);
|
|
7689
7690
|
Z(T);
|
|
7690
7691
|
for (let R = 0; R < l.length; ++R) {
|
|
7691
7692
|
const q = l[R], bt = E[R];
|
|
7692
7693
|
z[q] = bt, Bt(bt);
|
|
7693
7694
|
}
|
|
7694
|
-
await p.onBatchEnd(y, z), kr(z), y++,
|
|
7695
|
+
await p.onBatchEnd(y, z), kr(z), y++, v++;
|
|
7695
7696
|
}
|
|
7696
|
-
if (n ?
|
|
7697
|
+
if (n ? v >= e.batchesPerEpoch : C.done) {
|
|
7697
7698
|
if (i) {
|
|
7698
|
-
let
|
|
7699
|
-
ai(e.validationData) ?
|
|
7700
|
-
batchSize: e.validationBatchSize == null ?
|
|
7699
|
+
let N;
|
|
7700
|
+
ai(e.validationData) ? N = K(await s.evaluateDataset(e.validationData, { batches: e.validationBatches })) : N = K(s.evaluate(r, a, {
|
|
7701
|
+
batchSize: e.validationBatchSize == null ? Of : e.validationBatchSize,
|
|
7701
7702
|
verbose: 0
|
|
7702
7703
|
}));
|
|
7703
7704
|
for (let I = 0; I < s.metricsNames.length; ++I)
|
|
7704
|
-
m[`val_${s.metricsNames[I]}`] =
|
|
7705
|
+
m[`val_${s.metricsNames[I]}`] = N[I];
|
|
7705
7706
|
}
|
|
7706
7707
|
break;
|
|
7707
7708
|
}
|
|
@@ -7716,24 +7717,24 @@ async function Rf(s, t, e) {
|
|
|
7716
7717
|
s.isTraining = !1;
|
|
7717
7718
|
}
|
|
7718
7719
|
}
|
|
7719
|
-
function
|
|
7720
|
+
function Bf(s, t) {
|
|
7720
7721
|
let e = null;
|
|
7721
7722
|
return t.batchesPerEpoch != null ? e = t.batchesPerEpoch : Number.isFinite(s.size) && (e = s.size), e;
|
|
7722
7723
|
}
|
|
7723
7724
|
function ai(s) {
|
|
7724
7725
|
return typeof s.iterator == "function";
|
|
7725
7726
|
}
|
|
7726
|
-
function
|
|
7727
|
+
function Wf(s) {
|
|
7727
7728
|
return typeof s.next == "function";
|
|
7728
7729
|
}
|
|
7729
|
-
async function
|
|
7730
|
+
async function Gf(s, t, e) {
|
|
7730
7731
|
e = e || {};
|
|
7731
7732
|
const n = e.batches != null, i = s.testFunction;
|
|
7732
7733
|
let r = [];
|
|
7733
7734
|
if (e.verbose > 0)
|
|
7734
7735
|
throw new G("Verbose mode is not implemented yet.");
|
|
7735
7736
|
k(!n || e.batches > 0 && Number.isInteger(e.batches), () => `Test loop expects \`batches\` to be a positive integer, but received ${JSON.stringify(e.batches)}`);
|
|
7736
|
-
const a =
|
|
7737
|
+
const a = Wf(t) ? t : await t.iterator();
|
|
7737
7738
|
let o = 0, l = 0;
|
|
7738
7739
|
for (; !n || l < e.batches; ) {
|
|
7739
7740
|
const u = await a.next();
|
|
@@ -7745,8 +7746,8 @@ async function Wf(s, t, e) {
|
|
|
7745
7746
|
r.push(tt(0));
|
|
7746
7747
|
const g = p[0].shape[0];
|
|
7747
7748
|
for (let b = 0; b < f.length; ++b) {
|
|
7748
|
-
const m = f[b],
|
|
7749
|
-
r[b] = x(() => $(r[b], w(g, m))), l > 0 && Z(
|
|
7749
|
+
const m = f[b], v = r[b];
|
|
7750
|
+
r[b] = x(() => $(r[b], w(g, m))), l > 0 && Z(v);
|
|
7750
7751
|
}
|
|
7751
7752
|
Z(f), o += g, ++l;
|
|
7752
7753
|
}
|
|
@@ -7775,7 +7776,7 @@ function En(s) {
|
|
|
7775
7776
|
k(s > 0 && Number.isInteger(s), () => `batchSize is required to be a positive integer, but got ${s}`);
|
|
7776
7777
|
}
|
|
7777
7778
|
function Te(s, t, e) {
|
|
7778
|
-
return s == null ? [null] : Array.isArray(s) ? s.map((n) =>
|
|
7779
|
+
return s == null ? [null] : Array.isArray(s) ? s.map((n) => en(n, t, e - t)) : en(s, t, e - t);
|
|
7779
7780
|
}
|
|
7780
7781
|
function jn(s, t) {
|
|
7781
7782
|
return x(() => s == null ? null : Array.isArray(s) ? s.map((e) => jn(e, t)) : Vi(s, t.dtype === "int32" ? t : L(t, "int32")));
|
|
@@ -7802,7 +7803,7 @@ function Lr(s) {
|
|
|
7802
7803
|
}
|
|
7803
7804
|
return t;
|
|
7804
7805
|
}
|
|
7805
|
-
function
|
|
7806
|
+
function Nt(s, t) {
|
|
7806
7807
|
if (s == null)
|
|
7807
7808
|
return;
|
|
7808
7809
|
const e = [];
|
|
@@ -7840,14 +7841,14 @@ function vt(s, t) {
|
|
|
7840
7841
|
* https://opensource.org/licenses/MIT.
|
|
7841
7842
|
* =============================================================================
|
|
7842
7843
|
*/
|
|
7843
|
-
function
|
|
7844
|
+
function Pf(s) {
|
|
7844
7845
|
return s instanceof xe;
|
|
7845
7846
|
}
|
|
7846
7847
|
function Kn(s) {
|
|
7847
7848
|
return Array.isArray(s);
|
|
7848
7849
|
}
|
|
7849
7850
|
function oi(s) {
|
|
7850
|
-
return !
|
|
7851
|
+
return !Pf(s) && !Kn(s);
|
|
7851
7852
|
}
|
|
7852
7853
|
function li(s, t, e, n = !0, i = "") {
|
|
7853
7854
|
if (t == null || t.length === 0) {
|
|
@@ -7904,7 +7905,7 @@ function li(s, t, e, n = !0, i = "") {
|
|
|
7904
7905
|
}
|
|
7905
7906
|
return r;
|
|
7906
7907
|
}
|
|
7907
|
-
function
|
|
7908
|
+
function Uf(s, t, e) {
|
|
7908
7909
|
const n = jt(s.map((r) => r.shape[0]));
|
|
7909
7910
|
n.sort();
|
|
7910
7911
|
const i = jt(t.map((r) => r.shape[0]));
|
|
@@ -7915,9 +7916,9 @@ function Pf(s, t, e) {
|
|
|
7915
7916
|
if (n.length > 0 && i.length > 0 && !Ft(n, i))
|
|
7916
7917
|
throw new d(`Input Tensors should have the same number of samples as target Tensors. Found ${n[0]} input sample(s) and ${i[0]} target sample(s).`);
|
|
7917
7918
|
}
|
|
7918
|
-
function
|
|
7919
|
+
function Vf(s, t, e) {
|
|
7919
7920
|
const n = [
|
|
7920
|
-
|
|
7921
|
+
vn,
|
|
7921
7922
|
Sn,
|
|
7922
7923
|
Oe
|
|
7923
7924
|
];
|
|
@@ -7964,7 +7965,7 @@ function ui(s, t, e, n = !0, i = "") {
|
|
|
7964
7965
|
}
|
|
7965
7966
|
}
|
|
7966
7967
|
}
|
|
7967
|
-
function
|
|
7968
|
+
function jf(s, t) {
|
|
7968
7969
|
if (s == null || Array.isArray(s) && s.length === 0)
|
|
7969
7970
|
return t.map((n) => []);
|
|
7970
7971
|
let e;
|
|
@@ -7985,8 +7986,8 @@ function Vf(s, t) {
|
|
|
7985
7986
|
return n;
|
|
7986
7987
|
}
|
|
7987
7988
|
}
|
|
7988
|
-
const
|
|
7989
|
-
class we extends
|
|
7989
|
+
const Kf = "layers-model";
|
|
7990
|
+
class we extends vt {
|
|
7990
7991
|
constructor(t) {
|
|
7991
7992
|
super(t), this.isTraining = !1;
|
|
7992
7993
|
}
|
|
@@ -8028,7 +8029,7 @@ class we extends Nt {
|
|
|
8028
8029
|
summary(t, e, n = console.log) {
|
|
8029
8030
|
if (!this.built)
|
|
8030
8031
|
throw new d("This model has never been called, thus its weights have not been created yet. So no summary can be displayed. Build the model first (e.g., by calling it on some test data).");
|
|
8031
|
-
|
|
8032
|
+
Df(this, t, e, n);
|
|
8032
8033
|
}
|
|
8033
8034
|
/**
|
|
8034
8035
|
* Configures and prepares the model for training and evaluation. Compiling
|
|
@@ -8042,7 +8043,7 @@ class we extends Nt {
|
|
|
8042
8043
|
*/
|
|
8043
8044
|
compile(t) {
|
|
8044
8045
|
if (t.loss == null && (t.loss = []), this.loss = t.loss, typeof t.optimizer == "string")
|
|
8045
|
-
this.optimizer_ =
|
|
8046
|
+
this.optimizer_ = If(t.optimizer), this.isOptimizerOwned = !0;
|
|
8046
8047
|
else {
|
|
8047
8048
|
if (!(t.optimizer instanceof gu))
|
|
8048
8049
|
throw new d("User-defined optimizer must be an instance of tf.Optimizer.");
|
|
@@ -8080,7 +8081,7 @@ class we extends Nt {
|
|
|
8080
8081
|
this.outputs.length > 1 && (this.metricsTensors.push([o, a]), this.metricsNames.push(this.outputNames[a] + "_loss"));
|
|
8081
8082
|
}
|
|
8082
8083
|
});
|
|
8083
|
-
const i =
|
|
8084
|
+
const i = jf(t.metrics, this.outputNames), r = (a, o, l) => {
|
|
8084
8085
|
this.outputNames.length > 1 && (o = this.outputNames[a] + "_" + o), this.metricsNames.push(o), this.metricsTensors.push([l, a]);
|
|
8085
8086
|
};
|
|
8086
8087
|
le("metric", () => {
|
|
@@ -8093,11 +8094,11 @@ class we extends Nt {
|
|
|
8093
8094
|
for (const g of u) {
|
|
8094
8095
|
if (typeof g == "string" && ["accuracy", "acc", "crossentropy", "ce"].indexOf(g) !== -1) {
|
|
8095
8096
|
const m = this.internalOutputShapes[a];
|
|
8096
|
-
m[m.length - 1] === 1 || this.lossFunctions[a] === Sn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = Sr : ["crossentropy", "ce"].indexOf(g) !== -1 && (p =
|
|
8097
|
-
let
|
|
8098
|
-
["accuracy", "acc"].indexOf(g) !== -1 ?
|
|
8097
|
+
m[m.length - 1] === 1 || this.lossFunctions[a] === Sn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = Sr : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = bf) : this.lossFunctions[a] === dn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = yf : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Ir) : ["accuracy", "acc"].indexOf(g) !== -1 ? p = Ar : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Cr);
|
|
8098
|
+
let v;
|
|
8099
|
+
["accuracy", "acc"].indexOf(g) !== -1 ? v = "acc" : ["crossentropy", "ce"].indexOf(g) !== -1 && (v = "ce"), f = p, h = "" + v;
|
|
8099
8100
|
} else
|
|
8100
|
-
f =
|
|
8101
|
+
f = Cf(g), h = "" + tn(g);
|
|
8101
8102
|
let b;
|
|
8102
8103
|
le(h, () => {
|
|
8103
8104
|
b = f;
|
|
@@ -8160,7 +8161,7 @@ class we extends Nt {
|
|
|
8160
8161
|
const l = this.testFunction, u = this.testLoop(l, o, i, n.verbose, n.steps);
|
|
8161
8162
|
return ht(u);
|
|
8162
8163
|
} finally {
|
|
8163
|
-
|
|
8164
|
+
Nt(a[0], t), Nt(a[1], e);
|
|
8164
8165
|
}
|
|
8165
8166
|
}
|
|
8166
8167
|
// TODO(cais): Add code snippet below once real dataset objects are
|
|
@@ -8186,7 +8187,7 @@ class we extends Nt {
|
|
|
8186
8187
|
* @doc {heading: 'Models', subheading: 'Classes'}
|
|
8187
8188
|
*/
|
|
8188
8189
|
async evaluateDataset(t, e) {
|
|
8189
|
-
return this.makeTestFunction(),
|
|
8190
|
+
return this.makeTestFunction(), Gf(this, t, e);
|
|
8190
8191
|
}
|
|
8191
8192
|
/**
|
|
8192
8193
|
* Get number of samples provided for training, evaluation or prediction.
|
|
@@ -8239,7 +8240,7 @@ class we extends Nt {
|
|
|
8239
8240
|
* Retrieve the model's internal symbolic tensors from symbolic-tensor names.
|
|
8240
8241
|
*/
|
|
8241
8242
|
retrieveSymbolicTensors(t) {
|
|
8242
|
-
const e =
|
|
8243
|
+
const e = ue(null, t.length);
|
|
8243
8244
|
let n = t.length;
|
|
8244
8245
|
for (const i of this.layers) {
|
|
8245
8246
|
const r = Array.isArray(i.output) ? i.output : [i.output], a = r.map((o) => o.name);
|
|
@@ -8326,7 +8327,7 @@ class we extends Nt {
|
|
|
8326
8327
|
const i = e.batchSize == null ? 32 : e.batchSize;
|
|
8327
8328
|
return En(i), this.predictLoop(n, i);
|
|
8328
8329
|
} finally {
|
|
8329
|
-
|
|
8330
|
+
Nt(n, t);
|
|
8330
8331
|
}
|
|
8331
8332
|
}
|
|
8332
8333
|
/**
|
|
@@ -8357,7 +8358,7 @@ class we extends Nt {
|
|
|
8357
8358
|
const o = this.feedOutputShapes[a];
|
|
8358
8359
|
this.feedLossFns[a] === dn ? r.push(o.slice(0, o.length - 1).concat([1])) : r.push(o);
|
|
8359
8360
|
}
|
|
8360
|
-
if (t = li(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = li(e, this.feedOutputNames, r, !1, "target"),
|
|
8361
|
+
if (t = li(t, this.feedInputNames, this.feedInputShapes, !1, "input"), e = li(e, this.feedOutputNames, r, !1, "target"), Uf(t, e), Vf(e, this.feedLossFns, this.feedOutputShapes), this.stateful && i != null && i > 0 && t[0].shape[0] % i !== 0)
|
|
8361
8362
|
throw new d(`In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size ${i}. Found: ${t[0].shape[0]} sample(s).`);
|
|
8362
8363
|
return [t, e];
|
|
8363
8364
|
}
|
|
@@ -8395,13 +8396,13 @@ class we extends Nt {
|
|
|
8395
8396
|
{
|
|
8396
8397
|
const l = Ln(a, n), u = Rn(It(0, a));
|
|
8397
8398
|
for (let c = 0; c < l.length; ++c) {
|
|
8398
|
-
const h = l[c][0], p = l[c][1], f =
|
|
8399
|
+
const h = l[c][0], p = l[c][1], f = en(u, h, p - h), g = jn(e, f), b = t(g);
|
|
8399
8400
|
if (c === 0)
|
|
8400
8401
|
for (let m = 0; m < b.length; ++m)
|
|
8401
8402
|
o.push(tt(0));
|
|
8402
8403
|
for (let m = 0; m < b.length; ++m) {
|
|
8403
|
-
const
|
|
8404
|
-
o[m] = $(o[m], w(p - h,
|
|
8404
|
+
const v = b[m];
|
|
8405
|
+
o[m] = $(o[m], w(p - h, v));
|
|
8405
8406
|
}
|
|
8406
8407
|
}
|
|
8407
8408
|
for (let c = 0; c < o.length; ++c)
|
|
@@ -8443,18 +8444,18 @@ class we extends Nt {
|
|
|
8443
8444
|
let g;
|
|
8444
8445
|
for (let b = 0; b < this.lossFunctions.length; ++b) {
|
|
8445
8446
|
const m = this.lossFunctions[b];
|
|
8446
|
-
let
|
|
8447
|
-
r[b] != null && (
|
|
8448
|
-
const y = at(
|
|
8449
|
-
e.push(y), b === 0 ? g =
|
|
8447
|
+
let v = m(i[b], f[b]);
|
|
8448
|
+
r[b] != null && (v = Mf(v, r[b]));
|
|
8449
|
+
const y = at(v);
|
|
8450
|
+
e.push(y), b === 0 ? g = v : g = $(g, v);
|
|
8450
8451
|
}
|
|
8451
8452
|
for (let b = 0; b < this.metricsTensors.length; ++b) {
|
|
8452
8453
|
let m;
|
|
8453
8454
|
if (this.outputs.length > 1 && b < this.outputs.length)
|
|
8454
8455
|
m = e[b];
|
|
8455
8456
|
else {
|
|
8456
|
-
const
|
|
8457
|
-
m = at(
|
|
8457
|
+
const v = this.metricsTensors[b][0], y = this.metricsTensors[b][1];
|
|
8458
|
+
m = at(v(i[y], f[y]));
|
|
8458
8459
|
}
|
|
8459
8460
|
Bt(m), a.push(m);
|
|
8460
8461
|
}
|
|
@@ -8533,7 +8534,7 @@ class we extends Nt {
|
|
|
8533
8534
|
En(f);
|
|
8534
8535
|
const b = await this.standardizeUserData(t, e, n.sampleWeight, n.classWeight, !1, f);
|
|
8535
8536
|
i = b[0], r = b[1], p = b[2];
|
|
8536
|
-
let m = !1,
|
|
8537
|
+
let m = !1, v;
|
|
8537
8538
|
if (n.validationData != null && n.validationData.length > 0) {
|
|
8538
8539
|
if (m = !0, n.validationData.length === 2)
|
|
8539
8540
|
l = n.validationData[0], u = n.validationData[1];
|
|
@@ -8547,21 +8548,21 @@ class we extends Nt {
|
|
|
8547
8548
|
!0,
|
|
8548
8549
|
f
|
|
8549
8550
|
);
|
|
8550
|
-
c = R[0], h = R[1],
|
|
8551
|
+
c = R[0], h = R[1], v = c.concat(h);
|
|
8551
8552
|
} else if (n.validationSplit != null && n.validationSplit > 0 && n.validationSplit < 1) {
|
|
8552
8553
|
m = !0;
|
|
8553
8554
|
const E = Math.floor(i[0].shape[0] * (1 - n.validationSplit)), R = i[0].shape[0];
|
|
8554
|
-
c = Te(i, E, R), a = i, i = Te(i, 0, E), h = Te(r, E, R), o = r, r = Te(r, 0, E),
|
|
8555
|
+
c = Te(i, E, R), a = i, i = Te(i, 0, E), h = Te(r, E, R), o = r, r = Te(r, 0, E), v = c.concat(h);
|
|
8555
8556
|
} else n.validationSteps != null && (m = !0);
|
|
8556
8557
|
const y = i.concat(r).concat(p);
|
|
8557
8558
|
this.checkTrainableWeightsConsistency();
|
|
8558
|
-
const C = this.makeTrainFunction(),
|
|
8559
|
+
const C = this.makeTrainFunction(), N = this.getDedupedMetricsNames();
|
|
8559
8560
|
let I, z;
|
|
8560
|
-
m ? (this.makeTestFunction(), I = this.testFunction, z =
|
|
8561
|
+
m ? (this.makeTestFunction(), I = this.testFunction, z = N.slice().concat(N.map((E) => "val_" + E))) : (I = null, v = [], z = N.slice());
|
|
8561
8562
|
const _ = xr(n.callbacks, n.yieldEvery);
|
|
8562
|
-
return await this.fitLoop(C, y,
|
|
8563
|
+
return await this.fitLoop(C, y, N, f, n.epochs, n.verbose, _, I, v, n.shuffle, z, n.initialEpoch, null, null);
|
|
8563
8564
|
} finally {
|
|
8564
|
-
this.isTraining = !1,
|
|
8565
|
+
this.isTraining = !1, Nt(i, t), Nt(r, e), Nt(a, t), Nt(o, e), Nt(c, l), Nt(h, u), p != null && Z(p);
|
|
8565
8566
|
}
|
|
8566
8567
|
}
|
|
8567
8568
|
/**
|
|
@@ -8597,24 +8598,24 @@ class we extends Nt {
|
|
|
8597
8598
|
if (l != null && u != null && (b = !0), g != null && (b = !0, f == null))
|
|
8598
8599
|
throw new d("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");
|
|
8599
8600
|
const m = this.checkNumSamples(e, i, f, "steps_per_epoch");
|
|
8600
|
-
let
|
|
8601
|
-
m != null && (
|
|
8602
|
-
const { callbackList: y, history: C } =
|
|
8601
|
+
let v;
|
|
8602
|
+
m != null && (v = It(0, m)), a == null && (a = 1);
|
|
8603
|
+
const { callbackList: y, history: C } = Nr(o, a, r, p, m, f, i, b, h);
|
|
8603
8604
|
y.setModel(this), this.history = C, await y.onTrainBegin(), this.stopTraining_ = !1;
|
|
8604
|
-
for (let
|
|
8605
|
-
await y.onEpochBegin(
|
|
8605
|
+
for (let N = p; N < r; ++N) {
|
|
8606
|
+
await y.onEpochBegin(N);
|
|
8606
8607
|
const I = {};
|
|
8607
8608
|
if (f != null)
|
|
8608
8609
|
throw new G("stepsPerEpoch mode is not implemented yet.");
|
|
8609
8610
|
{
|
|
8610
8611
|
if (c === "batch")
|
|
8611
8612
|
throw new G("batch shuffling is not implemneted yet");
|
|
8612
|
-
c && bu(
|
|
8613
|
-
const z = Rn(
|
|
8613
|
+
c && bu(v);
|
|
8614
|
+
const z = Rn(v), _ = Ln(m, i);
|
|
8614
8615
|
for (let T = 0; T < _.length; ++T) {
|
|
8615
8616
|
const E = {};
|
|
8616
8617
|
if (await y.onBatchBegin(T, E), x(() => {
|
|
8617
|
-
const R = _[T][0], q = _[T][1], bt =
|
|
8618
|
+
const R = _[T][0], q = _[T][1], bt = en(z, R, q - R);
|
|
8618
8619
|
E.batch = T, E.size = q - R;
|
|
8619
8620
|
const ie = jn(e, bt), re = t(ie);
|
|
8620
8621
|
for (let xt = 0; xt < n.length; ++xt) {
|
|
@@ -8633,7 +8634,7 @@ class we extends Nt {
|
|
|
8633
8634
|
}
|
|
8634
8635
|
z.dispose();
|
|
8635
8636
|
}
|
|
8636
|
-
if (await y.onEpochEnd(
|
|
8637
|
+
if (await y.onEpochEnd(N, I), this.stopTraining_)
|
|
8637
8638
|
break;
|
|
8638
8639
|
}
|
|
8639
8640
|
return await y.onTrainEnd(), await this.history.syncData(), this.history;
|
|
@@ -8662,7 +8663,7 @@ class we extends Nt {
|
|
|
8662
8663
|
* @doc {heading: 'Models', subheading: 'Classes'}
|
|
8663
8664
|
*/
|
|
8664
8665
|
async fitDataset(t, e) {
|
|
8665
|
-
return
|
|
8666
|
+
return _f(this, t, e);
|
|
8666
8667
|
}
|
|
8667
8668
|
/**
|
|
8668
8669
|
* Runs a single gradient update on a single batch of data.
|
|
@@ -8693,7 +8694,7 @@ class we extends Nt {
|
|
|
8693
8694
|
const c = await u.data();
|
|
8694
8695
|
l.push(c[0]);
|
|
8695
8696
|
}
|
|
8696
|
-
return Z(o),
|
|
8697
|
+
return Z(o), Nt(n[0], t), Nt(n[1], e), ht(l);
|
|
8697
8698
|
}
|
|
8698
8699
|
/**
|
|
8699
8700
|
* Extract weight values of the model.
|
|
@@ -8925,7 +8926,7 @@ class we extends Nt {
|
|
|
8925
8926
|
throw new d("LayersModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");
|
|
8926
8927
|
const n = await Os(this.getNamedWeights(e)), o = {
|
|
8927
8928
|
modelTopology: this.toJSON(null, !1),
|
|
8928
|
-
format:
|
|
8929
|
+
format: Kf,
|
|
8929
8930
|
generatedBy: `TensorFlow.js tfjs-layers v${zr}`,
|
|
8930
8931
|
convertedBy: null
|
|
8931
8932
|
};
|
|
@@ -9023,7 +9024,7 @@ class Re extends we {
|
|
|
9023
9024
|
if (t.inboundNodes.length === 0) {
|
|
9024
9025
|
if (t.batchInputShape == null)
|
|
9025
9026
|
throw new d("The first layer in a Sequential model must get an `inputShape` or `batchInputShape` argument.");
|
|
9026
|
-
const i =
|
|
9027
|
+
const i = Ph({
|
|
9027
9028
|
batchShape: t.batchInputShape,
|
|
9028
9029
|
dtype: t.dtype,
|
|
9029
9030
|
name: t.name + "_input"
|
|
@@ -9039,7 +9040,7 @@ class Re extends we {
|
|
|
9039
9040
|
throw new d("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
|
|
9040
9041
|
this.checkShape(t), this.outputs = [t.inboundNodes[0].outputTensors[0]], this.inputs = dr(this.outputs[0]);
|
|
9041
9042
|
}
|
|
9042
|
-
this.inboundNodes = [], new
|
|
9043
|
+
this.inboundNodes = [], new Nn({
|
|
9043
9044
|
outboundLayer: this,
|
|
9044
9045
|
inboundLayers: [],
|
|
9045
9046
|
nodeIndices: [],
|
|
@@ -9047,7 +9048,7 @@ class Re extends we {
|
|
|
9047
9048
|
inputTensors: this.inputs,
|
|
9048
9049
|
outputTensors: this.outputs,
|
|
9049
9050
|
// no model-level masking for now
|
|
9050
|
-
inputMasks:
|
|
9051
|
+
inputMasks: ue(null, this.inputs.length),
|
|
9051
9052
|
outputMasks: [null],
|
|
9052
9053
|
inputShapes: this.inputs.map((i) => i.shape),
|
|
9053
9054
|
outputShapes: this.outputs[0].shape
|
|
@@ -9495,28 +9496,28 @@ class Mr extends ut {
|
|
|
9495
9496
|
* @return Output of the ELU activation.
|
|
9496
9497
|
*/
|
|
9497
9498
|
apply(t, e = 1) {
|
|
9498
|
-
return
|
|
9499
|
+
return Iu(t, e);
|
|
9499
9500
|
}
|
|
9500
9501
|
}
|
|
9501
9502
|
Mr.className = "elu";
|
|
9502
9503
|
S(Mr);
|
|
9503
9504
|
class Or extends ut {
|
|
9504
9505
|
apply(t) {
|
|
9505
|
-
return
|
|
9506
|
+
return lh(t);
|
|
9506
9507
|
}
|
|
9507
9508
|
}
|
|
9508
9509
|
Or.className = "selu";
|
|
9509
9510
|
S(Or);
|
|
9510
9511
|
class Rr extends ut {
|
|
9511
9512
|
apply(t) {
|
|
9512
|
-
return
|
|
9513
|
+
return Pe(t);
|
|
9513
9514
|
}
|
|
9514
9515
|
}
|
|
9515
9516
|
Rr.className = "relu";
|
|
9516
9517
|
S(Rr);
|
|
9517
9518
|
class _r extends ut {
|
|
9518
9519
|
apply(t) {
|
|
9519
|
-
return x(() =>
|
|
9520
|
+
return x(() => Zi(6, Pe(t)));
|
|
9520
9521
|
}
|
|
9521
9522
|
}
|
|
9522
9523
|
_r.className = "relu6";
|
|
@@ -9530,14 +9531,14 @@ Br.className = "linear";
|
|
|
9530
9531
|
S(Br);
|
|
9531
9532
|
class Wr extends ut {
|
|
9532
9533
|
apply(t) {
|
|
9533
|
-
return
|
|
9534
|
+
return Yn(t);
|
|
9534
9535
|
}
|
|
9535
9536
|
}
|
|
9536
9537
|
Wr.className = "sigmoid";
|
|
9537
9538
|
S(Wr);
|
|
9538
9539
|
class Gr extends ut {
|
|
9539
9540
|
apply(t) {
|
|
9540
|
-
return
|
|
9541
|
+
return Du(t);
|
|
9541
9542
|
}
|
|
9542
9543
|
}
|
|
9543
9544
|
Gr.className = "hardSigmoid";
|
|
@@ -9551,7 +9552,7 @@ Pr.className = "softplus";
|
|
|
9551
9552
|
S(Pr);
|
|
9552
9553
|
class Ur extends ut {
|
|
9553
9554
|
apply(t) {
|
|
9554
|
-
return
|
|
9555
|
+
return zu(t);
|
|
9555
9556
|
}
|
|
9556
9557
|
}
|
|
9557
9558
|
Ur.className = "softsign";
|
|
@@ -9563,7 +9564,7 @@ class Vr extends ut {
|
|
|
9563
9564
|
}
|
|
9564
9565
|
Vr.className = "tanh";
|
|
9565
9566
|
S(Vr);
|
|
9566
|
-
let
|
|
9567
|
+
let vs = class extends ut {
|
|
9567
9568
|
/**
|
|
9568
9569
|
* Calculate the activation function.
|
|
9569
9570
|
*
|
|
@@ -9580,8 +9581,8 @@ let Ns = class extends ut {
|
|
|
9580
9581
|
return tr(t, e);
|
|
9581
9582
|
}
|
|
9582
9583
|
};
|
|
9583
|
-
|
|
9584
|
-
S(
|
|
9584
|
+
vs.className = "softmax";
|
|
9585
|
+
S(vs);
|
|
9585
9586
|
class jr extends ut {
|
|
9586
9587
|
/**
|
|
9587
9588
|
* Calculate the activation function of log softmax:
|
|
@@ -9597,7 +9598,7 @@ class jr extends ut {
|
|
|
9597
9598
|
* @throws ValueError: In case `dim(x) < 2`.
|
|
9598
9599
|
*/
|
|
9599
9600
|
apply(t, e = -1) {
|
|
9600
|
-
return
|
|
9601
|
+
return Vc(t, e);
|
|
9601
9602
|
}
|
|
9602
9603
|
}
|
|
9603
9604
|
jr.className = "logSoftmax";
|
|
@@ -9611,7 +9612,7 @@ class Kr extends ut {
|
|
|
9611
9612
|
*/
|
|
9612
9613
|
apply(t) {
|
|
9613
9614
|
return x(() => x(() => {
|
|
9614
|
-
const e = Math.sqrt(2), n = w(0.5, $(1,
|
|
9615
|
+
const e = Math.sqrt(2), n = w(0.5, $(1, Bc(P(t, e))));
|
|
9615
9616
|
return w(t, n);
|
|
9616
9617
|
}));
|
|
9617
9618
|
}
|
|
@@ -9653,7 +9654,7 @@ class Zr extends ut {
|
|
|
9653
9654
|
* @returns a Tensor of the same shape as x
|
|
9654
9655
|
*/
|
|
9655
9656
|
apply(t, e = 1) {
|
|
9656
|
-
return x(() => w(
|
|
9657
|
+
return x(() => w(Yn(w(t, e)), t));
|
|
9657
9658
|
}
|
|
9658
9659
|
}
|
|
9659
9660
|
Zr.className = "swish";
|
|
@@ -9683,7 +9684,7 @@ function Qt(s) {
|
|
|
9683
9684
|
* https://opensource.org/licenses/MIT.
|
|
9684
9685
|
* =============================================================================
|
|
9685
9686
|
*/
|
|
9686
|
-
function
|
|
9687
|
+
function Hf(s) {
|
|
9687
9688
|
if (s != null && typeof s != "object")
|
|
9688
9689
|
throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an object, but received: ${s}`);
|
|
9689
9690
|
}
|
|
@@ -9691,7 +9692,7 @@ class Jr extends Be {
|
|
|
9691
9692
|
}
|
|
9692
9693
|
class Xr extends Jr {
|
|
9693
9694
|
constructor(t) {
|
|
9694
|
-
super(),
|
|
9695
|
+
super(), Hf(t), this.l1 = t == null || t.l1 == null ? 0.01 : t.l1, this.l2 = t == null || t.l2 == null ? 0.01 : t.l2, this.hasL1 = this.l1 !== 0, this.hasL2 = this.l2 !== 0;
|
|
9695
9696
|
}
|
|
9696
9697
|
/**
|
|
9697
9698
|
* Porting note: Renamed from __call__.
|
|
@@ -9700,7 +9701,7 @@ class Xr extends Jr {
|
|
|
9700
9701
|
apply(t) {
|
|
9701
9702
|
return x(() => {
|
|
9702
9703
|
let e = mt([1]);
|
|
9703
|
-
return this.hasL1 && (e = $(e, B(w(this.l1, Fe(t))))), this.hasL2 && (e = $(e, B(w(this.l2,
|
|
9704
|
+
return this.hasL1 && (e = $(e, B(w(this.l1, Fe(t))))), this.hasL2 && (e = $(e, B(w(this.l2, Ue(t))))), A(e, []);
|
|
9704
9705
|
});
|
|
9705
9706
|
}
|
|
9706
9707
|
getConfig() {
|
|
@@ -9745,7 +9746,7 @@ class Yr extends W {
|
|
|
9745
9746
|
}
|
|
9746
9747
|
call(t, e) {
|
|
9747
9748
|
t = O(t);
|
|
9748
|
-
let n =
|
|
9749
|
+
let n = Pe(t);
|
|
9749
9750
|
return this.maxValue != null && (n = Ct(n, 0, this.maxValue)), n;
|
|
9750
9751
|
}
|
|
9751
9752
|
computeOutputShape(t) {
|
|
@@ -9764,7 +9765,7 @@ class Qr extends W {
|
|
|
9764
9765
|
}
|
|
9765
9766
|
call(t, e) {
|
|
9766
9767
|
const n = O(t);
|
|
9767
|
-
return
|
|
9768
|
+
return Tu(n, this.alpha);
|
|
9768
9769
|
}
|
|
9769
9770
|
computeOutputShape(t) {
|
|
9770
9771
|
return t;
|
|
@@ -9804,7 +9805,7 @@ class ta extends W {
|
|
|
9804
9805
|
})], this.built = !0;
|
|
9805
9806
|
}
|
|
9806
9807
|
call(t, e) {
|
|
9807
|
-
return t = O(t),
|
|
9808
|
+
return t = O(t), $u(t, this.alpha.read());
|
|
9808
9809
|
}
|
|
9809
9810
|
getConfig() {
|
|
9810
9811
|
const t = {
|
|
@@ -9826,7 +9827,7 @@ class ea extends W {
|
|
|
9826
9827
|
}
|
|
9827
9828
|
call(t, e) {
|
|
9828
9829
|
const n = O(t);
|
|
9829
|
-
return
|
|
9830
|
+
return Eu(n);
|
|
9830
9831
|
}
|
|
9831
9832
|
computeOutputShape(t) {
|
|
9832
9833
|
return t;
|
|
@@ -9858,7 +9859,7 @@ na.className = "ThresholdedReLU";
|
|
|
9858
9859
|
S(na);
|
|
9859
9860
|
class sa extends W {
|
|
9860
9861
|
constructor(t) {
|
|
9861
|
-
super(t ?? {}), this.DEFAULT_AXIS = 1, t == null && (t = {}), this.softmax = new
|
|
9862
|
+
super(t ?? {}), this.DEFAULT_AXIS = 1, t == null && (t = {}), this.softmax = new vs().apply, this.axis = t.axis == null ? this.DEFAULT_AXIS : t.axis;
|
|
9862
9863
|
}
|
|
9863
9864
|
call(t, e) {
|
|
9864
9865
|
return x(() => {
|
|
@@ -9868,7 +9869,7 @@ class sa extends W {
|
|
|
9868
9869
|
const r = w(V(pe(n.shape), L(i, n.dtype)), tt(-1e9));
|
|
9869
9870
|
n = $(n, r);
|
|
9870
9871
|
}
|
|
9871
|
-
return this.axis instanceof Array ? this.axis.length > 1 ? Jt(V(n,
|
|
9872
|
+
return this.axis instanceof Array ? this.axis.length > 1 ? Jt(V(n, ju(n, this.axis, !0))) : this.softmax(n, this.axis[0]) : this.softmax(n, this.axis);
|
|
9872
9873
|
});
|
|
9873
9874
|
}
|
|
9874
9875
|
computeOutputShape(t) {
|
|
@@ -9892,12 +9893,12 @@ S(sa);
|
|
|
9892
9893
|
*/
|
|
9893
9894
|
function ke(s, t, e) {
|
|
9894
9895
|
if (typeof s == "number")
|
|
9895
|
-
return
|
|
9896
|
+
return ue(s, t);
|
|
9896
9897
|
if (s.length !== t)
|
|
9897
9898
|
throw new d(`The ${e} argument must be an integer or tuple of ${t} integers. Received: ${s.length} elements.`);
|
|
9898
9899
|
for (let n = 0; n < t; ++n) {
|
|
9899
9900
|
const i = s[n];
|
|
9900
|
-
if (!
|
|
9901
|
+
if (!Lu(i))
|
|
9901
9902
|
throw new d(`The ${e} argument must be an integer or tuple of ${t} integers. Received: ${JSON.stringify(s)} including a non-integer number ${i}`);
|
|
9902
9903
|
}
|
|
9903
9904
|
return s;
|
|
@@ -9913,7 +9914,7 @@ function $t(s, t, e, n) {
|
|
|
9913
9914
|
if (s == null)
|
|
9914
9915
|
return null;
|
|
9915
9916
|
if (n === "valid")
|
|
9916
|
-
s = s * t +
|
|
9917
|
+
s = s * t + Ht([e - t, 0]);
|
|
9917
9918
|
else if (n === "same")
|
|
9918
9919
|
s = s * t;
|
|
9919
9920
|
else
|
|
@@ -9935,7 +9936,7 @@ function Ss(s, t) {
|
|
|
9935
9936
|
function ia(s, t) {
|
|
9936
9937
|
return x(() => (et(t), t === "channelsFirst" ? j(s, [0, 2, 3, 4, 1]) : s));
|
|
9937
9938
|
}
|
|
9938
|
-
function
|
|
9939
|
+
function qf(s, t, e, n = 1, i = "valid", r, a = 1) {
|
|
9939
9940
|
return x(() => {
|
|
9940
9941
|
if (r == null && (r = ne()), et(r), s.shape.length !== 3)
|
|
9941
9942
|
throw new d(`The input of a conv1dWithBias operation should be 3, but is ${s.shape.length} instead.`);
|
|
@@ -9945,7 +9946,7 @@ function Hf(s, t, e, n = 1, i = "valid", r, a = 1) {
|
|
|
9945
9946
|
throw new d(`The bias for a conv1dWithBias operation should be 1, but is ${e.shape.length} instead`);
|
|
9946
9947
|
if (r === "channelsFirst" && (s = j(s, [0, 2, 1])), i === "causal")
|
|
9947
9948
|
throw new G("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");
|
|
9948
|
-
let o =
|
|
9949
|
+
let o = xc(s, t, n, i === "same" ? "same" : "valid", "NWC", a);
|
|
9949
9950
|
return e != null && (o = zt(o, e)), o;
|
|
9950
9951
|
});
|
|
9951
9952
|
}
|
|
@@ -9958,7 +9959,7 @@ function pi(s, t, e, n = [1, 1], i = "valid", r, a, o = null) {
|
|
|
9958
9959
|
let l = Ss(s, r);
|
|
9959
9960
|
if (i === "causal")
|
|
9960
9961
|
throw new G("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");
|
|
9961
|
-
return l =
|
|
9962
|
+
return l = yh({
|
|
9962
9963
|
x: l,
|
|
9963
9964
|
filter: t,
|
|
9964
9965
|
strides: n,
|
|
@@ -9970,7 +9971,7 @@ function pi(s, t, e, n = [1, 1], i = "valid", r, a, o = null) {
|
|
|
9970
9971
|
}), r === "channelsFirst" && (l = j(l, [0, 3, 1, 2])), l;
|
|
9971
9972
|
});
|
|
9972
9973
|
}
|
|
9973
|
-
function
|
|
9974
|
+
function Zf(s, t, e, n = [1, 1, 1], i = "valid", r, a) {
|
|
9974
9975
|
return x(() => {
|
|
9975
9976
|
if (r == null && (r = ne()), et(r), s.rank !== 4 && s.rank !== 5)
|
|
9976
9977
|
throw new d(`conv3dWithBias expects input to be of rank 4 or 5, but received ${s.rank}.`);
|
|
@@ -9979,7 +9980,7 @@ function qf(s, t, e, n = [1, 1, 1], i = "valid", r, a) {
|
|
|
9979
9980
|
let o = ia(s, r);
|
|
9980
9981
|
if (i === "causal")
|
|
9981
9982
|
throw new G("The support for CAUSAL padding mode in conv3dWithBias is not implemented yet.");
|
|
9982
|
-
return o =
|
|
9983
|
+
return o = Cc(o, t, n, i === "same" ? "same" : "valid", "NDHWC", a), e != null && (o = zt(o, e)), r === "channelsFirst" && (o = j(o, [0, 4, 1, 2, 3])), o;
|
|
9983
9984
|
});
|
|
9984
9985
|
}
|
|
9985
9986
|
class An extends W {
|
|
@@ -10001,7 +10002,7 @@ class An extends W {
|
|
|
10001
10002
|
}
|
|
10002
10003
|
}
|
|
10003
10004
|
static verifyArgs(t) {
|
|
10004
|
-
if (Ut("kernelSize" in t, "required key 'kernelSize' not in config"), typeof t.kernelSize != "number" && !
|
|
10005
|
+
if (Ut("kernelSize" in t, "required key 'kernelSize' not in config"), typeof t.kernelSize != "number" && !Qn(t.kernelSize, "number", 1, 3))
|
|
10005
10006
|
throw new d(`BaseConv expects config.kernelSize to be number or number[] with length 1, 2, or 3, but received ${JSON.stringify(t.kernelSize)}.`);
|
|
10006
10007
|
}
|
|
10007
10008
|
getConfig() {
|
|
@@ -10037,16 +10038,16 @@ class De extends An {
|
|
|
10037
10038
|
return x(() => {
|
|
10038
10039
|
t = O(t);
|
|
10039
10040
|
let n;
|
|
10040
|
-
const i = this.bias == null ? null : this.bias.read(), r =
|
|
10041
|
+
const i = this.bias == null ? null : this.bias.read(), r = ji(this.activation.getClassName());
|
|
10041
10042
|
if (r != null && this.rank === 2)
|
|
10042
10043
|
n = pi(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate, r);
|
|
10043
10044
|
else {
|
|
10044
10045
|
if (this.rank === 1)
|
|
10045
|
-
n =
|
|
10046
|
+
n = qf(t, this.kernel.read(), i, this.strides[0], this.padding, this.dataFormat, this.dilationRate[0]);
|
|
10046
10047
|
else if (this.rank === 2)
|
|
10047
10048
|
n = pi(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate);
|
|
10048
10049
|
else if (this.rank === 3)
|
|
10049
|
-
n =
|
|
10050
|
+
n = Zf(t, this.kernel.read(), i, this.strides, this.padding, this.dataFormat, this.dilationRate);
|
|
10050
10051
|
else
|
|
10051
10052
|
throw new G("convolutions greater than 3D are not implemented yet.");
|
|
10052
10053
|
this.activation != null && (n = this.activation.apply(n));
|
|
@@ -10087,7 +10088,7 @@ class Ze extends De {
|
|
|
10087
10088
|
return delete t.rank, t;
|
|
10088
10089
|
}
|
|
10089
10090
|
static verifyArgs(t) {
|
|
10090
|
-
if (typeof t.kernelSize != "number" && !
|
|
10091
|
+
if (typeof t.kernelSize != "number" && !Qn(t.kernelSize, "number", 1, 2))
|
|
10091
10092
|
throw new d(`Conv2D expects config.kernelSize to be number or number[] with length 1 or 2, but received ${JSON.stringify(t.kernelSize)}.`);
|
|
10092
10093
|
}
|
|
10093
10094
|
}
|
|
@@ -10132,8 +10133,8 @@ class ra extends Ze {
|
|
|
10132
10133
|
this.dataFormat === "channelsFirst" ? (a = 2, o = 3) : (a = 1, o = 2);
|
|
10133
10134
|
const l = i[a], u = i[o], c = this.kernelSize[0], h = this.kernelSize[1], p = this.strides[0], f = this.strides[1], g = $t(l, p, c, this.padding), b = $t(u, f, h, this.padding), m = [r, g, b, this.filters];
|
|
10134
10135
|
this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 1]));
|
|
10135
|
-
let
|
|
10136
|
-
return this.dataFormat !== "channelsLast" && (
|
|
10136
|
+
let v = Sc(n, this.kernel.read(), m, this.strides, this.padding);
|
|
10137
|
+
return this.dataFormat !== "channelsLast" && (v = j(v, [0, 3, 1, 2])), this.bias != null && (v = zt(v, this.bias.read(), this.dataFormat)), this.activation != null && (v = this.activation.apply(v)), v;
|
|
10137
10138
|
});
|
|
10138
10139
|
}
|
|
10139
10140
|
computeOutputShape(t) {
|
|
@@ -10173,9 +10174,9 @@ class aa extends Je {
|
|
|
10173
10174
|
const i = n.shape, r = i[0];
|
|
10174
10175
|
let a, o, l;
|
|
10175
10176
|
this.dataFormat === "channelsFirst" ? (l = 2, a = 3, o = 4) : (l = 1, a = 2, o = 3);
|
|
10176
|
-
const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1],
|
|
10177
|
+
const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1], v = this.strides[2], y = $t(u, b, p, this.padding), C = $t(c, m, f, this.padding), N = $t(h, v, g, this.padding), I = [r, y, C, N, this.filters];
|
|
10177
10178
|
this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 4, 1]));
|
|
10178
|
-
let z =
|
|
10179
|
+
let z = zc(n, this.kernel.read(), I, this.strides, this.padding);
|
|
10179
10180
|
return this.dataFormat !== "channelsLast" && (z = j(z, [0, 4, 1, 2, 3])), this.bias !== null && (z = zt(z, this.bias.read(), this.dataFormat)), this.activation !== null && (z = this.activation.apply(z)), z;
|
|
10180
10181
|
});
|
|
10181
10182
|
}
|
|
@@ -10223,7 +10224,7 @@ class oa extends De {
|
|
|
10223
10224
|
let n;
|
|
10224
10225
|
if (this.rank === 1)
|
|
10225
10226
|
throw new G("1D separable convolution is not implemented yet.");
|
|
10226
|
-
return this.rank === 2 && (this.dataFormat === "channelsFirst" && (t = j(t, [0, 2, 3, 1])), n =
|
|
10227
|
+
return this.rank === 2 && (this.dataFormat === "channelsFirst" && (t = j(t, [0, 2, 3, 1])), n = ch(t, this.depthwiseKernel.read(), this.pointwiseKernel.read(), this.strides, this.padding, this.dilationRate, "NHWC")), this.useBias && (n = zt(n, this.bias.read(), this.dataFormat)), this.activation != null && (n = this.activation.apply(n)), this.dataFormat === "channelsFirst" && (n = j(n, [0, 3, 1, 2])), n;
|
|
10227
10228
|
});
|
|
10228
10229
|
}
|
|
10229
10230
|
getConfig() {
|
|
@@ -10248,7 +10249,7 @@ class Cn extends De {
|
|
|
10248
10249
|
return delete t.rank, delete t.dataFormat, t;
|
|
10249
10250
|
}
|
|
10250
10251
|
static verifyArgs(t) {
|
|
10251
|
-
if (typeof t.kernelSize != "number" && !
|
|
10252
|
+
if (typeof t.kernelSize != "number" && !Qn(t.kernelSize, "number", 1, 1))
|
|
10252
10253
|
throw new d(`Conv1D expects config.kernelSize to be number or number[] with length 1, but received ${JSON.stringify(t.kernelSize)}.`);
|
|
10253
10254
|
}
|
|
10254
10255
|
}
|
|
@@ -10294,7 +10295,7 @@ ua.className = "Cropping2D";
|
|
|
10294
10295
|
S(ua);
|
|
10295
10296
|
class ca extends W {
|
|
10296
10297
|
constructor(t) {
|
|
10297
|
-
super(t), this.DEFAULT_SIZE = [2, 2], this.inputSpec = [{ ndim: 4 }], this.size = t.size == null ? this.DEFAULT_SIZE : t.size, this.dataFormat = t.dataFormat == null ? "channelsLast" : t.dataFormat, et(this.dataFormat), this.interpolation = t.interpolation == null ? "nearest" : t.interpolation,
|
|
10298
|
+
super(t), this.DEFAULT_SIZE = [2, 2], this.inputSpec = [{ ndim: 4 }], this.size = t.size == null ? this.DEFAULT_SIZE : t.size, this.dataFormat = t.dataFormat == null ? "channelsLast" : t.dataFormat, et(this.dataFormat), this.interpolation = t.interpolation == null ? "nearest" : t.interpolation, Fu(this.interpolation);
|
|
10298
10299
|
}
|
|
10299
10300
|
computeOutputShape(t) {
|
|
10300
10301
|
if (this.dataFormat === "channelsFirst") {
|
|
@@ -10339,7 +10340,7 @@ S(ca);
|
|
|
10339
10340
|
* https://opensource.org/licenses/MIT.
|
|
10340
10341
|
* =============================================================================
|
|
10341
10342
|
*/
|
|
10342
|
-
function
|
|
10343
|
+
function Jf(s, t, e = [1, 1], n = "valid", i, r) {
|
|
10343
10344
|
return x(() => {
|
|
10344
10345
|
i == null && (i = ne()), et(i);
|
|
10345
10346
|
let a = Ss(s, i);
|
|
@@ -10371,7 +10372,7 @@ class ha extends An {
|
|
|
10371
10372
|
call(t, e) {
|
|
10372
10373
|
return x(() => {
|
|
10373
10374
|
t = O(t);
|
|
10374
|
-
let n =
|
|
10375
|
+
let n = Jf(t, this.depthwiseKernel.read(), this.strides, this.padding, this.dataFormat, null);
|
|
10375
10376
|
return this.useBias && (n = zt(n, this.bias.read(), this.dataFormat)), this.activation != null && (n = this.activation.apply(n)), n;
|
|
10376
10377
|
});
|
|
10377
10378
|
}
|
|
@@ -10413,22 +10414,22 @@ function da(s, t, e, n = !1, i, r, a = !1, o = !1) {
|
|
|
10413
10414
|
if (l < 3)
|
|
10414
10415
|
throw new d(`Input should be at least 3D, but is ${l}D.`);
|
|
10415
10416
|
const u = [1, 0].concat(It(2, l));
|
|
10416
|
-
t = j(t, u), a && console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend."), i != null && (i = L(L(i, "bool"), "float32"), i.rank === l - 1 && (i =
|
|
10417
|
+
t = j(t, u), a && console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend."), i != null && (i = L(L(i, "bool"), "float32"), i.rank === l - 1 && (i = ce(i, -1)), i = j(i, u)), n && (t = on(t, 0), i != null && (i = on(i, 0)));
|
|
10417
10418
|
const c = [];
|
|
10418
10419
|
let h, p = e;
|
|
10419
|
-
const f = t.shape[0], g =
|
|
10420
|
+
const f = t.shape[0], g = nn(t);
|
|
10420
10421
|
let b;
|
|
10421
|
-
i != null && (b =
|
|
10422
|
-
for (let
|
|
10423
|
-
const y = g[
|
|
10422
|
+
i != null && (b = nn(i));
|
|
10423
|
+
for (let v = 0; v < f; ++v) {
|
|
10424
|
+
const y = g[v], C = x(() => s(y, p));
|
|
10424
10425
|
if (i == null)
|
|
10425
10426
|
h = C[0], p = C[1];
|
|
10426
10427
|
else {
|
|
10427
|
-
const
|
|
10428
|
-
const I = b[
|
|
10428
|
+
const N = x(() => {
|
|
10429
|
+
const I = b[v], z = V(Dt(I), I), _ = $(w(C[0], I), w(p[0], z)), T = p.map((E, R) => $(w(C[1][R], I), w(E, z)));
|
|
10429
10430
|
return { output: _, newStates: T };
|
|
10430
10431
|
});
|
|
10431
|
-
h =
|
|
10432
|
+
h = N.output, p = N.newStates;
|
|
10432
10433
|
}
|
|
10433
10434
|
o && c.push(h);
|
|
10434
10435
|
}
|
|
@@ -10642,9 +10643,9 @@ class In extends W {
|
|
|
10642
10643
|
}
|
|
10643
10644
|
class As extends In {
|
|
10644
10645
|
constructor(t) {
|
|
10645
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
|
|
10646
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = Ne([1, Ht([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Ne([
|
|
10646
10647
|
1,
|
|
10647
|
-
|
|
10648
|
+
Ht([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
10648
10649
|
]), this.dropoutFunc = t.dropoutFunc, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
10649
10650
|
}
|
|
10650
10651
|
build(t) {
|
|
@@ -10726,9 +10727,9 @@ class Cs extends In {
|
|
|
10726
10727
|
constructor(t) {
|
|
10727
10728
|
if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
|
|
10728
10729
|
throw new d("GRUCell does not support reset_after parameter set to true.");
|
|
10729
|
-
this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
|
|
10730
|
+
this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = Ne([1, Ht([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Ne([
|
|
10730
10731
|
1,
|
|
10731
|
-
|
|
10732
|
+
Ht([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
10732
10733
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
10733
10734
|
}
|
|
10734
10735
|
build(t) {
|
|
@@ -10760,10 +10761,10 @@ class Cs extends In {
|
|
|
10760
10761
|
0 < this.dropout && this.dropout < 1 && (t = w(t, r[0]));
|
|
10761
10762
|
let c = St(t, this.kernel.read());
|
|
10762
10763
|
this.useBias && (c = zt(c, this.bias.read())), 0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, a[0]));
|
|
10763
|
-
const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g = St(i, p), [b, m,
|
|
10764
|
+
const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g = St(i, p), [b, m, v] = Kt(c, 3, c.rank - 1), [y, C] = Kt(g, 2, g.rank - 1);
|
|
10764
10765
|
o = this.recurrentActivation.apply($(b, y)), l = this.recurrentActivation.apply($(m, C));
|
|
10765
|
-
const
|
|
10766
|
-
u = this.activation.apply($(
|
|
10766
|
+
const N = St(w(l, i), f);
|
|
10767
|
+
u = this.activation.apply($(v, N));
|
|
10767
10768
|
const I = $(w(o, i), w($(1, pt(o)), u));
|
|
10768
10769
|
return [I, I];
|
|
10769
10770
|
});
|
|
@@ -10814,9 +10815,9 @@ ma.className = "GRU";
|
|
|
10814
10815
|
S(ma);
|
|
10815
10816
|
class Dn extends In {
|
|
10816
10817
|
constructor(t) {
|
|
10817
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
|
|
10818
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = Ne([1, Ht([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Ne([
|
|
10818
10819
|
1,
|
|
10819
|
-
|
|
10820
|
+
Ht([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
10820
10821
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = [this.units, this.units], this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
10821
10822
|
}
|
|
10822
10823
|
build(t) {
|
|
@@ -10869,8 +10870,8 @@ class Dn extends In {
|
|
|
10869
10870
|
0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, o[0])), p = $(p, St(i, this.recurrentKernel.read())), this.useBias && (p = zt(p, this.bias.read()));
|
|
10870
10871
|
const [f, g, b, m] = Kt(p, 4, p.rank - 1);
|
|
10871
10872
|
l = this.recurrentActivation.apply(f), u = this.recurrentActivation.apply(g), c = $(w(u, r), w(l, this.activation.apply(b))), h = this.recurrentActivation.apply(m);
|
|
10872
|
-
const
|
|
10873
|
-
return [
|
|
10873
|
+
const v = w(h, this.activation.apply(c));
|
|
10874
|
+
return [v, v, c];
|
|
10874
10875
|
});
|
|
10875
10876
|
}
|
|
10876
10877
|
getConfig() {
|
|
@@ -11020,7 +11021,7 @@ class Is extends In {
|
|
|
11020
11021
|
Is.className = "StackedRNNCells";
|
|
11021
11022
|
S(Is);
|
|
11022
11023
|
function te(s) {
|
|
11023
|
-
const { ones: t, rate: e, training: n = !1, count: i = 1, dropoutFunc: r } = s, a = () => r != null ? r(t(), e) :
|
|
11024
|
+
const { ones: t, rate: e, training: n = !1, count: i = 1, dropoutFunc: r } = s, a = () => r != null ? r(t(), e) : Ki(t(), e), o = () => Ve(a, t, n);
|
|
11024
11025
|
return !i || i <= 1 ? Bt(o().clone()) : Array(i).fill(void 0).map(o).map((u) => Bt(u.clone()));
|
|
11025
11026
|
}
|
|
11026
11027
|
/**
|
|
@@ -11032,7 +11033,7 @@ function te(s) {
|
|
|
11032
11033
|
* https://opensource.org/licenses/MIT.
|
|
11033
11034
|
* =============================================================================
|
|
11034
11035
|
*/
|
|
11035
|
-
var
|
|
11036
|
+
var Xf = function(s, t) {
|
|
11036
11037
|
var e = {};
|
|
11037
11038
|
for (var n in s) Object.prototype.hasOwnProperty.call(s, n) && t.indexOf(n) < 0 && (e[n] = s[n]);
|
|
11038
11039
|
if (s != null && typeof Object.getOwnPropertySymbols == "function")
|
|
@@ -11121,7 +11122,7 @@ class Ds extends Dn {
|
|
|
11121
11122
|
l = new (e = class extends kt {
|
|
11122
11123
|
apply(p, f) {
|
|
11123
11124
|
const g = u.apply([c]), b = pe([c]), m = u.apply([c * 2]);
|
|
11124
|
-
return
|
|
11125
|
+
return ts([g, b, m]);
|
|
11125
11126
|
}
|
|
11126
11127
|
}, /** @nocollapse */
|
|
11127
11128
|
e.className = "CustomInit", e)();
|
|
@@ -11153,17 +11154,17 @@ class Ds extends Dn {
|
|
|
11153
11154
|
dropoutFunc: this.dropoutFunc
|
|
11154
11155
|
}));
|
|
11155
11156
|
const g = this.recurrentDropoutMask;
|
|
11156
|
-
let b = u(r, g, 0), m = u(r, g, 1),
|
|
11157
|
-
const C = 3, [
|
|
11158
|
-
c = this.inputConv(c,
|
|
11157
|
+
let b = u(r, g, 0), m = u(r, g, 1), v = u(r, g, 2), y = u(r, g, 3);
|
|
11158
|
+
const C = 3, [N, I, z, _] = Kt(this.kernel.read(), o, C), [T, E, R, q] = this.useBias ? Kt(this.bias.read(), o) : [null, null, null, null];
|
|
11159
|
+
c = this.inputConv(c, N, T, this.padding), h = this.inputConv(h, I, E, this.padding), p = this.inputConv(p, z, R, this.padding), f = this.inputConv(f, _, q, this.padding);
|
|
11159
11160
|
const [bt, ie, re, xt] = Kt(this.recurrentKernel.read(), o, C);
|
|
11160
|
-
b = this.recurrentConv(b, bt), m = this.recurrentConv(m, ie),
|
|
11161
|
-
const Tt = this.recurrentActivation.apply($(c, b)), me = this.recurrentActivation.apply($(h, m)), ze = $(w(me, a), w(Tt, this.activation.apply($(p,
|
|
11161
|
+
b = this.recurrentConv(b, bt), m = this.recurrentConv(m, ie), v = this.recurrentConv(v, re), y = this.recurrentConv(y, xt);
|
|
11162
|
+
const Tt = this.recurrentActivation.apply($(c, b)), me = this.recurrentActivation.apply($(h, m)), ze = $(w(me, a), w(Tt, this.activation.apply($(p, v)))), Ts = w(this.recurrentActivation.apply($(f, y)), this.activation.apply(ze));
|
|
11162
11163
|
return [Ts, Ts, ze];
|
|
11163
11164
|
});
|
|
11164
11165
|
}
|
|
11165
11166
|
getConfig() {
|
|
11166
|
-
const t = super.getConfig(), { units: e } = t, n =
|
|
11167
|
+
const t = super.getConfig(), { units: e } = t, n = Xf(t, ["units"]), i = {
|
|
11167
11168
|
filters: this.filters,
|
|
11168
11169
|
kernelSize: this.kernelSize,
|
|
11169
11170
|
padding: this.padding,
|
|
@@ -11222,7 +11223,7 @@ class zs extends W {
|
|
|
11222
11223
|
const n = O(t);
|
|
11223
11224
|
if (0 < this.rate && this.rate < 1) {
|
|
11224
11225
|
const i = e.training == null ? !1 : e.training, r = this.getNoiseShape(n);
|
|
11225
|
-
return
|
|
11226
|
+
return Ve(() => Ki(n, this.rate, r, this.seed), () => n, i);
|
|
11226
11227
|
}
|
|
11227
11228
|
return t;
|
|
11228
11229
|
});
|
|
@@ -11273,7 +11274,7 @@ class ka extends W {
|
|
|
11273
11274
|
call(t, e) {
|
|
11274
11275
|
return x(() => {
|
|
11275
11276
|
this.invokeCallHook(t, e);
|
|
11276
|
-
const n = O(t), i =
|
|
11277
|
+
const n = O(t), i = ji(this.activation.getClassName());
|
|
11277
11278
|
let r;
|
|
11278
11279
|
return i != null ? r = St(n, this.kernel.read(), i, this.bias ? this.bias.read() : null) : (r = St(n, this.kernel.read()), this.bias != null && (r = zt(r, this.bias.read())), this.activation != null && (r = this.activation.apply(r))), r;
|
|
11279
11280
|
});
|
|
@@ -11317,7 +11318,7 @@ class xa extends W {
|
|
|
11317
11318
|
i.push(r);
|
|
11318
11319
|
i.push(1), n = j(n, i);
|
|
11319
11320
|
}
|
|
11320
|
-
return
|
|
11321
|
+
return Mu(n);
|
|
11321
11322
|
});
|
|
11322
11323
|
}
|
|
11323
11324
|
getConfig() {
|
|
@@ -11329,7 +11330,7 @@ class xa extends W {
|
|
|
11329
11330
|
}
|
|
11330
11331
|
xa.className = "Flatten";
|
|
11331
11332
|
S(xa);
|
|
11332
|
-
class
|
|
11333
|
+
class Na extends W {
|
|
11333
11334
|
constructor(t) {
|
|
11334
11335
|
super(t), this.supportsMasking = !0, this.activation = Qt(t.activation);
|
|
11335
11336
|
}
|
|
@@ -11345,9 +11346,9 @@ class va extends W {
|
|
|
11345
11346
|
return Object.assign(t, e), t;
|
|
11346
11347
|
}
|
|
11347
11348
|
}
|
|
11348
|
-
|
|
11349
|
-
S(
|
|
11350
|
-
class
|
|
11349
|
+
Na.className = "Activation";
|
|
11350
|
+
S(Na);
|
|
11351
|
+
class va extends W {
|
|
11351
11352
|
constructor(t) {
|
|
11352
11353
|
super(t), this.n = t.n, this.inputSpec = [{ ndim: 2 }];
|
|
11353
11354
|
}
|
|
@@ -11355,7 +11356,7 @@ class Na extends W {
|
|
|
11355
11356
|
return [t[0], this.n, t[1]];
|
|
11356
11357
|
}
|
|
11357
11358
|
call(t, e) {
|
|
11358
|
-
return x(() => (t = O(t),
|
|
11359
|
+
return x(() => (t = O(t), Ou(t, this.n)));
|
|
11359
11360
|
}
|
|
11360
11361
|
getConfig() {
|
|
11361
11362
|
const t = {
|
|
@@ -11364,8 +11365,8 @@ class Na extends W {
|
|
|
11364
11365
|
return Object.assign(t, e), t;
|
|
11365
11366
|
}
|
|
11366
11367
|
}
|
|
11367
|
-
|
|
11368
|
-
S(
|
|
11368
|
+
va.className = "RepeatVector";
|
|
11369
|
+
S(va);
|
|
11369
11370
|
class Sa extends W {
|
|
11370
11371
|
constructor(t) {
|
|
11371
11372
|
super(t), this.targetShape = t.targetShape;
|
|
@@ -11635,7 +11636,7 @@ class fe extends W {
|
|
|
11635
11636
|
if (t = t, this.reshapeRequired) {
|
|
11636
11637
|
const n = [], i = t.map((r) => r.rank);
|
|
11637
11638
|
if (i.indexOf(null) === -1) {
|
|
11638
|
-
const r =
|
|
11639
|
+
const r = Ht(i);
|
|
11639
11640
|
for (let a of t) {
|
|
11640
11641
|
const o = a.rank;
|
|
11641
11642
|
for (let l = 0; l < r - o; ++l)
|
|
@@ -11699,10 +11700,10 @@ class fe extends W {
|
|
|
11699
11700
|
throw new d(`The Array 'inputs' and 'mask' are expected to have the same length, but have different lengths (${t.length} vs ${e.length})`);
|
|
11700
11701
|
if (e.every((i) => i == null))
|
|
11701
11702
|
return null;
|
|
11702
|
-
e = e.map((i) => i == null ? i :
|
|
11703
|
+
e = e.map((i) => i == null ? i : ce(i, 0));
|
|
11703
11704
|
let n = e[0];
|
|
11704
11705
|
for (let i = 1; i < e.length - 1; ++i)
|
|
11705
|
-
n =
|
|
11706
|
+
n = je(n, e[i]);
|
|
11706
11707
|
return n;
|
|
11707
11708
|
});
|
|
11708
11709
|
}
|
|
@@ -11775,7 +11776,7 @@ class Ea extends fe {
|
|
|
11775
11776
|
return x(() => {
|
|
11776
11777
|
let e = t[0];
|
|
11777
11778
|
for (let n = 1; n < t.length; ++n)
|
|
11778
|
-
e =
|
|
11779
|
+
e = Zi(e, t[n]);
|
|
11779
11780
|
return e;
|
|
11780
11781
|
});
|
|
11781
11782
|
}
|
|
@@ -11814,7 +11815,7 @@ class La extends fe {
|
|
|
11814
11815
|
throw new d("A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got input shapes: " + JSON.stringify(t));
|
|
11815
11816
|
}
|
|
11816
11817
|
mergeFunction(t) {
|
|
11817
|
-
return x(() =>
|
|
11818
|
+
return x(() => ts(t, this.axis));
|
|
11818
11819
|
}
|
|
11819
11820
|
computeOutputShape(t) {
|
|
11820
11821
|
if (!(Array.isArray(t) && Array.isArray(t[0])))
|
|
@@ -11849,9 +11850,9 @@ class La extends fe {
|
|
|
11849
11850
|
return null;
|
|
11850
11851
|
const i = [];
|
|
11851
11852
|
for (let a = 0; a < t.length; ++a)
|
|
11852
|
-
e[a] == null ? i.push(L(Dt(t[a]), "bool")) : e[a].rank < t[a].rank ? i.push(
|
|
11853
|
+
e[a] == null ? i.push(L(Dt(t[a]), "bool")) : e[a].rank < t[a].rank ? i.push(ce(e[a], -1)) : i.push(e[a]);
|
|
11853
11854
|
const r = is(i, this.axis);
|
|
11854
|
-
return
|
|
11855
|
+
return Xu(r, -1, !1);
|
|
11855
11856
|
});
|
|
11856
11857
|
}
|
|
11857
11858
|
getConfig() {
|
|
@@ -11868,7 +11869,7 @@ function $e(s, t) {
|
|
|
11868
11869
|
s += t;
|
|
11869
11870
|
return s;
|
|
11870
11871
|
}
|
|
11871
|
-
function
|
|
11872
|
+
function Yf(s, t, e) {
|
|
11872
11873
|
if (s.shape.length > 3 || t.shape.length > 3)
|
|
11873
11874
|
throw new G("batchDot is not implemented for tensors of 4D or higher rank yet");
|
|
11874
11875
|
if (k(s.shape.length >= 2, () => `batchDot requires the rank of x to be >= 2, but got ${s.shape.length}`), k(s.shape.length >= 2, () => `batchDot requires the rank of y to be >= 2, but got ${t.shape.length}`), typeof e == "number" && (e = [e, e]), s.dtype === "complex64" || t.dtype === "complex64")
|
|
@@ -11905,9 +11906,9 @@ function Xf(s, t, e) {
|
|
|
11905
11906
|
const u = [];
|
|
11906
11907
|
for (let c = l; c < l + a; ++c)
|
|
11907
11908
|
u.push(c);
|
|
11908
|
-
o =
|
|
11909
|
+
o = ns(o, u);
|
|
11909
11910
|
}
|
|
11910
|
-
return o.shape.length === 1 && (o =
|
|
11911
|
+
return o.shape.length === 1 && (o = ce(o, 1)), o;
|
|
11911
11912
|
});
|
|
11912
11913
|
}
|
|
11913
11914
|
class Fa extends fe {
|
|
@@ -11930,7 +11931,7 @@ class Fa extends fe {
|
|
|
11930
11931
|
return Array.isArray(this.axes) ? i = this.axes.map((r, a) => $e(r, t[a].shape.length)) : i = [
|
|
11931
11932
|
$e(this.axes, e.shape.length),
|
|
11932
11933
|
$e(this.axes, n.shape.length)
|
|
11933
|
-
], this.normalize && (e = pn(e, i[0]), n = pn(n, i[1])),
|
|
11934
|
+
], this.normalize && (e = pn(e, i[0]), n = pn(n, i[1])), Yf(e, n, i);
|
|
11934
11935
|
}
|
|
11935
11936
|
interpretAxes(t, e) {
|
|
11936
11937
|
let n;
|
|
@@ -11986,7 +11987,7 @@ class Ma extends W {
|
|
|
11986
11987
|
return x(() => {
|
|
11987
11988
|
this.invokeCallHook(t, e);
|
|
11988
11989
|
const n = O(t);
|
|
11989
|
-
return
|
|
11990
|
+
return Ve(() => $(bn(n.shape, 0, this.stddev), n), () => n, e.training || !1);
|
|
11990
11991
|
});
|
|
11991
11992
|
}
|
|
11992
11993
|
}
|
|
@@ -12007,7 +12008,7 @@ class Oa extends W {
|
|
|
12007
12008
|
return x(() => {
|
|
12008
12009
|
this.invokeCallHook(t, e);
|
|
12009
12010
|
const n = O(t);
|
|
12010
|
-
return this.rate > 0 && this.rate < 1 ?
|
|
12011
|
+
return this.rate > 0 && this.rate < 1 ? Ve(() => {
|
|
12011
12012
|
const r = Math.sqrt(this.rate / (1 - this.rate));
|
|
12012
12013
|
return w(n, bn(n.shape, 1, r));
|
|
12013
12014
|
}, () => n, e.training || !1) : n;
|
|
@@ -12034,9 +12035,9 @@ class Ra extends W {
|
|
|
12034
12035
|
return x(() => {
|
|
12035
12036
|
if (this.rate < 1 && this.rate > 0) {
|
|
12036
12037
|
const n = this._getNoiseShape(t);
|
|
12037
|
-
return
|
|
12038
|
+
return Ve(() => {
|
|
12038
12039
|
const r = O(t), o = -1.6732632423543772 * 1.0507009873554805;
|
|
12039
|
-
let l =
|
|
12040
|
+
let l = Ke(wn(n), this.rate);
|
|
12040
12041
|
l = Lt(l, "float32");
|
|
12041
12042
|
const u = ((1 - this.rate) * (1 + this.rate * o ** 2)) ** -0.5, c = -u * o * this.rate, h = $(w(r, l), w($(l, -1), o));
|
|
12042
12043
|
return $(w(h, u), c);
|
|
@@ -12060,22 +12061,22 @@ S(Ra);
|
|
|
12060
12061
|
function _e(s, t, e, n, i, r = 1e-3) {
|
|
12061
12062
|
let a;
|
|
12062
12063
|
if (s.rank === 2)
|
|
12063
|
-
a =
|
|
12064
|
+
a = fc(s, t, e, n, i, r);
|
|
12064
12065
|
else if (s.rank === 3)
|
|
12065
|
-
a =
|
|
12066
|
+
a = gc(s, t, e, n, i, r);
|
|
12066
12067
|
else if (s.rank === 4)
|
|
12067
|
-
a =
|
|
12068
|
+
a = yc(s, t, e, n, i, r);
|
|
12068
12069
|
else
|
|
12069
12070
|
throw new G(`batchNormalization is not implemented for array of rank ${s.rank} yet`);
|
|
12070
12071
|
return a;
|
|
12071
12072
|
}
|
|
12072
|
-
function
|
|
12073
|
+
function Qf(s, t, e, n, i = 1e-3) {
|
|
12073
12074
|
return x(() => {
|
|
12074
12075
|
const r = rs(s, n), a = r.mean, o = r.variance;
|
|
12075
12076
|
return [_e(s, a, o, e, t, i), a, o];
|
|
12076
12077
|
});
|
|
12077
12078
|
}
|
|
12078
|
-
function
|
|
12079
|
+
function tm(s, t, e, n, i = 1e-3) {
|
|
12079
12080
|
return x(() => {
|
|
12080
12081
|
const r = rs(s, n), a = r.mean, o = r.variance, l = [];
|
|
12081
12082
|
for (const g of It(0, s.rank))
|
|
@@ -12084,8 +12085,8 @@ function Qf(s, t, e, n, i = 1e-3) {
|
|
|
12084
12085
|
return [_e(s, u, c, p, h, i), a, o];
|
|
12085
12086
|
});
|
|
12086
12087
|
}
|
|
12087
|
-
function
|
|
12088
|
-
return Ft(n.slice().sort(), It(0, s.rank - 1)) ?
|
|
12088
|
+
function em(s, t, e, n, i = 1e-3) {
|
|
12089
|
+
return Ft(n.slice().sort(), It(0, s.rank - 1)) ? Qf(s, t, e, n, i) : tm(s, t, e, n, i);
|
|
12089
12090
|
}
|
|
12090
12091
|
class _a extends W {
|
|
12091
12092
|
constructor(t) {
|
|
@@ -12104,22 +12105,22 @@ class _a extends W {
|
|
|
12104
12105
|
return x(() => {
|
|
12105
12106
|
const n = e.training == null ? !1 : e.training, i = O(t), r = i.shape, a = r.length, o = It(0, a), l = this.axis >= 0 ? this.axis : this.axis + a;
|
|
12106
12107
|
o.splice(l, 1);
|
|
12107
|
-
const u =
|
|
12108
|
+
const u = ue(1, a);
|
|
12108
12109
|
u[l] = r[l];
|
|
12109
12110
|
const c = o.slice();
|
|
12110
12111
|
c.sort();
|
|
12111
12112
|
const h = !Ft(c, It(0, a).slice(0, a - 1)), p = () => {
|
|
12112
12113
|
if (h) {
|
|
12113
|
-
const y = A(this.movingMean.read(), u), C = A(this.movingVariance.read(), u),
|
|
12114
|
-
return _e(i, y, C,
|
|
12114
|
+
const y = A(this.movingMean.read(), u), C = A(this.movingVariance.read(), u), N = this.center ? A(this.beta.read(), u) : null, I = this.scale ? A(this.gamma.read(), u) : null;
|
|
12115
|
+
return _e(i, y, C, N, I, this.epsilon);
|
|
12115
12116
|
} else
|
|
12116
12117
|
return _e(i, this.movingMean.read(), this.movingVariance.read(), this.beta == null ? null : this.beta.read(), this.gamma == null ? null : this.gamma.read(), this.epsilon);
|
|
12117
12118
|
};
|
|
12118
12119
|
if (!n)
|
|
12119
12120
|
return p();
|
|
12120
|
-
const [f, g, b] =
|
|
12121
|
+
const [f, g, b] = em(i, this.gamma.read(), this.beta.read(), o, this.epsilon), m = (y, C, N) => {
|
|
12121
12122
|
x(() => {
|
|
12122
|
-
const I = 1 -
|
|
12123
|
+
const I = 1 - N, z = y.read(), _ = w(V(z, C), I);
|
|
12123
12124
|
y.write(V(z, _));
|
|
12124
12125
|
});
|
|
12125
12126
|
};
|
|
@@ -12180,7 +12181,7 @@ class Ba extends W {
|
|
|
12180
12181
|
const n = O(t), i = n.shape, r = i.length;
|
|
12181
12182
|
return x(() => {
|
|
12182
12183
|
let { mean: o, variance: l } = rs(n, this.axis, !0);
|
|
12183
|
-
const u =
|
|
12184
|
+
const u = ue(1, r);
|
|
12184
12185
|
for (const b of this.axis)
|
|
12185
12186
|
u[b] = i[b];
|
|
12186
12187
|
const c = (b) => b != null && b.shape.length !== r ? A(b, u) : b;
|
|
@@ -12216,7 +12217,7 @@ S(Ba);
|
|
|
12216
12217
|
* https://opensource.org/licenses/MIT.
|
|
12217
12218
|
* =============================================================================
|
|
12218
12219
|
*/
|
|
12219
|
-
function
|
|
12220
|
+
function nm(s, t, e) {
|
|
12220
12221
|
return x(() => {
|
|
12221
12222
|
if (s.rank !== 4)
|
|
12222
12223
|
throw new d(`temporalPadding expects input tensor to be 4-D, but received a ${s.rank}-D tensor.`);
|
|
@@ -12257,7 +12258,7 @@ class Wa extends W {
|
|
|
12257
12258
|
return this.dataFormat === "channelsFirst" ? (t[2] != null && t[2] >= 0 ? e = t[2] + this.padding[0][0] + this.padding[0][1] : e = null, t[3] != null && t[3] >= 0 ? n = t[3] + this.padding[1][0] + this.padding[1][1] : n = null, [t[0], t[1], e, n]) : (t[1] != null && t[1] >= 0 ? e = t[1] + this.padding[0][0] + this.padding[0][1] : e = null, t[2] != null && t[2] >= 0 ? n = t[2] + this.padding[1][0] + this.padding[1][1] : n = null, [t[0], e, n, t[3]]);
|
|
12258
12259
|
}
|
|
12259
12260
|
call(t, e) {
|
|
12260
|
-
return x(() =>
|
|
12261
|
+
return x(() => nm(O(t), this.padding, this.dataFormat));
|
|
12261
12262
|
}
|
|
12262
12263
|
getConfig() {
|
|
12263
12264
|
const t = {
|
|
@@ -12280,10 +12281,10 @@ S(Wa);
|
|
|
12280
12281
|
*/
|
|
12281
12282
|
function zn(s, t, e, n, i, r) {
|
|
12282
12283
|
return x(() => {
|
|
12283
|
-
et(i),
|
|
12284
|
+
et(i), Hi(r), gt(n), e == null && (e = [1, 1]), n == null && (n = "valid"), i == null && (i = ne()), r == null && (r = "max"), s = Ss(s, i);
|
|
12284
12285
|
let a;
|
|
12285
12286
|
const o = n === "same" ? "same" : "valid";
|
|
12286
|
-
return r === "max" ? a =
|
|
12287
|
+
return r === "max" ? a = qc(s, t, e, o) : a = rc(
|
|
12287
12288
|
// TODO(cais): Rank check?
|
|
12288
12289
|
s,
|
|
12289
12290
|
t,
|
|
@@ -12294,10 +12295,10 @@ function zn(s, t, e, n, i, r) {
|
|
|
12294
12295
|
}
|
|
12295
12296
|
function Ga(s, t, e, n, i, r) {
|
|
12296
12297
|
return x(() => {
|
|
12297
|
-
et(i),
|
|
12298
|
+
et(i), Hi(r), gt(n), e == null && (e = [1, 1, 1]), n == null && (n = "valid"), i == null && (i = ne()), r == null && (r = "max"), s = ia(s, i);
|
|
12298
12299
|
let a;
|
|
12299
12300
|
const o = n === "same" ? "same" : "valid";
|
|
12300
|
-
return r === "max" ? a =
|
|
12301
|
+
return r === "max" ? a = Jc(s, t, e, o) : a = oc(s, t, e, o), i === "channelsFirst" && (a = j(a, [0, 4, 1, 2, 3])), a;
|
|
12301
12302
|
});
|
|
12302
12303
|
}
|
|
12303
12304
|
class Pa extends W {
|
|
@@ -12333,7 +12334,7 @@ class Pa extends W {
|
|
|
12333
12334
|
return x(() => {
|
|
12334
12335
|
this.invokeCallHook(t, e), t = yn(O(t), 2);
|
|
12335
12336
|
const n = this.poolingFunction(O(t), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, "channelsLast");
|
|
12336
|
-
return
|
|
12337
|
+
return ns(n, [2]);
|
|
12337
12338
|
});
|
|
12338
12339
|
}
|
|
12339
12340
|
getConfig() {
|
|
@@ -12496,7 +12497,7 @@ class Qa extends Xa {
|
|
|
12496
12497
|
call(t, e) {
|
|
12497
12498
|
return x(() => {
|
|
12498
12499
|
const n = O(t);
|
|
12499
|
-
return
|
|
12500
|
+
return ve(n, 1);
|
|
12500
12501
|
});
|
|
12501
12502
|
}
|
|
12502
12503
|
}
|
|
@@ -12531,7 +12532,7 @@ class no extends to {
|
|
|
12531
12532
|
call(t, e) {
|
|
12532
12533
|
return x(() => {
|
|
12533
12534
|
const n = O(t);
|
|
12534
|
-
return this.dataFormat === "channelsLast" ?
|
|
12535
|
+
return this.dataFormat === "channelsLast" ? ve(n, [1, 2]) : ve(n, [2, 3]);
|
|
12535
12536
|
});
|
|
12536
12537
|
}
|
|
12537
12538
|
}
|
|
@@ -12634,17 +12635,17 @@ class io extends so {
|
|
|
12634
12635
|
}
|
|
12635
12636
|
io.className = "TimeDistributed";
|
|
12636
12637
|
S(io);
|
|
12637
|
-
function
|
|
12638
|
-
Zn(
|
|
12638
|
+
function sm(s) {
|
|
12639
|
+
Zn(Ru, "BidirectionalMergeMode", s);
|
|
12639
12640
|
}
|
|
12640
|
-
const
|
|
12641
|
+
const im = "concat";
|
|
12641
12642
|
class ro extends so {
|
|
12642
12643
|
constructor(t) {
|
|
12643
12644
|
super(t);
|
|
12644
12645
|
const e = t.layer.getConfig(), n = {};
|
|
12645
12646
|
n.className = t.layer.getClassName(), n.config = e, this.forwardLayer = Wt(n), e.goBackwards = e.goBackwards !== !0;
|
|
12646
12647
|
const i = {};
|
|
12647
|
-
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer = Wt(i), this.forwardLayer.name = "forward_" + this.forwardLayer.name, this.backwardLayer.name = "backward_" + this.backwardLayer.name, this.mergeMode = t.mergeMode === void 0 ?
|
|
12648
|
+
if (i.className = t.layer.getClassName(), i.config = e, this.backwardLayer = Wt(i), this.forwardLayer.name = "forward_" + this.forwardLayer.name, this.backwardLayer.name = "backward_" + this.backwardLayer.name, this.mergeMode = t.mergeMode === void 0 ? im : t.mergeMode, sm(this.mergeMode), t.weights)
|
|
12648
12649
|
throw new G("weights support is not implemented for Bidirectional layer yet.");
|
|
12649
12650
|
this._stateful = t.layer.stateful, this.returnSequences = t.layer.returnSequences, this.returnState = t.layer.returnState, this.supportsMasking = !0, this._trainable = !0, this.inputSpec = t.layer.inputSpec, this.numConstants = null;
|
|
12650
12651
|
}
|
|
@@ -12709,7 +12710,7 @@ class ro extends so {
|
|
|
12709
12710
|
let a;
|
|
12710
12711
|
this.returnState && (Array.isArray(i) && (a = i.slice(1).concat(r.slice(1))), i = i[0], r = r[0]), this.returnSequences && (r = on(r, 1));
|
|
12711
12712
|
let o;
|
|
12712
|
-
return this.mergeMode === "concat" ? o =
|
|
12713
|
+
return this.mergeMode === "concat" ? o = ts([i, r]) : this.mergeMode === "sum" ? o = $(i, r) : this.mergeMode === "ave" ? o = w(0.5, $(i, r)) : this.mergeMode === "mul" ? o = w(i, r) : this.mergeMode == null && (o = [i, r]), this.returnState ? this.mergeMode == null ? o.concat(a) : [o].concat(a) : o;
|
|
12713
12714
|
});
|
|
12714
12715
|
}
|
|
12715
12716
|
resetStates(t) {
|
|
@@ -12793,7 +12794,7 @@ S(ao);
|
|
|
12793
12794
|
* https://opensource.org/licenses/MIT.
|
|
12794
12795
|
* =============================================================================
|
|
12795
12796
|
*/
|
|
12796
|
-
const { resizeBilinear:
|
|
12797
|
+
const { resizeBilinear: rm, cropAndResize: am } = _t;
|
|
12797
12798
|
class oo extends W {
|
|
12798
12799
|
constructor(t) {
|
|
12799
12800
|
super(t), this.height = t.height, this.width = t.width;
|
|
@@ -12805,13 +12806,13 @@ class oo extends W {
|
|
|
12805
12806
|
t.rank === 3 ? (c = !0, u = kn([t])) : u = t;
|
|
12806
12807
|
for (let I = 0; I < u.shape[0]; I++)
|
|
12807
12808
|
m.push(b);
|
|
12808
|
-
const
|
|
12809
|
-
return c ? Lt(O(
|
|
12809
|
+
const v = Hu(m, [m.length, 4]), y = qu(0, m.length, 1, "int32"), N = am(u, v, y, [i, r], "nearest");
|
|
12810
|
+
return c ? Lt(O(nn(N)), l) : Lt(N, l);
|
|
12810
12811
|
});
|
|
12811
12812
|
}
|
|
12812
12813
|
upsize(t, e, n, i) {
|
|
12813
12814
|
return x(() => {
|
|
12814
|
-
const r =
|
|
12815
|
+
const r = rm(t, [e, n]);
|
|
12815
12816
|
return Lt(r, i);
|
|
12816
12817
|
});
|
|
12817
12818
|
}
|
|
@@ -12848,12 +12849,12 @@ S(oo);
|
|
|
12848
12849
|
* https://opensource.org/licenses/MIT.
|
|
12849
12850
|
* =============================================================================
|
|
12850
12851
|
*/
|
|
12851
|
-
function
|
|
12852
|
+
function om(s, t, e, n) {
|
|
12852
12853
|
let i = O(s);
|
|
12853
12854
|
if (i.dtype !== "int32" && (i = Lt(i, "int32")), t === "int")
|
|
12854
12855
|
return i;
|
|
12855
12856
|
const r = i.shape;
|
|
12856
|
-
if (i.rank === 0 && (i =
|
|
12857
|
+
if (i.rank === 0 && (i = ce(i, -1)), t === "oneHot" && i.shape[i.shape.length - 1] !== 1 && (i = ce(i, -1)), i.rank > 2)
|
|
12857
12858
|
throw new d(`When outputMode is not int, maximum output rank is 2 Received outputMode ${t} and input shape ${r} which would result in output rank ${i.rank}.`);
|
|
12858
12859
|
const a = ["multiHot", "oneHot"].includes(t), o = i;
|
|
12859
12860
|
let l;
|
|
@@ -12896,10 +12897,10 @@ class lo extends W {
|
|
|
12896
12897
|
Received countWeights=${e.countWeights}`);
|
|
12897
12898
|
n = O(e.countWeights);
|
|
12898
12899
|
}
|
|
12899
|
-
const i =
|
|
12900
|
+
const i = ve(t), r = Zu(t), a = Gt(this.numTokens, i).bufferSync().get(0), o = Ke(r, 0).bufferSync().get(0);
|
|
12900
12901
|
if (!(a && o))
|
|
12901
12902
|
throw new d(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
|
|
12902
|
-
return
|
|
12903
|
+
return om(t, this.outputMode, this.numTokens, n);
|
|
12903
12904
|
});
|
|
12904
12905
|
}
|
|
12905
12906
|
}
|
|
@@ -12914,7 +12915,7 @@ S(lo);
|
|
|
12914
12915
|
* https://opensource.org/licenses/MIT.
|
|
12915
12916
|
* =============================================================================
|
|
12916
12917
|
*/
|
|
12917
|
-
const
|
|
12918
|
+
const lm = ["bilinear", "nearest"], di = new Set(lm);
|
|
12918
12919
|
class uo extends W {
|
|
12919
12920
|
constructor(t) {
|
|
12920
12921
|
if (super(t), this.height = t.height, this.width = t.width, t.interpolation)
|
|
@@ -13002,7 +13003,7 @@ ho.className = "BaseRandomLayer";
|
|
|
13002
13003
|
* https://opensource.org/licenses/MIT.
|
|
13003
13004
|
* =============================================================================
|
|
13004
13005
|
*/
|
|
13005
|
-
const
|
|
13006
|
+
const um = ["bilinear", "nearest"], fi = new Set(um);
|
|
13006
13007
|
class po extends ho {
|
|
13007
13008
|
constructor(t) {
|
|
13008
13009
|
super(t);
|
|
@@ -13061,50 +13062,32 @@ class po extends ho {
|
|
|
13061
13062
|
}
|
|
13062
13063
|
po.className = "RandomWidth";
|
|
13063
13064
|
S(po);
|
|
13064
|
-
class
|
|
13065
|
+
class _m extends Pu {
|
|
13065
13066
|
vocabSize;
|
|
13066
13067
|
embedDim;
|
|
13067
|
-
tiedWeights;
|
|
13068
13068
|
initializer;
|
|
13069
|
-
|
|
13070
|
-
|
|
13069
|
+
WEIGHTS;
|
|
13070
|
+
constructor(t, e, n) {
|
|
13071
|
+
super(t, n), this.WEIGHTS = e, this.vocabSize = t.gpt.vocabSize, this.embedDim = t.gpt.nEmbed, this.initializer = Qd({
|
|
13071
13072
|
mean: 0,
|
|
13072
13073
|
stddev: 0.02
|
|
13073
|
-
}), this.
|
|
13074
|
-
this.initializer.apply([this.vocabSize, this.embedDim]),
|
|
13075
|
-
!0,
|
|
13076
|
-
e || "tied_embedding"
|
|
13077
|
-
);
|
|
13078
|
-
}
|
|
13079
|
-
get variables() {
|
|
13080
|
-
return [this.tiedWeights];
|
|
13074
|
+
}), this.addVariable(this.WEIGHTS, Ji(this.initializer.apply([this.vocabSize, this.embedDim]), !0));
|
|
13081
13075
|
}
|
|
13082
13076
|
embed(t) {
|
|
13083
|
-
return Qi(this.
|
|
13077
|
+
return Qi(this.getVariable(this.WEIGHTS), t, 0);
|
|
13084
13078
|
}
|
|
13085
13079
|
project(t) {
|
|
13086
|
-
return St(t, this.
|
|
13080
|
+
return St(t, this.getVariable(this.WEIGHTS).transpose());
|
|
13087
13081
|
}
|
|
13088
|
-
|
|
13089
|
-
|
|
13090
|
-
|
|
13091
|
-
setWeights(t) {
|
|
13092
|
-
this.tiedWeights.assign(t[0]);
|
|
13093
|
-
}
|
|
13094
|
-
getConfig() {
|
|
13095
|
-
return {
|
|
13096
|
-
vocabSize: this.vocabSize,
|
|
13097
|
-
embedDim: this.embedDim
|
|
13098
|
-
};
|
|
13099
|
-
}
|
|
13100
|
-
dispose() {
|
|
13101
|
-
this.tiedWeights.dispose();
|
|
13082
|
+
// Dummy, should not be used.
|
|
13083
|
+
forward(t, e) {
|
|
13084
|
+
return this.project(e);
|
|
13102
13085
|
}
|
|
13103
13086
|
}
|
|
13104
13087
|
export {
|
|
13105
13088
|
zs as D,
|
|
13106
13089
|
Ia as E,
|
|
13107
|
-
|
|
13090
|
+
_m as T,
|
|
13108
13091
|
sr as p,
|
|
13109
|
-
|
|
13092
|
+
Qd as r
|
|
13110
13093
|
};
|