@genai-fi/nanogpt 0.6.3 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +11 -11
- package/dist/NanoGPTModel.d.ts +2 -2
- package/dist/NanoGPTModel.js +104 -136
- package/dist/{RealDiv-BYViZwhN.js → RealDiv-C4hOvYOZ.js} +26 -25
- package/dist/{Reshape-t7Kcikjk.js → Reshape-BLijOA8h.js} +5 -5
- package/dist/TeachableLLM.js +5 -5
- package/dist/{TiedEmbedding-9WeDwvjO.js → TiedEmbedding-BLltddza.js} +4 -4
- package/dist/{axis_util-Bu4h7XWV.js → axis_util-DaAl5MER.js} +3 -3
- package/dist/backend.d.ts +1 -0
- package/dist/backend.js +7 -0
- package/dist/backend_util-DWiwsi2N.js +749 -0
- package/dist/{broadcast_to-DARN-DBD.js → broadcast_to-C4v-j9yA.js} +2 -2
- package/dist/{concat-5aPGqw3Z.js → concat-CsHeR4zV.js} +8 -8
- package/dist/{dataset-pgqp-YfL.js → dataset-JDyjG3QR.js} +3 -3
- package/dist/{dropout-Bciw46HT.js → dropout-hpDwECTe.js} +7 -7
- package/dist/{gather-DjyCjmOD.js → gather-D0_gPiBz.js} +4 -4
- package/dist/gelu-uyHP1x1f.js +26 -0
- package/dist/gpgpu_math-DJm3ZTAf.js +2371 -0
- package/dist/index-BPPzKVdR.js +12099 -0
- package/dist/{index-BAzbokzv.js → index-C0dhsYom.js} +405 -389
- package/dist/{kernel_funcs_utils-CUxJCg0g.js → kernel_funcs_utils-CwRTFqrc.js} +31 -30
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +5 -5
- package/dist/{log_sum_exp-YEo2h3gb.js → log_sum_exp-D086OgZJ.js} +15 -15
- package/dist/main.d.ts +2 -0
- package/dist/main.js +9 -5
- package/dist/{mat_mul-7121rsJk.js → mat_mul-1nwdPkQ_.js} +4 -4
- package/dist/{max-DtlIuVeW.js → max-BQc2Aj-I.js} +4 -4
- package/dist/{mulmat_packed_gpu-D4nKF7Je.js → mulmat_packed_gpu-Gzf3I9UV.js} +1 -1
- package/dist/non_max_suppression_impl-CsEgBuMA.js +134 -0
- package/dist/{ones-BBlSRqn1.js → ones-D63HpSF_.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +8 -8
- package/dist/ops/cpu/attentionMask.js +9 -9
- package/dist/ops/cpu/fusedSoftmax.js +17 -11
- package/dist/ops/cpu/gatherSub.js +7 -7
- package/dist/ops/cpu/gelu.js +13 -13
- package/dist/ops/cpu/matMulGelu.js +36 -24
- package/dist/ops/cpu/matMulMul.js +14 -8
- package/dist/ops/cpu/mulDropout.js +9 -3
- package/dist/ops/cpu/normRMS.js +5 -5
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +11 -11
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +3 -24
- package/dist/ops/grads/matMulGelu.js +5 -5
- package/dist/ops/grads/normRMS.js +6 -6
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +3 -3
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +5 -5
- package/dist/ops/webgl/matMulGelu.js +17 -17
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +4 -4
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/appendCache.d.ts +1 -0
- package/dist/ops/webgpu/appendCache.js +56 -0
- package/dist/ops/webgpu/attentionMask.d.ts +1 -0
- package/dist/ops/webgpu/attentionMask.js +64 -0
- package/dist/ops/webgpu/gatherSub.d.ts +1 -0
- package/dist/ops/webgpu/gatherSub.js +37 -0
- package/dist/ops/webgpu/gelu.d.ts +14 -0
- package/dist/ops/webgpu/gelu.js +86 -0
- package/dist/ops/webgpu/index.d.ts +0 -0
- package/dist/ops/webgpu/index.js +8 -0
- package/dist/ops/webgpu/normRMS.d.ts +1 -0
- package/dist/ops/webgpu/normRMS.js +115 -0
- package/dist/ops/webgpu/qkv.d.ts +1 -0
- package/dist/ops/webgpu/qkv.js +56 -0
- package/dist/ops/webgpu/rope.d.ts +1 -0
- package/dist/ops/webgpu/rope.js +68 -0
- package/dist/ops/webgpu/scatterSub.d.ts +1 -0
- package/dist/ops/webgpu/scatterSub.js +37 -0
- package/dist/{ops-C0sQEcPw.js → ops-CIQLNshk.js} +452 -503
- package/dist/{random_width-DWzaOgrn.js → random_width-DkYP8W8N.js} +143 -144
- package/dist/{range-DYsrnfiy.js → range-CYzpQY53.js} +1 -1
- package/dist/{reciprocal-CJQeasVa.js → reciprocal-_A9yv27J.js} +1 -1
- package/dist/{register_all_kernels-BfFCQAqs.js → register_all_kernels-guvSxp7M.js} +202 -200
- package/dist/{reshape-krWGKraP.js → reshape-BMUzc1UY.js} +3 -3
- package/dist/{scatter_nd_util-93ln7Hut.js → scatter_nd_util-IRBqKz_b.js} +3 -3
- package/dist/{selu_util-sntGesxr.js → selu_util-Dt_iuXaq.js} +6 -6
- package/dist/shared-BNa2q6jD.js +69 -0
- package/dist/{shared-Ca6iDobD.js → shared-CDu9S76h.js} +541 -606
- package/dist/{sin-D_h-qCSx.js → sin-Cocju-BY.js} +6 -6
- package/dist/{softmax-fsdtf6JC.js → softmax-GPNK3o-U.js} +3 -3
- package/dist/{split-eiktj-6L.js → split-CHzJjxDv.js} +4 -4
- package/dist/{stack-dfEEz2OY.js → stack-Dpgg_1W1.js} +2 -2
- package/dist/{sum-BE_Irnim.js → sum-B8wEpKsg.js} +5 -5
- package/dist/{tensor-Xyi595sG.js → tensor-RvZVNmg0.js} +1 -1
- package/dist/{tensor2d-CPEkynbH.js → tensor2d-B_kyod7_.js} +1 -1
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.js +1 -1
- package/dist/training/FullTrainer.js +20 -20
- package/dist/training/Trainer.d.ts +5 -6
- package/dist/training/Trainer.js +59 -60
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +19 -19
- package/dist/utilities/generate.js +15 -16
- package/dist/utilities/multinomialCPU.d.ts +2 -0
- package/dist/utilities/multinomialCPU.js +13 -0
- package/dist/utilities/performance.d.ts +2 -0
- package/dist/utilities/performance.js +16 -0
- package/dist/utilities/profile.d.ts +1 -0
- package/dist/utilities/profile.js +9 -6
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-wSS22xj5.js → variable-DXEUOwew.js} +1 -1
- package/dist/webgpu_util-g13LvDIv.js +625 -0
- package/dist/{zeros-YJDE7oRb.js → zeros-DCPCdFGq.js} +8 -8
- package/package.json +2 -1
- package/dist/gpgpu_math-CNslybmD.js +0 -3115
- package/dist/norm-CzltS9Fz.js +0 -86
|
@@ -1,25 +1,24 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { k as ke, c as Nt, o as ze, s as er, b as nr, d as Wu, m as sr, t as In, l as ir, v as os, a as Gu, S as Pu, p as Uu, w as as, x as rr, y as Je, z as Vu, A as ju } from "./selu_util-
|
|
3
|
-
import { s as Dt, n as wt, w as ne, a as Ze, g as
|
|
4
|
-
import { r as N } from "./reshape-
|
|
5
|
-
import { s as W } from "./sum-
|
|
6
|
-
import { m as ct } from "./mat_mul-
|
|
7
|
-
import { s as Qt } from "./split-
|
|
8
|
-
import { s as
|
|
9
|
-
import { e as Hn, g as lr, h as cs, c as
|
|
10
|
-
import { a as se, e as ie, l as Yu } from "./log_sum_exp-
|
|
11
|
-
import { s as Dn } from "./stack-
|
|
12
|
-
import { o as xe } from "./ones-
|
|
13
|
-
import { M as
|
|
14
|
-
import { z as vt } from "./zeros-
|
|
15
|
-
import { c as pe } from "./concat-
|
|
16
|
-
import { g as cr } from "./gather-
|
|
17
|
-
import { s as hr } from "./softmax-
|
|
18
|
-
import { m as Ee } from "./max-
|
|
19
|
-
import { t as
|
|
20
|
-
import { r as
|
|
21
|
-
import {
|
|
22
|
-
import { v as ic } from "./variable-wSS22xj5.js";
|
|
1
|
+
import { x as T, y as I, E as O, bE as Oa, bF as Ma, bG as Ci, l as b, bH as Ii, Q as L, bI as Di, bJ as $i, bK as Ti, bL as zi, h as Ei, bM as Li, bN as Fi, bO as Oi, bP as Mi, bQ as _a, bR as _i, bS as Ra, bT as Ri, bU as Ba, bV as Bi, F as Ge, k as kt, bq as Wa, bW as Wi, bX as Gi, bY as Pi, q as Pe, c as V, a as w, bZ as Ga, b_ as Ui, b$ as Vi, c0 as ji, p as ce, aU as pt, bv as Pa, c1 as Ki, c2 as Hi, c3 as qi, c4 as Ji, c5 as Zi, bx as Xi, c6 as Yi, c7 as Qi, C as Ua, ai as Va, ao as ja, c8 as tr, c9 as Ka, a6 as z, ca as Ps, cb as Ha, cc as qa, j as pn, cd as Us, ce as Ja, cf as Za, cg as Xa, A as Ya, ch as Qa, ci as tl, cj as el, bf as nl, ck as sl, aT as he, b as et, aj as U, cl as il, bm as rl, au as ht, cm as ol, z as Q, cn as al, co as ll, cp as ul, cq as cl, cr as hl, cs as pl, ct as dl, cu as fl, H as ml, cv as gl, bj as bl, bp as yl, cw as wl, M as kl, cx as xl, a3 as Nl, cy as vl, cz as Al, cA as Sl, aq as Cl, cB as Il, a7 as Dl, aV as $l, br as Tl, ah as zl, bs as El, G as Ll, aX as Fl, ak as Ol, cC as Ml, cD as _l, cE as Rl, al as Bl, a8 as Wl, cF as Gl, cG as Pl, cH as Ul, N as Vl, bt as jl, cI as Kl, cJ as Hl, aR as ql, b1 as Jl, cK as Zl, cL as Xl, bu as Yl, a$ as Ql, P as tu, cM as eu, n as rs, an as nu, bw as su, ax as iu, O as ru, as as ou, ar as au, I as lu, bc as uu, cN as cu, bd as hu, cO as pu, b3 as du, aP as fu, ap as mu, cP as gu, a4 as bu, aN as yu, S as wu, J as ku, bz as xu, cQ as Nu, bA as vu, at as Au, bC as Su, U as Cu, cR as Iu, L as Du, b5 as $u, b4 as Tu, cS as Oe, cT as zu, i as Eu, ag as Vs, cU as Lu, t as x, aS as $e, cV as S, cW as He, cX as qe, ab as Vt, d as J, ac as Fu, cY as js, o as Jt, K as Ou, T as Te, cZ as Mu, c_ as _u, m as Ks, c$ as Ru, d0 as Hs, d1 as Bu } from "./index-C0dhsYom.js";
|
|
2
|
+
import { k as ke, c as Nt, o as ze, s as er, b as nr, d as Wu, m as sr, t as In, l as ir, v as os, a as Gu, S as Pu, p as Uu, w as as, x as rr, y as Je, z as Vu, A as ju } from "./selu_util-Dt_iuXaq.js";
|
|
3
|
+
import { s as Dt, n as wt, w as ne, a as Ze, g as Xe, b as ls, t as K, c as Ce, d as Xt, e as Ku, u as dn, f as ye, h as fn, i as Hu, l as qu, j as us, m as or, k as qt, o as Ju } from "./ops-CIQLNshk.js";
|
|
4
|
+
import { r as N } from "./reshape-BMUzc1UY.js";
|
|
5
|
+
import { s as W } from "./sum-B8wEpKsg.js";
|
|
6
|
+
import { m as ct } from "./mat_mul-1nwdPkQ_.js";
|
|
7
|
+
import { s as Qt } from "./split-CHzJjxDv.js";
|
|
8
|
+
import { s as Zu, c as ar } from "./sin-Cocju-BY.js";
|
|
9
|
+
import { e as Hn, g as lr, h as cs, c as Xu } from "./axis_util-DaAl5MER.js";
|
|
10
|
+
import { a as se, e as ie, l as Yu } from "./log_sum_exp-D086OgZJ.js";
|
|
11
|
+
import { s as Dn } from "./stack-Dpgg_1W1.js";
|
|
12
|
+
import { o as xe } from "./ones-D63HpSF_.js";
|
|
13
|
+
import { M as Qu, f as ur, r as tc, d as ec, a as $n } from "./dropout-hpDwECTe.js";
|
|
14
|
+
import { z as vt } from "./zeros-DCPCdFGq.js";
|
|
15
|
+
import { c as pe } from "./concat-CsHeR4zV.js";
|
|
16
|
+
import { g as cr } from "./gather-D0_gPiBz.js";
|
|
17
|
+
import { s as hr } from "./softmax-GPNK3o-U.js";
|
|
18
|
+
import { m as Ee } from "./max-BQc2Aj-I.js";
|
|
19
|
+
import { t as nc } from "./tensor-RvZVNmg0.js";
|
|
20
|
+
import { r as sc } from "./range-CYzpQY53.js";
|
|
21
|
+
import { v as ic } from "./variable-DXEUOwew.js";
|
|
23
22
|
/**
|
|
24
23
|
* @license
|
|
25
24
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -531,13 +530,13 @@ const dr = /* @__PURE__ */ T({ depthwiseConv2d_: Zc });
|
|
|
531
530
|
* limitations under the License.
|
|
532
531
|
* =============================================================================
|
|
533
532
|
*/
|
|
534
|
-
function
|
|
533
|
+
function Xc(n, t) {
|
|
535
534
|
let e = I(n, "a", "equal", "string_or_numeric"), s = I(t, "b", "equal", "string_or_numeric");
|
|
536
535
|
[e, s] = Ge(e, s), kt(e.shape, s.shape);
|
|
537
536
|
const i = { a: e, b: s };
|
|
538
537
|
return O.runKernel(Wa, i);
|
|
539
538
|
}
|
|
540
|
-
const re = /* @__PURE__ */ T({ equal_:
|
|
539
|
+
const re = /* @__PURE__ */ T({ equal_: Xc });
|
|
541
540
|
/**
|
|
542
541
|
* @license
|
|
543
542
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -554,13 +553,13 @@ const re = /* @__PURE__ */ T({ equal_: Yc });
|
|
|
554
553
|
* limitations under the License.
|
|
555
554
|
* =============================================================================
|
|
556
555
|
*/
|
|
557
|
-
function
|
|
556
|
+
function Yc(n) {
|
|
558
557
|
let t = I(n, "x", "erf");
|
|
559
558
|
b(t.dtype === "int32" || t.dtype === "float32", () => "Input dtype must be `int32` or `float32`."), t.dtype === "int32" && (t = L(t, "float32"));
|
|
560
559
|
const e = { x: t };
|
|
561
560
|
return O.runKernel(Wi, e);
|
|
562
561
|
}
|
|
563
|
-
const Qc = /* @__PURE__ */ T({ erf_:
|
|
562
|
+
const Qc = /* @__PURE__ */ T({ erf_: Yc });
|
|
564
563
|
/**
|
|
565
564
|
* @license
|
|
566
565
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -900,7 +899,7 @@ const gn = /* @__PURE__ */ T({ reverse_: kh });
|
|
|
900
899
|
*/
|
|
901
900
|
function xh(n) {
|
|
902
901
|
const e = { x: I(n, "x", "rsqrt", "float32") };
|
|
903
|
-
return O.runKernel(
|
|
902
|
+
return O.runKernel(Xi, e);
|
|
904
903
|
}
|
|
905
904
|
const Nh = /* @__PURE__ */ T({ rsqrt_: xh });
|
|
906
905
|
/**
|
|
@@ -921,7 +920,7 @@ const Nh = /* @__PURE__ */ T({ rsqrt_: xh });
|
|
|
921
920
|
*/
|
|
922
921
|
function vh(n) {
|
|
923
922
|
const e = { x: I(n, "x", "selu") };
|
|
924
|
-
return O.runKernel(
|
|
923
|
+
return O.runKernel(Yi, e);
|
|
925
924
|
}
|
|
926
925
|
const Ah = /* @__PURE__ */ T({ selu_: vh });
|
|
927
926
|
function Sh(n, t, e, s, i, r = [1, 1], o = "NHWC") {
|
|
@@ -1060,7 +1059,7 @@ const bn = /* @__PURE__ */ T({ slice4d_: Eh });
|
|
|
1060
1059
|
function Lh(n, t = 0, e = 1, s, i) {
|
|
1061
1060
|
if (Ua(n), s != null && s === "bool")
|
|
1062
1061
|
throw new Error("Unsupported data type $ { dtype }");
|
|
1063
|
-
const r = new
|
|
1062
|
+
const r = new Qu(t, e, s, !0, i), o = Va(n, s);
|
|
1064
1063
|
for (let a = 0; a < o.values.length; a++)
|
|
1065
1064
|
o.values[a] = r.nextValue();
|
|
1066
1065
|
return o.toTensor();
|
|
@@ -1305,7 +1304,7 @@ function Uh({ a: n, b: t, transposeA: e = !1, transposeB: s = !1, bias: i, activ
|
|
|
1305
1304
|
return at([ut, ft, bt, mt]), { value: N(bt, C), gradFunc: E };
|
|
1306
1305
|
})(v, D, $);
|
|
1307
1306
|
}
|
|
1308
|
-
const
|
|
1307
|
+
const Xs = /* @__PURE__ */ T({ fusedMatMul_: Uh });
|
|
1309
1308
|
/**
|
|
1310
1309
|
* @license
|
|
1311
1310
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -1402,7 +1401,7 @@ class Vh {
|
|
|
1402
1401
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1403
1402
|
*/
|
|
1404
1403
|
static rmsprop(t, e = 0.9, s = 0, i = null, r = !1) {
|
|
1405
|
-
return new
|
|
1404
|
+
return new Xa(t, e, s, i, r);
|
|
1406
1405
|
}
|
|
1407
1406
|
/**
|
|
1408
1407
|
* Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
|
|
@@ -1417,7 +1416,7 @@ class Vh {
|
|
|
1417
1416
|
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
|
|
1418
1417
|
*/
|
|
1419
1418
|
static adam(t = 1e-3, e = 0.9, s = 0.999, i = null) {
|
|
1420
|
-
return new
|
|
1419
|
+
return new Ya(t, e, s, i);
|
|
1421
1420
|
}
|
|
1422
1421
|
/**
|
|
1423
1422
|
* Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
|
|
@@ -1663,7 +1662,7 @@ const Zh = {
|
|
|
1663
1662
|
* limitations under the License.
|
|
1664
1663
|
* =============================================================================
|
|
1665
1664
|
*/
|
|
1666
|
-
const
|
|
1665
|
+
const Xh = {
|
|
1667
1666
|
kernelName: Ci,
|
|
1668
1667
|
inputsToSave: ["x"],
|
|
1669
1668
|
gradFunc: (n, t) => {
|
|
@@ -1687,7 +1686,7 @@ const Yh = {
|
|
|
1687
1686
|
* limitations under the License.
|
|
1688
1687
|
* =============================================================================
|
|
1689
1688
|
*/
|
|
1690
|
-
const
|
|
1689
|
+
const Yh = {
|
|
1691
1690
|
kernelName: al,
|
|
1692
1691
|
inputsToSave: ["x"],
|
|
1693
1692
|
gradFunc: (n, t) => {
|
|
@@ -2092,7 +2091,7 @@ const mp = {
|
|
|
2092
2091
|
gradFunc: (n, t, e) => {
|
|
2093
2092
|
const [s] = t, { clipValueMin: i, clipValueMax: r } = e;
|
|
2094
2093
|
return {
|
|
2095
|
-
x: () => ne(Ze(
|
|
2094
|
+
x: () => ne(Ze(Xe(s, i), ls(s, r)), n, Q(n))
|
|
2096
2095
|
};
|
|
2097
2096
|
}
|
|
2098
2097
|
};
|
|
@@ -2270,7 +2269,7 @@ const vp = {
|
|
|
2270
2269
|
inputsToSave: ["x"],
|
|
2271
2270
|
gradFunc: (n, t) => {
|
|
2272
2271
|
const [e] = t;
|
|
2273
|
-
return { x: () => w(wt(
|
|
2272
|
+
return { x: () => w(wt(Zu(L(e, "float32"))), n) };
|
|
2274
2273
|
}
|
|
2275
2274
|
};
|
|
2276
2275
|
/**
|
|
@@ -2626,7 +2625,7 @@ const Mp = {
|
|
|
2626
2625
|
inputsToSave: ["x", "indices"],
|
|
2627
2626
|
gradFunc: (n, t, e) => {
|
|
2628
2627
|
const [s, i] = t, { axis: r, batchDims: o } = e, a = ce(r, s.shape)[0], l = (u, c, h) => () => {
|
|
2629
|
-
const p = u.shape, f = c.size, y = p.slice(0, a), g = y.length, m = p.slice(r, p.length).slice(1), A = m.length, k =
|
|
2628
|
+
const p = u.shape, f = c.size, y = p.slice(0, a), g = y.length, m = p.slice(r, p.length).slice(1), A = m.length, k = Ys(0, g), C = Ys(g + 1, g + 1 + A), v = Qs([
|
|
2630
2629
|
y,
|
|
2631
2630
|
[f],
|
|
2632
2631
|
m
|
|
@@ -2642,7 +2641,7 @@ const Mp = {
|
|
|
2642
2641
|
return { x: l(s, i, n), indices: () => i };
|
|
2643
2642
|
}
|
|
2644
2643
|
};
|
|
2645
|
-
function
|
|
2644
|
+
function Ys(n, t) {
|
|
2646
2645
|
const e = [];
|
|
2647
2646
|
for (let s = n; s < t; ++s)
|
|
2648
2647
|
e.push(s);
|
|
@@ -2779,7 +2778,7 @@ const Pp = {
|
|
|
2779
2778
|
kernelName: Bl,
|
|
2780
2779
|
inputsToSave: ["x"],
|
|
2781
2780
|
gradFunc: (n, t, e) => {
|
|
2782
|
-
const [s] = t, { alpha: i } = e, r =
|
|
2781
|
+
const [s] = t, { alpha: i } = e, r = Xt(s, 0);
|
|
2783
2782
|
return { x: () => ne(r, n, w(n, i)) };
|
|
2784
2783
|
}
|
|
2785
2784
|
};
|
|
@@ -2978,7 +2977,7 @@ const Jp = {
|
|
|
2978
2977
|
inputsToSave: ["a", "b"],
|
|
2979
2978
|
gradFunc: (n, t) => {
|
|
2980
2979
|
const [e, s] = t;
|
|
2981
|
-
return { a: () => w(n, L(
|
|
2980
|
+
return { a: () => w(n, L(Xe(e, s), "float32")), b: () => w(n, L(Ku(e, s), "float32")) };
|
|
2982
2981
|
}
|
|
2983
2982
|
};
|
|
2984
2983
|
/**
|
|
@@ -3016,7 +3015,7 @@ function Zp(n, t, e, s, i, r, o) {
|
|
|
3016
3015
|
const y = { dy: c, input: h, output: p }, g = { filterSize: s, strides: i, pad: r, dimRoundingMode: o }, m = O.runKernel(Kl, y, g);
|
|
3017
3016
|
return f ? N(m, [m.shape[1], m.shape[2], m.shape[3], m.shape[4]]) : m;
|
|
3018
3017
|
}
|
|
3019
|
-
const
|
|
3018
|
+
const Xp = /* @__PURE__ */ T({ maxPool3dGrad_: Zp });
|
|
3020
3019
|
/**
|
|
3021
3020
|
* @license
|
|
3022
3021
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3033,14 +3032,14 @@ const Yp = /* @__PURE__ */ T({ maxPool3dGrad_: Zp });
|
|
|
3033
3032
|
* limitations under the License.
|
|
3034
3033
|
* =============================================================================
|
|
3035
3034
|
*/
|
|
3036
|
-
const
|
|
3035
|
+
const Yp = {
|
|
3037
3036
|
kernelName: Vi,
|
|
3038
3037
|
inputsToSave: ["x"],
|
|
3039
3038
|
outputsToSave: [!0],
|
|
3040
3039
|
gradFunc: (n, t, e) => {
|
|
3041
3040
|
const [s, i] = t, { filterSize: r, strides: o, pad: a, dimRoundingMode: l } = e;
|
|
3042
3041
|
return {
|
|
3043
|
-
x: () =>
|
|
3042
|
+
x: () => Xp(n, s, i, r, o, a, l)
|
|
3044
3043
|
};
|
|
3045
3044
|
}
|
|
3046
3045
|
};
|
|
@@ -3114,7 +3113,7 @@ const nd = {
|
|
|
3114
3113
|
kernelName: ji,
|
|
3115
3114
|
inputsToSave: ["x"],
|
|
3116
3115
|
gradFunc: (n, t, e) => {
|
|
3117
|
-
const [s] = t, { axis: i } = e, r = ce(i, s.shape), a =
|
|
3116
|
+
const [s] = t, { axis: i } = e, r = ce(i, s.shape), a = Xu(s.shape, r)[1], l = pn(a);
|
|
3118
3117
|
return { x: () => {
|
|
3119
3118
|
const c = s.shape.slice();
|
|
3120
3119
|
r.forEach((f) => {
|
|
@@ -3173,7 +3172,7 @@ const id = {
|
|
|
3173
3172
|
inputsToSave: ["a", "b"],
|
|
3174
3173
|
gradFunc: (n, t) => {
|
|
3175
3174
|
const [e, s] = t;
|
|
3176
|
-
return { a: () => w(n, L(ls(e, s), "float32")), b: () => w(n, L(
|
|
3175
|
+
return { a: () => w(n, L(ls(e, s), "float32")), b: () => w(n, L(Xt(e, s), "float32")) };
|
|
3177
3176
|
}
|
|
3178
3177
|
};
|
|
3179
3178
|
/**
|
|
@@ -3217,7 +3216,7 @@ const rd = {
|
|
|
3217
3216
|
* =============================================================================
|
|
3218
3217
|
*/
|
|
3219
3218
|
const od = {
|
|
3220
|
-
kernelName:
|
|
3219
|
+
kernelName: Xl,
|
|
3221
3220
|
inputsToSave: ["a", "b"],
|
|
3222
3221
|
gradFunc: (n, t) => {
|
|
3223
3222
|
const [e, s] = t, i = kt(e.shape, s.shape);
|
|
@@ -3247,7 +3246,7 @@ const od = {
|
|
|
3247
3246
|
* =============================================================================
|
|
3248
3247
|
*/
|
|
3249
3248
|
const ad = {
|
|
3250
|
-
kernelName:
|
|
3249
|
+
kernelName: Yl,
|
|
3251
3250
|
inputsToSave: ["a", "b"],
|
|
3252
3251
|
gradFunc: (n, t) => {
|
|
3253
3252
|
const [e, s] = t, i = kt(e.shape, s.shape);
|
|
@@ -3400,7 +3399,7 @@ const pd = {
|
|
|
3400
3399
|
const p = ht(r.shape, a);
|
|
3401
3400
|
return p.length > 0 && (h = W(h, p)), N(h, r.shape);
|
|
3402
3401
|
}, b: () => {
|
|
3403
|
-
const c =
|
|
3402
|
+
const c = Xt(r, 0), h = ne(c, se(r), Q(r));
|
|
3404
3403
|
let p = w(n, w(i, h));
|
|
3405
3404
|
const f = ht(o.shape, a);
|
|
3406
3405
|
return f.length > 0 && (p = W(p, f)), N(p, o.shape);
|
|
@@ -3427,7 +3426,7 @@ const dd = {
|
|
|
3427
3426
|
kernelName: nu,
|
|
3428
3427
|
inputsToSave: ["x", "alpha"],
|
|
3429
3428
|
gradFunc: (n, t) => {
|
|
3430
|
-
const [e, s] = t, i =
|
|
3429
|
+
const [e, s] = t, i = Xt(e, 0);
|
|
3431
3430
|
return {
|
|
3432
3431
|
x: () => ne(i, n, w(n, s)),
|
|
3433
3432
|
alpha: () => {
|
|
@@ -3726,7 +3725,7 @@ const Sd = {
|
|
|
3726
3725
|
* =============================================================================
|
|
3727
3726
|
*/
|
|
3728
3727
|
const Cd = {
|
|
3729
|
-
kernelName:
|
|
3728
|
+
kernelName: Xi,
|
|
3730
3729
|
inputsToSave: ["x"],
|
|
3731
3730
|
gradFunc: (n, t) => {
|
|
3732
3731
|
const [e] = t;
|
|
@@ -3780,13 +3779,13 @@ const Id = {
|
|
|
3780
3779
|
* =============================================================================
|
|
3781
3780
|
*/
|
|
3782
3781
|
const Dd = {
|
|
3783
|
-
kernelName:
|
|
3782
|
+
kernelName: Yi,
|
|
3784
3783
|
inputsToSave: ["x"],
|
|
3785
3784
|
gradFunc: (n, t) => {
|
|
3786
3785
|
const [e] = t;
|
|
3787
3786
|
return {
|
|
3788
3787
|
x: () => {
|
|
3789
|
-
const s =
|
|
3788
|
+
const s = Xt(e, et(0)), i = et(Gu), r = et(Pu), o = w(n, r), a = w(w(n, i), ie(L(e, "float32")));
|
|
3790
3789
|
return ne(s, o, a);
|
|
3791
3790
|
}
|
|
3792
3791
|
};
|
|
@@ -4333,7 +4332,7 @@ const Hd = {
|
|
|
4333
4332
|
};
|
|
4334
4333
|
function qd(n, t) {
|
|
4335
4334
|
const e = Oe(t, Q(t)), s = cr(n, e);
|
|
4336
|
-
let i =
|
|
4335
|
+
let i = Xe(t, et(0, "int32"));
|
|
4337
4336
|
const r = s.rank - i.rank;
|
|
4338
4337
|
for (let a = 0; a < r; ++a)
|
|
4339
4338
|
i = ye(i, a + 1);
|
|
@@ -4383,8 +4382,8 @@ const Zd = [
|
|
|
4383
4382
|
qh,
|
|
4384
4383
|
Jh,
|
|
4385
4384
|
Zh,
|
|
4386
|
-
Yh,
|
|
4387
4385
|
Xh,
|
|
4386
|
+
Yh,
|
|
4388
4387
|
Qh,
|
|
4389
4388
|
tp,
|
|
4390
4389
|
ep,
|
|
@@ -4431,7 +4430,7 @@ const Zd = [
|
|
|
4431
4430
|
ti,
|
|
4432
4431
|
ti,
|
|
4433
4432
|
Jp,
|
|
4434
|
-
|
|
4433
|
+
Yp,
|
|
4435
4434
|
ed,
|
|
4436
4435
|
nd,
|
|
4437
4436
|
sd,
|
|
@@ -4636,7 +4635,7 @@ function Jn(n) {
|
|
|
4636
4635
|
}
|
|
4637
4636
|
}
|
|
4638
4637
|
}
|
|
4639
|
-
function
|
|
4638
|
+
function Ye(n, t = {}, e = {}, s = "object", i = !1) {
|
|
4640
4639
|
if (typeof n == "string") {
|
|
4641
4640
|
const r = n;
|
|
4642
4641
|
let o;
|
|
@@ -4683,11 +4682,11 @@ function Xe(n, t = {}, e = {}, s = "object", i = !1) {
|
|
|
4683
4682
|
}
|
|
4684
4683
|
}
|
|
4685
4684
|
}
|
|
4686
|
-
function
|
|
4685
|
+
function Xd(n, t) {
|
|
4687
4686
|
return n < t ? -1 : n > t ? 1 : 0;
|
|
4688
4687
|
}
|
|
4689
4688
|
function ln(n, t) {
|
|
4690
|
-
return -1 *
|
|
4689
|
+
return -1 * Xd(n, t);
|
|
4691
4690
|
}
|
|
4692
4691
|
function te(n) {
|
|
4693
4692
|
if (n == null)
|
|
@@ -4697,7 +4696,7 @@ function te(n) {
|
|
|
4697
4696
|
t.indexOf(e) === -1 && t.push(e);
|
|
4698
4697
|
return t;
|
|
4699
4698
|
}
|
|
4700
|
-
function
|
|
4699
|
+
function Yd(n) {
|
|
4701
4700
|
if (n == null)
|
|
4702
4701
|
throw new d(`Invalid value in obj: ${JSON.stringify(n)}`);
|
|
4703
4702
|
for (const t in n)
|
|
@@ -5023,7 +5022,7 @@ function Zn(n, t) {
|
|
|
5023
5022
|
return Ce(n, t);
|
|
5024
5023
|
}
|
|
5025
5024
|
function zn(n, t = 0, e = 1, s, i) {
|
|
5026
|
-
return
|
|
5025
|
+
return tc(n, t, e, s, i);
|
|
5027
5026
|
}
|
|
5028
5027
|
function Ut(n, t, e, s) {
|
|
5029
5028
|
if (n.rank < 2 || t.rank < 2)
|
|
@@ -5034,12 +5033,12 @@ function Ut(n, t, e, s) {
|
|
|
5034
5033
|
throw new B(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${n.shape} and y shape = ${t.shape}`);
|
|
5035
5034
|
}
|
|
5036
5035
|
if (n.rank === 2 && t.rank === 2)
|
|
5037
|
-
return
|
|
5036
|
+
return Xs({
|
|
5038
5037
|
a: n,
|
|
5039
5038
|
b: t,
|
|
5040
5039
|
transposeA: !1,
|
|
5041
5040
|
transposeB: !1,
|
|
5042
|
-
bias: s ?
|
|
5041
|
+
bias: s ? Xn(n.rank, s, _t()) : null,
|
|
5043
5042
|
activation: e
|
|
5044
5043
|
});
|
|
5045
5044
|
{
|
|
@@ -5048,12 +5047,12 @@ function Ut(n, t, e, s) {
|
|
|
5048
5047
|
const o = t.shape.slice(), a = o.pop(), l = o.pop(), u = [...o, a], c = Array.from({ length: t.rank }, (y, g) => g === 0 ? t.rank - 2 : g <= t.rank - 2 ? g - 1 : g);
|
|
5049
5048
|
t = N(K(t, c), [l, -1]);
|
|
5050
5049
|
const h = [...i, ...u];
|
|
5051
|
-
return N(
|
|
5050
|
+
return N(Xs({
|
|
5052
5051
|
a: n,
|
|
5053
5052
|
b: t,
|
|
5054
5053
|
transposeA: !1,
|
|
5055
5054
|
transposeB: !1,
|
|
5056
|
-
bias: s ?
|
|
5055
|
+
bias: s ? Xn(n.rank, s, _t()) : null,
|
|
5057
5056
|
activation: e
|
|
5058
5057
|
}), h);
|
|
5059
5058
|
}
|
|
@@ -5064,7 +5063,7 @@ function Ir(n, t, e) {
|
|
|
5064
5063
|
function tn(n) {
|
|
5065
5064
|
return w(n, n);
|
|
5066
5065
|
}
|
|
5067
|
-
function
|
|
5066
|
+
function Xn(n, t, e) {
|
|
5068
5067
|
const s = t.shape;
|
|
5069
5068
|
if (t.rank !== 1 && t.rank !== n)
|
|
5070
5069
|
throw new d(`Unexpected bias dimensions: ${t.rank}; expected it to be 1 or ${n}`);
|
|
@@ -5088,7 +5087,7 @@ function Yn(n, t, e) {
|
|
|
5088
5087
|
throw new d(`Unsupported input rank by biasAdd: ${t.rank}`);
|
|
5089
5088
|
}
|
|
5090
5089
|
function Rt(n, t, e) {
|
|
5091
|
-
return x(() => (e == null && (e = _t()), tt(e), z(n,
|
|
5090
|
+
return x(() => (e == null && (e = _t()), tt(e), z(n, Xn(n.rank, t, e))));
|
|
5092
5091
|
}
|
|
5093
5092
|
function ff(n, t = 1) {
|
|
5094
5093
|
if (t !== 1)
|
|
@@ -5099,7 +5098,7 @@ function mf(n) {
|
|
|
5099
5098
|
return x(() => U(n, z($e(n), 1)));
|
|
5100
5099
|
}
|
|
5101
5100
|
function Dr(n, t, e, s) {
|
|
5102
|
-
return x(() =>
|
|
5101
|
+
return x(() => ec(n, t, e, s));
|
|
5103
5102
|
}
|
|
5104
5103
|
function gf(n) {
|
|
5105
5104
|
return x(() => {
|
|
@@ -5438,9 +5437,9 @@ const ai = {
|
|
|
5438
5437
|
zeros: "Zeros"
|
|
5439
5438
|
};
|
|
5440
5439
|
function li(n, t = {}) {
|
|
5441
|
-
return
|
|
5440
|
+
return Ye(n, qe.getMap().classNameMap, t, "initializer");
|
|
5442
5441
|
}
|
|
5443
|
-
function
|
|
5442
|
+
function Y(n) {
|
|
5444
5443
|
return ks(n);
|
|
5445
5444
|
}
|
|
5446
5445
|
function Z(n) {
|
|
@@ -5473,7 +5472,7 @@ function Z(n) {
|
|
|
5473
5472
|
* https://opensource.org/licenses/MIT.
|
|
5474
5473
|
* =============================================================================
|
|
5475
5474
|
*/
|
|
5476
|
-
function
|
|
5475
|
+
function Yn(n) {
|
|
5477
5476
|
return Array.isArray(n) && Array.isArray(n[0]);
|
|
5478
5477
|
}
|
|
5479
5478
|
function yn(n) {
|
|
@@ -6438,14 +6437,14 @@ function zf(n, t) {
|
|
|
6438
6437
|
throw new d(`The dtype of the feed (${t.dtype}) can not be cast to the dtype of the key '${n.name}' (${n.dtype}).`);
|
|
6439
6438
|
}
|
|
6440
6439
|
}
|
|
6441
|
-
class
|
|
6440
|
+
class Yt {
|
|
6442
6441
|
/**
|
|
6443
6442
|
* Constructor, optionally does copy-construction.
|
|
6444
6443
|
* @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
|
|
6445
6444
|
* copy-construction will be performed.
|
|
6446
6445
|
*/
|
|
6447
6446
|
constructor(t) {
|
|
6448
|
-
if (this.id2Value = {}, this.id2Mask = {}, this.name2Id = {}, t instanceof
|
|
6447
|
+
if (this.id2Value = {}, this.id2Mask = {}, this.name2Id = {}, t instanceof Yt)
|
|
6449
6448
|
for (const e in t.id2Value)
|
|
6450
6449
|
this.id2Value[e] = t.id2Value[e], e in t.id2Mask && (this.id2Mask[e] = t.id2Mask[e]);
|
|
6451
6450
|
else {
|
|
@@ -6551,7 +6550,7 @@ function Be(n, t, e, s) {
|
|
|
6551
6550
|
h = y.sorted, p = y.recipientCounts, kn.put(c, h), xn.put(c, p);
|
|
6552
6551
|
}
|
|
6553
6552
|
p = {}, i || Object.assign(p, xn.get(c));
|
|
6554
|
-
const f = new
|
|
6553
|
+
const f = new Yt(t);
|
|
6555
6554
|
for (let y = 0; y < h.length; ++y) {
|
|
6556
6555
|
const g = h[y], m = g.sourceLayer;
|
|
6557
6556
|
if (m instanceof nn)
|
|
@@ -6743,7 +6742,7 @@ function rt(n) {
|
|
|
6743
6742
|
return ks(n);
|
|
6744
6743
|
}
|
|
6745
6744
|
function pi(n, t = {}) {
|
|
6746
|
-
return
|
|
6745
|
+
return Ye(n, qe.getMap().classNameMap, t, "constraint");
|
|
6747
6746
|
}
|
|
6748
6747
|
function ot(n) {
|
|
6749
6748
|
if (n == null)
|
|
@@ -7088,7 +7087,7 @@ function Ur(n, t, e, s, i, r, o, a, l) {
|
|
|
7088
7087
|
* =============================================================================
|
|
7089
7088
|
*/
|
|
7090
7089
|
function Zt(n, t = {}, e = !1) {
|
|
7091
|
-
return
|
|
7090
|
+
return Ye(n, qe.getMap().classNameMap, t, "layer", e);
|
|
7092
7091
|
}
|
|
7093
7092
|
/**
|
|
7094
7093
|
* @license
|
|
@@ -7235,7 +7234,7 @@ function Un(n) {
|
|
|
7235
7234
|
*/
|
|
7236
7235
|
function jr(n, t) {
|
|
7237
7236
|
return x(() => {
|
|
7238
|
-
const e = w(0.5, Ot(t)), s = Pt(
|
|
7237
|
+
const e = w(0.5, Ot(t)), s = Pt(Xt(t, e), n.dtype);
|
|
7239
7238
|
return nt(re(n, s), -1);
|
|
7240
7239
|
});
|
|
7241
7240
|
}
|
|
@@ -7245,13 +7244,13 @@ function Kr(n, t) {
|
|
|
7245
7244
|
function Zf(n, t) {
|
|
7246
7245
|
return x(() => L(W(Ze(re(n, 1), re(t, 1))), "float32"));
|
|
7247
7246
|
}
|
|
7248
|
-
function
|
|
7247
|
+
function Xf(n, t) {
|
|
7249
7248
|
return x(() => L(W(Ze(re(n, 0), re(t, 1))), "float32"));
|
|
7250
7249
|
}
|
|
7251
|
-
function
|
|
7250
|
+
function Yf(n, t) {
|
|
7252
7251
|
return x(() => {
|
|
7253
|
-
const e = Zf(n, t), s =
|
|
7254
|
-
return L(ne(
|
|
7252
|
+
const e = Zf(n, t), s = Xf(n, t), i = z(e, s);
|
|
7253
|
+
return L(ne(Xt(i, 0), U(e, i), 0), "float32");
|
|
7255
7254
|
});
|
|
7256
7255
|
}
|
|
7257
7256
|
function Qf(n, t) {
|
|
@@ -7263,7 +7262,7 @@ function tm(n, t) {
|
|
|
7263
7262
|
const em = Ln, nm = Ln, sm = Es, im = Es, rm = Ls, om = Ls, Hr = Ve, am = Vr, qr = vn, Sn = {
|
|
7264
7263
|
binaryAccuracy: jr,
|
|
7265
7264
|
categoricalAccuracy: Kr,
|
|
7266
|
-
precision:
|
|
7265
|
+
precision: Yf,
|
|
7267
7266
|
categoricalCrossentropy: Hr,
|
|
7268
7267
|
sparseCategoricalCrossentropy: qr,
|
|
7269
7268
|
mse: em,
|
|
@@ -7838,7 +7837,7 @@ class Lt extends G {
|
|
|
7838
7837
|
call(t, e) {
|
|
7839
7838
|
return x(() => {
|
|
7840
7839
|
t = H(t);
|
|
7841
|
-
const s = new
|
|
7840
|
+
const s = new Yt();
|
|
7842
7841
|
for (let i = 0; i < this.inputs.length; ++i)
|
|
7843
7842
|
s.add(this.inputs[i], t[i]);
|
|
7844
7843
|
return Be(this.outputs, s, e);
|
|
@@ -8110,7 +8109,7 @@ class Lt extends G {
|
|
|
8110
8109
|
const c = e.name, h = e.layers;
|
|
8111
8110
|
for (const m of h)
|
|
8112
8111
|
u(m);
|
|
8113
|
-
for (; !
|
|
8112
|
+
for (; !Yd(o); )
|
|
8114
8113
|
for (const m of h) {
|
|
8115
8114
|
const A = r[m.name];
|
|
8116
8115
|
if (A.name in o) {
|
|
@@ -8191,10 +8190,10 @@ function gm(n, t, e) {
|
|
|
8191
8190
|
} else
|
|
8192
8191
|
throw new Error(`The model has multiple (${s}) outputs, so ${e} must be either an array with ${s} elements or an object with ${t} keys. Provided ${e} not understood: ${JSON.stringify(n)}`);
|
|
8193
8192
|
}
|
|
8194
|
-
function
|
|
8193
|
+
function Xr(n, t) {
|
|
8195
8194
|
return gm(n, t, "classWeight");
|
|
8196
8195
|
}
|
|
8197
|
-
async function
|
|
8196
|
+
async function Yr(n, t, e, s) {
|
|
8198
8197
|
if (e != null) {
|
|
8199
8198
|
const i = x(() => {
|
|
8200
8199
|
if (n.shape.length === 1)
|
|
@@ -8314,9 +8313,9 @@ async function km(n, t, e) {
|
|
|
8314
8313
|
$.batch = k, $.size = v[0].shape[0], await p.onBatchBegin(k, $);
|
|
8315
8314
|
const R = [];
|
|
8316
8315
|
if (e.classWeight != null) {
|
|
8317
|
-
const M =
|
|
8316
|
+
const M = Xr(e.classWeight, n.outputNames);
|
|
8318
8317
|
for (let P = 0; P < M.length; ++P)
|
|
8319
|
-
R.push(await
|
|
8318
|
+
R.push(await Yr(D[P], null, M[P]));
|
|
8320
8319
|
}
|
|
8321
8320
|
const E = v.concat(D).concat(R), F = a(E);
|
|
8322
8321
|
J(E);
|
|
@@ -8852,7 +8851,7 @@ class Ie extends Lt {
|
|
|
8852
8851
|
execute(t, e) {
|
|
8853
8852
|
if (Array.isArray(e) && e.length === 0)
|
|
8854
8853
|
throw new d("`outputs` is an empty Array, which is not allowed.");
|
|
8855
|
-
const s = Array.isArray(e), i = s ? e : [e], r = this.retrieveSymbolicTensors(i), o = new
|
|
8854
|
+
const s = Array.isArray(e), i = s ? e : [e], r = this.retrieveSymbolicTensors(i), o = new Yt();
|
|
8856
8855
|
if (t instanceof Te && (t = [t]), Array.isArray(t)) {
|
|
8857
8856
|
if (t.length !== this.inputs.length)
|
|
8858
8857
|
throw new d(`The number of inputs provided (${t.length}) does not match the number of inputs of this model (${this.inputs.length}).`);
|
|
@@ -8919,7 +8918,7 @@ class Ie extends Lt {
|
|
|
8919
8918
|
p.push({ key: this.inputs[y], value: h[y] });
|
|
8920
8919
|
else
|
|
8921
8920
|
p.push({ key: this.inputs[0], value: h });
|
|
8922
|
-
const f = new
|
|
8921
|
+
const f = new Yt(p);
|
|
8923
8922
|
return Be(this.outputs, f);
|
|
8924
8923
|
}).forEach((u, c) => o[c].push(u));
|
|
8925
8924
|
return yt(o.map((a) => pe(a, 0)));
|
|
@@ -9000,10 +8999,10 @@ class Ie extends Lt {
|
|
|
9000
8999
|
throw new Error("sample weight is not supported yet.");
|
|
9001
9000
|
let u = null;
|
|
9002
9001
|
if (i != null) {
|
|
9003
|
-
const c =
|
|
9002
|
+
const c = Xr(i, this.outputNames);
|
|
9004
9003
|
u = [];
|
|
9005
9004
|
for (let h = 0; h < c.length; ++h)
|
|
9006
|
-
u.push(await
|
|
9005
|
+
u.push(await Yr(l[h], null, c[h]));
|
|
9007
9006
|
}
|
|
9008
9007
|
return [a, l, u];
|
|
9009
9008
|
}
|
|
@@ -9072,7 +9071,7 @@ class Ie extends Lt {
|
|
|
9072
9071
|
const h = [];
|
|
9073
9072
|
for (let g = 0; g < this.inputs.length; ++g)
|
|
9074
9073
|
h.push({ key: this.inputs[g], value: s[g] });
|
|
9075
|
-
const p = new
|
|
9074
|
+
const p = new Yt(h), f = Be(this.outputs, p, { training: !0 });
|
|
9076
9075
|
let y;
|
|
9077
9076
|
for (let g = 0; g < this.lossFunctions.length; ++g) {
|
|
9078
9077
|
const m = this.lossFunctions[g];
|
|
@@ -9110,7 +9109,7 @@ class Ie extends Lt {
|
|
|
9110
9109
|
const i = t.slice(0, this.inputs.length), r = t.slice(this.inputs.length, this.inputs.length + this.outputs.length), o = [];
|
|
9111
9110
|
for (let u = 0; u < this.inputs.length; ++u)
|
|
9112
9111
|
o.push({ key: this.inputs[u], value: i[u] });
|
|
9113
|
-
const a = new
|
|
9112
|
+
const a = new Yt(o), l = Be(this.outputs, a);
|
|
9114
9113
|
for (let u = 0; u < this.lossFunctions.length; ++u) {
|
|
9115
9114
|
const c = this.lossFunctions[u], h = nt(c(r[u], l[u]));
|
|
9116
9115
|
u === 0 ? s = h : s = z(s, h), e.push(s);
|
|
@@ -10295,7 +10294,7 @@ function ae(n) {
|
|
|
10295
10294
|
return n.getClassName();
|
|
10296
10295
|
}
|
|
10297
10296
|
function Kn(n, t = {}) {
|
|
10298
|
-
return
|
|
10297
|
+
return Ye(n, qe.getMap().classNameMap, t, "activation");
|
|
10299
10298
|
}
|
|
10300
10299
|
function le(n) {
|
|
10301
10300
|
if (n == null) {
|
|
@@ -10353,9 +10352,9 @@ function q(n) {
|
|
|
10353
10352
|
return ks(n);
|
|
10354
10353
|
}
|
|
10355
10354
|
function Ni(n, t = {}) {
|
|
10356
|
-
return
|
|
10355
|
+
return Ye(n, qe.getMap().classNameMap, t, "regularizer");
|
|
10357
10356
|
}
|
|
10358
|
-
function
|
|
10357
|
+
function X(n) {
|
|
10359
10358
|
if (n == null)
|
|
10360
10359
|
return null;
|
|
10361
10360
|
if (typeof n == "string") {
|
|
@@ -10411,7 +10410,7 @@ xo.className = "LeakyReLU";
|
|
|
10411
10410
|
S(xo);
|
|
10412
10411
|
class No extends G {
|
|
10413
10412
|
constructor(t) {
|
|
10414
|
-
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer = Z(t.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER), this.alphaRegularizer =
|
|
10413
|
+
if (super(t ?? {}), this.DEFAULT_ALPHA_INITIALIZER = "zeros", t == null && (t = {}), this.supportsMasking = !0, this.alphaInitializer = Z(t.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER), this.alphaRegularizer = X(t.alphaRegularizer), this.alphaConstraint = ot(t.alphaConstraint), t.sharedAxes == null)
|
|
10415
10414
|
this.sharedAxes = null;
|
|
10416
10415
|
else if (Array.isArray(t.sharedAxes))
|
|
10417
10416
|
this.sharedAxes = t.sharedAxes;
|
|
@@ -10441,7 +10440,7 @@ class No extends G {
|
|
|
10441
10440
|
}
|
|
10442
10441
|
getConfig() {
|
|
10443
10442
|
const t = {
|
|
10444
|
-
alphaInitializer:
|
|
10443
|
+
alphaInitializer: Y(this.alphaInitializer),
|
|
10445
10444
|
alphaRegularizer: q(this.alphaRegularizer),
|
|
10446
10445
|
alphaConstraint: rt(this.alphaConstraint),
|
|
10447
10446
|
sharedAxes: this.sharedAxes
|
|
@@ -10477,7 +10476,7 @@ class Ao extends G {
|
|
|
10477
10476
|
}
|
|
10478
10477
|
call(t, e) {
|
|
10479
10478
|
const s = _(t);
|
|
10480
|
-
return w(s, L(
|
|
10479
|
+
return w(s, L(Xt(s, this.theta), "float32"));
|
|
10481
10480
|
}
|
|
10482
10481
|
computeOutputShape(t) {
|
|
10483
10482
|
return t;
|
|
@@ -10619,7 +10618,7 @@ class On extends G {
|
|
|
10619
10618
|
constructor(t, e) {
|
|
10620
10619
|
if (super(e), this.bias = null, this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_BIAS_INITIALIZER = "zeros", On.verifyArgs(e), this.rank = t, lt(this.rank, "rank"), this.rank !== 1 && this.rank !== 2 && this.rank !== 3)
|
|
10621
10620
|
throw new B(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is not implemented yet.`);
|
|
10622
|
-
if (this.kernelSize = De(e.kernelSize, t, "kernelSize"), this.strides = De(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, At(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, tt(this.dataFormat), this.activation = le(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer = Z(e.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.biasConstraint = ot(e.biasConstraint), this.biasRegularizer =
|
|
10621
|
+
if (this.kernelSize = De(e.kernelSize, t, "kernelSize"), this.strides = De(e.strides == null ? 1 : e.strides, t, "strides"), this.padding = e.padding == null ? "valid" : e.padding, At(this.padding), this.dataFormat = e.dataFormat == null ? "channelsLast" : e.dataFormat, tt(this.dataFormat), this.activation = le(e.activation), this.useBias = e.useBias == null ? !0 : e.useBias, this.biasInitializer = Z(e.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.biasConstraint = ot(e.biasConstraint), this.biasRegularizer = X(e.biasRegularizer), this.activityRegularizer = X(e.activityRegularizer), this.dilationRate = De(e.dilationRate == null ? 1 : e.dilationRate, t, "dilationRate"), this.rank === 1 && Array.isArray(this.dilationRate) && this.dilationRate.length !== 1)
|
|
10623
10622
|
throw new d(`dilationRate must be a number or an array of a single number for 1D convolution, but received ${JSON.stringify(this.dilationRate)}`);
|
|
10624
10623
|
if (this.rank === 2) {
|
|
10625
10624
|
if (typeof this.dilationRate == "number")
|
|
@@ -10646,7 +10645,7 @@ class On extends G {
|
|
|
10646
10645
|
dilationRate: this.dilationRate,
|
|
10647
10646
|
activation: ae(this.activation),
|
|
10648
10647
|
useBias: this.useBias,
|
|
10649
|
-
biasInitializer:
|
|
10648
|
+
biasInitializer: Y(this.biasInitializer),
|
|
10650
10649
|
biasRegularizer: q(this.biasRegularizer),
|
|
10651
10650
|
activityRegularizer: q(this.activityRegularizer),
|
|
10652
10651
|
biasConstraint: rt(this.biasConstraint)
|
|
@@ -10656,7 +10655,7 @@ class On extends G {
|
|
|
10656
10655
|
}
|
|
10657
10656
|
class Me extends On {
|
|
10658
10657
|
constructor(t, e) {
|
|
10659
|
-
super(t, e), this.kernel = null, Me.verifyArgs(e), this.filters = e.filters, lt(this.filters, "filters"), this.kernelInitializer = Z(e.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.kernelConstraint = ot(e.kernelConstraint), this.kernelRegularizer =
|
|
10658
|
+
super(t, e), this.kernel = null, Me.verifyArgs(e), this.filters = e.filters, lt(this.filters, "filters"), this.kernelInitializer = Z(e.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.kernelConstraint = ot(e.kernelConstraint), this.kernelRegularizer = X(e.kernelRegularizer);
|
|
10660
10659
|
}
|
|
10661
10660
|
build(t) {
|
|
10662
10661
|
t = j(t);
|
|
@@ -10700,7 +10699,7 @@ class Me extends On {
|
|
|
10700
10699
|
getConfig() {
|
|
10701
10700
|
const t = {
|
|
10702
10701
|
filters: this.filters,
|
|
10703
|
-
kernelInitializer:
|
|
10702
|
+
kernelInitializer: Y(this.kernelInitializer),
|
|
10704
10703
|
kernelRegularizer: q(this.kernelRegularizer),
|
|
10705
10704
|
kernelConstraint: rt(this.kernelConstraint)
|
|
10706
10705
|
}, e = super.getConfig();
|
|
@@ -10835,7 +10834,7 @@ class $o extends Me {
|
|
|
10835
10834
|
throw new d("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");
|
|
10836
10835
|
if (e.padding != null && e.padding !== "same" && e.padding !== "valid")
|
|
10837
10836
|
throw new d(`SeparableConv${this.rank}D supports only padding modes: 'same' and 'valid', but received ${JSON.stringify(e.padding)}`);
|
|
10838
|
-
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer = Z(e.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER), this.depthwiseRegularizer =
|
|
10837
|
+
this.depthMultiplier = e.depthMultiplier == null ? 1 : e.depthMultiplier, this.depthwiseInitializer = Z(e.depthwiseInitializer || this.DEFAULT_DEPTHWISE_INITIALIZER), this.depthwiseRegularizer = X(e.depthwiseRegularizer), this.depthwiseConstraint = ot(e.depthwiseConstraint), this.pointwiseInitializer = Z(e.depthwiseInitializer || this.DEFAULT_POINTWISE_INITIALIZER), this.pointwiseRegularizer = X(e.pointwiseRegularizer), this.pointwiseConstraint = ot(e.pointwiseConstraint);
|
|
10839
10838
|
}
|
|
10840
10839
|
build(t) {
|
|
10841
10840
|
if (t = j(t), t.length < this.rank + 2)
|
|
@@ -10861,7 +10860,7 @@ class $o extends Me {
|
|
|
10861
10860
|
}
|
|
10862
10861
|
getConfig() {
|
|
10863
10862
|
const t = super.getConfig();
|
|
10864
|
-
return delete t.rank, delete t.kernelInitializer, delete t.kernelRegularizer, delete t.kernelConstraint, t.depthwiseInitializer =
|
|
10863
|
+
return delete t.rank, delete t.kernelInitializer, delete t.kernelRegularizer, delete t.kernelConstraint, t.depthwiseInitializer = Y(this.depthwiseInitializer), t.pointwiseInitializer = Y(this.pointwiseInitializer), t.depthwiseRegularizer = q(this.depthwiseRegularizer), t.pointwiseRegularizer = q(this.pointwiseRegularizer), t.depthwiseConstraint = rt(this.depthwiseConstraint), t.pointwiseConstraint = rt(this.pointwiseConstraint), t;
|
|
10865
10864
|
}
|
|
10866
10865
|
}
|
|
10867
10866
|
$o.className = "SeparableConv";
|
|
@@ -10985,7 +10984,7 @@ function Em(n, t, e = [1, 1], s = "valid", i, r) {
|
|
|
10985
10984
|
}
|
|
10986
10985
|
class Lo extends On {
|
|
10987
10986
|
constructor(t) {
|
|
10988
|
-
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer = Z(t.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.depthwiseConstraint = ot(t.depthwiseConstraint), this.depthwiseRegularizer =
|
|
10987
|
+
super(2, t), this.depthwiseKernel = null, this.depthMultiplier = t.depthMultiplier == null ? 1 : t.depthMultiplier, this.depthwiseInitializer = Z(t.depthwiseInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.depthwiseConstraint = ot(t.depthwiseConstraint), this.depthwiseRegularizer = X(t.depthwiseRegularizer);
|
|
10989
10988
|
}
|
|
10990
10989
|
build(t) {
|
|
10991
10990
|
if (t = j(t), t.length < 4)
|
|
@@ -11015,7 +11014,7 @@ class Lo extends On {
|
|
|
11015
11014
|
}
|
|
11016
11015
|
getConfig() {
|
|
11017
11016
|
const t = super.getConfig();
|
|
11018
|
-
return t.depthMultiplier = this.depthMultiplier, t.depthwiseInitializer =
|
|
11017
|
+
return t.depthMultiplier = this.depthMultiplier, t.depthwiseInitializer = Y(this.depthwiseInitializer), t.depthwiseRegularizer = q(this.depthwiseRegularizer), t.depthwiseConstraint = rt(this.depthwiseRegularizer), t;
|
|
11019
11018
|
}
|
|
11020
11019
|
}
|
|
11021
11020
|
Lo.className = "DepthwiseConv2D";
|
|
@@ -11094,7 +11093,7 @@ class de extends G {
|
|
|
11094
11093
|
this.states_ = t;
|
|
11095
11094
|
}
|
|
11096
11095
|
computeOutputShape(t) {
|
|
11097
|
-
|
|
11096
|
+
Yn(t) && (t = t[0]), t = t;
|
|
11098
11097
|
let e = this.cell.stateSize;
|
|
11099
11098
|
Array.isArray(e) || (e = [e]);
|
|
11100
11099
|
const s = e[0];
|
|
@@ -11139,7 +11138,7 @@ class de extends G {
|
|
|
11139
11138
|
build(t) {
|
|
11140
11139
|
if (this.numConstants != null)
|
|
11141
11140
|
throw new B("Constants support is not implemented in RNN yet.");
|
|
11142
|
-
|
|
11141
|
+
Yn(t) && (t = t[0]), t = t;
|
|
11143
11142
|
const e = this.stateful ? t[0] : null, s = t.slice(2);
|
|
11144
11143
|
this.inputSpec[0] = new it({ shape: [e, null, ...s] });
|
|
11145
11144
|
const i = [t[0]].concat(t.slice(2));
|
|
@@ -11275,7 +11274,7 @@ class _n extends G {
|
|
|
11275
11274
|
}
|
|
11276
11275
|
class Ms extends _n {
|
|
11277
11276
|
constructor(t) {
|
|
11278
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer =
|
|
11277
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation == null ? this.DEFAULT_ACTIVATION : t.activation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11279
11278
|
1,
|
|
11280
11279
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11281
11280
|
]), this.dropoutFunc = t.dropoutFunc, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11319,9 +11318,9 @@ class Ms extends _n {
|
|
|
11319
11318
|
units: this.units,
|
|
11320
11319
|
activation: ae(this.activation),
|
|
11321
11320
|
useBias: this.useBias,
|
|
11322
|
-
kernelInitializer:
|
|
11323
|
-
recurrentInitializer:
|
|
11324
|
-
biasInitializer:
|
|
11321
|
+
kernelInitializer: Y(this.kernelInitializer),
|
|
11322
|
+
recurrentInitializer: Y(this.recurrentInitializer),
|
|
11323
|
+
biasInitializer: Y(this.biasInitializer),
|
|
11325
11324
|
kernelRegularizer: q(this.kernelRegularizer),
|
|
11326
11325
|
recurrentRegularizer: q(this.recurrentRegularizer),
|
|
11327
11326
|
biasRegularizer: q(this.biasRegularizer),
|
|
@@ -11359,7 +11358,7 @@ class _s extends _n {
|
|
|
11359
11358
|
constructor(t) {
|
|
11360
11359
|
if (super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", t.resetAfter)
|
|
11361
11360
|
throw new d("GRUCell does not support reset_after parameter set to true.");
|
|
11362
|
-
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer =
|
|
11361
|
+
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11363
11362
|
1,
|
|
11364
11363
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11365
11364
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = this.units, this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11407,9 +11406,9 @@ class _s extends _n {
|
|
|
11407
11406
|
activation: ae(this.activation),
|
|
11408
11407
|
recurrentActivation: ae(this.recurrentActivation),
|
|
11409
11408
|
useBias: this.useBias,
|
|
11410
|
-
kernelInitializer:
|
|
11411
|
-
recurrentInitializer:
|
|
11412
|
-
biasInitializer:
|
|
11409
|
+
kernelInitializer: Y(this.kernelInitializer),
|
|
11410
|
+
recurrentInitializer: Y(this.recurrentInitializer),
|
|
11411
|
+
biasInitializer: Y(this.biasInitializer),
|
|
11413
11412
|
kernelRegularizer: q(this.kernelRegularizer),
|
|
11414
11413
|
recurrentRegularizer: q(this.recurrentRegularizer),
|
|
11415
11414
|
biasRegularizer: q(this.biasRegularizer),
|
|
@@ -11447,7 +11446,7 @@ _o.className = "GRU";
|
|
|
11447
11446
|
S(_o);
|
|
11448
11447
|
class Rn extends _n {
|
|
11449
11448
|
constructor(t) {
|
|
11450
|
-
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer =
|
|
11449
|
+
super(t), this.DEFAULT_ACTIVATION = "tanh", this.DEFAULT_RECURRENT_ACTIVATION = "hardSigmoid", this.DEFAULT_KERNEL_INITIALIZER = "glorotNormal", this.DEFAULT_RECURRENT_INITIALIZER = "orthogonal", this.DEFAULT_BIAS_INITIALIZER = "zeros", this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation === void 0 ? this.DEFAULT_ACTIVATION : t.activation), this.recurrentActivation = le(t.recurrentActivation === void 0 ? this.DEFAULT_RECURRENT_ACTIVATION : t.recurrentActivation), this.useBias = t.useBias == null ? !0 : t.useBias, this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.recurrentInitializer = Z(t.recurrentInitializer || this.DEFAULT_RECURRENT_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.unitForgetBias = t.unitForgetBias, this.kernelRegularizer = X(t.kernelRegularizer), this.recurrentRegularizer = X(t.recurrentRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.kernelConstraint = ot(t.kernelConstraint), this.recurrentConstraint = ot(t.recurrentConstraint), this.biasConstraint = ot(t.biasConstraint), this.dropout = Fe([1, oe([0, t.dropout == null ? 0 : t.dropout])]), this.recurrentDropout = Fe([
|
|
11451
11450
|
1,
|
|
11452
11451
|
oe([0, t.recurrentDropout == null ? 0 : t.recurrentDropout])
|
|
11453
11452
|
]), this.dropoutFunc = t.dropoutFunc, this.implementation = t.implementation, this.stateSize = [this.units, this.units], this.dropoutMask = null, this.recurrentDropoutMask = null;
|
|
@@ -11512,9 +11511,9 @@ class Rn extends _n {
|
|
|
11512
11511
|
activation: ae(this.activation),
|
|
11513
11512
|
recurrentActivation: ae(this.recurrentActivation),
|
|
11514
11513
|
useBias: this.useBias,
|
|
11515
|
-
kernelInitializer:
|
|
11516
|
-
recurrentInitializer:
|
|
11517
|
-
biasInitializer:
|
|
11514
|
+
kernelInitializer: Y(this.kernelInitializer),
|
|
11515
|
+
recurrentInitializer: Y(this.recurrentInitializer),
|
|
11516
|
+
biasInitializer: Y(this.biasInitializer),
|
|
11518
11517
|
unitForgetBias: this.unitForgetBias,
|
|
11519
11518
|
kernelRegularizer: q(this.kernelRegularizer),
|
|
11520
11519
|
recurrentRegularizer: q(this.recurrentRegularizer),
|
|
@@ -11581,7 +11580,7 @@ class Rs extends _n {
|
|
|
11581
11580
|
});
|
|
11582
11581
|
}
|
|
11583
11582
|
build(t) {
|
|
11584
|
-
|
|
11583
|
+
Yn(t) && (t = t[0]), t = t;
|
|
11585
11584
|
let e;
|
|
11586
11585
|
this.cells.forEach((s, i) => {
|
|
11587
11586
|
ge(`RNNCell_${i}`, () => {
|
|
@@ -11891,7 +11890,7 @@ class Po extends G {
|
|
|
11891
11890
|
let e = null;
|
|
11892
11891
|
t.batchSize != null && (e = t.batchSize), this.batchInputShape = [e, t.inputDim];
|
|
11893
11892
|
}
|
|
11894
|
-
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelConstraint = ot(t.kernelConstraint), this.biasConstraint = ot(t.biasConstraint), this.kernelRegularizer =
|
|
11893
|
+
this.units = t.units, lt(this.units, "units"), this.activation = le(t.activation), t.useBias != null && (this.useBias = t.useBias), this.kernelInitializer = Z(t.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER), this.biasInitializer = Z(t.biasInitializer || this.DEFAULT_BIAS_INITIALIZER), this.kernelConstraint = ot(t.kernelConstraint), this.biasConstraint = ot(t.biasConstraint), this.kernelRegularizer = X(t.kernelRegularizer), this.biasRegularizer = X(t.biasRegularizer), this.activityRegularizer = X(t.activityRegularizer), this.supportsMasking = !0, this.inputSpec = [{ minNDim: 2 }];
|
|
11895
11894
|
}
|
|
11896
11895
|
build(t) {
|
|
11897
11896
|
t = j(t);
|
|
@@ -11916,8 +11915,8 @@ class Po extends G {
|
|
|
11916
11915
|
units: this.units,
|
|
11917
11916
|
activation: ae(this.activation),
|
|
11918
11917
|
useBias: this.useBias,
|
|
11919
|
-
kernelInitializer:
|
|
11920
|
-
biasInitializer:
|
|
11918
|
+
kernelInitializer: Y(this.kernelInitializer),
|
|
11919
|
+
biasInitializer: Y(this.biasInitializer),
|
|
11921
11920
|
kernelRegularizer: q(this.kernelRegularizer),
|
|
11922
11921
|
biasRegularizer: q(this.biasRegularizer),
|
|
11923
11922
|
activityRegularizer: q(this.activityRegularizer),
|
|
@@ -12139,7 +12138,7 @@ class Jo extends G {
|
|
|
12139
12138
|
let e = null;
|
|
12140
12139
|
t.batchSize != null && (e = t.batchSize), t.inputLength == null ? this.batchInputShape = [e, null] : this.batchInputShape = [e].concat(H(t.inputLength));
|
|
12141
12140
|
}
|
|
12142
|
-
this.inputDim = t.inputDim, lt(this.inputDim, "inputDim"), this.outputDim = t.outputDim, lt(this.outputDim, "outputDim"), this.embeddingsInitializer = Z(t.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER), this.embeddingsRegularizer =
|
|
12141
|
+
this.inputDim = t.inputDim, lt(this.inputDim, "inputDim"), this.outputDim = t.outputDim, lt(this.outputDim, "outputDim"), this.embeddingsInitializer = Z(t.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER), this.embeddingsRegularizer = X(t.embeddingsRegularizer), this.activityRegularizer = X(t.activityRegularizer), this.embeddingsConstraint = ot(t.embeddingsConstraint), this.maskZero = t.maskZero, this.supportsMasking = t.maskZero, this.inputLength = t.inputLength;
|
|
12143
12142
|
}
|
|
12144
12143
|
build(t) {
|
|
12145
12144
|
this.embeddings = this.addWeight("embeddings", [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, !0, this.embeddingsConstraint), this.built = !0;
|
|
@@ -12181,7 +12180,7 @@ class Jo extends G {
|
|
|
12181
12180
|
const t = {
|
|
12182
12181
|
inputDim: this.inputDim,
|
|
12183
12182
|
outputDim: this.outputDim,
|
|
12184
|
-
embeddingsInitializer:
|
|
12183
|
+
embeddingsInitializer: Y(this.embeddingsInitializer),
|
|
12185
12184
|
embeddingsRegularizer: q(this.embeddingsRegularizer),
|
|
12186
12185
|
activityRegularizer: q(this.activityRegularizer),
|
|
12187
12186
|
embeddingsConstraint: rt(this.embeddingsConstraint),
|
|
@@ -12355,7 +12354,7 @@ class Zo extends ve {
|
|
|
12355
12354
|
}
|
|
12356
12355
|
Zo.className = "Add";
|
|
12357
12356
|
S(Zo);
|
|
12358
|
-
class
|
|
12357
|
+
class Xo extends ve {
|
|
12359
12358
|
constructor(t) {
|
|
12360
12359
|
super(t);
|
|
12361
12360
|
}
|
|
@@ -12368,9 +12367,9 @@ class Yo extends ve {
|
|
|
12368
12367
|
});
|
|
12369
12368
|
}
|
|
12370
12369
|
}
|
|
12371
|
-
|
|
12372
|
-
S(
|
|
12373
|
-
class
|
|
12370
|
+
Xo.className = "Multiply";
|
|
12371
|
+
S(Xo);
|
|
12372
|
+
class Yo extends ve {
|
|
12374
12373
|
constructor(t) {
|
|
12375
12374
|
super(t);
|
|
12376
12375
|
}
|
|
@@ -12383,8 +12382,8 @@ class Xo extends ve {
|
|
|
12383
12382
|
});
|
|
12384
12383
|
}
|
|
12385
12384
|
}
|
|
12386
|
-
|
|
12387
|
-
S(
|
|
12385
|
+
Yo.className = "Average";
|
|
12386
|
+
S(Yo);
|
|
12388
12387
|
class Qo extends ve {
|
|
12389
12388
|
constructor(t) {
|
|
12390
12389
|
super(t);
|
|
@@ -12669,7 +12668,7 @@ class ra extends G {
|
|
|
12669
12668
|
const s = this._getNoiseShape(t);
|
|
12670
12669
|
return en(() => {
|
|
12671
12670
|
const r = _(t), a = -1.6732632423543772 * 1.0507009873554805;
|
|
12672
|
-
let l =
|
|
12671
|
+
let l = Xe($n(s), this.rate);
|
|
12673
12672
|
l = Pt(l, "float32");
|
|
12674
12673
|
const u = ((1 - this.rate) * (1 + this.rate * a ** 2)) ** -0.5, c = -u * a * this.rate, h = z(w(r, l), w(z(l, -1), a));
|
|
12675
12674
|
return z(w(h, u), c);
|
|
@@ -12722,7 +12721,7 @@ function _m(n, t, e, s, i = 1e-3) {
|
|
|
12722
12721
|
}
|
|
12723
12722
|
class oa extends G {
|
|
12724
12723
|
constructor(t) {
|
|
12725
|
-
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = Z(t.betaInitializer || "zeros"), this.gammaInitializer = Z(t.gammaInitializer || "ones"), this.movingMeanInitializer = Z(t.movingMeanInitializer || "zeros"), this.movingVarianceInitializer = Z(t.movingVarianceInitializer || "ones"), this.betaConstraint = ot(t.betaConstraint), this.gammaConstraint = ot(t.gammaConstraint), this.betaRegularizer =
|
|
12724
|
+
t == null && (t = {}), super(t), this.supportsMasking = !0, this.axis = t.axis == null ? -1 : t.axis, this.momentum = t.momentum == null ? 0.99 : t.momentum, this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = Z(t.betaInitializer || "zeros"), this.gammaInitializer = Z(t.gammaInitializer || "ones"), this.movingMeanInitializer = Z(t.movingMeanInitializer || "zeros"), this.movingVarianceInitializer = Z(t.movingVarianceInitializer || "ones"), this.betaConstraint = ot(t.betaConstraint), this.gammaConstraint = ot(t.gammaConstraint), this.betaRegularizer = X(t.betaRegularizer), this.gammaRegularizer = X(t.gammaRegularizer);
|
|
12726
12725
|
}
|
|
12727
12726
|
build(t) {
|
|
12728
12727
|
t = j(t);
|
|
@@ -12768,10 +12767,10 @@ class oa extends G {
|
|
|
12768
12767
|
epsilon: this.epsilon,
|
|
12769
12768
|
center: this.center,
|
|
12770
12769
|
scale: this.scale,
|
|
12771
|
-
betaInitializer:
|
|
12772
|
-
gammaInitializer:
|
|
12773
|
-
movingMeanInitializer:
|
|
12774
|
-
movingVarianceInitializer:
|
|
12770
|
+
betaInitializer: Y(this.betaInitializer),
|
|
12771
|
+
gammaInitializer: Y(this.gammaInitializer),
|
|
12772
|
+
movingMeanInitializer: Y(this.movingMeanInitializer),
|
|
12773
|
+
movingVarianceInitializer: Y(this.movingVarianceInitializer),
|
|
12775
12774
|
betaRegularizer: q(this.betaRegularizer),
|
|
12776
12775
|
gammaRegularizer: q(this.gammaRegularizer),
|
|
12777
12776
|
betaConstraint: rt(this.betaConstraint),
|
|
@@ -12793,7 +12792,7 @@ class aa extends G {
|
|
|
12793
12792
|
throw new Error(`Expected axis to be an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
12794
12793
|
} else
|
|
12795
12794
|
throw new Error(`Expected axis to be an integer or an array of integers, but received ${JSON.stringify(this.axis)}`);
|
|
12796
|
-
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = Z(t.betaInitializer || "zeros"), this.gammaInitializer = Z(t.gammaInitializer || "ones"), this.betaRegularizer =
|
|
12795
|
+
this.epsilon = t.epsilon == null ? 1e-3 : t.epsilon, this.center = t.center == null ? !0 : t.center, this.scale = t.scale == null ? !0 : t.scale, this.betaInitializer = Z(t.betaInitializer || "zeros"), this.gammaInitializer = Z(t.gammaInitializer || "ones"), this.betaRegularizer = X(t.betaRegularizer), this.gammaRegularizer = X(t.gammaRegularizer), this.supportsMasking = !0;
|
|
12797
12796
|
}
|
|
12798
12797
|
build(t) {
|
|
12799
12798
|
t = j(t);
|
|
@@ -12830,8 +12829,8 @@ class aa extends G {
|
|
|
12830
12829
|
epsilon: this.epsilon,
|
|
12831
12830
|
center: this.center,
|
|
12832
12831
|
scale: this.scale,
|
|
12833
|
-
betaInitializer:
|
|
12834
|
-
gammaInitializer:
|
|
12832
|
+
betaInitializer: Y(this.betaInitializer),
|
|
12833
|
+
gammaInitializer: Y(this.gammaInitializer),
|
|
12835
12834
|
betaRegularizer: q(this.betaRegularizer),
|
|
12836
12835
|
gammaRegularizer: q(this.gammaRegularizer)
|
|
12837
12836
|
}, e = super.getConfig();
|
|
@@ -13438,7 +13437,7 @@ class $a extends G {
|
|
|
13438
13437
|
t.rank === 3 ? (c = !0, u = Dn([t])) : u = t;
|
|
13439
13438
|
for (let D = 0; D < u.shape[0]; D++)
|
|
13440
13439
|
m.push(g);
|
|
13441
|
-
const A =
|
|
13440
|
+
const A = nc(m, [m.length, 4]), k = sc(0, m.length, 1, "int32"), v = Pm(u, A, k, [i, r], "nearest");
|
|
13442
13441
|
return Pt(c ? _(dn(v)) : v, l);
|
|
13443
13442
|
});
|
|
13444
13443
|
}
|
|
@@ -13529,7 +13528,7 @@ class Ta extends G {
|
|
|
13529
13528
|
Received countWeights=${e.countWeights}`);
|
|
13530
13529
|
s = _(e.countWeights);
|
|
13531
13530
|
}
|
|
13532
|
-
const i = Ee(t), r =
|
|
13531
|
+
const i = Ee(t), r = Ju(t), o = Xt(this.numTokens, i).bufferSync().get(0), a = Xe(r, 0).bufferSync().get(0);
|
|
13533
13532
|
if (!(o && a))
|
|
13534
13533
|
throw new d(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);
|
|
13535
13534
|
return Um(t, this.outputMode, this.numTokens, s);
|