@genai-fi/nanogpt 0.4.3 → 0.4.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +3 -3
- package/dist/NanoGPTModel.js +7 -7
- package/dist/Reshape-CiAY8ltP.js +212 -0
- package/dist/TeachableLLM.js +7 -1
- package/dist/{TiedEmbedding-CnJ1bx4q.js → TiedEmbedding-DznFwzcB.js} +244 -244
- package/dist/{axis_util-BgTGy5w8.js → axis_util-QP0LdI1v.js} +1 -1
- package/dist/{concat-CuRsVY-K.js → concat-DvWM7HGZ.js} +1 -1
- package/dist/data/parquet.js +9 -6
- package/dist/data/textLoader.js +6 -5
- package/dist/{dropout-DfDdklfL.js → dropout-DFEXTPV0.js} +4 -4
- package/dist/{gather-ZYRWhmXR.js → gather-C5D8PxwA.js} +1 -1
- package/dist/gpgpu_math-CUzjlO9A.js +23 -0
- package/dist/{index-C4JCoBvj.js → index--6vO-cOz.js} +87 -87
- package/dist/{kernel_funcs_utils-CAd1h9X1.js → kernel_funcs_utils-C6YBCuOt.js} +72 -91
- package/dist/layers/CausalSelfAttention.js +8 -8
- package/dist/layers/MLP.js +31 -33
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} +7 -7
- package/dist/main.js +25 -19
- package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} +1 -1
- package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} +1 -1
- package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} +5 -5
- package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} +13 -13
- package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +2 -2
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +4 -4
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.d.ts +1 -0
- package/dist/ops/cpu/matMulGelu.js +40 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +24 -3
- package/dist/ops/grads/matMulGelu.d.ts +1 -0
- package/dist/ops/grads/matMulGelu.js +17 -0
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.d.ts +3 -0
- package/dist/ops/matMulGelu.js +14 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +689 -895
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +20 -0
- package/dist/ops/webgl/matMulGelu.js +166 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{range-9AzeApCc.js → range-C_vpUjBu.js} +1 -1
- package/dist/{reshape-Boe4DuIO.js → reshape-z51Eu-re.js} +1 -1
- package/dist/{sin-KmhiDuMa.js → sin-H567uayl.js} +1 -1
- package/dist/{slice_util-19zDNNSn.js → slice_util-BdhYwFY_.js} +2 -2
- package/dist/{softmax-Cujsg4ay.js → softmax-Dsxflvdl.js} +1 -1
- package/dist/{split-DbcNm1-i.js → split-B_k_jwud.js} +1 -1
- package/dist/{stack-D1YjmgKN.js → stack-CmqSdsfs.js} +1 -1
- package/dist/{sum-R28pucR5.js → sum-DdkDf2MG.js} +1 -1
- package/dist/{tensor-BVeHdl7V.js → tensor-BGYi41cj.js} +1 -1
- package/dist/{tensor2d-DqFGNs_K.js → tensor2d-DUr_htjt.js} +1 -1
- package/dist/{tfjs_backend-Cug-PH75.js → tfjs_backend-DuKis_xG.js} +46 -46
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +18 -18
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-LJT9Ld63.js → variable-BJTZ3jOy.js} +1 -1
- package/dist/{zeros-dnQxFgAD.js → zeros-8xl-W2DC.js} +1 -1
- package/package.json +1 -1
- package/dist/gelu-CnCt17Lk.js +0 -26
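
The file list shows two kinds of changes: content-hash renames of shared chunks from a fresh bundler build (for example dropout-DfDdklfL.js → dropout-DFEXTPV0.js, with matching import-path updates in every file that references them), and one functional addition, a fused matMulGelu op that appears in the op registry (ops/matMulGelu.js), the CPU and WebGL backends (ops/cpu/matMulGelu.js, ops/webgl/matMulGelu.js), and the gradient table (ops/grads/matMulGelu.js). As a minimal sketch of what such a fused kernel computes, assuming the usual tanh approximation of GELU, a reference implementation might look like the following; the function name and array-of-arrays shapes are hypothetical, not taken from the package:

```ts
// Hypothetical reference for a fused matMul + GELU op (not @genai-fi/nanogpt's code).
// Fusing means the matmul result is activated in the same pass, so the raw
// product matrix never has to be materialized as a separate tensor.
function matMulGeluRef(a: number[][], b: number[][]): number[][] {
  const m = a.length, k = b.length, n = b[0].length;
  const c = Math.sqrt(2 / Math.PI); // constant in the tanh approximation of GELU
  const out: number[][] = [];
  for (let i = 0; i < m; i++) {
    const row = new Array<number>(n);
    for (let j = 0; j < n; j++) {
      let acc = 0;
      for (let p = 0; p < k; p++) acc += a[i][p] * b[p][j]; // plain matmul
      // GELU(x) ~= 0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))
      row[j] = 0.5 * acc * (1 + Math.tanh(c * (acc + 0.044715 * acc ** 3)));
    }
    out.push(row);
  }
  return out;
}
```

The remainder of this page is the re-emitted TensorFlow.js chunk; the changes below appear to be almost entirely minifier identifier renames (v*/N* names swap between the two builds) plus the updated chunk filenames in the import statements. Old-side lines that the diff viewer truncated are marked with `…`.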
|
@@ -1,26 +1,26 @@
|
|
|
1
|
-
import { o as F, h as D, E as M, bb as fo, bc as mo, bd as mi, j as k,
|
|
2
|
-
import { s as ku, a as xu, g as
|
|
3
|
-
import { M as Gu, a as wn, f as Zi } from "./dropout-
|
|
4
|
-
import { z as mt } from "./zeros-
|
|
5
|
-
import { o as pe } from "./ones-
|
|
6
|
-
import { v as Ji } from "./variable-
|
|
7
|
-
import { r as A } from "./reshape-
|
|
8
|
-
import { s as B } from "./sum-
|
|
9
|
-
import { m as Ot } from "./mat_mul-
|
|
10
|
-
import { s as Kt } from "./split-
|
|
11
|
-
import { s as Pu, c as Xi } from "./sin-
|
|
12
|
-
import { g as Yi, d as ss, e as Ws, c as Uu } from "./axis_util-
|
|
13
|
-
import { a as Zt, e as Jt, l as Vu } from "./log_sum_exp-
|
|
14
|
-
import { s as kn } from "./stack-
|
|
15
|
-
import { p as ju } from "./slice_util-
|
|
16
|
-
import { c as is } from "./concat-
|
|
17
|
-
import { g as Qi } from "./gather-
|
|
18
|
-
import { m as at, a as rs } from "./moments-
|
|
19
|
-
import { s as tr } from "./softmax-
|
|
20
|
-
import { m as
|
|
21
|
-
import { t as Ku } from "./tensor-
|
|
22
|
-
import { r as Hu } from "./range-
|
|
23
|
-
import { m as qu } from "./norm-
|
|
1
|
+
import { o as F, h as D, E as M, bb as fo, bc as mo, bd as mi, j as k, b9 as Mn, be as gi, x as L, bf as bi, bg as yi, bh as wi, bi as ki, bj as xi, bk as vi, bl as Ni, bm as go, bn as Si, bo, bp as Ai, bq as yo, br as Ci, p as Hn, O as wt, bs as wo, bt as Ii, bu as Di, bv as zi, c as On, s as V, b as w, bw as ko, bx as Ti, by as $i, bz as xo, bA as Ei, bB as Li, bC as Fi, bD as Mi, bE as Oi, bF as Ri, bG as _i, bH as Bi, k as vo, ag as No, bI as Wi, bJ as So, a8 as $, bK as Ls, bL as Ao, bM as Co, bN as Io, bO as Do, bP as zo, A as To, bQ as $o, bR as Eo, bS as Lo, bT as S, t as x, f as tt, n as Gi, bU as Be, bV as We, ad as Ft, a as Z, L as Fo, bW as Mo, bX as Oo, a0 as ct, a3 as ee, ai as P, bY as Ro, bZ as _o, aP as lt, b_ as Bo, z as Q, b$ as Wo, c0 as Go, c1 as Po, c2 as Uo, c3 as Vo, c4 as jo, c5 as Ko, c6 as Ho, B as qo, c7 as Zo, c8 as Jo, c9 as Xo, as as Yo, ca as Qo, C as tl, $ as he, cb as el, V as nl, cc as sl, cd as il, ce as rl, av as al, cf as ol, a6 as ll, aw as ul, cg as cl, af as hl, ch as pl, G as dl, ay as fl, ci as ml, cj as gl, ck as bl, cl as yl, aA as wl, a7 as kl, cm as xl, cn as vl, co as Nl, M as Sl, cp as Al, cq as Cl, cr as Il, _ as Dl, a1 as zl, aF as Tl, cs as $l, a9 as El, ct as Ll, aD as Fl, P as Ml, cu as Ol, U as qn, aG as Rl, cv as _l, cw as Bl, T as Wl, aJ as Gl, aI as Pl, q as Ul, aX as Vl, cx as jl, aY as Kl, cy as Hl, aK as ql, at as Zl, ap as Jl, cz as Xl, W as Yl, aq as Ql, S as tu, u as eu, cA as nu, cB as su, cC as iu, aM as ru, cD as au, y as ou, cE as lu, a4 as uu, aO as cu, aN as hu, cF as Ie, cG as pu, g as du, cH as Fs, F as Bt, a2 as Fe, D as fu, w as mu, ah as xe, cI as gu, cJ as bu, m as Ms, cK as yu, cL as Os, cM as wu } from "./index--6vO-cOz.js";
|
|
2
|
+
import { s as ku, a as xu, g as vu, b as Nu, V as d, N as G, r as bn, e as Su, l as Au, c as Zn, f as et, h as ye, i as Ge, j as Jn, k as Pi, m as Ui, t as Rt, R as Et, n as ht, A as Pt, o as K, p as le, q as Xn, u as pt, w as Ht, v as Pe, x as Ue, y as Yn, z as j, B as Ee, C as Gt, D as Cu, E as be, F as en, G as Qn, H as ue, I as Ct, J as nt, K as Ve, L as Iu, M as Du, O as je, P as Ut, Q as Lt, S as ts, T as oe, U as jt, W as Xe, X as ce, Y as zu, Z as Rn, _ as nn, $ as yn, a0 as Vi, a1 as It, a2 as Rs, a3 as Tu, a4 as ji, a5 as $u, a6 as Eu, a7 as Lu, a8 as Fu, a9 as Mu, aa as qt, ab as Ou, ac as es, ad as zt, ae as Ye, af as Ru, ag as _t, ah as ot, ai as Ki, aj as gt, ak as ne, al as _s, am as ve, d as St, an as Bs, ao as Ke, ap as Hi, aq as ns, ar as _u, as as Bu, at as qi, au as Wu } from "./tfjs_backend-DuKis_xG.js";
|
|
3
|
+
import { M as Gu, a as wn, f as Zi } from "./dropout-DFEXTPV0.js";
|
|
4
|
+
import { z as mt } from "./zeros-8xl-W2DC.js";
|
|
5
|
+
import { o as pe } from "./ones-D6kB8bdY.js";
|
|
6
|
+
import { v as Ji } from "./variable-BJTZ3jOy.js";
|
|
7
|
+
import { r as A } from "./reshape-z51Eu-re.js";
|
|
8
|
+
import { s as B } from "./sum-DdkDf2MG.js";
|
|
9
|
+
import { m as Ot } from "./mat_mul-BEHRPMh0.js";
|
|
10
|
+
import { s as Kt } from "./split-B_k_jwud.js";
|
|
11
|
+
import { s as Pu, c as Xi } from "./sin-H567uayl.js";
|
|
12
|
+
import { g as Yi, d as ss, e as Ws, c as Uu } from "./axis_util-QP0LdI1v.js";
|
|
13
|
+
import { a as Zt, e as Jt, l as Vu } from "./log_sum_exp-CiEy1aUe.js";
|
|
14
|
+
import { s as kn } from "./stack-CmqSdsfs.js";
|
|
15
|
+
import { p as ju } from "./slice_util-BdhYwFY_.js";
|
|
16
|
+
import { c as is } from "./concat-DvWM7HGZ.js";
|
|
17
|
+
import { g as Qi } from "./gather-C5D8PxwA.js";
|
|
18
|
+
import { m as at, a as rs } from "./moments-DYOHXoRV.js";
|
|
19
|
+
import { s as tr } from "./softmax-Dsxflvdl.js";
|
|
20
|
+
import { m as Ne } from "./max-BUShNgfh.js";
|
|
21
|
+
import { t as Ku } from "./tensor-BGYi41cj.js";
|
|
22
|
+
import { r as Hu } from "./range-C_vpUjBu.js";
|
|
23
|
+
import { m as qu } from "./norm-DSva3hI3.js";
|
|
24
24
|
/**
|
|
25
25
|
* @license
|
|
26
26
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -108,7 +108,7 @@ function Qu(s, t, e, n, i, r, a = !1, o = "channelsLast") {
|
|
|
108
108
|
[l, h, u, c] = s;
|
|
109
109
|
else
|
|
110
110
|
throw new Error(`Unknown dataFormat ${o}`);
|
|
111
|
-
const [p, f, , g] = t, [b, m] = rn(e), [
|
|
111
|
+
const [p, f, , g] = t, [b, m] = rn(e), [N, y] = rn(n), C = _n(p, N), v = _n(f, y), { padInfo: I, outHeight: z, outWidth: _ } = nc(i, u, c, b, m, C, v, r, o), T = a ? g * h : g;
|
|
112
112
|
let E;
|
|
113
113
|
return o === "channelsFirst" ? E = [l, T, z, _] : o === "channelsLast" && (E = [l, z, _, T]), {
|
|
114
114
|
batchSize: l,
|
|
@@ -125,8 +125,8 @@ function Qu(s, t, e, n, i, r, a = !1, o = "channelsLast") {
|
|
|
125
125
|
filterHeight: p,
|
|
126
126
|
filterWidth: f,
|
|
127
127
|
effectiveFilterHeight: C,
|
|
128
|
-
effectiveFilterWidth:
|
|
129
|
-
dilationHeight:
|
|
128
|
+
effectiveFilterWidth: v,
|
|
129
|
+
dilationHeight: N,
|
|
130
130
|
dilationWidth: y,
|
|
131
131
|
inShape: s,
|
|
132
132
|
outShape: E,
|
|
@@ -156,8 +156,8 @@ function nc(s, t, e, n, i, r, a, o, l) {
|
|
|
156
156
|
c = f[0], h = f[1];
|
|
157
157
|
} else if (s === "same") {
|
|
158
158
|
c = Math.ceil(t / n), h = Math.ceil(e / i);
|
|
159
|
-
const p = Math.max(0, (c - 1) * n + r - t), f = Math.max(0, (h - 1) * i + a - e), g = Math.floor(p / 2), b = p - g, m = Math.floor(f / 2),
|
|
160
|
-
u = { top: g, bottom: b, left: m, right:
|
|
159
|
+
const p = Math.max(0, (c - 1) * n + r - t), f = Math.max(0, (h - 1) * i + a - e), g = Math.floor(p / 2), b = p - g, m = Math.floor(f / 2), N = f - m;
|
|
160
|
+
u = { top: g, bottom: b, left: m, right: N, type: "SAME" };
|
|
161
161
|
} else if (s === "valid")
|
|
162
162
|
u = { top: 0, bottom: 0, left: 0, right: 0, type: "VALID" }, c = Math.ceil((t - r + 1) / n), h = Math.ceil((e - a + 1) / i);
|
|
163
163
|
else if (typeof s == "object") {
|
|
@@ -419,15 +419,15 @@ function xc(s, t, e, n, i, r = "NHWC", a) {
|
|
|
419
419
|
t.rank === 3 && (u = !0, l = A(t, [1, t.shape[0], t.shape[1], t.shape[2]]), o = [1, s[0], s[1], s[2]]), k(o.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${o.length}.`), k(l.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${l.rank}`), k(e.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${e.rank}`);
|
|
420
420
|
const c = r === "NHWC" ? o[3] : o[1], h = r === "NHWC" ? l.shape[3] : l.shape[1];
|
|
421
421
|
k(c === e.shape[2], () => `Error in conv2dDerInput: depth of input (${c}) must match input depth for filter ${e.shape[2]}.`), k(h === e.shape[3], () => `Error in conv2dDerInput: depth of output (${h}) must match output depth for filter ${e.shape[3]}.`), ft("conv2dDerInput", i, a);
|
|
422
|
-
const p = { dy: l, filter: e }, f = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, inputShape: o }, g = M.runKernel(
|
|
422
|
+
const p = { dy: l, filter: e }, f = { strides: n, pad: i, dataFormat: r, dimRoundingMode: a, inputShape: o }, g = M.runKernel(vi, p, f);
|
|
423
423
|
return u ? A(g, [g.shape[1], g.shape[2], g.shape[3]]) : g;
|
|
424
424
|
}
|
|
425
425
|
const ls = /* @__PURE__ */ F({ conv2DBackpropInput_: xc });
|
|
426
|
-
function
|
|
426
|
+
function vc(s, t, e, n, i, r) {
|
|
427
427
|
const a = D(s, "x", "conv2dTranspose"), o = D(t, "filter", "conv2dTranspose");
|
|
428
428
|
return ls(e, a, o, n, i, "NHWC", r);
|
|
429
429
|
}
|
|
430
|
-
const
|
|
430
|
+
const Nc = /* @__PURE__ */ F({ conv2dTranspose_: vc });
|
|
431
431
|
/**
|
|
432
432
|
* @license
|
|
433
433
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -448,7 +448,7 @@ function Sc(s, t, e, n, i = "NDHWC", r = [1, 1, 1]) {
|
|
|
448
448
|
const a = D(s, "x", "conv3d"), o = D(t, "filter", "conv3d");
|
|
449
449
|
let l = a, u = !1;
|
|
450
450
|
a.rank === 4 && (u = !0, l = A(a, [1, a.shape[0], a.shape[1], a.shape[2], a.shape[3]])), k(l.rank === 5, () => `Error in conv3d: input must be rank 5, but got rank ${l.rank}.`), k(o.rank === 5, () => `Error in conv3d: filter must be rank 5, but got rank ${o.rank}.`), k(l.shape[4] === o.shape[3], () => `Error in conv3d: depth of input (${l.shape[4]}) must match input depth for filter ${o.shape[3]}.`), k(de(e, r), () => `Error in conv3D: Either strides or dilations must be 1. Got strides ${e} and dilations '${r}'`), k(i === "NDHWC", () => `Error in conv3d: got dataFormat of ${i} but only NDHWC is currently supported.`), k(Ae(r), () => "Error in conv3D: Dilated rates should be larger than 0."), k(Ae(e), () => "Error in conv3D: Strides should be larger than 0.");
|
|
451
|
-
const c = { x: l, filter: o }, h = { strides: e, pad: n, dataFormat: i, dilations: r }, p = M.runKernel(
|
|
451
|
+
const c = { x: l, filter: o }, h = { strides: e, pad: n, dataFormat: i, dilations: r }, p = M.runKernel(Ni, c, h);
|
|
452
452
|
return u ? A(p, [p.shape[1], p.shape[2], p.shape[3], p.shape[4]]) : p;
|
|
453
453
|
}
|
|
454
454
|
const Ac = /* @__PURE__ */ F({ conv3d_: Sc });
|
|
@@ -704,7 +704,7 @@ function Pc(s, t = -1) {
|
|
|
704
704
|
if (t === -1 && (t = e.rank - 1), t !== e.rank - 1)
|
|
705
705
|
throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${e.rank} and axis was ${t}`);
|
|
706
706
|
return On((i, r) => {
|
|
707
|
-
const o =
|
|
707
|
+
const o = Ne(i, t, !0), l = V(i, o), u = V(L(l, "float32"), Zt(B(Jt(l), t, !0)));
|
|
708
708
|
return r([u]), { value: u, gradFunc: (h, p) => {
|
|
709
709
|
const [f] = p, g = !0, b = Jt(f);
|
|
710
710
|
return V(h, w(B(h, t, g), b));
|
|
@@ -1008,9 +1008,9 @@ const hh = /* @__PURE__ */ F({ sinh_: ch });
|
|
|
1008
1008
|
* =============================================================================
|
|
1009
1009
|
*/
|
|
1010
1010
|
function ph(s, t = 0, e = 1, n, i) {
|
|
1011
|
-
if (
|
|
1011
|
+
if (vo(s), n != null && n === "bool")
|
|
1012
1012
|
throw new Error("Unsupported data type $ { dtype }");
|
|
1013
|
-
const r = new Gu(t, e, n, !0, i), a =
|
|
1013
|
+
const r = new Gu(t, e, n, !0, i), a = No(s, n);
|
|
1014
1014
|
for (let o = 0; o < a.values.length; o++)
|
|
1015
1015
|
a.values[o] = r.nextValue();
|
|
1016
1016
|
return a.toTensor();
|
|
@@ -1094,8 +1094,8 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1094
1094
|
const b = i === "NHWC" ? f.shape[3] : f.shape[1];
|
|
1095
1095
|
k(p.shape[2] === b, () => `Error in conv2d: depth of input (${b}) must match input depth for filter ${p.shape[2]}.`), k(de(e, r), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${e} and dilations '${r}'`);
|
|
1096
1096
|
const m = Qu(f.shape, p.shape, e, r, n, a);
|
|
1097
|
-
let
|
|
1098
|
-
o != null && (
|
|
1097
|
+
let N;
|
|
1098
|
+
o != null && (N = D(o, "bias", "fused conv2d"), [N] = Hn(N, h), i === "NHWC" ? wt(m.outShape, N.shape) : (k(N.shape.length <= 1, () => `Error in fused conv2d: only supports scalar or 1-D Tensor bias for NCHW format but got the bias of rank-${N.shape.length}.`), k(N.shape.length === 0 || N.shape[0] === m.outChannels || N.shape[0] === 1, () => `Error in fused conv2d: bias shape (${N.shape}) is not compatible with the number of output channels (${m.outChannels})`)));
|
|
1099
1099
|
let y;
|
|
1100
1100
|
if (u != null) {
|
|
1101
1101
|
const z = u.shape;
|
|
@@ -1112,18 +1112,18 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1112
1112
|
}
|
|
1113
1113
|
const C = (z, _) => {
|
|
1114
1114
|
k(i === "NHWC", () => `Error in gradient of fused conv2D: got dataFormat of ${i} but only NHWC is currently supported.`);
|
|
1115
|
-
const [T, E, R, q] = _, bt =
|
|
1115
|
+
const [T, E, R, q] = _, bt = vu(z, R, l);
|
|
1116
1116
|
k(Se(r), () => `Error in gradient of fused conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${r}'`);
|
|
1117
1117
|
const ie = ls(E.shape, bt, T, e, n), re = cs(E, bt, T.shape, e, n), xt = [ie, re];
|
|
1118
1118
|
if (q != null) {
|
|
1119
|
-
const Tt =
|
|
1119
|
+
const Tt = Nu(q, bt);
|
|
1120
1120
|
xt.push(Tt);
|
|
1121
1121
|
}
|
|
1122
1122
|
return xt;
|
|
1123
|
-
},
|
|
1123
|
+
}, v = {
|
|
1124
1124
|
x: f,
|
|
1125
1125
|
filter: p,
|
|
1126
|
-
bias:
|
|
1126
|
+
bias: N,
|
|
1127
1127
|
preluActivationWeights: y
|
|
1128
1128
|
}, I = {
|
|
1129
1129
|
strides: e,
|
|
@@ -1137,13 +1137,13 @@ function gh({ x: s, filter: t, strides: e, pad: n, dataFormat: i = "NHWC", dilat
|
|
|
1137
1137
|
return o == null ? On((_, T, E) => {
|
|
1138
1138
|
let R = (
|
|
1139
1139
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1140
|
-
M.runKernel(Ls,
|
|
1140
|
+
M.runKernel(Ls, v, I)
|
|
1141
1141
|
);
|
|
1142
1142
|
return E([T, _, R]), g && (R = A(R, [R.shape[1], R.shape[2], R.shape[3]])), { value: R, gradFunc: C };
|
|
1143
1143
|
})(f, p) : On((_, T, E, R) => {
|
|
1144
|
-
let q = M.runKernel(Ls,
|
|
1144
|
+
let q = M.runKernel(Ls, v, I);
|
|
1145
1145
|
return R([T, _, q, E]), g && (q = A(q, [q.shape[1], q.shape[2], q.shape[3]])), { value: q, gradFunc: C };
|
|
1146
|
-
})(f, p,
|
|
1146
|
+
})(f, p, N);
|
|
1147
1147
|
}
|
|
1148
1148
|
const bh = /* @__PURE__ */ F({ fusedConv2d_: gh });
|
|
1149
1149
|
/**
|
|
@@ -1213,7 +1213,7 @@ const xh = F({ depthwiseConv2dNativeBackpropInput_: kh });
|
|
|
1213
1213
|
* limitations under the License.
|
|
1214
1214
|
* =============================================================================
|
|
1215
1215
|
*/
|
|
1216
|
-
class
|
|
1216
|
+
class vh {
|
|
1217
1217
|
/**
|
|
1218
1218
|
* Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
|
|
1219
1219
|
*
|
|
@@ -1377,7 +1377,7 @@ class Nh {
|
|
|
1377
1377
|
* limitations under the License.
|
|
1378
1378
|
* =============================================================================
|
|
1379
1379
|
*/
|
|
1380
|
-
const ge =
|
|
1380
|
+
const ge = vh;
|
|
1381
1381
|
/**
|
|
1382
1382
|
* @license
|
|
1383
1383
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
@@ -1394,9 +1394,9 @@ const ge = Nh;
|
|
|
1394
1394
|
* limitations under the License.
|
|
1395
1395
|
* =============================================================================
|
|
1396
1396
|
*/
|
|
1397
|
-
const
|
|
1397
|
+
const Nh = typeof requestAnimationFrame < "u" ? requestAnimationFrame : typeof setImmediate < "u" ? setImmediate : (s) => s();
|
|
1398
1398
|
function Sh() {
|
|
1399
|
-
return new Promise((s) =>
|
|
1399
|
+
return new Promise((s) => Nh(() => s()));
|
|
1400
1400
|
}
|
|
1401
1401
|
/**
|
|
1402
1402
|
* @license
|
|
@@ -1998,7 +1998,7 @@ class Mt {
|
|
|
1998
1998
|
}
|
|
1999
1999
|
}
|
|
2000
2000
|
let Mh = 0;
|
|
2001
|
-
class
|
|
2001
|
+
class vn {
|
|
2002
2002
|
constructor(t, e) {
|
|
2003
2003
|
this.callArgs = e, this.id = Mh++, this.outboundLayer = t.outboundLayer, this.inboundLayers = t.inboundLayers, this.nodeIndices = t.nodeIndices, this.tensorIndices = t.tensorIndices, this.inputTensors = t.inputTensors, this.outputTensors = t.outputTensors, this.inputMasks = t.inputMasks, this.outputMasks = t.outputMasks, this.inputShapes = t.inputShapes, this.outputShapes = t.outputShapes;
|
|
2004
2004
|
for (const n of t.inboundLayers)
|
|
@@ -2587,7 +2587,7 @@ class W extends Be {
|
|
|
2587
2587
|
const u = [], c = [], h = [];
|
|
2588
2588
|
for (const p of l)
|
|
2589
2589
|
u.push(p.sourceLayer), c.push(p.nodeIndex), h.push(p.tensorIndex);
|
|
2590
|
-
new
|
|
2590
|
+
new vn({
|
|
2591
2591
|
outboundLayer: this,
|
|
2592
2592
|
inboundLayers: u,
|
|
2593
2593
|
nodeIndices: c,
|
|
@@ -2751,7 +2751,7 @@ class He extends W {
|
|
|
2751
2751
|
const n = t.dtype || "float32";
|
|
2752
2752
|
this.batchInputShape = e, this.dtype = n, this.inputSpec = [{ shape: e }];
|
|
2753
2753
|
const i = new Mt(this.dtype, this.batchInputShape, this, [], {}, this.name);
|
|
2754
|
-
i.nodeIndex = 0, i.tensorIndex = 0, new
|
|
2754
|
+
i.nodeIndex = 0, i.tensorIndex = 0, new vn({
|
|
2755
2755
|
outboundLayer: this,
|
|
2756
2756
|
inboundLayers: [],
|
|
2757
2757
|
nodeIndices: [],
|
|
@@ -2932,16 +2932,16 @@ function Le(s, t, e, n) {
|
|
|
2932
2932
|
const b = h[g], m = b.sourceLayer;
|
|
2933
2933
|
if (m instanceof He)
|
|
2934
2934
|
continue;
|
|
2935
|
-
const
|
|
2936
|
-
let
|
|
2935
|
+
const N = [], y = [], C = [];
|
|
2936
|
+
let v = !1;
|
|
2937
2937
|
for (const E of b.inputs) {
|
|
2938
2938
|
const R = f.getValue(E), q = f.getMask(E);
|
|
2939
|
-
|
|
2939
|
+
N.push(R), y.push(q), q != null && (v = !0), i || (p[E.name]--, p[E.name] === 0 && !t.hasKey(E) && o.indexOf(E.name) === -1 && !R.isDisposed && E.sourceLayer.stateful !== !0 && C.push(R));
|
|
2940
2940
|
}
|
|
2941
|
-
|
|
2942
|
-
const I = K(m.apply(
|
|
2941
|
+
v && (e = e || {}, e.mask = y[0]);
|
|
2942
|
+
const I = K(m.apply(N, e));
|
|
2943
2943
|
let z = null;
|
|
2944
|
-
m.supportsMasking && (z = m.computeMask(
|
|
2944
|
+
m.supportsMasking && (z = m.computeMask(N, y));
|
|
2945
2945
|
const _ = Kh(b), T = Array.isArray(_) ? _ : [_];
|
|
2946
2946
|
for (let E = 0; E < T.length; ++E) {
|
|
2947
2947
|
f.hasKey(T[E]) || f.add(T[E], I[E], Array.isArray(z) ? z[0] : z);
|
|
@@ -3713,7 +3713,7 @@ const wp = {
|
|
|
3713
3713
|
* =============================================================================
|
|
3714
3714
|
*/
|
|
3715
3715
|
const kp = {
|
|
3716
|
-
kernelName:
|
|
3716
|
+
kernelName: vi,
|
|
3717
3717
|
inputsToSave: ["dy", "filter"],
|
|
3718
3718
|
gradFunc: (s, t, e) => {
|
|
3719
3719
|
const [n, i] = t, { strides: r, pad: a, dataFormat: o, dimRoundingMode: l } = e;
|
|
@@ -3747,7 +3747,7 @@ function xp(s, t, e, n, i) {
|
|
|
3747
3747
|
const o = { x: r, dy: a }, l = { strides: n, pad: i, filterShape: e };
|
|
3748
3748
|
return M.runKernel(el, o, l);
|
|
3749
3749
|
}
|
|
3750
|
-
const
|
|
3750
|
+
const vp = /* @__PURE__ */ F({ conv3DBackpropFilter_: xp });
|
|
3751
3751
|
/**
|
|
3752
3752
|
* @license
|
|
3753
3753
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3764,8 +3764,8 @@ const Np = /* @__PURE__ */ F({ conv3DBackpropFilter_: xp });
|
|
|
3764
3764
|
* limitations under the License.
|
|
3765
3765
|
* =============================================================================
|
|
3766
3766
|
*/
|
|
3767
|
-
const
|
|
3768
|
-
kernelName:
|
|
3767
|
+
const Np = {
|
|
3768
|
+
kernelName: Ni,
|
|
3769
3769
|
inputsToSave: ["x", "filter"],
|
|
3770
3770
|
gradFunc: (s, t, e) => {
|
|
3771
3771
|
const { dilations: n, strides: i, pad: r } = e;
|
|
@@ -3773,7 +3773,7 @@ const vp = {
|
|
|
3773
3773
|
const [a, o] = t;
|
|
3774
3774
|
return {
|
|
3775
3775
|
x: () => er(a.shape, s, o, i, r),
|
|
3776
|
-
filter: () =>
|
|
3776
|
+
filter: () => vp(a, s, o.shape, i, r)
|
|
3777
3777
|
};
|
|
3778
3778
|
}
|
|
3779
3779
|
};
|
|
@@ -4106,29 +4106,29 @@ const Op = {
|
|
|
4106
4106
|
gradFunc: (s, t, e) => {
|
|
4107
4107
|
const { varianceEpsilon: n } = e, [i, r, a, o] = t, l = o ?? tt(1), u = lt(r.shape, i.shape), c = [];
|
|
4108
4108
|
if (r.rank === 1) {
|
|
4109
|
-
for (let
|
|
4110
|
-
c.push(i.shape[
|
|
4109
|
+
for (let v = 0; v < i.shape.length - 1; ++v)
|
|
4110
|
+
c.push(i.shape[v]);
|
|
4111
4111
|
c.push(1);
|
|
4112
4112
|
}
|
|
4113
4113
|
const h = V(i, r), p = w(s, l), f = rh($(a, tt(n))), g = w(w(w(f, f), f), tt(-0.5));
|
|
4114
4114
|
return {
|
|
4115
4115
|
x: () => r.rank === 1 ? A(w(w(s, Ee(A(f, [1, 1, 1, r.shape[0]]), c)), l), i.shape) : A(w(w(s, f), l), i.shape),
|
|
4116
4116
|
mean: () => {
|
|
4117
|
-
let
|
|
4118
|
-
return r.rank === 1 && (
|
|
4117
|
+
let v = w(w(f, tt(-1)), p);
|
|
4118
|
+
return r.rank === 1 && (v = B(v, u)), A(v, r.shape);
|
|
4119
4119
|
},
|
|
4120
4120
|
variance: () => {
|
|
4121
|
-
let
|
|
4122
|
-
return r.rank === 1 && (
|
|
4121
|
+
let v = w(w(g, h), p);
|
|
4122
|
+
return r.rank === 1 && (v = B(v, u)), A(v, r.shape);
|
|
4123
4123
|
},
|
|
4124
4124
|
scale: () => {
|
|
4125
|
-
const
|
|
4126
|
-
let I = w(s,
|
|
4125
|
+
const v = w(h, f);
|
|
4126
|
+
let I = w(s, v);
|
|
4127
4127
|
return r.rank === 1 && (I = B(I, u)), A(I, r.shape);
|
|
4128
4128
|
},
|
|
4129
4129
|
offset: () => {
|
|
4130
|
-
let
|
|
4131
|
-
return r.rank === 1 && (
|
|
4130
|
+
let v = s;
|
|
4131
|
+
return r.rank === 1 && (v = B(v, u)), A(v, r.shape);
|
|
4132
4132
|
}
|
|
4133
4133
|
};
|
|
4134
4134
|
}
|
|
@@ -4154,11 +4154,11 @@ const Rp = {
|
|
|
4154
4154
|
inputsToSave: ["x", "indices"],
|
|
4155
4155
|
gradFunc: (s, t, e) => {
|
|
4156
4156
|
const [n, i] = t, { axis: r, batchDims: a } = e, o = he(r, n.shape)[0], l = (u, c, h) => () => {
|
|
4157
|
-
const p = u.shape, f = c.size, g = p.slice(0, o), b = g.length, m = p.slice(r, p.length).slice(1),
|
|
4157
|
+
const p = u.shape, f = c.size, g = p.slice(0, o), b = g.length, m = p.slice(r, p.length).slice(1), N = m.length, y = qs(0, b), C = qs(b + 1, b + 1 + N), v = Zs([
|
|
4158
4158
|
g,
|
|
4159
4159
|
[f],
|
|
4160
4160
|
m
|
|
4161
|
-
]), I = A(h,
|
|
4161
|
+
]), I = A(h, v), z = A(c, [f]), _ = Zs([[b], y, C]), T = j(I, _);
|
|
4162
4162
|
let E = fh(T, z, u.shape[o]);
|
|
4163
4163
|
const R = ss(_);
|
|
4164
4164
|
return E = j(E, R), E;
|
|
@@ -4407,7 +4407,7 @@ const Kp = {
|
|
|
4407
4407
|
*/
|
|
4408
4408
|
function Hp(s, t, e, n = 5, i = 1, r = 1, a = 0.5) {
|
|
4409
4409
|
const o = { x: s, y: t, dy: e }, l = { depthRadius: n, bias: i, alpha: r, beta: a };
|
|
4410
|
-
return M.runKernel(
|
|
4410
|
+
return M.runKernel(vl, o, l);
|
|
4411
4411
|
}
|
|
4412
4412
|
const qp = F({ localResponseNormalizationBackprop_: Hp });
|
|
4413
4413
|
/**
|
|
@@ -4427,7 +4427,7 @@ const qp = F({ localResponseNormalizationBackprop_: Hp });
|
|
|
4427
4427
|
* =============================================================================
|
|
4428
4428
|
*/
|
|
4429
4429
|
const Zp = {
|
|
4430
|
-
kernelName:
|
|
4430
|
+
kernelName: Nl,
|
|
4431
4431
|
inputsToSave: ["x"],
|
|
4432
4432
|
outputsToSave: [!0],
|
|
4433
4433
|
gradFunc: (s, t, e) => {
|
|
@@ -5132,7 +5132,7 @@ const xd = {
|
|
|
5132
5132
|
* limitations under the License.
|
|
5133
5133
|
* =============================================================================
|
|
5134
5134
|
*/
|
|
5135
|
-
const
|
|
5135
|
+
const vd = {
|
|
5136
5136
|
kernelName: Ul,
|
|
5137
5137
|
inputsToSave: ["x"],
|
|
5138
5138
|
gradFunc: (s, t) => {
|
|
@@ -5156,7 +5156,7 @@ const Nd = {
|
|
|
5156
5156
|
* limitations under the License.
|
|
5157
5157
|
* =============================================================================
|
|
5158
5158
|
*/
|
|
5159
|
-
const
|
|
5159
|
+
const Nd = {
|
|
5160
5160
|
kernelName: Vl,
|
|
5161
5161
|
inputsToSave: ["images"],
|
|
5162
5162
|
gradFunc: (s, t, e) => {
|
|
@@ -5930,7 +5930,7 @@ const Xd = [
|
|
|
5930
5930
|
yp,
|
|
5931
5931
|
kp,
|
|
5932
5932
|
wp,
|
|
5933
|
-
|
|
5933
|
+
Np,
|
|
5934
5934
|
Sp,
|
|
5935
5935
|
Ap,
|
|
5936
5936
|
Cp,
|
|
@@ -5979,8 +5979,8 @@ const Xd = [
|
|
|
5979
5979
|
wd,
|
|
5980
5980
|
kd,
|
|
5981
5981
|
xd,
|
|
5982
|
-
Nd,
|
|
5983
5982
|
vd,
|
|
5983
|
+
Nd,
|
|
5984
5984
|
Sd,
|
|
5985
5985
|
Ad,
|
|
5986
5986
|
Cd,
|
|
@@ -6427,7 +6427,7 @@ class yt {
|
|
|
6427
6427
|
}
|
|
6428
6428
|
}
|
|
6429
6429
|
yt.constructors = {};
|
|
6430
|
-
function
|
|
6430
|
+
function vr(s, t, e, n, i, r, a, o, l) {
|
|
6431
6431
|
const u = new nf(), c = [
|
|
6432
6432
|
new ef(),
|
|
6433
6433
|
...yt.createCallbacks(t)
|
|
@@ -6473,13 +6473,13 @@ function pn(s, t) {
|
|
|
6473
6473
|
return P(s, i);
|
|
6474
6474
|
});
|
|
6475
6475
|
}
|
|
6476
|
-
function
|
|
6476
|
+
function Nn(s, t) {
|
|
6477
6477
|
return x(() => at(je(V(t, s)), -1));
|
|
6478
6478
|
}
|
|
6479
6479
|
function xs(s, t) {
|
|
6480
6480
|
return x(() => at(Fe(V(t, s)), -1));
|
|
6481
6481
|
}
|
|
6482
|
-
function
|
|
6482
|
+
function vs(s, t) {
|
|
6483
6483
|
return x(() => {
|
|
6484
6484
|
const e = V(s, t), n = Ct(Fe(s), nt(), Number.MAX_VALUE), i = Fe(P(e, n));
|
|
6485
6485
|
return w(100, at(i, -1));
|
|
@@ -6505,7 +6505,7 @@ function of(s, t) {
|
|
|
6505
6505
|
}
|
|
6506
6506
|
function lf(s, t) {
|
|
6507
6507
|
return x(() => {
|
|
6508
|
-
const e = B(w(s, t), -1), n =
|
|
6508
|
+
const e = B(w(s, t), -1), n = Ne(w(V(1, s), t), -1);
|
|
6509
6509
|
return Ie(0, $(1, V(n, e)));
|
|
6510
6510
|
});
|
|
6511
6511
|
}
|
|
@@ -6560,16 +6560,16 @@ function pf(s, t) {
|
|
|
6560
6560
|
return at(V(t, w(s, e)), -1);
|
|
6561
6561
|
});
|
|
6562
6562
|
}
|
|
6563
|
-
function
|
|
6563
|
+
function Nr(s, t) {
|
|
6564
6564
|
return x(() => {
|
|
6565
6565
|
const e = pn(s, -1), n = pn(t, -1), i = w(e, n);
|
|
6566
6566
|
return pt(B(i, -1));
|
|
6567
6567
|
});
|
|
6568
6568
|
}
|
|
6569
6569
|
const fn = {
|
|
6570
|
-
meanSquaredError:
|
|
6570
|
+
meanSquaredError: Nn,
|
|
6571
6571
|
meanAbsoluteError: xs,
|
|
6572
|
-
meanAbsolutePercentageError:
|
|
6572
|
+
meanAbsolutePercentageError: vs,
|
|
6573
6573
|
meanSquaredLogarithmicError: rf,
|
|
6574
6574
|
squaredHinge: af,
|
|
6575
6575
|
hinge: of,
|
|
@@ -6580,7 +6580,7 @@ const fn = {
|
|
|
6580
6580
|
binaryCrossentropy: Sn,
|
|
6581
6581
|
kullbackLeiblerDivergence: hf,
|
|
6582
6582
|
poisson: pf,
|
|
6583
|
-
cosineProximity:
|
|
6583
|
+
cosineProximity: Nr
|
|
6584
6584
|
};
|
|
6585
6585
|
function $n(s) {
|
|
6586
6586
|
if (typeof s == "string") {
|
|
@@ -6627,7 +6627,7 @@ function gf(s, t) {
|
|
|
6627
6627
|
function bf(s, t) {
|
|
6628
6628
|
return s.rank === t.rank && (s = ts(s, [s.rank - 1])), t = sn(t, -1), t.dtype !== s.dtype && (t = L(t, s.dtype)), L(Xt(s, t), "float32");
|
|
6629
6629
|
}
|
|
6630
|
-
const yf =
|
|
6630
|
+
const yf = Nn, wf = Nn, kf = xs, xf = xs, vf = vs, Nf = vs, Cr = Oe, Sf = Nr, Ir = dn, mn = {
|
|
6631
6631
|
binaryAccuracy: Sr,
|
|
6632
6632
|
categoricalAccuracy: Ar,
|
|
6633
6633
|
precision: mf,
|
|
@@ -6637,8 +6637,8 @@ const yf = vn, wf = vn, kf = xs, xf = xs, Nf = Ns, vf = Ns, Cr = Oe, Sf = vr, Ir
|
|
|
6637
6637
|
MSE: wf,
|
|
6638
6638
|
mae: kf,
|
|
6639
6639
|
MAE: xf,
|
|
6640
|
-
mape:
|
|
6641
|
-
MAPE:
|
|
6640
|
+
mape: vf,
|
|
6641
|
+
MAPE: Nf,
|
|
6642
6642
|
cosine: Sf
|
|
6643
6643
|
};
|
|
6644
6644
|
function Af(s) {
|
|
@@ -6924,7 +6924,7 @@ const Ef = (s) => {
|
|
|
6924
6924
|
const e = t[0].split("/");
|
|
6925
6925
|
return !isNaN(parseInt(e[e.length - 1], 10));
|
|
6926
6926
|
};
|
|
6927
|
-
class
|
|
6927
|
+
class Nt extends W {
|
|
6928
6928
|
constructor(t) {
|
|
6929
6929
|
if (super({}), this.containerNodes = /* @__PURE__ */ new Set(), this.name = t.name, this.name == null) {
|
|
6930
6930
|
const y = this.getClassName().toLowerCase();
|
|
@@ -6934,12 +6934,12 @@ class vt extends W {
|
|
|
6934
6934
|
throw new d(`The list of inputs passed to the model is redundant. All inputs should only appear once. Found: ${this.inputs.map((y) => y.name)}`);
|
|
6935
6935
|
jt(this.outputs).length !== this.outputs.length && console.warn(`The list of outputs passed to the model is redundant. All outputs should only appear once. Found: ${this.outputs.map((y) => y.name)}`), this.inputLayers = [], this.inputLayersNodeIndices = [], this.inputLayersTensorIndices = [], this.outputLayers = [], this.outputLayersNodeIndices = [], this.outputLayersTensorIndices = [], this.layers = [], this.internalContainerRefs = [];
|
|
6936
6936
|
for (const y of this.outputs) {
|
|
6937
|
-
const C = y.sourceLayer,
|
|
6938
|
-
this.outputLayers.push(C), this.outputLayersNodeIndices.push(
|
|
6937
|
+
const C = y.sourceLayer, v = y.nodeIndex, I = y.tensorIndex;
|
|
6938
|
+
this.outputLayers.push(C), this.outputLayersNodeIndices.push(v), this.outputLayersTensorIndices.push(I);
|
|
6939
6939
|
}
|
|
6940
6940
|
for (const y of this.inputs) {
|
|
6941
|
-
const C = y.sourceLayer,
|
|
6942
|
-
Ut(
|
|
6941
|
+
const C = y.sourceLayer, v = y.nodeIndex, I = y.tensorIndex;
|
|
6942
|
+
Ut(v === 0, "input layer has >1 nodes"), Ut(I === 0, "input layer has >1 tensors"), this.inputLayers.push(C), this.inputLayersNodeIndices.push(v), this.inputLayersTensorIndices.push(I);
|
|
6943
6943
|
}
|
|
6944
6944
|
this.inputNames = [], this.outputNames = [], this.feedInputShapes = [], this.feedInputNames = [], this.feedOutputNames = [];
|
|
6945
6945
|
for (let y = 0; y < this.inputLayers.length; y++) {
|
|
@@ -6951,21 +6951,21 @@ class vt extends W {
|
|
|
6951
6951
|
for (const y of this.outputLayers)
|
|
6952
6952
|
this.outputNames.push(y.name);
|
|
6953
6953
|
this.internalInputShapes = this.inputs.map((y) => y.shape), this.internalOutputShapes = this.outputs.map((y) => y.shape);
|
|
6954
|
-
const e = {}, n = {}, i = {}, r = {}, a = {}, o = [], l = (y, C,
|
|
6954
|
+
const e = {}, n = {}, i = {}, r = {}, a = {}, o = [], l = (y, C, v, I, z, _) => {
|
|
6955
6955
|
(I == null || z == null || _ == null) && (I = y.sourceLayer, z = y.nodeIndex, _ = y.tensorIndex);
|
|
6956
6956
|
const T = I.inboundNodes[z];
|
|
6957
|
-
if (
|
|
6957
|
+
if (v.indexOf(T) !== -1)
|
|
6958
6958
|
throw new Et(`The tensor ${y.name} at layer "${I.name}" is part of a cycle.`);
|
|
6959
6959
|
if (C.indexOf(T) !== -1)
|
|
6960
6960
|
return;
|
|
6961
|
-
this.containerNodes.add(
|
|
6961
|
+
this.containerNodes.add(Nt.nodeKey(I, z)), I.id in a || (a[I.id] = Object.keys(a).length), v.indexOf(T) === -1 && v.push(T);
|
|
6962
6962
|
const E = T.inboundLayers.length;
|
|
6963
6963
|
for (let R = 0; R < E; R++) {
|
|
6964
6964
|
const q = T.inputTensors[R], bt = T.inboundLayers[R], ie = T.nodeIndices[R], re = T.tensorIndices[R];
|
|
6965
|
-
l(q, C,
|
|
6965
|
+
l(q, C, v, bt, ie, re);
|
|
6966
6966
|
}
|
|
6967
|
-
for (C.push(T);
|
|
6968
|
-
|
|
6967
|
+
for (C.push(T); v.indexOf(T) >= 0; )
|
|
6968
|
+
v.splice(v.indexOf(T), 1);
|
|
6969
6969
|
o.push(T);
|
|
6970
6970
|
}, u = [], c = [];
|
|
6971
6971
|
for (const y of this.outputs)
|
|
@@ -6974,8 +6974,8 @@ class vt extends W {
|
|
|
6974
6974
|
for (const y of h) {
|
|
6975
6975
|
n[y.id] = y, y.id in e || (e[y.id] = 0);
|
|
6976
6976
|
let C = e[y.id];
|
|
6977
|
-
const
|
|
6978
|
-
C = Math.max(C,
|
|
6977
|
+
const v = i[y.outboundLayer.id] == null ? 0 : i[y.outboundLayer.id];
|
|
6978
|
+
C = Math.max(C, v), i[y.outboundLayer.id] = C, r[y.outboundLayer.id] = y.outboundLayer, e[y.id] = C;
|
|
6979
6979
|
for (let I = 0; I < y.inboundLayers.length; I++) {
|
|
6980
6980
|
const z = y.inboundLayers[I], _ = y.nodeIndices[I], T = z.inboundNodes[_], E = e[T.id] == null ? 0 : e[T.id];
|
|
6981
6981
|
e[T.id] = Math.max(C + 1, E), n[T.id] = T;
|
|
@@ -6995,35 +6995,35 @@ class vt extends W {
|
|
|
6995
6995
|
this.layers = [];
|
|
6996
6996
|
for (const y of g) {
|
|
6997
6997
|
const C = f[y];
|
|
6998
|
-
C.sort((
|
|
6999
|
-
const z = a[
|
|
6998
|
+
C.sort((v, I) => {
|
|
6999
|
+
const z = a[v.id], _ = a[I.id];
|
|
7000
7000
|
return z < _ ? -1 : z > _ ? 1 : 0;
|
|
7001
7001
|
});
|
|
7002
|
-
for (const
|
|
7003
|
-
|
|
7002
|
+
for (const v of C)
|
|
7003
|
+
v instanceof Nt && this.internalContainerRefs.push(v), this.layers.push(v);
|
|
7004
7004
|
}
|
|
7005
7005
|
this.layersByDepth = f, g = Object.keys(p).map((y) => parseInt(y, 10)).sort(Xe);
|
|
7006
7006
|
const b = this.inputs.slice(), m = [];
|
|
7007
7007
|
for (const y of g)
|
|
7008
7008
|
for (const C of p[y]) {
|
|
7009
|
-
const
|
|
7010
|
-
if (
|
|
7009
|
+
const v = C.outboundLayer;
|
|
7010
|
+
if (v != null) {
|
|
7011
7011
|
for (const I of C.inputTensors)
|
|
7012
7012
|
if (b.indexOf(I) === -1)
|
|
7013
|
-
throw new Et(`Graph disconnected: cannot obtain value for tensor ${I} at layer "${
|
|
7013
|
+
throw new Et(`Graph disconnected: cannot obtain value for tensor ${I} at layer "${v.name}". The following previous layers were accessed without issue: ${m}`);
|
|
7014
7014
|
for (const I of C.outputTensors)
|
|
7015
7015
|
b.push(I);
|
|
7016
|
-
m.push(
|
|
7016
|
+
m.push(v.name);
|
|
7017
7017
|
}
|
|
7018
7018
|
}
|
|
7019
7019
|
this.nodesByDepth = p;
|
|
7020
|
-
const
|
|
7021
|
-
for (const y of
|
|
7022
|
-
const C =
|
|
7020
|
+
const N = this.layers.map((y) => y.name);
|
|
7021
|
+
for (const y of N) {
|
|
7022
|
+
const C = N.filter((v) => v === y).length;
|
|
7023
7023
|
if (C !== 1)
|
|
7024
|
-
throw new Et(`The name "${y}" is used ${C} times in the model. All layer names should be unique. Layer names: ` + JSON.stringify(
|
|
7024
|
+
throw new Et(`The name "${y}" is used ${C} times in the model. All layer names should be unique. Layer names: ` + JSON.stringify(N));
|
|
7025
7025
|
}
|
|
7026
|
-
this.outboundNodes = [], this.inboundNodes = [], new
|
|
7026
|
+
this.outboundNodes = [], this.inboundNodes = [], new vn({
|
|
7027
7027
|
outboundLayer: this,
|
|
7028
7028
|
inboundLayers: [],
|
|
7029
7029
|
nodeIndices: [],
|
|
@@ -7255,8 +7255,8 @@ class vt extends W {
|
|
|
7255
7255
|
continue;
|
|
7256
7256
|
const h = [];
|
|
7257
7257
|
for (let b = 0; b < u.inboundLayers.length; b++) {
|
|
7258
|
-
const m = u.inboundLayers[b],
|
|
7259
|
-
h.push(
|
|
7258
|
+
const m = u.inboundLayers[b], N = u.nodeIndices[b], y = u.tensorIndices[b], C = `${m.name}_${N}_${y}`, v = n[C];
|
|
7259
|
+
h.push(v);
|
|
7260
7260
|
}
|
|
7261
7261
|
const p = c.computeOutputShape(ht(h)), f = ln(p), g = c.inboundNodes.indexOf(u);
|
|
7262
7262
|
for (let b = 0; b < f.length; b++) {
|
|
@@ -7301,16 +7301,16 @@ class vt extends W {
|
|
|
7301
7301
|
for (const b of p)
|
|
7302
7302
|
b.id in n && g.push(n[b.id]);
|
|
7303
7303
|
if (g.length === p.length) {
|
|
7304
|
-
let b = {}, m,
|
|
7304
|
+
let b = {}, m, N, y, C;
|
|
7305
7305
|
if (c.callArgs != null && (b = c.callArgs), g.length === 1) {
|
|
7306
|
-
const [
|
|
7307
|
-
b.mask == null && (b.mask = I), y = K(h.call(
|
|
7306
|
+
const [v, I] = g[0];
|
|
7307
|
+
b.mask == null && (b.mask = I), y = K(h.call(v, b)), C = K(h.computeMask(v, I)), m = [v], N = [I];
|
|
7308
7308
|
} else
|
|
7309
|
-
m = g.map((
|
|
7309
|
+
m = g.map((v) => v[0]), N = g.map((v) => v[1]), b.mask == null && (b.mask = N), y = K(h.call(m, b)), C = K(h.computeMask(m, N));
|
|
7310
7310
|
if (h.activityRegularizer)
|
|
7311
7311
|
throw new G("LayersModel invocation with concrete Tensor value(s) in the presence of activity regularizer(s) is not supported yet.");
|
|
7312
|
-
for (let
|
|
7313
|
-
const I = f[
|
|
7312
|
+
for (let v = 0; v < f.length; ++v) {
|
|
7313
|
+
const I = f[v], z = y[v], _ = C[v];
|
|
7314
7314
|
n[I.id] = [z, _];
|
|
7315
7315
|
}
|
|
7316
7316
|
}
|
|
@@ -7336,9 +7336,9 @@ class vt extends W {
|
|
|
7336
7336
|
const e = {};
|
|
7337
7337
|
let n;
|
|
7338
7338
|
for (const i of this.layers) {
|
|
7339
|
-
n = i instanceof
|
|
7339
|
+
n = i instanceof Nt ? 1 : 0;
|
|
7340
7340
|
for (let r = 0; r < i.inboundNodes.length; r++) {
|
|
7341
|
-
const a =
|
|
7341
|
+
const a = Nt.nodeKey(i, r);
|
|
7342
7342
|
this.containerNodes.has(a) && (e[a] = n, n += 1);
|
|
7343
7343
|
}
|
|
7344
7344
|
}
|
|
@@ -7371,7 +7371,7 @@ class vt extends W {
|
|
|
7371
7371
|
const t = [];
|
|
7372
7372
|
for (const e of this.layers)
|
|
7373
7373
|
for (let n = 0; n < e.inboundNodes.length; ++n) {
|
|
7374
|
-
const i =
|
|
7374
|
+
const i = Nt.nodeKey(e, n);
|
|
7375
7375
|
this.containerNodes.has(i) && t.push(...e.calculateLosses());
|
|
7376
7376
|
}
|
|
7377
7377
|
return t;
|
|
@@ -7382,7 +7382,7 @@ class vt extends W {
|
|
|
7382
7382
|
for (const a of this.layers) {
|
|
7383
7383
|
const o = a.getClassName(), l = a.getConfig(), u = [];
|
|
7384
7384
|
for (let h = 0; h < a.inboundNodes.length; h++) {
|
|
7385
|
-
const p = a.inboundNodes[h], f =
|
|
7385
|
+
const p = a.inboundNodes[h], f = Nt.nodeKey(a, h);
|
|
7386
7386
|
let g = {};
|
|
7387
7387
|
if (this.containerNodes.has(f)) {
|
|
7388
7388
|
if (p.callArgs)
|
|
@@ -7394,9 +7394,9 @@ class vt extends W {
|
|
|
7394
7394
|
if (p.inboundLayers.length > 0) {
|
|
7395
7395
|
const b = [];
|
|
7396
7396
|
for (let m = 0; m < p.inboundLayers.length; m++) {
|
|
7397
|
-
const
|
|
7398
|
-
let I = e[
|
|
7399
|
-
I == null && (I = 0), b.push([
|
|
7397
|
+
const N = p.inboundLayers[m], y = p.nodeIndices[m], C = p.tensorIndices[m], v = Nt.nodeKey(N, y);
|
|
7398
|
+
let I = e[v];
|
|
7399
|
+
I == null && (I = 0), b.push([N.name, I, C, g]);
|
|
7400
7400
|
}
|
|
7401
7401
|
u.push(b);
|
|
7402
7402
|
}
|
|
@@ -7408,7 +7408,7 @@ class vt extends W {
|
|
|
7408
7408
|
t.layers = n;
|
|
7409
7409
|
const i = [];
|
|
7410
7410
|
for (let a = 0; a < this.inputLayers.length; a++) {
|
|
7411
|
-
const o = this.inputLayers[a], l = this.inputLayersNodeIndices[a], u =
|
|
7411
|
+
const o = this.inputLayers[a], l = this.inputLayersNodeIndices[a], u = Nt.nodeKey(o, l);
|
|
7412
7412
|
if (!this.containerNodes.has(u))
|
|
7413
7413
|
continue;
|
|
7414
7414
|
let c = e[u];
|
|
@@ -7419,7 +7419,7 @@ class vt extends W {
|
|
|
7419
7419
|
t.inputLayers = i;
|
|
7420
7420
|
const r = [];
|
|
7421
7421
|
for (let a = 0; a < this.outputLayers.length; a++) {
|
|
7422
|
-
const o = this.outputLayers[a], l = this.outputLayersNodeIndices[a], u =
|
|
7422
|
+
const o = this.outputLayers[a], l = this.outputLayersNodeIndices[a], u = Nt.nodeKey(o, l);
|
|
7423
7423
|
if (!this.containerNodes.has(u))
|
|
7424
7424
|
continue;
|
|
7425
7425
|
let c = e[u];
|
|
@@ -7444,21 +7444,21 @@ class vt extends W {
|
|
|
7444
7444
|
/** @nocollapse */
|
|
7445
7445
|
static fromConfig(t, e, n = {}, i = !1) {
|
|
7446
7446
|
const r = {}, a = {};
|
|
7447
|
-
function o(m,
|
|
7448
|
-
m.name in a ? a[m.name].push(
|
|
7447
|
+
function o(m, N) {
|
|
7448
|
+
m.name in a ? a[m.name].push(N) : a[m.name] = [N];
|
|
7449
7449
|
}
|
|
7450
|
-
function l(m,
|
|
7450
|
+
function l(m, N) {
|
|
7451
7451
|
const y = [];
|
|
7452
7452
|
let C;
|
|
7453
|
-
for (const
|
|
7454
|
-
const I =
|
|
7455
|
-
if (C =
|
|
7456
|
-
o(m,
|
|
7453
|
+
for (const v of N) {
|
|
7454
|
+
const I = v[0], z = v[1], _ = v[2];
|
|
7455
|
+
if (C = v[3] == null ? {} : v[3], !(I in r)) {
|
|
7456
|
+
o(m, N);
|
|
7457
7457
|
return;
|
|
7458
7458
|
}
|
|
7459
7459
|
const T = r[I];
|
|
7460
7460
|
if (T.inboundNodes.length <= z) {
|
|
7461
|
-
o(m,
|
|
7461
|
+
o(m, N);
|
|
7462
7462
|
return;
|
|
7463
7463
|
}
|
|
7464
7464
|
const E = T.inboundNodes[z];
|
|
@@ -7467,11 +7467,11 @@ class vt extends W {
|
|
|
7467
7467
|
y.length > 0 && m.apply(ht(y), C);
|
|
7468
7468
|
}
|
|
7469
7469
|
function u(m) {
|
|
7470
|
-
const
|
|
7471
|
-
y.setFastWeightInitDuringBuild(i), r[
|
|
7472
|
-
if (!(
|
|
7473
|
-
throw new d(`Corrupted configuration, expected array for nodeData: ${
|
|
7474
|
-
o(y,
|
|
7470
|
+
const N = m.name, y = Wt(m, e.customObjects != null ? e.customObjects : {});
|
|
7471
|
+
y.setFastWeightInitDuringBuild(i), r[N] = y, m.inboundNodes.forEach((v) => {
|
|
7472
|
+
if (!(v instanceof Array))
|
|
7473
|
+
throw new d(`Corrupted configuration, expected array for nodeData: ${v}`);
|
|
7474
|
+
o(y, v);
|
|
7475
7475
|
});
|
|
7476
7476
|
}
|
|
7477
7477
|
const c = e.name, h = e.layers;
|
|
@@ -7479,26 +7479,26 @@ class vt extends W {
|
|
|
7479
7479
|
u(m);
|
|
7480
7480
|
for (; !zu(a); )
|
|
7481
7481
|
for (const m of h) {
|
|
7482
|
-
const
|
|
7483
|
-
if (
|
|
7484
|
-
const y = a[
|
|
7485
|
-
delete a[
|
|
7482
|
+
const N = r[m.name];
|
|
7483
|
+
if (N.name in a) {
|
|
7484
|
+
const y = a[N.name];
|
|
7485
|
+
delete a[N.name];
|
|
7486
7486
|
for (const C of y)
|
|
7487
|
-
l(
|
|
7487
|
+
l(N, C);
|
|
7488
7488
|
}
|
|
7489
7489
|
}
|
|
7490
7490
|
const p = [], f = [], g = e.inputLayers;
|
|
7491
7491
|
for (const m of g) {
|
|
7492
|
-
const
|
|
7493
|
-
Ut(
|
|
7494
|
-
const I = r[
|
|
7492
|
+
const N = m[0], y = m[1], C = m[2];
|
|
7493
|
+
Ut(N in r);
|
|
7494
|
+
const I = r[N].inboundNodes[y].outputTensors;
|
|
7495
7495
|
p.push(I[C]);
|
|
7496
7496
|
}
|
|
7497
7497
|
const b = e.outputLayers;
|
|
7498
7498
|
for (const m of b) {
|
|
7499
|
-
const
|
|
7500
|
-
Ut(
|
|
7501
|
-
const I = r[
|
|
7499
|
+
const N = m[0], y = m[1], C = m[2];
|
|
7500
|
+
Ut(N in r);
|
|
7501
|
+
const I = r[N].inboundNodes[y].outputTensors;
|
|
7502
7502
|
f.push(I[C]);
|
|
7503
7503
|
}
|
|
7504
7504
|
return new t({ inputs: p, outputs: f, name: c });
|
|
@@ -7652,7 +7652,7 @@ async function Rf(s, t, e) {
|
|
|
7652
7652
|
const o = s.makeTrainFunction(), l = s.getDedupedMetricsNames();
|
|
7653
7653
|
let u;
|
|
7654
7654
|
i ? u = l.slice().concat(l.map((m) => "val_" + m)) : u = l.slice();
|
|
7655
|
-
const c = xr(e.callbacks, e.yieldEvery), h = e.verbose == null ? 1 : e.verbose, { callbackList: p, history: f } =
|
|
7655
|
+
const c = xr(e.callbacks, e.yieldEvery), h = e.verbose == null ? 1 : e.verbose, { callbackList: p, history: f } = vr(
|
|
7656
7656
|
c,
|
|
7657
7657
|
h,
|
|
7658
7658
|
e.epochs,
|
|
@@ -7669,39 +7669,39 @@ async function Rf(s, t, e) {
|
|
|
7669
7669
|
for (; g < e.epochs; ) {
|
|
7670
7670
|
const m = {};
|
|
7671
7671
|
await p.onEpochBegin(g);
|
|
7672
|
-
let
|
|
7673
|
-
for (n || (b = await t.iterator()); !n ||
|
|
7672
|
+
let N = 0, y = 0;
|
|
7673
|
+
for (n || (b = await t.iterator()); !n || N < e.batchesPerEpoch; ) {
|
|
7674
7674
|
const C = await b.next();
|
|
7675
7675
|
if (n && C.done) {
|
|
7676
|
-
console.warn(`You provided \`batchesPerEpoch\` as ${e.batchesPerEpoch}, but your dataset iterator ran out of data after ${
|
|
7676
|
+
console.warn(`You provided \`batchesPerEpoch\` as ${e.batchesPerEpoch}, but your dataset iterator ran out of data after ${N} batches; interrupting training. Make sure that your dataset can generate at least \`batchesPerEpoch * epochs\` batches (in this case, ${e.batchesPerEpoch * e.epochs} batches). You may need to use the repeat() function when building your dataset.`);
|
|
7677
7677
|
break;
|
|
7678
7678
|
}
|
|
7679
7679
|
if (C.value != null) {
|
|
7680
|
-
const { xs:
|
|
7681
|
-
z.batch = y, z.size =
|
|
7680
|
+
const { xs: v, ys: I } = Er(s, C.value), z = {};
|
|
7681
|
+
z.batch = y, z.size = v[0].shape[0], await p.onBatchBegin(y, z);
|
|
7682
7682
|
const _ = [];
|
|
7683
7683
|
if (e.classWeight != null) {
|
|
7684
7684
|
const R = Tr(e.classWeight, s.outputNames);
|
|
7685
7685
|
for (let q = 0; q < R.length; ++q)
|
|
7686
7686
|
_.push(await $r(I[q], null, R[q]));
|
|
7687
7687
|
}
|
|
7688
|
-
const T =
|
|
7688
|
+
const T = v.concat(I).concat(_), E = o(T);
|
|
7689
7689
|
Z(T);
|
|
7690
7690
|
for (let R = 0; R < l.length; ++R) {
|
|
7691
7691
|
const q = l[R], bt = E[R];
|
|
7692
7692
|
z[q] = bt, Bt(bt);
|
|
7693
7693
|
}
|
|
7694
|
-
await p.onBatchEnd(y, z), kr(z), y++,
|
|
7694
|
+
await p.onBatchEnd(y, z), kr(z), y++, N++;
|
|
7695
7695
|
}
|
|
7696
|
-
if (n ?
|
|
7696
|
+
if (n ? N >= e.batchesPerEpoch : C.done) {
|
|
7697
7697
|
if (i) {
|
|
7698
|
-
let
|
|
7699
|
-
ai(e.validationData) ?
|
|
7698
|
+
let v;
|
|
7699
|
+
ai(e.validationData) ? v = K(await s.evaluateDataset(e.validationData, { batches: e.validationBatches })) : v = K(s.evaluate(r, a, {
|
|
7700
7700
|
batchSize: e.validationBatchSize == null ? Mf : e.validationBatchSize,
|
|
7701
7701
|
verbose: 0
|
|
7702
7702
|
}));
|
|
7703
7703
|
for (let I = 0; I < s.metricsNames.length; ++I)
|
|
7704
|
-
m[`val_${s.metricsNames[I]}`] =
|
|
7704
|
+
m[`val_${s.metricsNames[I]}`] = v[I];
|
|
7705
7705
|
}
|
|
7706
7706
|
break;
|
|
7707
7707
|
}
|
|
@@ -7745,8 +7745,8 @@ async function Wf(s, t, e) {
|
|
|
7745
7745
|
r.push(tt(0));
|
|
7746
7746
|
const g = p[0].shape[0];
|
|
7747
7747
|
for (let b = 0; b < f.length; ++b) {
|
|
7748
|
-
const m = f[b],
|
|
7749
|
-
r[b] = x(() => $(r[b], w(g, m))), l > 0 && Z(
|
|
7748
|
+
const m = f[b], N = r[b];
|
|
7749
|
+
r[b] = x(() => $(r[b], w(g, m))), l > 0 && Z(N);
|
|
7750
7750
|
}
|
|
7751
7751
|
Z(f), o += g, ++l;
|
|
7752
7752
|
}
|
|
@@ -7802,7 +7802,7 @@ function Lr(s) {
|
|
|
7802
7802
|
}
|
|
7803
7803
|
return t;
|
|
7804
7804
|
}
|
|
7805
|
-
function
|
|
7805
|
+
function vt(s, t) {
|
|
7806
7806
|
if (s == null)
|
|
7807
7807
|
return;
|
|
7808
7808
|
const e = [];
|
|
@@ -7917,7 +7917,7 @@ function Pf(s, t, e) {
|
|
|
7917
7917
|
}
|
|
7918
7918
|
function Uf(s, t, e) {
|
|
7919
7919
|
const n = [
|
|
7920
|
-
|
|
7920
|
+
Nn,
|
|
7921
7921
|
Sn,
|
|
7922
7922
|
Oe
|
|
7923
7923
|
];
|
|
@@ -7986,7 +7986,7 @@ function Vf(s, t) {
|
|
|
7986
7986
|
}
|
|
7987
7987
|
}
|
|
7988
7988
|
const jf = "layers-model";
|
|
7989
|
-
class we extends
|
|
7989
|
+
class we extends Nt {
|
|
7990
7990
|
constructor(t) {
|
|
7991
7991
|
super(t), this.isTraining = !1;
|
|
7992
7992
|
}
|
|
@@ -8094,8 +8094,8 @@ class we extends vt {
|
|
|
8094
8094
|
if (typeof g == "string" && ["accuracy", "acc", "crossentropy", "ce"].indexOf(g) !== -1) {
|
|
8095
8095
|
const m = this.internalOutputShapes[a];
|
|
8096
8096
|
m[m.length - 1] === 1 || this.lossFunctions[a] === Sn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = Sr : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = gf) : this.lossFunctions[a] === dn ? ["accuracy", "acc"].indexOf(g) !== -1 ? p = bf : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Ir) : ["accuracy", "acc"].indexOf(g) !== -1 ? p = Ar : ["crossentropy", "ce"].indexOf(g) !== -1 && (p = Cr);
|
|
8097
|
-
let
|
|
8098
|
-
["accuracy", "acc"].indexOf(g) !== -1 ?
|
|
8097
|
+
let N;
|
|
8098
|
+
["accuracy", "acc"].indexOf(g) !== -1 ? N = "acc" : ["crossentropy", "ce"].indexOf(g) !== -1 && (N = "ce"), f = p, h = "" + N;
|
|
8099
8099
|
} else
|
|
8100
8100
|
f = Af(g), h = "" + tn(g);
|
|
8101
8101
|
let b;
|
|
@@ -8160,7 +8160,7 @@ class we extends vt {
|
|
|
8160
8160
|
const l = this.testFunction, u = this.testLoop(l, o, i, n.verbose, n.steps);
|
|
8161
8161
|
return ht(u);
|
|
8162
8162
|
} finally {
|
|
8163
|
-
|
|
8163
|
+
vt(a[0], t), vt(a[1], e);
|
|
8164
8164
|
}
|
|
8165
8165
|
}
|
|
8166
8166
|
// TODO(cais): Add code snippet below once real dataset objects are
|
|
@@ -8326,7 +8326,7 @@ class we extends vt {
|
|
|
8326
8326
|
const i = e.batchSize == null ? 32 : e.batchSize;
|
|
8327
8327
|
return En(i), this.predictLoop(n, i);
|
|
8328
8328
|
} finally {
|
|
8329
|
-
|
|
8329
|
+
vt(n, t);
|
|
8330
8330
|
}
|
|
8331
8331
|
}
|
|
8332
8332
|
/**
|
|
@@ -8400,8 +8400,8 @@ class we extends vt {
|
|
|
8400
8400
|
for (let m = 0; m < b.length; ++m)
|
|
8401
8401
|
o.push(tt(0));
|
|
8402
8402
|
for (let m = 0; m < b.length; ++m) {
|
|
8403
|
-
const
|
|
8404
|
-
o[m] = $(o[m], w(p - h,
|
|
8403
|
+
const N = b[m];
|
|
8404
|
+
o[m] = $(o[m], w(p - h, N));
|
|
8405
8405
|
}
|
|
8406
8406
|
}
|
|
8407
8407
|
for (let c = 0; c < o.length; ++c)
|
|
@@ -8443,18 +8443,18 @@ class we extends vt {
|
|
|
8443
8443
|
let g;
|
|
8444
8444
|
for (let b = 0; b < this.lossFunctions.length; ++b) {
|
|
8445
8445
|
const m = this.lossFunctions[b];
|
|
8446
|
-
let
|
|
8447
|
-
r[b] != null && (
|
|
8448
|
-
const y = at(
|
|
8449
|
-
e.push(y), b === 0 ? g =
|
|
8446
|
+
let N = m(i[b], f[b]);
|
|
8447
|
+
r[b] != null && (N = Ff(N, r[b]));
|
|
8448
|
+
const y = at(N);
|
|
8449
|
+
e.push(y), b === 0 ? g = N : g = $(g, N);
|
|
8450
8450
|
}
|
|
8451
8451
|
for (let b = 0; b < this.metricsTensors.length; ++b) {
|
|
8452
8452
|
let m;
|
|
8453
8453
|
if (this.outputs.length > 1 && b < this.outputs.length)
|
|
8454
8454
|
m = e[b];
|
|
8455
8455
|
else {
|
|
8456
|
-
const
|
|
8457
|
-
m = at(
|
|
8456
|
+
const N = this.metricsTensors[b][0], y = this.metricsTensors[b][1];
|
|
8457
|
+
m = at(N(i[y], f[y]));
|
|
8458
8458
|
}
|
|
8459
8459
|
Bt(m), a.push(m);
|
|
8460
8460
|
}
|
|
@@ -8533,7 +8533,7 @@ class we extends vt {
|
|
|
8533
8533
|
En(f);
|
|
8534
8534
|
const b = await this.standardizeUserData(t, e, n.sampleWeight, n.classWeight, !1, f);
|
|
8535
8535
|
i = b[0], r = b[1], p = b[2];
|
|
8536
|
-
let m = !1,
|
|
8536
|
+
let m = !1, N;
|
|
8537
8537
|
if (n.validationData != null && n.validationData.length > 0) {
|
|
8538
8538
|
if (m = !0, n.validationData.length === 2)
|
|
8539
8539
|
l = n.validationData[0], u = n.validationData[1];
|
|
@@ -8547,21 +8547,21 @@ class we extends vt {
|
|
|
8547
8547
|
!0,
|
|
8548
8548
|
f
|
|
8549
8549
|
);
|
|
8550
|
-
c = R[0], h = R[1],
|
|
8550
|
+
c = R[0], h = R[1], N = c.concat(h);
|
|
8551
8551
|
} else if (n.validationSplit != null && n.validationSplit > 0 && n.validationSplit < 1) {
|
|
8552
8552
|
m = !0;
|
|
8553
8553
|
const E = Math.floor(i[0].shape[0] * (1 - n.validationSplit)), R = i[0].shape[0];
|
|
8554
|
-
c = Te(i, E, R), a = i, i = Te(i, 0, E), h = Te(r, E, R), o = r, r = Te(r, 0, E),
|
|
8554
|
+
c = Te(i, E, R), a = i, i = Te(i, 0, E), h = Te(r, E, R), o = r, r = Te(r, 0, E), N = c.concat(h);
|
|
8555
8555
|
} else n.validationSteps != null && (m = !0);
|
|
8556
8556
|
const y = i.concat(r).concat(p);
|
|
8557
8557
|
this.checkTrainableWeightsConsistency();
|
|
8558
|
-
const C = this.makeTrainFunction(),
|
|
8558
|
+
const C = this.makeTrainFunction(), v = this.getDedupedMetricsNames();
|
|
8559
8559
|
let I, z;
|
|
8560
|
-
m ? (this.makeTestFunction(), I = this.testFunction, z =
|
|
8560
|
+
m ? (this.makeTestFunction(), I = this.testFunction, z = v.slice().concat(v.map((E) => "val_" + E))) : (I = null, N = [], z = v.slice());
|
|
8561
8561
|
const _ = xr(n.callbacks, n.yieldEvery);
|
|
8562
|
-
return await this.fitLoop(C, y,
|
|
8562
|
+
return await this.fitLoop(C, y, v, f, n.epochs, n.verbose, _, I, N, n.shuffle, z, n.initialEpoch, null, null);
|
|
8563
8563
|
} finally {
|
|
8564
|
-
this.isTraining = !1,
|
|
8564
|
+
this.isTraining = !1, vt(i, t), vt(r, e), vt(a, t), vt(o, e), vt(c, l), vt(h, u), p != null && Z(p);
|
|
8565
8565
|
}
|
|
8566
8566
|
}
|
|
8567
8567
|
/**
|
|
@@ -8597,20 +8597,20 @@ class we extends vt {
|
|
|
8597
8597
|
if (l != null && u != null && (b = !0), g != null && (b = !0, f == null))
|
|
8598
8598
|
throw new d("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");
|
|
8599
8599
|
const m = this.checkNumSamples(e, i, f, "steps_per_epoch");
|
|
8600
|
-
let
|
|
8601
|
-
m != null && (
|
|
8602
|
-
const { callbackList: y, history: C } =
|
|
8600
|
+
let N;
|
|
8601
|
+
m != null && (N = It(0, m)), a == null && (a = 1);
|
|
8602
|
+
const { callbackList: y, history: C } = vr(o, a, r, p, m, f, i, b, h);
|
|
8603
8603
|
y.setModel(this), this.history = C, await y.onTrainBegin(), this.stopTraining_ = !1;
|
|
8604
|
-
for (let
|
|
8605
|
-
await y.onEpochBegin(
|
|
8604
|
+
for (let v = p; v < r; ++v) {
|
|
8605
|
+
await y.onEpochBegin(v);
|
|
8606
8606
|
const I = {};
|
|
8607
8607
|
if (f != null)
|
|
8608
8608
|
throw new G("stepsPerEpoch mode is not implemented yet.");
|
|
8609
8609
|
{
|
|
8610
8610
|
if (c === "batch")
|
|
8611
8611
|
throw new G("batch shuffling is not implemneted yet");
|
|
8612
|
-
c && bu(
|
|
8613
|
-
const z = Rn(
|
|
8612
|
+
c && bu(N);
|
|
8613
|
+
const z = Rn(N), _ = Ln(m, i);
|
|
8614
8614
|
for (let T = 0; T < _.length; ++T) {
|
|
8615
8615
|
const E = {};
|
|
8616
8616
|
if (await y.onBatchBegin(T, E), x(() => {
|
|
@@ -8633,7 +8633,7 @@ class we extends vt {
 }
 z.dispose();
 }
-if (await y.onEpochEnd(
+if (await y.onEpochEnd(v, I), this.stopTraining_)
 break;
 }
 return await y.onTrainEnd(), await this.history.syncData(), this.history;
@@ -8693,7 +8693,7 @@ class we extends vt {
 const c = await u.data();
 l.push(c[0]);
 }
-return Z(o),
+return Z(o), vt(n[0], t), vt(n[1], e), ht(l);
 }
 /**
 * Extract weight values of the model.
@@ -9039,7 +9039,7 @@ class Re extends we {
 throw new d("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
 this.checkShape(t), this.outputs = [t.inboundNodes[0].outputTensors[0]], this.inputs = dr(this.outputs[0]);
 }
-this.inboundNodes = [], new
+this.inboundNodes = [], new vn({
 outboundLayer: this,
 inboundLayers: [],
 nodeIndices: [],
@@ -9563,7 +9563,7 @@ class Vr extends ut {
 }
 Vr.className = "tanh";
 S(Vr);
-let
+let Ns = class extends ut {
 /**
 * Calculate the activation function.
 *
@@ -9580,8 +9580,8 @@ let vs = class extends ut {
 return tr(t, e);
 }
 };
-
-S(
+Ns.className = "softmax";
+S(Ns);
 class jr extends ut {
 /**
 * Calculate the activation function of log softmax:
@@ -9858,7 +9858,7 @@ na.className = "ThresholdedReLU";
 S(na);
 class sa extends W {
 constructor(t) {
-super(t ?? {}), this.DEFAULT_AXIS = 1, t == null && (t = {}), this.softmax = new
+super(t ?? {}), this.DEFAULT_AXIS = 1, t == null && (t = {}), this.softmax = new Ns().apply, this.axis = t.axis == null ? this.DEFAULT_AXIS : t.axis;
 }
 call(t, e) {
 return x(() => {
@@ -10132,8 +10132,8 @@ class ra extends Ze {
 this.dataFormat === "channelsFirst" ? (a = 2, o = 3) : (a = 1, o = 2);
 const l = i[a], u = i[o], c = this.kernelSize[0], h = this.kernelSize[1], p = this.strides[0], f = this.strides[1], g = $t(l, p, c, this.padding), b = $t(u, f, h, this.padding), m = [r, g, b, this.filters];
 this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 1]));
-let
-return this.dataFormat !== "channelsLast" && (
+let N = Nc(n, this.kernel.read(), m, this.strides, this.padding);
+return this.dataFormat !== "channelsLast" && (N = j(N, [0, 3, 1, 2])), this.bias != null && (N = zt(N, this.bias.read(), this.dataFormat)), this.activation != null && (N = this.activation.apply(N)), N;
 });
 }
 computeOutputShape(t) {
@@ -10173,7 +10173,7 @@ class aa extends Je {
 const i = n.shape, r = i[0];
 let a, o, l;
 this.dataFormat === "channelsFirst" ? (l = 2, a = 3, o = 4) : (l = 1, a = 2, o = 3);
-const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1],
+const u = i[l], c = i[a], h = i[o], p = this.kernelSize[0], f = this.kernelSize[1], g = this.kernelSize[2], b = this.strides[0], m = this.strides[1], N = this.strides[2], y = $t(u, b, p, this.padding), C = $t(c, m, f, this.padding), v = $t(h, N, g, this.padding), I = [r, y, C, v, this.filters];
 this.dataFormat !== "channelsLast" && (n = j(n, [0, 2, 3, 4, 1]));
 let z = Dc(n, this.kernel.read(), I, this.strides, this.padding);
 return this.dataFormat !== "channelsLast" && (z = j(z, [0, 4, 1, 2, 3])), this.bias !== null && (z = zt(z, this.bias.read(), this.dataFormat)), this.activation !== null && (z = this.activation.apply(z)), z;
@@ -10419,16 +10419,16 @@ function da(s, t, e, n = !1, i, r, a = !1, o = !1) {
 const f = t.shape[0], g = en(t);
 let b;
 i != null && (b = en(i));
-for (let
-const y = g[
+for (let N = 0; N < f; ++N) {
+const y = g[N], C = x(() => s(y, p));
 if (i == null)
 h = C[0], p = C[1];
 else {
-const
-const I = b[
+const v = x(() => {
+const I = b[N], z = V(Dt(I), I), _ = $(w(C[0], I), w(p[0], z)), T = p.map((E, R) => $(w(C[1][R], I), w(E, z)));
 return { output: _, newStates: T };
 });
-h =
+h = v.output, p = v.newStates;
 }
 o && c.push(h);
 }
@@ -10642,7 +10642,7 @@ class In extends W {
 }
 class As extends In {
 constructor(t) {
-super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
+super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = ve([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ve([
 1,
 qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
 ]), this.dropoutFunc = t.dropoutFunc, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
@@ -10726,7 +10726,7 @@ class Cs extends In {
 constructor(t) {
 if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
 throw new d("GRUCell does not support reset_after parameter set to true.");
-this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
+this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = ve([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ve([
 1,
 qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
 ]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
@@ -10760,10 +10760,10 @@ class Cs extends In {
 0 < this.dropout && this.dropout < 1 && (t = w(t, r[0]));
 let c = St(t, this.kernel.read());
 this.useBias && (c = zt(c, this.bias.read())), 0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, a[0]));
-const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g = St(i, p), [b, m,
+const h = this.recurrentKernel.read(), [p, f] = Kt(h, [2 * this.units, this.units], h.rank - 1), g = St(i, p), [b, m, N] = Kt(c, 3, c.rank - 1), [y, C] = Kt(g, 2, g.rank - 1);
 o = this.recurrentActivation.apply($(b, y)), l = this.recurrentActivation.apply($(m, C));
-const
-u = this.activation.apply($(
+const v = St(w(l, i), f);
+u = this.activation.apply($(N, v));
 const I = $(w(o, i), w($(1, pt(o)), u));
 return [I, I];
 });
@@ -10814,7 +10814,7 @@ ma.className = "GRU";
 S(ma);
 class Dn extends In {
 constructor(t) {
-super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout =
+super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, ot(this.units, "units"), this.activation = Qt(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = Qt(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = J(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = J(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = J(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = rt(t.kernelConstraint), this.recurrentConstraint = rt(t.recurrentConstraint), this.biasConstraint = rt(t.biasConstraint), this.dropout = ve([1, qt([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = ve([
 1,
 qt([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
 ]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = [this.units, this.units], this.dropoutMask = null, this.recurrentDropoutMask = null;
@@ -10869,8 +10869,8 @@ class Dn extends In {
 0 < this.recurrentDropout && this.recurrentDropout < 1 && (i = w(i, o[0])), p = $(p, St(i, this.recurrentKernel.read())), this.useBias && (p = zt(p, this.bias.read()));
 const [f, g, b, m] = Kt(p, 4, p.rank - 1);
 l = this.recurrentActivation.apply(f), u = this.recurrentActivation.apply(g), c = $(w(u, r), w(l, this.activation.apply(b))), h = this.recurrentActivation.apply(m);
-const
-return [
+const N = w(h, this.activation.apply(c));
+return [N, N, c];
 });
 }
 getConfig() {
@@ -11153,12 +11153,12 @@ class Ds extends Dn {
 dropoutFunc: this.dropoutFunc
 }));
 const g = this.recurrentDropoutMask;
-let b = u(r, g, 0), m = u(r, g, 1),
-const C = 3, [
-c = this.inputConv(c,
+let b = u(r, g, 0), m = u(r, g, 1), N = u(r, g, 2), y = u(r, g, 3);
+const C = 3, [v, I, z, _] = Kt(this.kernel.read(), o, C), [T, E, R, q] = this.useBias ? Kt(this.bias.read(), o) : [null, null, null, null];
+c = this.inputConv(c, v, T, this.padding), h = this.inputConv(h, I, E, this.padding), p = this.inputConv(p, z, R, this.padding), f = this.inputConv(f, _, q, this.padding);
 const [bt, ie, re, xt] = Kt(this.recurrentKernel.read(), o, C);
-b = this.recurrentConv(b, bt), m = this.recurrentConv(m, ie),
-const Tt = this.recurrentActivation.apply($(c, b)), me = this.recurrentActivation.apply($(h, m)), ze = $(w(me, a), w(Tt, this.activation.apply($(p,
+b = this.recurrentConv(b, bt), m = this.recurrentConv(m, ie), N = this.recurrentConv(N, re), y = this.recurrentConv(y, xt);
+const Tt = this.recurrentActivation.apply($(c, b)), me = this.recurrentActivation.apply($(h, m)), ze = $(w(me, a), w(Tt, this.activation.apply($(p, N)))), Ts = w(this.recurrentActivation.apply($(f, y)), this.activation.apply(ze));
 return [Ts, Ts, ze];
 });
 }
@@ -11329,7 +11329,7 @@ class xa extends W {
 }
 xa.className = "Flatten";
 S(xa);
-class
+class va extends W {
 constructor(t) {
 super(t), this.supportsMasking = !0, this.activation = Qt(t.activation);
 }
@@ -11345,9 +11345,9 @@ class Na extends W {
 return Object.assign(t, e), t;
 }
 }
-
-S(
-class
+va.className = "Activation";
+S(va);
+class Na extends W {
 constructor(t) {
 super(t), this.n = t.n, this.inputSpec = [{ ndim: 2 }];
 }
@@ -11364,8 +11364,8 @@ class va extends W {
 return Object.assign(t, e), t;
 }
 }
-
-S(
+Na.className = "RepeatVector";
+S(Na);
 class Sa extends W {
 constructor(t) {
 super(t), this.targetShape = t.targetShape;
@@ -12110,16 +12110,16 @@ class _a extends W {
 c.sort();
 const h = !Ft(c, It(0, a).slice(0, a - 1)), p = () => {
 if (h) {
-const y = A(this.movingMean.read(), u), C = A(this.movingVariance.read(), u),
-return _e(i, y, C,
+const y = A(this.movingMean.read(), u), C = A(this.movingVariance.read(), u), v = this.center ? A(this.beta.read(), u) : null, I = this.scale ? A(this.gamma.read(), u) : null;
+return _e(i, y, C, v, I, this.epsilon);
 } else
 return _e(i, this.movingMean.read(), this.movingVariance.read(), this.beta == null ? null : this.beta.read(), this.gamma == null ? null : this.gamma.read(), this.epsilon);
 };
 if (!n)
 return p();
-const [f, g, b] = tm(i, this.gamma.read(), this.beta.read(), o, this.epsilon), m = (y, C,
+const [f, g, b] = tm(i, this.gamma.read(), this.beta.read(), o, this.epsilon), m = (y, C, v) => {
 x(() => {
-const I = 1 -
+const I = 1 - v, z = y.read(), _ = w(V(z, C), I);
 y.write(V(z, _));
 });
 };
@@ -12496,7 +12496,7 @@ class Qa extends Xa {
 call(t, e) {
 return x(() => {
 const n = O(t);
-return
+return Ne(n, 1);
 });
 }
 }
@@ -12531,7 +12531,7 @@ class no extends to {
 call(t, e) {
 return x(() => {
 const n = O(t);
-return this.dataFormat === "channelsLast" ?
+return this.dataFormat === "channelsLast" ? Ne(n, [1, 2]) : Ne(n, [2, 3]);
 });
 }
 }
@@ -12805,8 +12805,8 @@ class oo extends W {
 t.rank === 3 ? (c = !0, u = kn([t])) : u = t;
 for (let I = 0; I < u.shape[0]; I++)
 m.push(b);
-const
-return c ? Lt(O(en(
+const N = Ku(m, [m.length, 4]), y = Hu(0, m.length, 1, "int32"), v = rm(u, N, y, [i, r], "nearest");
+return c ? Lt(O(en(v)), l) : Lt(v, l);
 });
 }
 upsize(t, e, n, i) {
@@ -12896,7 +12896,7 @@ class lo extends W {
 Received countWeights=${e.countWeights}`);
 n = O(e.countWeights);
 }
-const i =
+const i = Ne(t), r = qu(t), a = Gt(this.numTokens, i).bufferSync().get(0), o = Ue(r, 0).bufferSync().get(0);
 if (!(a && o))
 throw new d(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
 return am(t, this.outputMode, this.numTokens, n);