@genai-fi/nanogpt 0.2.8 → 0.2.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +2 -0
- package/dist/Generator.js +37 -32
- package/dist/NanoGPTModel.d.ts +4 -1
- package/dist/NanoGPTModel.js +33 -25
- package/dist/TeachableLLM.d.ts +4 -0
- package/dist/TeachableLLM.js +31 -16
- package/dist/{complex-CeoYJn2o.js → complex-x7w5HPOS.js} +6 -6
- package/dist/{index-DQfEAU9u.js → index-CWQLouWz.js} +312 -303
- package/dist/layers/BaseLayer.d.ts +8 -0
- package/dist/layers/BaseLayer.js +18 -0
- package/dist/layers/CausalSelfAttention.d.ts +2 -1
- package/dist/layers/CausalSelfAttention.js +10 -8
- package/dist/layers/MLP.d.ts +2 -1
- package/dist/layers/MLP.js +16 -14
- package/dist/layers/RMSNorm.d.ts +2 -1
- package/dist/layers/RMSNorm.js +13 -11
- package/dist/layers/TiedEmbedding.js +4 -4
- package/dist/layers/TransformerBlock.d.ts +4 -1
- package/dist/layers/TransformerBlock.js +9 -5
- package/dist/{mat_mul-CuHB58-H.js → mat_mul-4v7St11W.js} +5 -5
- package/dist/ops/attentionMask.js +47 -21
- package/dist/ops/gatherSub.js +2 -2
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/scatterSub.js +10 -10
- package/dist/{stack-C9cTkqpq.js → stack-CTdK-itU.js} +5 -5
- package/dist/{sum-B-O33dgG.js → sum-CnIf1YOh.js} +3 -3
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/Trainer.js +30 -29
- package/dist/training/sparseCrossEntropy.js +12 -12
- package/dist/utilities/profile.d.ts +10 -0
- package/dist/utilities/profile.js +29 -0
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { g as
|
|
1
|
+
import { g as Ot } from "./index-D5v913EJ.js";
|
|
2
2
|
import { p as Q } from "./index-xuotMAFm.js";
|
|
3
3
|
import { B as gt } from "./index-Tf7vU29b.js";
|
|
4
4
|
/**
|
|
@@ -17,8 +17,8 @@ import { B as gt } from "./index-Tf7vU29b.js";
|
|
|
17
17
|
* limitations under the License.
|
|
18
18
|
* =============================================================================
|
|
19
19
|
*/
|
|
20
|
-
const
|
|
21
|
-
class
|
|
20
|
+
const Ee = 1e-7, Ae = 1e-4;
|
|
21
|
+
class Be {
|
|
22
22
|
refCount(t) {
|
|
23
23
|
return v("refCount");
|
|
24
24
|
}
|
|
@@ -64,7 +64,7 @@ class Ae {
|
|
|
64
64
|
}
|
|
65
65
|
/** Returns the smallest representable number. */
|
|
66
66
|
epsilon() {
|
|
67
|
-
return this.floatPrecision() === 32 ?
|
|
67
|
+
return this.floatPrecision() === 32 ? Ee : Ae;
|
|
68
68
|
}
|
|
69
69
|
dispose() {
|
|
70
70
|
return v("dispose");
|
|
@@ -94,9 +94,9 @@ function y(n, t) {
|
|
|
94
94
|
throw new Error(typeof t == "string" ? t : t());
|
|
95
95
|
}
|
|
96
96
|
function Is(n, t, e = "") {
|
|
97
|
-
y(
|
|
97
|
+
y(Rt(n, t), () => e + ` Shapes ${n} and ${t} must match`);
|
|
98
98
|
}
|
|
99
|
-
function
|
|
99
|
+
function G(n) {
|
|
100
100
|
if (n.length === 0)
|
|
101
101
|
return 1;
|
|
102
102
|
let t = n[0];
|
|
@@ -104,7 +104,7 @@ function U(n) {
|
|
|
104
104
|
t *= n[e];
|
|
105
105
|
return t;
|
|
106
106
|
}
|
|
107
|
-
function
|
|
107
|
+
function Rt(n, t) {
|
|
108
108
|
if (n === t)
|
|
109
109
|
return !0;
|
|
110
110
|
if (n == null || t == null || n.length !== t.length)
|
|
@@ -114,7 +114,7 @@ function $t(n, t) {
|
|
|
114
114
|
return !1;
|
|
115
115
|
return !0;
|
|
116
116
|
}
|
|
117
|
-
function
|
|
117
|
+
function ve(n) {
|
|
118
118
|
return n % 1 === 0;
|
|
119
119
|
}
|
|
120
120
|
function ct(n, t) {
|
|
@@ -122,9 +122,9 @@ function ct(n, t) {
|
|
|
122
122
|
}
|
|
123
123
|
function Ts(n, t) {
|
|
124
124
|
const e = t.length;
|
|
125
|
-
return n = n == null ? t.map((s, r) => r) : [].concat(n), y(n.every((s) => s >= -e && s < e), () => `All values in axis param must be in range [-${e}, ${e}) but got axis ${n}`), y(n.every((s) =>
|
|
125
|
+
return n = n == null ? t.map((s, r) => r) : [].concat(n), y(n.every((s) => s >= -e && s < e), () => `All values in axis param must be in range [-${e}, ${e}) but got axis ${n}`), y(n.every((s) => ve(s)), () => `All values in axis param must be integers but got axis ${n}`), n.map((s) => s < 0 ? e + s : s);
|
|
126
126
|
}
|
|
127
|
-
function
|
|
127
|
+
function Me(n, t) {
|
|
128
128
|
let e = null;
|
|
129
129
|
if (n == null || n === "float32")
|
|
130
130
|
e = new Float32Array(t);
|
|
@@ -138,14 +138,14 @@ function ve(n, t) {
|
|
|
138
138
|
throw new Error(`Unknown data type ${n}`);
|
|
139
139
|
return e;
|
|
140
140
|
}
|
|
141
|
-
function
|
|
141
|
+
function Fe(n, t) {
|
|
142
142
|
for (let e = 0; e < n.length; e++) {
|
|
143
143
|
const s = n[e];
|
|
144
144
|
if (isNaN(s) || !isFinite(s))
|
|
145
145
|
throw Error(`A tensor of type ${t} being uploaded contains ${s}.`);
|
|
146
146
|
}
|
|
147
147
|
}
|
|
148
|
-
function
|
|
148
|
+
function $e(n) {
|
|
149
149
|
return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string";
|
|
150
150
|
}
|
|
151
151
|
function St(n) {
|
|
@@ -157,28 +157,28 @@ function St(n) {
|
|
|
157
157
|
return 1;
|
|
158
158
|
throw new Error(`Unknown dtype ${n}`);
|
|
159
159
|
}
|
|
160
|
-
function
|
|
160
|
+
function Re(n) {
|
|
161
161
|
if (n == null)
|
|
162
162
|
return 0;
|
|
163
163
|
let t = 0;
|
|
164
164
|
return n.forEach((e) => t += e.length), t;
|
|
165
165
|
}
|
|
166
|
-
function
|
|
166
|
+
function xt(n) {
|
|
167
167
|
return typeof n == "string" || n instanceof String;
|
|
168
168
|
}
|
|
169
|
-
function
|
|
169
|
+
function xe(n) {
|
|
170
170
|
return typeof n == "boolean";
|
|
171
171
|
}
|
|
172
|
-
function
|
|
172
|
+
function Ne(n) {
|
|
173
173
|
return typeof n == "number";
|
|
174
174
|
}
|
|
175
175
|
function mt(n) {
|
|
176
|
-
return Array.isArray(n) ? mt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" :
|
|
176
|
+
return Array.isArray(n) ? mt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" : Ne(n) ? "float32" : xt(n) ? "string" : xe(n) ? "bool" : "float32";
|
|
177
177
|
}
|
|
178
178
|
function kt(n) {
|
|
179
179
|
return !!(n && n.constructor && n.call && n.apply);
|
|
180
180
|
}
|
|
181
|
-
function
|
|
181
|
+
function Nt(n) {
|
|
182
182
|
const t = n.length;
|
|
183
183
|
if (t < 2)
|
|
184
184
|
return [];
|
|
@@ -188,7 +188,7 @@ function xt(n) {
|
|
|
188
188
|
e[s] = e[s + 1] * n[s + 1];
|
|
189
189
|
return e;
|
|
190
190
|
}
|
|
191
|
-
function
|
|
191
|
+
function Qt(n, t, e, s = !1) {
|
|
192
192
|
const r = new Array();
|
|
193
193
|
if (t.length === 1) {
|
|
194
194
|
const i = t[0] * (s ? 2 : 1);
|
|
@@ -197,11 +197,11 @@ function Yt(n, t, e, s = !1) {
|
|
|
197
197
|
} else {
|
|
198
198
|
const i = t[0], o = t.slice(1), a = o.reduce((c, l) => c * l) * (s ? 2 : 1);
|
|
199
199
|
for (let c = 0; c < i; c++)
|
|
200
|
-
r[c] =
|
|
200
|
+
r[c] = Qt(n + c * a, o, e, s);
|
|
201
201
|
}
|
|
202
202
|
return r;
|
|
203
203
|
}
|
|
204
|
-
function
|
|
204
|
+
function Lt(n, t, e = !1) {
|
|
205
205
|
if (n.length === 0)
|
|
206
206
|
return t[0];
|
|
207
207
|
const s = n.reduce((r, i) => r * i) * (e ? 2 : 1);
|
|
@@ -209,15 +209,15 @@ function Ot(n, t, e = !1) {
|
|
|
209
209
|
return [];
|
|
210
210
|
if (s !== t.length)
|
|
211
211
|
throw new Error(`[${n}] does not match the input size ${t.length}${e ? " for a complex tensor" : ""}.`);
|
|
212
|
-
return
|
|
212
|
+
return Qt(0, n, t, e);
|
|
213
213
|
}
|
|
214
|
-
function
|
|
215
|
-
const e =
|
|
214
|
+
function De(n, t) {
|
|
215
|
+
const e = Zt(n, t);
|
|
216
216
|
for (let s = 0; s < e.length; s++)
|
|
217
217
|
e[s] = 1;
|
|
218
218
|
return e;
|
|
219
219
|
}
|
|
220
|
-
function
|
|
220
|
+
function Zt(n, t) {
|
|
221
221
|
if (t == null || t === "float32" || t === "complex64")
|
|
222
222
|
return new Float32Array(n);
|
|
223
223
|
if (t === "int32")
|
|
@@ -226,12 +226,12 @@ function Qt(n, t) {
|
|
|
226
226
|
return new Uint8Array(n);
|
|
227
227
|
throw new Error(`Unknown data type ${t}`);
|
|
228
228
|
}
|
|
229
|
-
function
|
|
229
|
+
function Dt(n) {
|
|
230
230
|
n.forEach((t) => {
|
|
231
231
|
y(Number.isInteger(t) && t >= 0, () => `Tensor must have a shape comprised of positive integers but got shape [${n}].`);
|
|
232
232
|
});
|
|
233
233
|
}
|
|
234
|
-
function
|
|
234
|
+
function Ct(n) {
|
|
235
235
|
return n && n.then && typeof n.then == "function";
|
|
236
236
|
}
|
|
237
237
|
/**
|
|
@@ -250,11 +250,11 @@ function Dt(n) {
|
|
|
250
250
|
* limitations under the License.
|
|
251
251
|
* =============================================================================
|
|
252
252
|
*/
|
|
253
|
-
const
|
|
254
|
-
class
|
|
253
|
+
const Ut = "tfjsflags";
|
|
254
|
+
class Ce {
|
|
255
255
|
// tslint:disable-next-line: no-any
|
|
256
256
|
constructor(t) {
|
|
257
|
-
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams =
|
|
257
|
+
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams = _e, this.populateURLFlags();
|
|
258
258
|
}
|
|
259
259
|
setPlatform(t, e) {
|
|
260
260
|
this.platform != null && (S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)), this.platformName = t, this.platform = e;
|
|
@@ -272,7 +272,7 @@ class De {
|
|
|
272
272
|
if (t in this.flags)
|
|
273
273
|
return this.flags[t];
|
|
274
274
|
const e = this.evaluateFlag(t);
|
|
275
|
-
if (
|
|
275
|
+
if (Ct(e))
|
|
276
276
|
throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);
|
|
277
277
|
return this.flags[t] = e, this.flags[t];
|
|
278
278
|
}
|
|
@@ -312,29 +312,29 @@ class De {
|
|
|
312
312
|
if (typeof this.global > "u" || typeof this.global.location > "u" || typeof this.global.location.search > "u")
|
|
313
313
|
return;
|
|
314
314
|
const t = this.getQueryParams(this.global.location.search);
|
|
315
|
-
|
|
315
|
+
Ut in t && t[Ut].split(",").forEach((s) => {
|
|
316
316
|
const [r, i] = s.split(":");
|
|
317
|
-
this.urlFlags[r] =
|
|
317
|
+
this.urlFlags[r] = Oe(r, i);
|
|
318
318
|
});
|
|
319
319
|
}
|
|
320
320
|
}
|
|
321
|
-
function
|
|
321
|
+
function _e(n) {
|
|
322
322
|
const t = {};
|
|
323
|
-
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (
|
|
323
|
+
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (Pe(t, s[0], s[1]), s.join("="))), t;
|
|
324
324
|
}
|
|
325
|
-
function
|
|
325
|
+
function Pe(n, t, e) {
|
|
326
326
|
n[decodeURIComponent(t)] = decodeURIComponent(e || "");
|
|
327
327
|
}
|
|
328
|
-
function
|
|
328
|
+
function Oe(n, t) {
|
|
329
329
|
const e = t.toLowerCase();
|
|
330
330
|
return e === "true" || e === "false" ? e === "true" : `${+e}` === e ? +e : t;
|
|
331
331
|
}
|
|
332
332
|
function S() {
|
|
333
|
-
return
|
|
333
|
+
return te;
|
|
334
334
|
}
|
|
335
|
-
let
|
|
336
|
-
function
|
|
337
|
-
|
|
335
|
+
let te = null;
|
|
336
|
+
function Le(n) {
|
|
337
|
+
te = n;
|
|
338
338
|
}
|
|
339
339
|
/**
|
|
340
340
|
* @license
|
|
@@ -353,13 +353,13 @@ function Oe(n) {
|
|
|
353
353
|
* =============================================================================
|
|
354
354
|
*/
|
|
355
355
|
let pt;
|
|
356
|
-
function
|
|
356
|
+
function ee() {
|
|
357
357
|
if (pt == null) {
|
|
358
358
|
let n;
|
|
359
359
|
if (typeof window < "u")
|
|
360
360
|
n = window;
|
|
361
|
-
else if (typeof
|
|
362
|
-
n =
|
|
361
|
+
else if (typeof Ot < "u")
|
|
362
|
+
n = Ot;
|
|
363
363
|
else if (typeof Q < "u")
|
|
364
364
|
n = Q;
|
|
365
365
|
else if (typeof self < "u")
|
|
@@ -370,12 +370,12 @@ function te() {
|
|
|
370
370
|
}
|
|
371
371
|
return pt;
|
|
372
372
|
}
|
|
373
|
-
function
|
|
374
|
-
const n =
|
|
373
|
+
function Ue() {
|
|
374
|
+
const n = ee();
|
|
375
375
|
return n._tfGlobals == null && (n._tfGlobals = /* @__PURE__ */ new Map()), n._tfGlobals;
|
|
376
376
|
}
|
|
377
|
-
function
|
|
378
|
-
const e =
|
|
377
|
+
function _t(n, t) {
|
|
378
|
+
const e = Ue();
|
|
379
379
|
if (e.has(n))
|
|
380
380
|
return e.get(n);
|
|
381
381
|
{
|
|
@@ -383,7 +383,7 @@ function Ct(n, t) {
|
|
|
383
383
|
return e.set(n, s), e.get(n);
|
|
384
384
|
}
|
|
385
385
|
}
|
|
386
|
-
const
|
|
386
|
+
const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", We = "RealDiv", Bs = "Elu", vs = "Exp", je = "Fill", Ke = "FloorDiv", Ms = "GatherNd", re = "Identity", Fs = "Imag", $s = "LeakyRelu", Rs = "Log", xs = "Max", Ve = "Maximum", qe = "Multiply", Ns = "Neg", Ds = "Pack", He = "Pow", Cs = "Prelu", _s = "Range", Ps = "Real", Os = "Relu", Ls = "Reshape", Us = "Relu6", Gs = "ScatterNd", zs = "Sigmoid", Je = "Sqrt", Ws = "Sum", js = "Softmax", Xe = "Sub", Ks = "Transpose", Ye = "ZerosLike", Vs = "Step", qs = "_FusedMatMul";
|
|
387
387
|
/**
|
|
388
388
|
* @license
|
|
389
389
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -400,7 +400,7 @@ const Ue = "Abs", ee = "Add", Es = "BatchMatMul", ne = "Cast", As = "Complex", G
|
|
|
400
400
|
* limitations under the License.
|
|
401
401
|
* =============================================================================
|
|
402
402
|
*/
|
|
403
|
-
function
|
|
403
|
+
function O(...n) {
|
|
404
404
|
S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(...n);
|
|
405
405
|
}
|
|
406
406
|
/**
|
|
@@ -419,15 +419,15 @@ function J(...n) {
|
|
|
419
419
|
* limitations under the License.
|
|
420
420
|
* =============================================================================
|
|
421
421
|
*/
|
|
422
|
-
const ht =
|
|
423
|
-
function
|
|
424
|
-
const e =
|
|
422
|
+
const ht = _t("kernelRegistry", () => /* @__PURE__ */ new Map()), It = _t("gradRegistry", () => /* @__PURE__ */ new Map());
|
|
423
|
+
function Gt(n, t) {
|
|
424
|
+
const e = ie(n, t);
|
|
425
425
|
return ht.get(e);
|
|
426
426
|
}
|
|
427
|
-
function Gt(n) {
|
|
428
|
-
return Ye.get(n);
|
|
429
|
-
}
|
|
430
427
|
function zt(n) {
|
|
428
|
+
return It.get(n);
|
|
429
|
+
}
|
|
430
|
+
function Wt(n) {
|
|
431
431
|
const t = ht.entries(), e = [];
|
|
432
432
|
for (; ; ) {
|
|
433
433
|
const { done: s, value: r } = t.next();
|
|
@@ -439,10 +439,14 @@ function zt(n) {
|
|
|
439
439
|
return e;
|
|
440
440
|
}
|
|
441
441
|
function Hs(n) {
|
|
442
|
-
const { kernelName: t, backendName: e } = n, s =
|
|
443
|
-
ht.has(s) &&
|
|
442
|
+
const { kernelName: t, backendName: e } = n, s = ie(t, e);
|
|
443
|
+
ht.has(s) && O(`The kernel '${t}' for backend '${e}' is already registered`), ht.set(s, n);
|
|
444
444
|
}
|
|
445
|
-
function
|
|
445
|
+
function Js(n) {
|
|
446
|
+
const { kernelName: t } = n;
|
|
447
|
+
It.has(t) && S().getBool("DEBUG") && O(`Overriding the gradient for '${t}'`), It.set(t, n);
|
|
448
|
+
}
|
|
449
|
+
function ie(n, t) {
|
|
446
450
|
return `${t}_${n}`;
|
|
447
451
|
}
|
|
448
452
|
/**
|
|
@@ -461,7 +465,7 @@ function re(n, t) {
|
|
|
461
465
|
* limitations under the License.
|
|
462
466
|
* =============================================================================
|
|
463
467
|
*/
|
|
464
|
-
function
|
|
468
|
+
function oe(n) {
|
|
465
469
|
return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray;
|
|
466
470
|
}
|
|
467
471
|
/**
|
|
@@ -483,10 +487,10 @@ function ie(n) {
|
|
|
483
487
|
function Qe(n, t) {
|
|
484
488
|
return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool";
|
|
485
489
|
}
|
|
486
|
-
function
|
|
490
|
+
function ae(n, t) {
|
|
487
491
|
if (t === "string")
|
|
488
492
|
throw new Error("Cannot convert a string[] to a TypedArray");
|
|
489
|
-
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") &&
|
|
493
|
+
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") && Fe(n, t), Qe(n, t))
|
|
490
494
|
return n;
|
|
491
495
|
if (t == null || t === "float32" || t === "complex64")
|
|
492
496
|
return new Float32Array(n);
|
|
@@ -506,14 +510,14 @@ function ft() {
|
|
|
506
510
|
function Ze(n, t = "utf-8") {
|
|
507
511
|
return t = t || "utf-8", S().platform.encode(n, t);
|
|
508
512
|
}
|
|
509
|
-
function
|
|
513
|
+
function jt(n, t = "utf-8") {
|
|
510
514
|
return t = t || "utf-8", S().platform.decode(n, t);
|
|
511
515
|
}
|
|
512
516
|
function $(n) {
|
|
513
|
-
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) :
|
|
517
|
+
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) : oe(n);
|
|
514
518
|
}
|
|
515
519
|
function at(n, t = [], e = !1) {
|
|
516
|
-
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" ||
|
|
520
|
+
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" || Ct(n) || n == null || $(n) && e)
|
|
517
521
|
t.push(n);
|
|
518
522
|
else if (Array.isArray(n) || $(n))
|
|
519
523
|
for (let s = 0; s < n.length; ++s)
|
|
@@ -687,7 +691,7 @@ function rn(n, t, e, s) {
|
|
|
687
691
|
if (l.dtype !== "float32")
|
|
688
692
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input ${c} must have 'float32' dtype, but has '${l.dtype}'`);
|
|
689
693
|
const u = i.inputs[c];
|
|
690
|
-
if (
|
|
694
|
+
if (!Rt(l.shape, u.shape))
|
|
691
695
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input '${c}' has shape '${l.shape}', which does not match the shape of the input '${u.shape}'`);
|
|
692
696
|
if (n[u.id] == null)
|
|
693
697
|
n[u.id] = l;
|
|
@@ -714,15 +718,15 @@ function rn(n, t, e, s) {
|
|
|
714
718
|
* limitations under the License.
|
|
715
719
|
* =============================================================================
|
|
716
720
|
*/
|
|
717
|
-
const
|
|
721
|
+
const Kt = 20, rt = 3, yt = 7;
|
|
718
722
|
function on(n, t, e, s) {
|
|
719
|
-
const r =
|
|
723
|
+
const r = Nt(t), i = an(n, t, e, r), o = t.length, a = ut(n, t, e, r, i), c = ["Tensor"];
|
|
720
724
|
return s && (c.push(` dtype: ${e}`), c.push(` rank: ${o}`), c.push(` shape: [${t}]`), c.push(" values:")), c.push(a.map((l) => " " + l).join(`
|
|
721
725
|
`)), c.join(`
|
|
722
726
|
`);
|
|
723
727
|
}
|
|
724
728
|
function an(n, t, e, s) {
|
|
725
|
-
const r =
|
|
729
|
+
const r = G(t), i = s[s.length - 1], o = new Array(i).fill(0), a = t.length, c = e === "complex64" ? ot(n) : n;
|
|
726
730
|
if (a > 1)
|
|
727
731
|
for (let l = 0; l < r / i; l++) {
|
|
728
732
|
const u = l * i;
|
|
@@ -733,9 +737,9 @@ function an(n, t, e, s) {
|
|
|
733
737
|
}
|
|
734
738
|
function it(n, t, e) {
|
|
735
739
|
let s;
|
|
736
|
-
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(yt))} + ${parseFloat(n[1].toFixed(yt))}j` :
|
|
740
|
+
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(yt))} + ${parseFloat(n[1].toFixed(yt))}j` : xt(n) ? s = `'${n}'` : e === "bool" ? s = le(n) : s = parseFloat(n.toFixed(yt)).toString(), ct(s, t);
|
|
737
741
|
}
|
|
738
|
-
function
|
|
742
|
+
function le(n) {
|
|
739
743
|
return n === 0 ? "false" : "true";
|
|
740
744
|
}
|
|
741
745
|
function ut(n, t, e, s, r, i = !0) {
|
|
@@ -745,14 +749,14 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
745
749
|
const d = ot(n);
|
|
746
750
|
return [it(d[0], 0, e)];
|
|
747
751
|
}
|
|
748
|
-
return e === "bool" ? [
|
|
752
|
+
return e === "bool" ? [le(n[0])] : [n[0].toString()];
|
|
749
753
|
}
|
|
750
754
|
if (c === 1) {
|
|
751
|
-
if (a >
|
|
755
|
+
if (a > Kt) {
|
|
752
756
|
const k = rt * o;
|
|
753
757
|
let T = Array.from(n.slice(0, k)), nt = Array.from(n.slice((a - rt) * o, a * o));
|
|
754
758
|
return e === "complex64" && (T = ot(T), nt = ot(nt)), [
|
|
755
|
-
"[" + T.map((
|
|
759
|
+
"[" + T.map((H, J) => it(H, r[J], e)).join(", ") + ", ..., " + nt.map((H, J) => it(H, r[a - rt + J], e)).join(", ") + "]"
|
|
756
760
|
];
|
|
757
761
|
}
|
|
758
762
|
return [
|
|
@@ -760,7 +764,7 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
760
764
|
];
|
|
761
765
|
}
|
|
762
766
|
const l = t.slice(1), u = s.slice(1), h = s[0] * o, f = [];
|
|
763
|
-
if (a >
|
|
767
|
+
if (a > Kt) {
|
|
764
768
|
for (let d = 0; d < rt; d++) {
|
|
765
769
|
const k = d * h, T = k + h;
|
|
766
770
|
f.push(...ut(
|
|
@@ -834,13 +838,13 @@ function ot(n) {
|
|
|
834
838
|
*/
|
|
835
839
|
class ln {
|
|
836
840
|
constructor(t, e, s) {
|
|
837
|
-
if (this.dtype = e, this.shape = t.slice(), this.size =
|
|
841
|
+
if (this.dtype = e, this.shape = t.slice(), this.size = G(t), s != null) {
|
|
838
842
|
const r = s.length;
|
|
839
843
|
y(r === this.size, () => `Length of values '${r}' does not match the size inferred by the shape '${this.size}'.`);
|
|
840
844
|
}
|
|
841
845
|
if (e === "complex64")
|
|
842
846
|
throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
|
|
843
|
-
this.values = s ||
|
|
847
|
+
this.values = s || Me(e, this.size), this.strides = Nt(t);
|
|
844
848
|
}
|
|
845
849
|
/**
|
|
846
850
|
* Sets a value in the buffer at a given location.
|
|
@@ -918,7 +922,7 @@ function un(n) {
|
|
|
918
922
|
}
|
|
919
923
|
class x {
|
|
920
924
|
constructor(t, e, s, r) {
|
|
921
|
-
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size =
|
|
925
|
+
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size = G(t), this.strides = Nt(t), this.dataId = s, this.id = r, this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
|
|
922
926
|
}
|
|
923
927
|
get rank() {
|
|
924
928
|
return this.shape.length;
|
|
@@ -947,7 +951,7 @@ class x {
|
|
|
947
951
|
*/
|
|
948
952
|
async array() {
|
|
949
953
|
const t = await this.data();
|
|
950
|
-
return
|
|
954
|
+
return Lt(this.shape, t, this.dtype === "complex64");
|
|
951
955
|
}
|
|
952
956
|
/**
|
|
953
957
|
* Returns the tensor data as a nested array. The transfer of data is done
|
|
@@ -956,7 +960,7 @@ class x {
|
|
|
956
960
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
957
961
|
*/
|
|
958
962
|
arraySync() {
|
|
959
|
-
return
|
|
963
|
+
return Lt(this.shape, this.dataSync(), this.dtype === "complex64");
|
|
960
964
|
}
|
|
961
965
|
/**
|
|
962
966
|
* Asynchronously downloads the values from the `tf.Tensor`. Returns a
|
|
@@ -970,7 +974,7 @@ class x {
|
|
|
970
974
|
if (this.dtype === "string") {
|
|
971
975
|
const e = await t;
|
|
972
976
|
try {
|
|
973
|
-
return e.map((s) =>
|
|
977
|
+
return e.map((s) => jt(s));
|
|
974
978
|
} catch {
|
|
975
979
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
976
980
|
}
|
|
@@ -1025,7 +1029,7 @@ class x {
|
|
|
1025
1029
|
const t = R().readSync(this.dataId);
|
|
1026
1030
|
if (this.dtype === "string")
|
|
1027
1031
|
try {
|
|
1028
|
-
return t.map((e) =>
|
|
1032
|
+
return t.map((e) => jt(e));
|
|
1029
1033
|
} catch {
|
|
1030
1034
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
1031
1035
|
}
|
|
@@ -1089,10 +1093,10 @@ class x {
|
|
|
1089
1093
|
Object.defineProperty(x, Symbol.hasInstance, {
|
|
1090
1094
|
value: (n) => !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null
|
|
1091
1095
|
});
|
|
1092
|
-
function
|
|
1093
|
-
return
|
|
1096
|
+
function ce() {
|
|
1097
|
+
return _t("Tensor", () => x);
|
|
1094
1098
|
}
|
|
1095
|
-
|
|
1099
|
+
ce();
|
|
1096
1100
|
class dt extends x {
|
|
1097
1101
|
constructor(t, e, s, r) {
|
|
1098
1102
|
super(t.shape, t.dtype, t.dataId, r), this.trainable = e, this.name = s;
|
|
@@ -1108,7 +1112,7 @@ class dt extends x {
|
|
|
1108
1112
|
assign(t) {
|
|
1109
1113
|
if (t.dtype !== this.dtype)
|
|
1110
1114
|
throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);
|
|
1111
|
-
if (
|
|
1115
|
+
if (!Rt(t.shape, this.shape))
|
|
1112
1116
|
throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);
|
|
1113
1117
|
R().disposeTensor(this), this.dataId = t.dataId, R().incRef(
|
|
1114
1118
|
this,
|
|
@@ -1139,31 +1143,31 @@ Object.defineProperty(dt, Symbol.hasInstance, {
|
|
|
1139
1143
|
* limitations under the License.
|
|
1140
1144
|
* =============================================================================
|
|
1141
1145
|
*/
|
|
1142
|
-
var
|
|
1146
|
+
var Vt;
|
|
1143
1147
|
(function(n) {
|
|
1144
1148
|
n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6";
|
|
1145
|
-
})(
|
|
1146
|
-
var It;
|
|
1147
|
-
(function(n) {
|
|
1148
|
-
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1149
|
-
})(It || (It = {}));
|
|
1149
|
+
})(Vt || (Vt = {}));
|
|
1150
1150
|
var Tt;
|
|
1151
1151
|
(function(n) {
|
|
1152
|
-
n.float32 = "float32", n.int32 = "int32", n.bool = "
|
|
1152
|
+
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1153
1153
|
})(Tt || (Tt = {}));
|
|
1154
1154
|
var Et;
|
|
1155
1155
|
(function(n) {
|
|
1156
|
-
n.float32 = "float32", n.int32 = "
|
|
1156
|
+
n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
|
|
1157
1157
|
})(Et || (Et = {}));
|
|
1158
1158
|
var At;
|
|
1159
1159
|
(function(n) {
|
|
1160
|
-
n.float32 = "
|
|
1160
|
+
n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64";
|
|
1161
1161
|
})(At || (At = {}));
|
|
1162
|
+
var Bt;
|
|
1163
|
+
(function(n) {
|
|
1164
|
+
n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64";
|
|
1165
|
+
})(Bt || (Bt = {}));
|
|
1162
1166
|
const hn = {
|
|
1163
|
-
float32:
|
|
1164
|
-
int32:
|
|
1165
|
-
bool:
|
|
1166
|
-
complex64:
|
|
1167
|
+
float32: At,
|
|
1168
|
+
int32: Tt,
|
|
1169
|
+
bool: Et,
|
|
1170
|
+
complex64: Bt
|
|
1167
1171
|
};
|
|
1168
1172
|
function fn(n, t) {
|
|
1169
1173
|
if (n === "string" || t === "string") {
|
|
@@ -1173,10 +1177,10 @@ function fn(n, t) {
|
|
|
1173
1177
|
}
|
|
1174
1178
|
return hn[n][t];
|
|
1175
1179
|
}
|
|
1176
|
-
function
|
|
1180
|
+
function ue(n) {
|
|
1177
1181
|
return n != null && typeof n == "object" && "texture" in n && n.texture instanceof WebGLTexture;
|
|
1178
1182
|
}
|
|
1179
|
-
function
|
|
1183
|
+
function he(n) {
|
|
1180
1184
|
return typeof GPUBuffer < "u" && n != null && typeof n == "object" && "buffer" in n && n.buffer instanceof GPUBuffer;
|
|
1181
1185
|
}
|
|
1182
1186
|
/**
|
|
@@ -1195,17 +1199,17 @@ function ue(n) {
|
|
|
1195
1199
|
* limitations under the License.
|
|
1196
1200
|
* =============================================================================
|
|
1197
1201
|
*/
|
|
1198
|
-
function
|
|
1202
|
+
function V(n, t) {
|
|
1199
1203
|
if (n.dtype === t.dtype)
|
|
1200
1204
|
return [n, t];
|
|
1201
1205
|
const e = fn(n.dtype, t.dtype);
|
|
1202
1206
|
return [n.cast(e), t.cast(e)];
|
|
1203
1207
|
}
|
|
1204
|
-
function
|
|
1208
|
+
function fe(n) {
|
|
1205
1209
|
const t = [];
|
|
1206
|
-
return
|
|
1210
|
+
return de(n, t, /* @__PURE__ */ new Set()), t;
|
|
1207
1211
|
}
|
|
1208
|
-
function
|
|
1212
|
+
function de(n, t, e) {
|
|
1209
1213
|
if (n == null)
|
|
1210
1214
|
return;
|
|
1211
1215
|
if (n instanceof x) {
|
|
@@ -1217,7 +1221,7 @@ function fe(n, t, e) {
|
|
|
1217
1221
|
const s = n;
|
|
1218
1222
|
for (const r in s) {
|
|
1219
1223
|
const i = s[r];
|
|
1220
|
-
e.has(i) || (e.add(i),
|
|
1224
|
+
e.has(i) || (e.add(i), de(i, t, e));
|
|
1221
1225
|
}
|
|
1222
1226
|
}
|
|
1223
1227
|
function dn(n) {
|
|
@@ -1242,7 +1246,7 @@ function dn(n) {
|
|
|
1242
1246
|
function bt(n) {
|
|
1243
1247
|
return n.kernelName != null;
|
|
1244
1248
|
}
|
|
1245
|
-
class
|
|
1249
|
+
class qt {
|
|
1246
1250
|
constructor() {
|
|
1247
1251
|
this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = /* @__PURE__ */ new WeakMap(), this.profiling = !1, this.activeProfile = {
|
|
1248
1252
|
newBytes: 0,
|
|
@@ -1262,7 +1266,7 @@ class Vt {
|
|
|
1262
1266
|
}
|
|
1263
1267
|
class tt {
|
|
1264
1268
|
constructor(t) {
|
|
1265
|
-
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new
|
|
1269
|
+
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new qt();
|
|
1266
1270
|
}
|
|
1267
1271
|
async ready() {
|
|
1268
1272
|
if (this.pendingBackendInit != null)
|
|
@@ -1308,7 +1312,7 @@ class tt {
|
|
|
1308
1312
|
return t in this.registryFactory ? this.registryFactory[t].factory : null;
|
|
1309
1313
|
}
|
|
1310
1314
|
registerBackend(t, e, s = 1) {
|
|
1311
|
-
return t in this.registryFactory ? (
|
|
1315
|
+
return t in this.registryFactory ? (O(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, !0);
|
|
1312
1316
|
}
|
|
1313
1317
|
async setBackend(t) {
|
|
1314
1318
|
if (this.registryFactory[t] == null)
|
|
@@ -1322,12 +1326,12 @@ class tt {
|
|
|
1322
1326
|
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new tn(this.backendInstance), !0;
|
|
1323
1327
|
}
|
|
1324
1328
|
setupRegisteredKernels() {
|
|
1325
|
-
|
|
1329
|
+
Wt(this.backendName).forEach((e) => {
|
|
1326
1330
|
e.setupFunc != null && e.setupFunc(this.backendInstance);
|
|
1327
1331
|
});
|
|
1328
1332
|
}
|
|
1329
1333
|
disposeRegisteredKernels(t) {
|
|
1330
|
-
|
|
1334
|
+
Wt(t).forEach((s) => {
|
|
1331
1335
|
s.disposeFunc != null && s.disposeFunc(this.registry[t]);
|
|
1332
1336
|
});
|
|
1333
1337
|
}
|
|
@@ -1343,13 +1347,13 @@ class tt {
|
|
|
1343
1347
|
throw new Error(`Cannot initialize backend ${t}, no registration found.`);
|
|
1344
1348
|
try {
|
|
1345
1349
|
const s = e.factory();
|
|
1346
|
-
if (s && !(s instanceof
|
|
1347
|
-
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null,
|
|
1350
|
+
if (s && !(s instanceof Be) && typeof s.then == "function") {
|
|
1351
|
+
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, O(`Initialization of backend ${t} failed`), O(o.stack || o.message)), !1));
|
|
1348
1352
|
return this.pendingBackendInit = i, { success: i, asyncInit: !0 };
|
|
1349
1353
|
} else
|
|
1350
1354
|
return this.registry[t] = s, { success: !0, asyncInit: !1 };
|
|
1351
1355
|
} catch (s) {
|
|
1352
|
-
return
|
|
1356
|
+
return O(`Initialization of backend ${t} failed`), O(s.stack || s.message), { success: !1, asyncInit: !1 };
|
|
1353
1357
|
}
|
|
1354
1358
|
}
|
|
1355
1359
|
removeBackend(t) {
|
|
@@ -1413,11 +1417,11 @@ class tt {
|
|
|
1413
1417
|
* execution.
|
|
1414
1418
|
*/
|
|
1415
1419
|
clone(t) {
|
|
1416
|
-
const e = g.runKernel(
|
|
1420
|
+
const e = g.runKernel(re, { x: t }), s = { x: t }, r = (o) => ({
|
|
1417
1421
|
x: () => {
|
|
1418
1422
|
const a = "float32", c = { x: o }, l = { dtype: a };
|
|
1419
1423
|
return g.runKernel(
|
|
1420
|
-
|
|
1424
|
+
se,
|
|
1421
1425
|
c,
|
|
1422
1426
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1423
1427
|
l
|
|
@@ -1440,7 +1444,7 @@ class tt {
|
|
|
1440
1444
|
* tensors are not visible to the user.
|
|
1441
1445
|
*/
|
|
1442
1446
|
runKernel(t, e, s) {
|
|
1443
|
-
if (this.backendName == null && this.backend, !(
|
|
1447
|
+
if (this.backendName == null && this.backend, !(Gt(t, this.backendName) != null))
|
|
1444
1448
|
throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
|
|
1445
1449
|
return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
|
|
1446
1450
|
}
|
|
@@ -1473,18 +1477,18 @@ class tt {
|
|
|
1473
1477
|
if (bt(t)) {
|
|
1474
1478
|
const { kernelName: b, inputs: d, attrs: k } = t;
|
|
1475
1479
|
this.backendName == null && this.backend;
|
|
1476
|
-
const T =
|
|
1480
|
+
const T = Gt(b, this.backendName);
|
|
1477
1481
|
y(T != null, () => `Cannot find registered kernel '${b}' for backend '${this.backendName}'`), a = () => {
|
|
1478
1482
|
const nt = this.backend.numDataIds();
|
|
1479
1483
|
c = T.kernelFunc({ inputs: d, attrs: k, backend: this.backend });
|
|
1480
|
-
const
|
|
1481
|
-
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, nt,
|
|
1482
|
-
const
|
|
1484
|
+
const H = Array.isArray(c) ? c : [c];
|
|
1485
|
+
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, nt, H);
|
|
1486
|
+
const J = H.map((st) => st.rank != null ? st : this.makeTensorFromTensorInfo(st));
|
|
1483
1487
|
if (r) {
|
|
1484
|
-
const st = this.getTensorsForGradient(b, d,
|
|
1488
|
+
const st = this.getTensorsForGradient(b, d, J);
|
|
1485
1489
|
s = this.saveTensorsForBackwardMode(st);
|
|
1486
1490
|
}
|
|
1487
|
-
return
|
|
1491
|
+
return J;
|
|
1488
1492
|
};
|
|
1489
1493
|
} else {
|
|
1490
1494
|
const { forwardFunc: b } = t, d = (k) => {
|
|
@@ -1534,7 +1538,7 @@ class tt {
|
|
|
1534
1538
|
* @param outputs an array of output tensors from forward mode of kernel.
|
|
1535
1539
|
*/
|
|
1536
1540
|
getTensorsForGradient(t, e, s) {
|
|
1537
|
-
const r =
|
|
1541
|
+
const r = zt(t);
|
|
1538
1542
|
if (r != null) {
|
|
1539
1543
|
const i = r.inputsToSave || [], o = r.outputsToSave || [];
|
|
1540
1544
|
let a;
|
|
@@ -1554,10 +1558,10 @@ class tt {
|
|
|
1554
1558
|
throw new Error("Values passed to engine.makeTensor() are null");
|
|
1555
1559
|
s = s || "float32", r = r || this.backend;
|
|
1556
1560
|
let i = t;
|
|
1557
|
-
s === "string" &&
|
|
1561
|
+
s === "string" && xt(t[0]) && (i = t.map((c) => Ze(c)));
|
|
1558
1562
|
const o = r.write(i, e, s), a = new x(e, s, o, this.nextTensorId());
|
|
1559
1563
|
if (this.trackTensor(a, r), s === "string") {
|
|
1560
|
-
const c = this.state.tensorInfo.get(o), l =
|
|
1564
|
+
const c = this.state.tensorInfo.get(o), l = Re(i);
|
|
1561
1565
|
this.state.numBytes += l - c.bytes, c.bytes = l;
|
|
1562
1566
|
}
|
|
1563
1567
|
return a;
|
|
@@ -1645,10 +1649,10 @@ class tt {
|
|
|
1645
1649
|
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
|
|
1646
1650
|
}
|
|
1647
1651
|
addTapeNode(t, e, s, r, i, o) {
|
|
1648
|
-
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c =
|
|
1652
|
+
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c = zt(t);
|
|
1649
1653
|
c != null && (r = c.gradFunc), r != null && (a.gradient = (l) => (l = l.map((u, h) => {
|
|
1650
1654
|
if (u == null) {
|
|
1651
|
-
const f = s[h], m =
|
|
1655
|
+
const f = s[h], m = Zt(f.size, f.dtype);
|
|
1652
1656
|
return this.makeTensor(m, f.shape, f.dtype);
|
|
1653
1657
|
}
|
|
1654
1658
|
return u;
|
|
@@ -1680,7 +1684,7 @@ class tt {
|
|
|
1680
1684
|
* as scope() without the need for a function closure.
|
|
1681
1685
|
*/
|
|
1682
1686
|
endScope(t) {
|
|
1683
|
-
const e =
|
|
1687
|
+
const e = fe(t), s = new Set(e.map((i) => i.id));
|
|
1684
1688
|
for (let i = 0; i < this.state.activeScope.track.length; i++) {
|
|
1685
1689
|
const o = this.state.activeScope.track[i];
|
|
1686
1690
|
!o.kept && !s.has(o.id) && o.dispose();
|
|
@@ -1774,7 +1778,7 @@ class tt {
|
|
|
1774
1778
|
* registered backend factories.
|
|
1775
1779
|
*/
|
|
1776
1780
|
reset() {
|
|
1777
|
-
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new
|
|
1781
|
+
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new qt();
|
|
1778
1782
|
for (const t in this.registry)
|
|
1779
1783
|
this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
|
|
1780
1784
|
this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
|
|
@@ -1783,21 +1787,21 @@ class tt {
|
|
|
1783
1787
|
tt.nextTensorId = 0;
|
|
1784
1788
|
tt.nextVariableId = 0;
|
|
1785
1789
|
function gn(n) {
|
|
1786
|
-
const t =
|
|
1790
|
+
const t = De(G(n), "float32");
|
|
1787
1791
|
return g.makeTensor(t, n, "float32");
|
|
1788
1792
|
}
|
|
1789
|
-
function
|
|
1790
|
-
const n =
|
|
1793
|
+
function ge() {
|
|
1794
|
+
const n = ee();
|
|
1791
1795
|
if (n._tfengine == null) {
|
|
1792
|
-
const t = new
|
|
1796
|
+
const t = new Ce(n);
|
|
1793
1797
|
n._tfengine = new tt(t);
|
|
1794
1798
|
}
|
|
1795
|
-
return
|
|
1799
|
+
return Le(n._tfengine.ENV), cn(() => n._tfengine), n._tfengine;
|
|
1796
1800
|
}
|
|
1797
|
-
const g =
|
|
1801
|
+
const g = ge();
|
|
1798
1802
|
function mn(n, t) {
|
|
1799
1803
|
const e = { a: n, b: t };
|
|
1800
|
-
return g.runKernel(
|
|
1804
|
+
return g.runKernel(ne, e);
|
|
1801
1805
|
}
|
|
1802
1806
|
/**
|
|
1803
1807
|
* @license
|
|
@@ -1855,19 +1859,19 @@ function yn(n, t) {
|
|
|
1855
1859
|
let e = n;
|
|
1856
1860
|
if ($(n))
|
|
1857
1861
|
return t === "string" ? [] : [n.length];
|
|
1858
|
-
if (
|
|
1862
|
+
if (ue(n)) {
|
|
1859
1863
|
const r = n.channels || "RGBA";
|
|
1860
1864
|
return [n.height, n.width * r.length];
|
|
1861
|
-
} else if (
|
|
1865
|
+
} else if (he(n))
|
|
1862
1866
|
return [n.buffer.size / (t == null ? 4 : St(t))];
|
|
1863
1867
|
if (!Array.isArray(n))
|
|
1864
1868
|
return [];
|
|
1865
1869
|
const s = [];
|
|
1866
1870
|
for (; Array.isArray(e) || $(e) && t !== "string"; )
|
|
1867
1871
|
s.push(e.length), e = e[0];
|
|
1868
|
-
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") &&
|
|
1872
|
+
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && me(n, s, []), s;
|
|
1869
1873
|
}
|
|
1870
|
-
function
|
|
1874
|
+
function me(n, t, e) {
|
|
1871
1875
|
if (e = e || [], !Array.isArray(n) && !$(n)) {
|
|
1872
1876
|
y(t.length === 0, () => `Element arr[${e.join("][")}] is a primitive, but should be an array/TypedArray of ${t[0]} elements`);
|
|
1873
1877
|
return;
|
|
@@ -1875,9 +1879,9 @@ function ge(n, t, e) {
|
|
|
1875
1879
|
y(t.length > 0, () => `Element arr[${e.join("][")}] should be a primitive, but is an array of ${n.length} elements`), y(n.length === t[0], () => `Element arr[${e.join("][")}] should have ${t[0]} elements, but has ${n.length} elements`);
|
|
1876
1880
|
const s = t.slice(1);
|
|
1877
1881
|
for (let r = 0; r < n.length; ++r)
|
|
1878
|
-
|
|
1882
|
+
me(n[r], s, e.concat(r));
|
|
1879
1883
|
}
|
|
1880
|
-
function
|
|
1884
|
+
function Ht(n, t, e, s) {
|
|
1881
1885
|
if (n !== "string_or_numeric") {
|
|
1882
1886
|
if (n == null)
|
|
1883
1887
|
throw new Error("Expected dtype cannot be null.");
|
|
@@ -1886,19 +1890,19 @@ function qt(n, t, e, s) {
|
|
|
1886
1890
|
}
|
|
1887
1891
|
}
|
|
1888
1892
|
function I(n, t, e, s = "numeric") {
|
|
1889
|
-
if (n instanceof
|
|
1890
|
-
return
|
|
1893
|
+
if (n instanceof ce())
|
|
1894
|
+
return Ht(s, n.dtype, t, e), n;
|
|
1891
1895
|
let r = mt(n);
|
|
1892
|
-
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s),
|
|
1896
|
+
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s), Ht(s, r, t, e), n == null || !$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
|
|
1893
1897
|
const c = n == null ? "null" : n.constructor.name;
|
|
1894
1898
|
throw new Error(`Argument '${t}' passed to '${e}' must be a Tensor or TensorLike, but got '${c}'`);
|
|
1895
1899
|
}
|
|
1896
1900
|
const i = yn(n, r);
|
|
1897
1901
|
!$(n) && !Array.isArray(n) && (n = [n]);
|
|
1898
|
-
const a = r !== "string" ?
|
|
1902
|
+
const a = r !== "string" ? ae(n, r) : at(n, [], !0);
|
|
1899
1903
|
return g.makeTensor(a, i, r);
|
|
1900
1904
|
}
|
|
1901
|
-
function
|
|
1905
|
+
function Xs(n, t, e, s = "numeric") {
|
|
1902
1906
|
if (!Array.isArray(n))
|
|
1903
1907
|
throw new Error(`Argument ${t} passed to ${e} must be a \`Tensor[]\` or \`TensorLike[]\``);
|
|
1904
1908
|
return n.map((i, o) => I(i, `${t}[${o}]`, e, s));
|
|
@@ -1931,7 +1935,7 @@ function F(n) {
|
|
|
1931
1935
|
g.startScope(e);
|
|
1932
1936
|
try {
|
|
1933
1937
|
const o = s(...i);
|
|
1934
|
-
return
|
|
1938
|
+
return Ct(o) && console.error("Cannot return a Promise inside of tidy."), g.endScope(o), o;
|
|
1935
1939
|
} catch (o) {
|
|
1936
1940
|
throw g.endScope(null), o;
|
|
1937
1941
|
}
|
|
@@ -1959,7 +1963,7 @@ function wn(n, t, e, s) {
|
|
|
1959
1963
|
s = mt(n);
|
|
1960
1964
|
else if (s === "complex64")
|
|
1961
1965
|
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
|
|
1962
|
-
if (
|
|
1966
|
+
if (he(n) || ue(n)) {
|
|
1963
1967
|
if (s !== "float32" && s !== "int32")
|
|
1964
1968
|
throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);
|
|
1965
1969
|
return g.backend.createTensorFromGPUData(n, t || e, s);
|
|
@@ -1967,15 +1971,15 @@ function wn(n, t, e, s) {
|
|
|
1967
1971
|
if (!$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string")
|
|
1968
1972
|
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
|
|
1969
1973
|
if (t != null) {
|
|
1970
|
-
|
|
1971
|
-
const r =
|
|
1974
|
+
Dt(t);
|
|
1975
|
+
const r = G(t), i = G(e);
|
|
1972
1976
|
y(r === i, () => `Based on the provided shape, [${t}], the tensor should have ${r} values but has ${i}`);
|
|
1973
1977
|
for (let o = 0; o < e.length; ++o) {
|
|
1974
|
-
const a = e[o], c = o === e.length - 1 ? a !==
|
|
1978
|
+
const a = e[o], c = o === e.length - 1 ? a !== G(t.slice(o)) : !0;
|
|
1975
1979
|
y(e[o] === t[o] || !c, () => `Error creating a new Tensor. Inferred shape (${e}) does not match the provided shape (${t}). `);
|
|
1976
1980
|
}
|
|
1977
1981
|
}
|
|
1978
|
-
return !$(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ?
|
|
1982
|
+
return !$(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ? ae(n, s) : at(n, [], !0), g.makeTensor(n, t, s);
|
|
1979
1983
|
}
|
|
1980
1984
|
class lt {
|
|
1981
1985
|
/**
|
|
@@ -2061,24 +2065,27 @@ function Sn(n, t) {
|
|
|
2061
2065
|
* limitations under the License.
|
|
2062
2066
|
* =============================================================================
|
|
2063
2067
|
*/
|
|
2064
|
-
function
|
|
2068
|
+
function Ys() {
|
|
2065
2069
|
return g;
|
|
2066
2070
|
}
|
|
2071
|
+
function Qs() {
|
|
2072
|
+
return g.memory();
|
|
2073
|
+
}
|
|
2067
2074
|
function E(n, t) {
|
|
2068
2075
|
return g.tidy(n, t);
|
|
2069
2076
|
}
|
|
2070
2077
|
function M(n) {
|
|
2071
|
-
|
|
2078
|
+
fe(n).forEach((e) => e.dispose());
|
|
2072
2079
|
}
|
|
2073
2080
|
function kn(n) {
|
|
2074
2081
|
return g.keep(n);
|
|
2075
2082
|
}
|
|
2076
|
-
const
|
|
2077
|
-
function
|
|
2078
|
-
return
|
|
2083
|
+
const Pt = typeof gt < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
|
|
2084
|
+
function Jt(n) {
|
|
2085
|
+
return Pt ? gt.byteLength(n, "utf8") : new Blob([n]).size;
|
|
2079
2086
|
}
|
|
2080
2087
|
function In(n) {
|
|
2081
|
-
if (
|
|
2088
|
+
if (Pt)
|
|
2082
2089
|
return gt.from(n).toString("base64");
|
|
2083
2090
|
const t = new Uint8Array(n);
|
|
2084
2091
|
let e = "";
|
|
@@ -2087,7 +2094,7 @@ function In(n) {
|
|
|
2087
2094
|
return btoa(e);
|
|
2088
2095
|
}
|
|
2089
2096
|
function Tn(n) {
|
|
2090
|
-
if (
|
|
2097
|
+
if (Pt) {
|
|
2091
2098
|
const s = gt.from(n, "base64");
|
|
2092
2099
|
return s.buffer.slice(s.byteOffset, s.byteOffset + s.byteLength);
|
|
2093
2100
|
}
|
|
@@ -2096,14 +2103,14 @@ function Tn(n) {
|
|
|
2096
2103
|
e.set([t.charCodeAt(s)], s);
|
|
2097
2104
|
return e.buffer;
|
|
2098
2105
|
}
|
|
2099
|
-
function
|
|
2106
|
+
function pe(n) {
|
|
2100
2107
|
if (n.modelTopology instanceof ArrayBuffer)
|
|
2101
2108
|
throw new Error("Expected JSON model topology, received ArrayBuffer.");
|
|
2102
2109
|
return {
|
|
2103
2110
|
dateSaved: /* @__PURE__ */ new Date(),
|
|
2104
2111
|
modelTopologyType: "JSON",
|
|
2105
|
-
modelTopologyBytes: n.modelTopology == null ? 0 :
|
|
2106
|
-
weightSpecsBytes: n.weightSpecs == null ? 0 :
|
|
2112
|
+
modelTopologyBytes: n.modelTopology == null ? 0 : Jt(JSON.stringify(n.modelTopology)),
|
|
2113
|
+
weightSpecsBytes: n.weightSpecs == null ? 0 : Jt(JSON.stringify(n.weightSpecs)),
|
|
2107
2114
|
weightDataBytes: n.weightData == null ? 0 : new lt(n.weightData).byteLength
|
|
2108
2115
|
};
|
|
2109
2116
|
}
|
|
@@ -2194,8 +2201,8 @@ class A {
|
|
|
2194
2201
|
* limitations under the License.
|
|
2195
2202
|
* =============================================================================
|
|
2196
2203
|
*/
|
|
2197
|
-
const
|
|
2198
|
-
function
|
|
2204
|
+
const vt = "tensorflowjs", Mt = 1, U = "models_store", P = "model_info_store";
|
|
2205
|
+
function ye() {
|
|
2199
2206
|
if (!S().getBool("IS_BROWSER"))
|
|
2200
2207
|
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
|
|
2201
2208
|
const n = typeof window > "u" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB;
|
|
@@ -2203,13 +2210,13 @@ function pe() {
|
|
|
2203
2210
|
throw new Error("The current browser does not appear to support IndexedDB.");
|
|
2204
2211
|
return t;
|
|
2205
2212
|
}
|
|
2206
|
-
function
|
|
2213
|
+
function Ft(n) {
|
|
2207
2214
|
const t = n.result;
|
|
2208
|
-
t.createObjectStore(
|
|
2215
|
+
t.createObjectStore(U, { keyPath: "modelPath" }), t.createObjectStore(P, { keyPath: "modelPath" });
|
|
2209
2216
|
}
|
|
2210
|
-
class
|
|
2217
|
+
class W {
|
|
2211
2218
|
constructor(t) {
|
|
2212
|
-
if (this.indexedDB =
|
|
2219
|
+
if (this.indexedDB = ye(), t == null || !t)
|
|
2213
2220
|
throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
|
|
2214
2221
|
this.modelPath = t;
|
|
2215
2222
|
}
|
|
@@ -2237,11 +2244,11 @@ class z {
|
|
|
2237
2244
|
*/
|
|
2238
2245
|
databaseAction(t, e) {
|
|
2239
2246
|
return new Promise((s, r) => {
|
|
2240
|
-
const i = this.indexedDB.open(
|
|
2241
|
-
i.onupgradeneeded = () =>
|
|
2247
|
+
const i = this.indexedDB.open(vt, Mt);
|
|
2248
|
+
i.onupgradeneeded = () => Ft(i), i.onsuccess = () => {
|
|
2242
2249
|
const o = i.result;
|
|
2243
2250
|
if (e == null) {
|
|
2244
|
-
const a = o.transaction(
|
|
2251
|
+
const a = o.transaction(U, "readonly"), l = a.objectStore(U).get(this.modelPath);
|
|
2245
2252
|
l.onsuccess = () => {
|
|
2246
2253
|
if (l.result == null)
|
|
2247
2254
|
return o.close(), r(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));
|
|
@@ -2249,7 +2256,7 @@ class z {
|
|
|
2249
2256
|
}, l.onerror = (u) => (o.close(), r(l.error)), a.oncomplete = () => o.close();
|
|
2250
2257
|
} else {
|
|
2251
2258
|
e.weightData = lt.join(e.weightData);
|
|
2252
|
-
const a =
|
|
2259
|
+
const a = pe(e), c = o.transaction(P, "readwrite");
|
|
2253
2260
|
let l = c.objectStore(P), u;
|
|
2254
2261
|
try {
|
|
2255
2262
|
u = l.put({ modelPath: this.modelPath, modelArtifactsInfo: a });
|
|
@@ -2258,8 +2265,8 @@ class z {
|
|
|
2258
2265
|
}
|
|
2259
2266
|
let h;
|
|
2260
2267
|
u.onsuccess = () => {
|
|
2261
|
-
h = o.transaction(
|
|
2262
|
-
const f = h.objectStore(
|
|
2268
|
+
h = o.transaction(U, "readwrite");
|
|
2269
|
+
const f = h.objectStore(U);
|
|
2263
2270
|
let m;
|
|
2264
2271
|
try {
|
|
2265
2272
|
m = f.put({
|
|
@@ -2283,24 +2290,24 @@ class z {
|
|
|
2283
2290
|
});
|
|
2284
2291
|
}
|
|
2285
2292
|
}
|
|
2286
|
-
|
|
2287
|
-
const
|
|
2288
|
-
A.registerSaveRouter(
|
|
2289
|
-
A.registerLoadRouter(
|
|
2293
|
+
W.URL_SCHEME = "indexeddb://";
|
|
2294
|
+
const be = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(W.URL_SCHEME) ? En(n.slice(W.URL_SCHEME.length)) : null;
|
|
2295
|
+
A.registerSaveRouter(be);
|
|
2296
|
+
A.registerLoadRouter(be);
|
|
2290
2297
|
function En(n) {
|
|
2291
|
-
return new
|
|
2298
|
+
return new W(n);
|
|
2292
2299
|
}
|
|
2293
2300
|
function An(n) {
|
|
2294
|
-
return n.startsWith(
|
|
2301
|
+
return n.startsWith(W.URL_SCHEME) ? n.slice(W.URL_SCHEME.length) : n;
|
|
2295
2302
|
}
|
|
2296
2303
|
class Bn {
|
|
2297
2304
|
constructor() {
|
|
2298
|
-
this.indexedDB =
|
|
2305
|
+
this.indexedDB = ye();
|
|
2299
2306
|
}
|
|
2300
2307
|
async listModels() {
|
|
2301
2308
|
return new Promise((t, e) => {
|
|
2302
|
-
const s = this.indexedDB.open(
|
|
2303
|
-
s.onupgradeneeded = () =>
|
|
2309
|
+
const s = this.indexedDB.open(vt, Mt);
|
|
2310
|
+
s.onupgradeneeded = () => Ft(s), s.onsuccess = () => {
|
|
2304
2311
|
const r = s.result, i = r.transaction(P, "readonly"), a = i.objectStore(P).getAll();
|
|
2305
2312
|
a.onsuccess = () => {
|
|
2306
2313
|
const c = {};
|
|
@@ -2313,8 +2320,8 @@ class Bn {
|
|
|
2313
2320
|
}
|
|
2314
2321
|
async removeModel(t) {
|
|
2315
2322
|
return t = An(t), new Promise((e, s) => {
|
|
2316
|
-
const r = this.indexedDB.open(
|
|
2317
|
-
r.onupgradeneeded = () =>
|
|
2323
|
+
const r = this.indexedDB.open(vt, Mt);
|
|
2324
|
+
r.onupgradeneeded = () => Ft(r), r.onsuccess = () => {
|
|
2318
2325
|
const i = r.result, o = i.transaction(P, "readwrite"), a = o.objectStore(P), c = a.get(t);
|
|
2319
2326
|
let l;
|
|
2320
2327
|
c.onsuccess = () => {
|
|
@@ -2322,8 +2329,8 @@ class Bn {
|
|
|
2322
2329
|
return i.close(), s(new Error(`Cannot find model with path '${t}' in IndexedDB.`));
|
|
2323
2330
|
{
|
|
2324
2331
|
const u = a.delete(t), h = () => {
|
|
2325
|
-
l = i.transaction(
|
|
2326
|
-
const m = l.objectStore(
|
|
2332
|
+
l = i.transaction(U, "readwrite");
|
|
2333
|
+
const m = l.objectStore(U).delete(t);
|
|
2327
2334
|
m.onsuccess = () => e(c.result.modelArtifactsInfo), m.onerror = (b) => s(c.error);
|
|
2328
2335
|
};
|
|
2329
2336
|
u.onsuccess = h, u.onerror = (f) => (h(), i.close(), s(c.error));
|
|
@@ -2351,17 +2358,17 @@ class Bn {
|
|
|
2351
2358
|
* limitations under the License.
|
|
2352
2359
|
* =============================================================================
|
|
2353
2360
|
*/
|
|
2354
|
-
const _ = "/", Y = "tensorflowjs_models",
|
|
2355
|
-
function
|
|
2361
|
+
const _ = "/", Y = "tensorflowjs_models", we = "info", vn = "model_topology", Mn = "weight_specs", Fn = "weight_data", $n = "model_metadata";
|
|
2362
|
+
function Se(n) {
|
|
2356
2363
|
return {
|
|
2357
|
-
info: [Y, n,
|
|
2364
|
+
info: [Y, n, we].join(_),
|
|
2358
2365
|
topology: [Y, n, vn].join(_),
|
|
2359
2366
|
weightSpecs: [Y, n, Mn].join(_),
|
|
2360
2367
|
weightData: [Y, n, Fn].join(_),
|
|
2361
2368
|
modelMetadata: [Y, n, $n].join(_)
|
|
2362
2369
|
};
|
|
2363
2370
|
}
|
|
2364
|
-
function
|
|
2371
|
+
function ke(n) {
|
|
2365
2372
|
for (const t of Object.values(n))
|
|
2366
2373
|
window.localStorage.removeItem(t);
|
|
2367
2374
|
}
|
|
@@ -2372,15 +2379,15 @@ function Rn(n) {
|
|
|
2372
2379
|
return t.slice(1, t.length - 1).join(_);
|
|
2373
2380
|
}
|
|
2374
2381
|
function xn(n) {
|
|
2375
|
-
return n.startsWith(
|
|
2382
|
+
return n.startsWith(j.URL_SCHEME) ? n.slice(j.URL_SCHEME.length) : n;
|
|
2376
2383
|
}
|
|
2377
|
-
class
|
|
2384
|
+
class j {
|
|
2378
2385
|
constructor(t) {
|
|
2379
2386
|
if (!S().getBool("IS_BROWSER") || typeof window > "u" || typeof window.localStorage > "u")
|
|
2380
2387
|
throw new Error("The current environment does not support local storage.");
|
|
2381
2388
|
if (this.LS = window.localStorage, t == null || !t)
|
|
2382
2389
|
throw new Error("For local storage, modelPath must not be null, undefined or empty.");
|
|
2383
|
-
this.modelPath = t, this.keys =
|
|
2390
|
+
this.modelPath = t, this.keys = Se(this.modelPath);
|
|
2384
2391
|
}
|
|
2385
2392
|
/**
|
|
2386
2393
|
* Save model artifacts to browser local storage.
|
|
@@ -2395,7 +2402,7 @@ class W {
|
|
|
2395
2402
|
if (t.modelTopology instanceof ArrayBuffer)
|
|
2396
2403
|
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
|
|
2397
2404
|
{
|
|
2398
|
-
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r =
|
|
2405
|
+
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r = pe(t), i = lt.join(t.weightData);
|
|
2399
2406
|
try {
|
|
2400
2407
|
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, s), this.LS.setItem(this.keys.weightData, In(i));
|
|
2401
2408
|
const o = {
|
|
@@ -2410,7 +2417,7 @@ class W {
|
|
|
2410
2417
|
};
|
|
2411
2418
|
return this.LS.setItem(this.keys.modelMetadata, JSON.stringify(o)), { modelArtifactsInfo: r };
|
|
2412
2419
|
} catch {
|
|
2413
|
-
throw
|
|
2420
|
+
throw ke(this.keys), new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${r.modelTopologyBytes}, weightSpecsBytes=${r.weightSpecsBytes}, weightDataBytes=${r.weightDataBytes}.`);
|
|
2414
2421
|
}
|
|
2415
2422
|
}
|
|
2416
2423
|
}
|
|
@@ -2447,19 +2454,19 @@ class W {
|
|
|
2447
2454
|
return e.weightData = Tn(o), e;
|
|
2448
2455
|
}
|
|
2449
2456
|
}
|
|
2450
|
-
|
|
2451
|
-
const
|
|
2452
|
-
A.registerSaveRouter(
|
|
2453
|
-
A.registerLoadRouter(
|
|
2457
|
+
j.URL_SCHEME = "localstorage://";
|
|
2458
|
+
const Ie = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(j.URL_SCHEME) ? Nn(n.slice(j.URL_SCHEME.length)) : null;
|
|
2459
|
+
A.registerSaveRouter(Ie);
|
|
2460
|
+
A.registerLoadRouter(Ie);
|
|
2454
2461
|
function Nn(n) {
|
|
2455
|
-
return new
|
|
2462
|
+
return new j(n);
|
|
2456
2463
|
}
|
|
2457
2464
|
class Dn {
|
|
2458
2465
|
constructor() {
|
|
2459
2466
|
y(S().getBool("IS_BROWSER"), () => "Current environment is not a web browser"), y(typeof window > "u" || typeof window.localStorage < "u", () => "Current browser does not appear to support localStorage"), this.LS = window.localStorage;
|
|
2460
2467
|
}
|
|
2461
2468
|
async listModels() {
|
|
2462
|
-
const t = {}, e = Y + _, s = _ +
|
|
2469
|
+
const t = {}, e = Y + _, s = _ + we;
|
|
2463
2470
|
for (let r = 0; r < this.LS.length; ++r) {
|
|
2464
2471
|
const i = this.LS.key(r);
|
|
2465
2472
|
if (i.startsWith(e) && i.endsWith(s)) {
|
|
@@ -2471,11 +2478,11 @@ class Dn {
|
|
|
2471
2478
|
}
|
|
2472
2479
|
async removeModel(t) {
|
|
2473
2480
|
t = xn(t);
|
|
2474
|
-
const e =
|
|
2481
|
+
const e = Se(t);
|
|
2475
2482
|
if (this.LS.getItem(e.info) == null)
|
|
2476
2483
|
throw new Error(`Cannot find model at path '${t}'`);
|
|
2477
2484
|
const s = JSON.parse(this.LS.getItem(e.info));
|
|
2478
|
-
return
|
|
2485
|
+
return ke(e), s;
|
|
2479
2486
|
}
|
|
2480
2487
|
}
|
|
2481
2488
|
/**
|
|
@@ -2494,7 +2501,7 @@ class Dn {
|
|
|
2494
2501
|
* limitations under the License.
|
|
2495
2502
|
* =============================================================================
|
|
2496
2503
|
*/
|
|
2497
|
-
const
|
|
2504
|
+
const Xt = "://";
|
|
2498
2505
|
class N {
|
|
2499
2506
|
constructor() {
|
|
2500
2507
|
this.managers = {};
|
|
@@ -2509,7 +2516,7 @@ class N {
|
|
|
2509
2516
|
* of `IOHandler` with the `save` method defined or `null`.
|
|
2510
2517
|
*/
|
|
2511
2518
|
static registerManager(t, e) {
|
|
2512
|
-
y(t != null, () => "scheme must not be undefined or null."), t.endsWith(
|
|
2519
|
+
y(t != null, () => "scheme must not be undefined or null."), t.endsWith(Xt) && (t = t.slice(0, t.indexOf(Xt))), y(t.length > 0, () => "scheme must not be an empty string.");
|
|
2513
2520
|
const s = N.getInstance();
|
|
2514
2521
|
y(s.managers[t] == null, () => `A model store manager is already registered for scheme '${t}'.`), s.managers[t] = e;
|
|
2515
2522
|
}
|
|
@@ -2577,17 +2584,17 @@ class Cn {
|
|
|
2577
2584
|
}, !0));
|
|
2578
2585
|
}
|
|
2579
2586
|
isTypedArray(t) {
|
|
2580
|
-
return
|
|
2587
|
+
return oe(t);
|
|
2581
2588
|
}
|
|
2582
2589
|
}
|
|
2583
2590
|
if (S().get("IS_BROWSER")) {
|
|
2584
2591
|
S().setPlatform("browser", new Cn());
|
|
2585
2592
|
try {
|
|
2586
|
-
N.registerManager(
|
|
2593
|
+
N.registerManager(j.URL_SCHEME, new Dn());
|
|
2587
2594
|
} catch {
|
|
2588
2595
|
}
|
|
2589
2596
|
try {
|
|
2590
|
-
N.registerManager(
|
|
2597
|
+
N.registerManager(W.URL_SCHEME, new Bn());
|
|
2591
2598
|
} catch {
|
|
2592
2599
|
}
|
|
2593
2600
|
}
|
|
@@ -2637,7 +2644,7 @@ S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new Pn()
|
|
|
2637
2644
|
* =============================================================================
|
|
2638
2645
|
*/
|
|
2639
2646
|
function On(n, t = "float32", e) {
|
|
2640
|
-
return t = t || "float32",
|
|
2647
|
+
return t = t || "float32", Dt(n), new ln(n, t, e);
|
|
2641
2648
|
}
|
|
2642
2649
|
/**
|
|
2643
2650
|
* @license
|
|
@@ -2657,14 +2664,14 @@ function On(n, t = "float32", e) {
|
|
|
2657
2664
|
*/
|
|
2658
2665
|
function Ln(n, t) {
|
|
2659
2666
|
const e = I(n, "x", "cast");
|
|
2660
|
-
if (
|
|
2667
|
+
if (!$e(t))
|
|
2661
2668
|
throw new Error(`Failed to cast to unknown dtype ${t}`);
|
|
2662
2669
|
if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string")
|
|
2663
2670
|
throw new Error("Only strings can be casted to strings");
|
|
2664
2671
|
const s = { x: e }, r = { dtype: t };
|
|
2665
|
-
return g.runKernel(
|
|
2672
|
+
return g.runKernel(se, s, r);
|
|
2666
2673
|
}
|
|
2667
|
-
const
|
|
2674
|
+
const $t = /* @__PURE__ */ F({ cast_: Ln });
|
|
2668
2675
|
/**
|
|
2669
2676
|
* @license
|
|
2670
2677
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2683,7 +2690,7 @@ const Ft = /* @__PURE__ */ F({ cast_: Ln });
|
|
|
2683
2690
|
*/
|
|
2684
2691
|
function Un(n) {
|
|
2685
2692
|
const e = { x: I(n, "x", "clone", "string_or_numeric") };
|
|
2686
|
-
return g.runKernel(
|
|
2693
|
+
return g.runKernel(re, e);
|
|
2687
2694
|
}
|
|
2688
2695
|
const Gn = /* @__PURE__ */ F({ clone_: Un });
|
|
2689
2696
|
/**
|
|
@@ -2721,10 +2728,10 @@ function zn(n, t = !1) {
|
|
|
2721
2728
|
* limitations under the License.
|
|
2722
2729
|
* =============================================================================
|
|
2723
2730
|
*/
|
|
2724
|
-
|
|
2731
|
+
ge();
|
|
2725
2732
|
const Wn = {
|
|
2726
2733
|
buffer: On,
|
|
2727
|
-
cast:
|
|
2734
|
+
cast: $t,
|
|
2728
2735
|
clone: Gn,
|
|
2729
2736
|
print: zn
|
|
2730
2737
|
};
|
|
@@ -2747,9 +2754,9 @@ un(Wn);
|
|
|
2747
2754
|
*/
|
|
2748
2755
|
function jn(n, t) {
|
|
2749
2756
|
let e = I(n, "a", "add"), s = I(t, "b", "add");
|
|
2750
|
-
[e, s] =
|
|
2757
|
+
[e, s] = V(e, s);
|
|
2751
2758
|
const r = { a: e, b: s };
|
|
2752
|
-
return g.runKernel(
|
|
2759
|
+
return g.runKernel(ne, r);
|
|
2753
2760
|
}
|
|
2754
2761
|
const w = /* @__PURE__ */ F({ add_: jn });
|
|
2755
2762
|
/**
|
|
@@ -2770,9 +2777,9 @@ const w = /* @__PURE__ */ F({ add_: jn });
|
|
|
2770
2777
|
*/
|
|
2771
2778
|
function Kn(n, t) {
|
|
2772
2779
|
let e = I(n, "a", "floorDiv"), s = I(t, "b", "floorDiv");
|
|
2773
|
-
[e, s] =
|
|
2780
|
+
[e, s] = V(e, s);
|
|
2774
2781
|
const r = { a: e, b: s };
|
|
2775
|
-
return g.runKernel(
|
|
2782
|
+
return g.runKernel(Ke, r);
|
|
2776
2783
|
}
|
|
2777
2784
|
const Vn = /* @__PURE__ */ F({ floorDiv_: Kn });
|
|
2778
2785
|
/**
|
|
@@ -2793,10 +2800,10 @@ const Vn = /* @__PURE__ */ F({ floorDiv_: Kn });
|
|
|
2793
2800
|
*/
|
|
2794
2801
|
function qn(n, t) {
|
|
2795
2802
|
let e = I(n, "a", "div"), s = I(t, "b", "div");
|
|
2796
|
-
if ([e, s] =
|
|
2803
|
+
if ([e, s] = V(e, s), e.dtype === "int32" && s.dtype === "int32")
|
|
2797
2804
|
return Vn(e, s);
|
|
2798
2805
|
const r = { a: e, b: s }, i = {};
|
|
2799
|
-
return g.runKernel(
|
|
2806
|
+
return g.runKernel(We, r, i);
|
|
2800
2807
|
}
|
|
2801
2808
|
const D = /* @__PURE__ */ F({ div_: qn });
|
|
2802
2809
|
/**
|
|
@@ -2817,9 +2824,9 @@ const D = /* @__PURE__ */ F({ div_: qn });
|
|
|
2817
2824
|
*/
|
|
2818
2825
|
function Hn(n, t) {
|
|
2819
2826
|
let e = I(n, "a", "mul"), s = I(t, "b", "mul");
|
|
2820
|
-
[e, s] =
|
|
2827
|
+
[e, s] = V(e, s);
|
|
2821
2828
|
const r = { a: e, b: s };
|
|
2822
|
-
return g.runKernel(
|
|
2829
|
+
return g.runKernel(qe, r);
|
|
2823
2830
|
}
|
|
2824
2831
|
const p = /* @__PURE__ */ F({ mul_: Hn });
|
|
2825
2832
|
/**
|
|
@@ -2842,10 +2849,10 @@ function Jn(n) {
|
|
|
2842
2849
|
const t = I(n, "x", "abs");
|
|
2843
2850
|
if (t.dtype === "complex64") {
|
|
2844
2851
|
const e = { x: t };
|
|
2845
|
-
return g.runKernel(
|
|
2852
|
+
return g.runKernel(ze, e);
|
|
2846
2853
|
} else {
|
|
2847
2854
|
const e = { x: t };
|
|
2848
|
-
return g.runKernel(
|
|
2855
|
+
return g.runKernel(Ge, e);
|
|
2849
2856
|
}
|
|
2850
2857
|
}
|
|
2851
2858
|
const Xn = /* @__PURE__ */ F({ abs_: Jn });
|
|
@@ -2866,9 +2873,9 @@ const Xn = /* @__PURE__ */ F({ abs_: Jn });
|
|
|
2866
2873
|
* =============================================================================
|
|
2867
2874
|
*/
|
|
2868
2875
|
function Yn(n, t, e) {
|
|
2869
|
-
|
|
2876
|
+
Dt(n), e = e || mt(t);
|
|
2870
2877
|
const s = { shape: n, value: t, dtype: e };
|
|
2871
|
-
return g.runKernel(
|
|
2878
|
+
return g.runKernel(je, {}, s);
|
|
2872
2879
|
}
|
|
2873
2880
|
/**
|
|
2874
2881
|
* @license
|
|
@@ -2886,7 +2893,7 @@ function Yn(n, t, e) {
|
|
|
2886
2893
|
* limitations under the License.
|
|
2887
2894
|
* =============================================================================
|
|
2888
2895
|
*/
|
|
2889
|
-
function
|
|
2896
|
+
function Zs(n, t) {
|
|
2890
2897
|
const e = [];
|
|
2891
2898
|
for (let s = 0; s < t.length; s++) {
|
|
2892
2899
|
const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
|
|
@@ -2930,7 +2937,7 @@ function Qn(n, t) {
|
|
|
2930
2937
|
*/
|
|
2931
2938
|
function Zn(n) {
|
|
2932
2939
|
const e = { x: I(n, "x", "zerosLike") };
|
|
2933
|
-
return g.runKernel(
|
|
2940
|
+
return g.runKernel(Ye, e);
|
|
2934
2941
|
}
|
|
2935
2942
|
const C = /* @__PURE__ */ F({ zerosLike_: Zn });
|
|
2936
2943
|
/**
|
|
@@ -2951,11 +2958,11 @@ const C = /* @__PURE__ */ F({ zerosLike_: Zn });
|
|
|
2951
2958
|
*/
|
|
2952
2959
|
function ts(n, t) {
|
|
2953
2960
|
let e = I(n, "base", "pow"), s = I(t, "exp", "pow");
|
|
2954
|
-
[e, s] =
|
|
2961
|
+
[e, s] = V(e, s);
|
|
2955
2962
|
const r = { a: e, b: s };
|
|
2956
|
-
return g.runKernel(
|
|
2963
|
+
return g.runKernel(He, r);
|
|
2957
2964
|
}
|
|
2958
|
-
const
|
|
2965
|
+
const Yt = /* @__PURE__ */ F({ pow_: ts });
|
|
2959
2966
|
/**
|
|
2960
2967
|
* @license
|
|
2961
2968
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -2972,7 +2979,7 @@ const Xt = /* @__PURE__ */ F({ pow_: ts });
|
|
|
2972
2979
|
* limitations under the License.
|
|
2973
2980
|
* =============================================================================
|
|
2974
2981
|
*/
|
|
2975
|
-
function
|
|
2982
|
+
function K(n, t) {
|
|
2976
2983
|
if (($(n) && t !== "string" || Array.isArray(n)) && t !== "complex64")
|
|
2977
2984
|
throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
|
|
2978
2985
|
if (t === "string" && $(n) && !(n instanceof Uint8Array))
|
|
@@ -2997,7 +3004,7 @@ function j(n, t) {
|
|
|
2997
3004
|
*/
|
|
2998
3005
|
function es(n) {
|
|
2999
3006
|
const e = { x: I(n, "x", "sqrt", "float32") };
|
|
3000
|
-
return g.runKernel(
|
|
3007
|
+
return g.runKernel(Je, e);
|
|
3001
3008
|
}
|
|
3002
3009
|
const et = /* @__PURE__ */ F({ sqrt_: es });
|
|
3003
3010
|
/**
|
|
@@ -3020,7 +3027,7 @@ function ns(n) {
|
|
|
3020
3027
|
const t = I(n, "x", "square"), e = {};
|
|
3021
3028
|
return g.runKernel("Square", { x: t }, e);
|
|
3022
3029
|
}
|
|
3023
|
-
const
|
|
3030
|
+
const z = /* @__PURE__ */ F({ square_: ns });
|
|
3024
3031
|
/**
|
|
3025
3032
|
* @license
|
|
3026
3033
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3054,7 +3061,7 @@ function ss(n, t) {
|
|
|
3054
3061
|
a[u] != null && (c[l.name] = a[u]);
|
|
3055
3062
|
}), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
|
|
3056
3063
|
}
|
|
3057
|
-
function
|
|
3064
|
+
function tr(n) {
|
|
3058
3065
|
return g.customGrad(n);
|
|
3059
3066
|
}
|
|
3060
3067
|
/**
|
|
@@ -3075,9 +3082,9 @@ function Qs(n) {
|
|
|
3075
3082
|
*/
|
|
3076
3083
|
function rs(n, t) {
|
|
3077
3084
|
let e = I(n, "a", "sub"), s = I(t, "b", "sub");
|
|
3078
|
-
[e, s] =
|
|
3085
|
+
[e, s] = V(e, s);
|
|
3079
3086
|
const r = { a: e, b: s };
|
|
3080
|
-
return g.runKernel(
|
|
3087
|
+
return g.runKernel(Xe, r);
|
|
3081
3088
|
}
|
|
3082
3089
|
const Z = /* @__PURE__ */ F({ sub_: rs });
|
|
3083
3090
|
/**
|
|
@@ -3098,9 +3105,9 @@ const Z = /* @__PURE__ */ F({ sub_: rs });
|
|
|
3098
3105
|
*/
|
|
3099
3106
|
function is(n, t) {
|
|
3100
3107
|
let e = I(n, "a", "maximum"), s = I(t, "b", "maximum");
|
|
3101
|
-
[e, s] =
|
|
3108
|
+
[e, s] = V(e, s), e.dtype === "bool" && (e = $t(e, "int32"), s = $t(s, "int32")), Qn(e.shape, s.shape);
|
|
3102
3109
|
const r = { a: e, b: s };
|
|
3103
|
-
return g.runKernel(
|
|
3110
|
+
return g.runKernel(Ve, r);
|
|
3104
3111
|
}
|
|
3105
3112
|
const os = /* @__PURE__ */ F({ maximum_: is });
|
|
3106
3113
|
/**
|
|
@@ -3148,7 +3155,7 @@ class cs {
|
|
|
3148
3155
|
return new t(e);
|
|
3149
3156
|
}
|
|
3150
3157
|
}
|
|
3151
|
-
class
|
|
3158
|
+
class L {
|
|
3152
3159
|
constructor() {
|
|
3153
3160
|
this.classNameMap = {};
|
|
3154
3161
|
}
|
|
@@ -3156,19 +3163,19 @@ class O {
|
|
|
3156
3163
|
* Returns the singleton instance of the map.
|
|
3157
3164
|
*/
|
|
3158
3165
|
static getMap() {
|
|
3159
|
-
return
|
|
3166
|
+
return L.instance == null && (L.instance = new L()), L.instance;
|
|
3160
3167
|
}
|
|
3161
3168
|
/**
|
|
3162
3169
|
* Registers the class as serializable.
|
|
3163
3170
|
*/
|
|
3164
3171
|
static register(t) {
|
|
3165
|
-
|
|
3172
|
+
L.getMap().classNameMap[t.className] = [t, t.fromConfig];
|
|
3166
3173
|
}
|
|
3167
3174
|
}
|
|
3168
3175
|
function us(n, t, e) {
|
|
3169
3176
|
y(n.className != null, () => "Class being registered does not have the static className property defined."), y(typeof n.className == "string", () => "className is required to be a string, but got type " + typeof n.className), y(n.className.length > 0, () => "Class being registered has an empty-string as its className, which is disallowed."), typeof t > "u" && (t = "Custom"), typeof e > "u" && (e = n.className);
|
|
3170
3177
|
const s = e, r = t + ">" + s;
|
|
3171
|
-
return
|
|
3178
|
+
return L.register(n), as.set(r, n), ls.set(n, r), n;
|
|
3172
3179
|
}
|
|
3173
3180
|
/**
|
|
3174
3181
|
* @license
|
|
@@ -3186,7 +3193,7 @@ function us(n, t, e) {
|
|
|
3186
3193
|
* limitations under the License.
|
|
3187
3194
|
* =============================================================================
|
|
3188
3195
|
*/
|
|
3189
|
-
class
|
|
3196
|
+
class q extends cs {
|
|
3190
3197
|
/**
|
|
3191
3198
|
* Executes `f()` and minimizes the scalar output of `f()` by computing
|
|
3192
3199
|
* gradients of y with respect to the list of trainable variables provided by
|
|
@@ -3245,7 +3252,7 @@ class V extends cs {
|
|
|
3245
3252
|
return this.iterations_ == null && (this.iterations_ = 0), {
|
|
3246
3253
|
name: "iter",
|
|
3247
3254
|
// TODO(cais): Use 'int64' type when available.
|
|
3248
|
-
tensor:
|
|
3255
|
+
tensor: K(this.iterations_, "int32")
|
|
3249
3256
|
};
|
|
3250
3257
|
}
|
|
3251
3258
|
async getWeights() {
|
|
@@ -3265,7 +3272,7 @@ class V extends cs {
|
|
|
3265
3272
|
return this.iterations_ = (await t[0].tensor.data())[0], t.slice(1);
|
|
3266
3273
|
}
|
|
3267
3274
|
}
|
|
3268
|
-
Object.defineProperty(
|
|
3275
|
+
Object.defineProperty(q, Symbol.hasInstance, {
|
|
3269
3276
|
value: (n) => n.minimize != null && n.computeGradients != null && n.applyGradients != null
|
|
3270
3277
|
});
|
|
3271
3278
|
/**
|
|
@@ -3284,7 +3291,7 @@ Object.defineProperty(V, Symbol.hasInstance, {
|
|
|
3284
3291
|
* limitations under the License.
|
|
3285
3292
|
* =============================================================================
|
|
3286
3293
|
*/
|
|
3287
|
-
class hs extends
|
|
3294
|
+
class hs extends q {
|
|
3288
3295
|
/** @nocollapse */
|
|
3289
3296
|
static get className() {
|
|
3290
3297
|
return "Adadelta";
|
|
@@ -3307,7 +3314,7 @@ class hs extends V {
|
|
|
3307
3314
|
return;
|
|
3308
3315
|
const c = this.accumulatedGrads[r].variable, l = this.accumulatedUpdates[r].variable;
|
|
3309
3316
|
E(() => {
|
|
3310
|
-
const u = w(p(c, this.rho), p(
|
|
3317
|
+
const u = w(p(c, this.rho), p(z(a), 1 - this.rho)), h = p(D(et(w(l, this.epsilon)), et(w(c, this.epsilon))), a), f = w(p(l, this.rho), p(z(h), 1 - this.rho));
|
|
3311
3318
|
c.assign(u), l.assign(f);
|
|
3312
3319
|
const m = w(p(h, -this.learningRate), i);
|
|
3313
3320
|
i.assign(m);
|
|
@@ -3360,7 +3367,7 @@ class hs extends V {
|
|
|
3360
3367
|
* limitations under the License.
|
|
3361
3368
|
* =============================================================================
|
|
3362
3369
|
*/
|
|
3363
|
-
class fs extends
|
|
3370
|
+
class fs extends q {
|
|
3364
3371
|
/** @nocollapse */
|
|
3365
3372
|
static get className() {
|
|
3366
3373
|
return "Adagrad";
|
|
@@ -3380,7 +3387,7 @@ class fs extends V {
|
|
|
3380
3387
|
return;
|
|
3381
3388
|
const a = this.accumulatedGrads[r].variable;
|
|
3382
3389
|
E(() => {
|
|
3383
|
-
const c = w(a,
|
|
3390
|
+
const c = w(a, z(o));
|
|
3384
3391
|
a.assign(c);
|
|
3385
3392
|
const l = w(p(D(o, et(w(c, g.backend.epsilon()))), -this.learningRate), i);
|
|
3386
3393
|
i.assign(l);
|
|
@@ -3425,14 +3432,14 @@ class fs extends V {
|
|
|
3425
3432
|
* limitations under the License.
|
|
3426
3433
|
* =============================================================================
|
|
3427
3434
|
*/
|
|
3428
|
-
class ds extends
|
|
3435
|
+
class ds extends q {
|
|
3429
3436
|
/** @nocollapse */
|
|
3430
3437
|
static get className() {
|
|
3431
3438
|
return "Adam";
|
|
3432
3439
|
}
|
|
3433
3440
|
constructor(t, e, s, r = null) {
|
|
3434
3441
|
super(), this.learningRate = t, this.beta1 = e, this.beta2 = s, this.epsilon = r, this.accumulatedFirstMoment = [], this.accumulatedSecondMoment = [], E(() => {
|
|
3435
|
-
this.accBeta1 =
|
|
3442
|
+
this.accBeta1 = K(e).variable(), this.accBeta2 = K(s).variable();
|
|
3436
3443
|
}), r == null && (this.epsilon = g.backend.epsilon());
|
|
3437
3444
|
}
|
|
3438
3445
|
applyGradients(t) {
|
|
@@ -3451,7 +3458,7 @@ class ds extends V {
|
|
|
3451
3458
|
const l = Array.isArray(t) ? t[o].tensor : t[i];
|
|
3452
3459
|
if (l == null)
|
|
3453
3460
|
return;
|
|
3454
|
-
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(
|
|
3461
|
+
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(z(l), 1 - this.beta2)), b = D(f, s), d = D(m, r);
|
|
3455
3462
|
u.assign(f), h.assign(m);
|
|
3456
3463
|
const k = w(p(D(b, w(et(d), this.epsilon)), -this.learningRate), a);
|
|
3457
3464
|
a.assign(k);
|
|
@@ -3467,7 +3474,7 @@ class ds extends V {
|
|
|
3467
3474
|
}
|
|
3468
3475
|
async setWeights(t) {
|
|
3469
3476
|
t = await this.extractIterations(t), E(() => {
|
|
3470
|
-
this.accBeta1.assign(
|
|
3477
|
+
this.accBeta1.assign(Yt(this.beta1, this.iterations_ + 1)), this.accBeta2.assign(Yt(this.beta2, this.iterations_ + 1));
|
|
3471
3478
|
});
|
|
3472
3479
|
const e = t.length / 2, s = !1;
|
|
3473
3480
|
this.accumulatedFirstMoment = t.slice(0, e).map((r) => ({
|
|
@@ -3507,14 +3514,14 @@ class ds extends V {
|
|
|
3507
3514
|
* limitations under the License.
|
|
3508
3515
|
* =============================================================================
|
|
3509
3516
|
*/
|
|
3510
|
-
class gs extends
|
|
3517
|
+
class gs extends q {
|
|
3511
3518
|
/** @nocollapse */
|
|
3512
3519
|
static get className() {
|
|
3513
3520
|
return "Adamax";
|
|
3514
3521
|
}
|
|
3515
3522
|
constructor(t, e, s, r = null, i = 0) {
|
|
3516
3523
|
super(), this.learningRate = t, this.beta1 = e, this.beta2 = s, this.epsilon = r, this.decay = i, this.accumulatedFirstMoment = [], this.accumulatedWeightedInfNorm = [], E(() => {
|
|
3517
|
-
this.iteration =
|
|
3524
|
+
this.iteration = K(0).variable(), this.accBeta1 = K(e).variable();
|
|
3518
3525
|
}), r == null && (this.epsilon = g.backend.epsilon());
|
|
3519
3526
|
}
|
|
3520
3527
|
applyGradients(t) {
|
|
@@ -3579,7 +3586,7 @@ class gs extends V {
|
|
|
3579
3586
|
* limitations under the License.
|
|
3580
3587
|
* =============================================================================
|
|
3581
3588
|
*/
|
|
3582
|
-
class
|
|
3589
|
+
class Te extends q {
|
|
3583
3590
|
/** @nocollapse */
|
|
3584
3591
|
static get className() {
|
|
3585
3592
|
return "SGD";
|
|
@@ -3603,7 +3610,7 @@ class Ie extends V {
|
|
|
3603
3610
|
* Sets the learning rate of the optimizer.
|
|
3604
3611
|
*/
|
|
3605
3612
|
setLearningRate(t) {
|
|
3606
|
-
this.learningRate = t, this.c != null && this.c.dispose(), this.c = kn(
|
|
3613
|
+
this.learningRate = t, this.c != null && this.c.dispose(), this.c = kn(K(-t));
|
|
3607
3614
|
}
|
|
3608
3615
|
dispose() {
|
|
3609
3616
|
this.c.dispose();
|
|
@@ -3639,14 +3646,14 @@ class Ie extends V {
|
|
|
3639
3646
|
* limitations under the License.
|
|
3640
3647
|
* =============================================================================
|
|
3641
3648
|
*/
|
|
3642
|
-
class ms extends
|
|
3649
|
+
class ms extends Te {
|
|
3643
3650
|
/** @nocollapse */
|
|
3644
3651
|
// Name matters for Python compatibility.
|
|
3645
3652
|
static get className() {
|
|
3646
3653
|
return "Momentum";
|
|
3647
3654
|
}
|
|
3648
3655
|
constructor(t, e, s = !1) {
|
|
3649
|
-
super(t), this.learningRate = t, this.momentum = e, this.useNesterov = s, this.accumulations = [], this.m =
|
|
3656
|
+
super(t), this.learningRate = t, this.momentum = e, this.useNesterov = s, this.accumulations = [], this.m = K(this.momentum);
|
|
3650
3657
|
}
|
|
3651
3658
|
applyGradients(t) {
|
|
3652
3659
|
(Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t)).forEach((s, r) => {
|
|
@@ -3710,7 +3717,7 @@ class ms extends Ie {
|
|
|
3710
3717
|
* limitations under the License.
|
|
3711
3718
|
* =============================================================================
|
|
3712
3719
|
*/
|
|
3713
|
-
class ps extends
|
|
3720
|
+
class ps extends q {
|
|
3714
3721
|
/** @nocollapse */
|
|
3715
3722
|
static get className() {
|
|
3716
3723
|
return "RMSProp";
|
|
@@ -3737,14 +3744,14 @@ class ps extends V {
|
|
|
3737
3744
|
return;
|
|
3738
3745
|
const c = this.accumulatedMeanSquares[r].variable, l = this.accumulatedMoments[r].variable;
|
|
3739
3746
|
E(() => {
|
|
3740
|
-
const u = w(p(c, this.decay), p(
|
|
3747
|
+
const u = w(p(c, this.decay), p(z(a), 1 - this.decay));
|
|
3741
3748
|
if (this.centered) {
|
|
3742
|
-
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), et(Z(u, w(
|
|
3749
|
+
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), et(Z(u, w(z(f), this.epsilon)))), b = w(p(l, this.momentum), m);
|
|
3743
3750
|
c.assign(u), h.assign(f), l.assign(b);
|
|
3744
3751
|
const d = Z(i, b);
|
|
3745
3752
|
i.assign(d);
|
|
3746
3753
|
} else {
|
|
3747
|
-
const h = w(p(c, this.decay), p(
|
|
3754
|
+
const h = w(p(c, this.decay), p(z(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate), et(w(h, this.epsilon))));
|
|
3748
3755
|
c.assign(h), l.assign(f);
|
|
3749
3756
|
const m = Z(i, f);
|
|
3750
3757
|
i.assign(m);
|
|
@@ -3810,7 +3817,7 @@ const ys = [
|
|
|
3810
3817
|
gs,
|
|
3811
3818
|
ms,
|
|
3812
3819
|
ps,
|
|
3813
|
-
|
|
3820
|
+
Te
|
|
3814
3821
|
];
|
|
3815
3822
|
function bs() {
|
|
3816
3823
|
for (const n of ys)
|
|
@@ -3837,50 +3844,52 @@ export {
|
|
|
3837
3844
|
ds as A,
|
|
3838
3845
|
Es as B,
|
|
3839
3846
|
As as C,
|
|
3840
|
-
|
|
3847
|
+
C as D,
|
|
3841
3848
|
g as E,
|
|
3842
|
-
|
|
3849
|
+
zs as F,
|
|
3843
3850
|
Ms as G,
|
|
3844
|
-
|
|
3851
|
+
Bs as H,
|
|
3845
3852
|
Fs as I,
|
|
3846
|
-
|
|
3847
|
-
|
|
3853
|
+
$s as J,
|
|
3854
|
+
Cs as K,
|
|
3848
3855
|
Rs as L,
|
|
3849
3856
|
xs as M,
|
|
3850
3857
|
Ns as N,
|
|
3851
|
-
|
|
3858
|
+
Ps as O,
|
|
3852
3859
|
Ds as P,
|
|
3853
|
-
|
|
3860
|
+
Os as Q,
|
|
3854
3861
|
_s as R,
|
|
3855
3862
|
Ws as S,
|
|
3856
|
-
|
|
3857
|
-
|
|
3858
|
-
|
|
3863
|
+
Us as T,
|
|
3864
|
+
Vs as U,
|
|
3865
|
+
Ks as V,
|
|
3866
|
+
Zs as W,
|
|
3867
|
+
Qn as X,
|
|
3859
3868
|
qs as _,
|
|
3860
|
-
|
|
3861
|
-
|
|
3862
|
-
|
|
3863
|
-
|
|
3864
|
-
|
|
3865
|
-
|
|
3866
|
-
|
|
3867
|
-
|
|
3868
|
-
|
|
3869
|
-
|
|
3870
|
-
|
|
3871
|
-
|
|
3872
|
-
|
|
3873
|
-
|
|
3869
|
+
p as a,
|
|
3870
|
+
Z as b,
|
|
3871
|
+
Js as c,
|
|
3872
|
+
I as d,
|
|
3873
|
+
Ys as e,
|
|
3874
|
+
V as f,
|
|
3875
|
+
Is as g,
|
|
3876
|
+
Xs as h,
|
|
3877
|
+
y as i,
|
|
3878
|
+
Ls as j,
|
|
3879
|
+
$t as k,
|
|
3880
|
+
Dt as l,
|
|
3881
|
+
Qs as m,
|
|
3882
|
+
Zt as n,
|
|
3874
3883
|
F as o,
|
|
3875
|
-
|
|
3876
|
-
|
|
3884
|
+
G as p,
|
|
3885
|
+
De as q,
|
|
3877
3886
|
Hs as r,
|
|
3878
|
-
|
|
3879
|
-
|
|
3880
|
-
|
|
3881
|
-
|
|
3882
|
-
|
|
3883
|
-
|
|
3884
|
-
|
|
3885
|
-
|
|
3887
|
+
K as s,
|
|
3888
|
+
Gs as t,
|
|
3889
|
+
vs as u,
|
|
3890
|
+
Ts as v,
|
|
3891
|
+
w,
|
|
3892
|
+
js as x,
|
|
3893
|
+
tr as y,
|
|
3894
|
+
E as z
|
|
3886
3895
|
};
|