@genai-fi/nanogpt 0.2.7 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/TeachableLLM.js +1 -0
- package/dist/{complex-D6Bq1XDf.js → complex-Cd8sqiBC.js} +1 -1
- package/dist/{index-D1SlunD-.js → index-Dsg28SG6.js} +304 -299
- package/dist/layers/CausalSelfAttention.js +40 -39
- package/dist/layers/TiedEmbedding.js +106 -128
- package/dist/main.js +15 -14
- package/dist/mat_mul-BAYDrXvE.js +27 -0
- package/dist/ops/attentionMask.d.ts +2 -0
- package/dist/ops/attentionMask.js +82 -0
- package/dist/ops/gatherSub.js +2 -2
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/scatterSub.js +9 -9
- package/dist/{stack-DB2YLlAs.js → stack-1o648CP_.js} +5 -5
- package/dist/{sum-02UQ5Eaq.js → sum-NWazHI7f.js} +3 -3
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/sparseCrossEntropy.js +12 -12
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { g as
|
|
1
|
+
import { g as Ot } from "./index-D5v913EJ.js";
|
|
2
2
|
import { p as Q } from "./index-xuotMAFm.js";
|
|
3
3
|
import { B as gt } from "./index-Tf7vU29b.js";
|
|
4
4
|
/**
|
|
@@ -17,8 +17,8 @@ import { B as gt } from "./index-Tf7vU29b.js";
|
|
|
17
17
|
* limitations under the License.
|
|
18
18
|
* =============================================================================
|
|
19
19
|
*/
|
|
20
|
-
const
|
|
21
|
-
class
|
|
20
|
+
const Ee = 1e-7, Ae = 1e-4;
|
|
21
|
+
class Be {
|
|
22
22
|
refCount(t) {
|
|
23
23
|
return v("refCount");
|
|
24
24
|
}
|
|
@@ -64,7 +64,7 @@ class Ae {
|
|
|
64
64
|
}
|
|
65
65
|
/** Returns the smallest representable number. */
|
|
66
66
|
epsilon() {
|
|
67
|
-
return this.floatPrecision() === 32 ?
|
|
67
|
+
return this.floatPrecision() === 32 ? Ee : Ae;
|
|
68
68
|
}
|
|
69
69
|
dispose() {
|
|
70
70
|
return v("dispose");
|
|
@@ -94,9 +94,9 @@ function y(n, t) {
|
|
|
94
94
|
throw new Error(typeof t == "string" ? t : t());
|
|
95
95
|
}
|
|
96
96
|
function Is(n, t, e = "") {
|
|
97
|
-
y(
|
|
97
|
+
y(Rt(n, t), () => e + ` Shapes ${n} and ${t} must match`);
|
|
98
98
|
}
|
|
99
|
-
function
|
|
99
|
+
function G(n) {
|
|
100
100
|
if (n.length === 0)
|
|
101
101
|
return 1;
|
|
102
102
|
let t = n[0];
|
|
@@ -104,7 +104,7 @@ function U(n) {
|
|
|
104
104
|
t *= n[e];
|
|
105
105
|
return t;
|
|
106
106
|
}
|
|
107
|
-
function
|
|
107
|
+
function Rt(n, t) {
|
|
108
108
|
if (n === t)
|
|
109
109
|
return !0;
|
|
110
110
|
if (n == null || t == null || n.length !== t.length)
|
|
@@ -114,7 +114,7 @@ function $t(n, t) {
|
|
|
114
114
|
return !1;
|
|
115
115
|
return !0;
|
|
116
116
|
}
|
|
117
|
-
function
|
|
117
|
+
function ve(n) {
|
|
118
118
|
return n % 1 === 0;
|
|
119
119
|
}
|
|
120
120
|
function ct(n, t) {
|
|
@@ -122,9 +122,9 @@ function ct(n, t) {
|
|
|
122
122
|
}
|
|
123
123
|
function Ts(n, t) {
|
|
124
124
|
const e = t.length;
|
|
125
|
-
return n = n == null ? t.map((s, r) => r) : [].concat(n), y(n.every((s) => s >= -e && s < e), () => `All values in axis param must be in range [-${e}, ${e}) but got axis ${n}`), y(n.every((s) =>
|
|
125
|
+
return n = n == null ? t.map((s, r) => r) : [].concat(n), y(n.every((s) => s >= -e && s < e), () => `All values in axis param must be in range [-${e}, ${e}) but got axis ${n}`), y(n.every((s) => ve(s)), () => `All values in axis param must be integers but got axis ${n}`), n.map((s) => s < 0 ? e + s : s);
|
|
126
126
|
}
|
|
127
|
-
function
|
|
127
|
+
function Me(n, t) {
|
|
128
128
|
let e = null;
|
|
129
129
|
if (n == null || n === "float32")
|
|
130
130
|
e = new Float32Array(t);
|
|
@@ -138,14 +138,14 @@ function ve(n, t) {
|
|
|
138
138
|
throw new Error(`Unknown data type ${n}`);
|
|
139
139
|
return e;
|
|
140
140
|
}
|
|
141
|
-
function
|
|
141
|
+
function Fe(n, t) {
|
|
142
142
|
for (let e = 0; e < n.length; e++) {
|
|
143
143
|
const s = n[e];
|
|
144
144
|
if (isNaN(s) || !isFinite(s))
|
|
145
145
|
throw Error(`A tensor of type ${t} being uploaded contains ${s}.`);
|
|
146
146
|
}
|
|
147
147
|
}
|
|
148
|
-
function
|
|
148
|
+
function $e(n) {
|
|
149
149
|
return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string";
|
|
150
150
|
}
|
|
151
151
|
function St(n) {
|
|
@@ -157,28 +157,28 @@ function St(n) {
|
|
|
157
157
|
return 1;
|
|
158
158
|
throw new Error(`Unknown dtype ${n}`);
|
|
159
159
|
}
|
|
160
|
-
function
|
|
160
|
+
function Re(n) {
|
|
161
161
|
if (n == null)
|
|
162
162
|
return 0;
|
|
163
163
|
let t = 0;
|
|
164
164
|
return n.forEach((e) => t += e.length), t;
|
|
165
165
|
}
|
|
166
|
-
function
|
|
166
|
+
function xt(n) {
|
|
167
167
|
return typeof n == "string" || n instanceof String;
|
|
168
168
|
}
|
|
169
|
-
function
|
|
169
|
+
function xe(n) {
|
|
170
170
|
return typeof n == "boolean";
|
|
171
171
|
}
|
|
172
|
-
function
|
|
172
|
+
function Ne(n) {
|
|
173
173
|
return typeof n == "number";
|
|
174
174
|
}
|
|
175
175
|
function mt(n) {
|
|
176
|
-
return Array.isArray(n) ? mt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" :
|
|
176
|
+
return Array.isArray(n) ? mt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" : Ne(n) ? "float32" : xt(n) ? "string" : xe(n) ? "bool" : "float32";
|
|
177
177
|
}
|
|
178
178
|
function kt(n) {
|
|
179
179
|
return !!(n && n.constructor && n.call && n.apply);
|
|
180
180
|
}
|
|
181
|
-
function
|
|
181
|
+
function Nt(n) {
|
|
182
182
|
const t = n.length;
|
|
183
183
|
if (t < 2)
|
|
184
184
|
return [];
|
|
@@ -188,7 +188,7 @@ function xt(n) {
|
|
|
188
188
|
e[s] = e[s + 1] * n[s + 1];
|
|
189
189
|
return e;
|
|
190
190
|
}
|
|
191
|
-
function
|
|
191
|
+
function Qt(n, t, e, s = !1) {
|
|
192
192
|
const r = new Array();
|
|
193
193
|
if (t.length === 1) {
|
|
194
194
|
const i = t[0] * (s ? 2 : 1);
|
|
@@ -197,11 +197,11 @@ function Yt(n, t, e, s = !1) {
|
|
|
197
197
|
} else {
|
|
198
198
|
const i = t[0], o = t.slice(1), a = o.reduce((c, l) => c * l) * (s ? 2 : 1);
|
|
199
199
|
for (let c = 0; c < i; c++)
|
|
200
|
-
r[c] =
|
|
200
|
+
r[c] = Qt(n + c * a, o, e, s);
|
|
201
201
|
}
|
|
202
202
|
return r;
|
|
203
203
|
}
|
|
204
|
-
function
|
|
204
|
+
function Lt(n, t, e = !1) {
|
|
205
205
|
if (n.length === 0)
|
|
206
206
|
return t[0];
|
|
207
207
|
const s = n.reduce((r, i) => r * i) * (e ? 2 : 1);
|
|
@@ -209,15 +209,15 @@ function Ot(n, t, e = !1) {
|
|
|
209
209
|
return [];
|
|
210
210
|
if (s !== t.length)
|
|
211
211
|
throw new Error(`[${n}] does not match the input size ${t.length}${e ? " for a complex tensor" : ""}.`);
|
|
212
|
-
return
|
|
212
|
+
return Qt(0, n, t, e);
|
|
213
213
|
}
|
|
214
|
-
function
|
|
215
|
-
const e =
|
|
214
|
+
function De(n, t) {
|
|
215
|
+
const e = Zt(n, t);
|
|
216
216
|
for (let s = 0; s < e.length; s++)
|
|
217
217
|
e[s] = 1;
|
|
218
218
|
return e;
|
|
219
219
|
}
|
|
220
|
-
function
|
|
220
|
+
function Zt(n, t) {
|
|
221
221
|
if (t == null || t === "float32" || t === "complex64")
|
|
222
222
|
return new Float32Array(n);
|
|
223
223
|
if (t === "int32")
|
|
@@ -226,12 +226,12 @@ function Qt(n, t) {
|
|
|
226
226
|
return new Uint8Array(n);
|
|
227
227
|
throw new Error(`Unknown data type ${t}`);
|
|
228
228
|
}
|
|
229
|
-
function
|
|
229
|
+
function Dt(n) {
|
|
230
230
|
n.forEach((t) => {
|
|
231
231
|
y(Number.isInteger(t) && t >= 0, () => `Tensor must have a shape comprised of positive integers but got shape [${n}].`);
|
|
232
232
|
});
|
|
233
233
|
}
|
|
234
|
-
function
|
|
234
|
+
function Ct(n) {
|
|
235
235
|
return n && n.then && typeof n.then == "function";
|
|
236
236
|
}
|
|
237
237
|
/**
|
|
@@ -250,11 +250,11 @@ function Dt(n) {
|
|
|
250
250
|
* limitations under the License.
|
|
251
251
|
* =============================================================================
|
|
252
252
|
*/
|
|
253
|
-
const
|
|
254
|
-
class
|
|
253
|
+
const Ut = "tfjsflags";
|
|
254
|
+
class Ce {
|
|
255
255
|
// tslint:disable-next-line: no-any
|
|
256
256
|
constructor(t) {
|
|
257
|
-
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams =
|
|
257
|
+
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams = _e, this.populateURLFlags();
|
|
258
258
|
}
|
|
259
259
|
setPlatform(t, e) {
|
|
260
260
|
this.platform != null && (S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)), this.platformName = t, this.platform = e;
|
|
@@ -272,7 +272,7 @@ class De {
|
|
|
272
272
|
if (t in this.flags)
|
|
273
273
|
return this.flags[t];
|
|
274
274
|
const e = this.evaluateFlag(t);
|
|
275
|
-
if (
|
|
275
|
+
if (Ct(e))
|
|
276
276
|
throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);
|
|
277
277
|
return this.flags[t] = e, this.flags[t];
|
|
278
278
|
}
|
|
@@ -312,29 +312,29 @@ class De {
|
|
|
312
312
|
if (typeof this.global > "u" || typeof this.global.location > "u" || typeof this.global.location.search > "u")
|
|
313
313
|
return;
|
|
314
314
|
const t = this.getQueryParams(this.global.location.search);
|
|
315
|
-
|
|
315
|
+
Ut in t && t[Ut].split(",").forEach((s) => {
|
|
316
316
|
const [r, i] = s.split(":");
|
|
317
|
-
this.urlFlags[r] =
|
|
317
|
+
this.urlFlags[r] = Oe(r, i);
|
|
318
318
|
});
|
|
319
319
|
}
|
|
320
320
|
}
|
|
321
|
-
function
|
|
321
|
+
function _e(n) {
|
|
322
322
|
const t = {};
|
|
323
|
-
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (
|
|
323
|
+
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (Pe(t, s[0], s[1]), s.join("="))), t;
|
|
324
324
|
}
|
|
325
|
-
function
|
|
325
|
+
function Pe(n, t, e) {
|
|
326
326
|
n[decodeURIComponent(t)] = decodeURIComponent(e || "");
|
|
327
327
|
}
|
|
328
|
-
function
|
|
328
|
+
function Oe(n, t) {
|
|
329
329
|
const e = t.toLowerCase();
|
|
330
330
|
return e === "true" || e === "false" ? e === "true" : `${+e}` === e ? +e : t;
|
|
331
331
|
}
|
|
332
332
|
function S() {
|
|
333
|
-
return
|
|
333
|
+
return te;
|
|
334
334
|
}
|
|
335
|
-
let
|
|
336
|
-
function
|
|
337
|
-
|
|
335
|
+
let te = null;
|
|
336
|
+
function Le(n) {
|
|
337
|
+
te = n;
|
|
338
338
|
}
|
|
339
339
|
/**
|
|
340
340
|
* @license
|
|
@@ -353,13 +353,13 @@ function Oe(n) {
|
|
|
353
353
|
* =============================================================================
|
|
354
354
|
*/
|
|
355
355
|
let pt;
|
|
356
|
-
function
|
|
356
|
+
function ee() {
|
|
357
357
|
if (pt == null) {
|
|
358
358
|
let n;
|
|
359
359
|
if (typeof window < "u")
|
|
360
360
|
n = window;
|
|
361
|
-
else if (typeof
|
|
362
|
-
n =
|
|
361
|
+
else if (typeof Ot < "u")
|
|
362
|
+
n = Ot;
|
|
363
363
|
else if (typeof Q < "u")
|
|
364
364
|
n = Q;
|
|
365
365
|
else if (typeof self < "u")
|
|
@@ -370,12 +370,12 @@ function te() {
|
|
|
370
370
|
}
|
|
371
371
|
return pt;
|
|
372
372
|
}
|
|
373
|
-
function
|
|
374
|
-
const n =
|
|
373
|
+
function Ue() {
|
|
374
|
+
const n = ee();
|
|
375
375
|
return n._tfGlobals == null && (n._tfGlobals = /* @__PURE__ */ new Map()), n._tfGlobals;
|
|
376
376
|
}
|
|
377
|
-
function
|
|
378
|
-
const e =
|
|
377
|
+
function _t(n, t) {
|
|
378
|
+
const e = Ue();
|
|
379
379
|
if (e.has(n))
|
|
380
380
|
return e.get(n);
|
|
381
381
|
{
|
|
@@ -383,7 +383,7 @@ function Ct(n, t) {
|
|
|
383
383
|
return e.set(n, s), e.get(n);
|
|
384
384
|
}
|
|
385
385
|
}
|
|
386
|
-
const
|
|
386
|
+
const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", We = "RealDiv", Bs = "Elu", vs = "Exp", je = "Fill", Ke = "FloorDiv", Ms = "GatherNd", re = "Identity", Fs = "Imag", $s = "LeakyRelu", Rs = "Log", xs = "Max", Ve = "Maximum", qe = "Multiply", Ns = "Neg", Ds = "Pack", He = "Pow", Cs = "Prelu", _s = "Range", Ps = "Real", Os = "Relu", Ls = "Reshape", Us = "Relu6", Gs = "ScatterNd", zs = "Sigmoid", Je = "Sqrt", Ws = "Sum", js = "Softmax", Xe = "Sub", Ks = "Transpose", Ye = "ZerosLike", Vs = "Step", qs = "_FusedMatMul";
|
|
387
387
|
/**
|
|
388
388
|
* @license
|
|
389
389
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -400,7 +400,7 @@ const Ue = "Abs", ee = "Add", Es = "BatchMatMul", ne = "Cast", As = "Complex", G
|
|
|
400
400
|
* limitations under the License.
|
|
401
401
|
* =============================================================================
|
|
402
402
|
*/
|
|
403
|
-
function
|
|
403
|
+
function O(...n) {
|
|
404
404
|
S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(...n);
|
|
405
405
|
}
|
|
406
406
|
/**
|
|
@@ -419,15 +419,15 @@ function J(...n) {
|
|
|
419
419
|
* limitations under the License.
|
|
420
420
|
* =============================================================================
|
|
421
421
|
*/
|
|
422
|
-
const ht =
|
|
423
|
-
function
|
|
424
|
-
const e =
|
|
422
|
+
const ht = _t("kernelRegistry", () => /* @__PURE__ */ new Map()), It = _t("gradRegistry", () => /* @__PURE__ */ new Map());
|
|
423
|
+
function Gt(n, t) {
|
|
424
|
+
const e = ie(n, t);
|
|
425
425
|
return ht.get(e);
|
|
426
426
|
}
|
|
427
|
-
function Gt(n) {
|
|
428
|
-
return Ye.get(n);
|
|
429
|
-
}
|
|
430
427
|
function zt(n) {
|
|
428
|
+
return It.get(n);
|
|
429
|
+
}
|
|
430
|
+
function Wt(n) {
|
|
431
431
|
const t = ht.entries(), e = [];
|
|
432
432
|
for (; ; ) {
|
|
433
433
|
const { done: s, value: r } = t.next();
|
|
@@ -439,10 +439,14 @@ function zt(n) {
|
|
|
439
439
|
return e;
|
|
440
440
|
}
|
|
441
441
|
function Hs(n) {
|
|
442
|
-
const { kernelName: t, backendName: e } = n, s =
|
|
443
|
-
ht.has(s) &&
|
|
442
|
+
const { kernelName: t, backendName: e } = n, s = ie(t, e);
|
|
443
|
+
ht.has(s) && O(`The kernel '${t}' for backend '${e}' is already registered`), ht.set(s, n);
|
|
444
444
|
}
|
|
445
|
-
function
|
|
445
|
+
function Js(n) {
|
|
446
|
+
const { kernelName: t } = n;
|
|
447
|
+
It.has(t) && S().getBool("DEBUG") && O(`Overriding the gradient for '${t}'`), It.set(t, n);
|
|
448
|
+
}
|
|
449
|
+
function ie(n, t) {
|
|
446
450
|
return `${t}_${n}`;
|
|
447
451
|
}
|
|
448
452
|
/**
|
|
@@ -461,7 +465,7 @@ function re(n, t) {
|
|
|
461
465
|
* limitations under the License.
|
|
462
466
|
* =============================================================================
|
|
463
467
|
*/
|
|
464
|
-
function
|
|
468
|
+
function oe(n) {
|
|
465
469
|
return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray;
|
|
466
470
|
}
|
|
467
471
|
/**
|
|
@@ -483,10 +487,10 @@ function ie(n) {
|
|
|
483
487
|
function Qe(n, t) {
|
|
484
488
|
return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool";
|
|
485
489
|
}
|
|
486
|
-
function
|
|
490
|
+
function ae(n, t) {
|
|
487
491
|
if (t === "string")
|
|
488
492
|
throw new Error("Cannot convert a string[] to a TypedArray");
|
|
489
|
-
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") &&
|
|
493
|
+
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") && Fe(n, t), Qe(n, t))
|
|
490
494
|
return n;
|
|
491
495
|
if (t == null || t === "float32" || t === "complex64")
|
|
492
496
|
return new Float32Array(n);
|
|
@@ -506,14 +510,14 @@ function ft() {
|
|
|
506
510
|
function Ze(n, t = "utf-8") {
|
|
507
511
|
return t = t || "utf-8", S().platform.encode(n, t);
|
|
508
512
|
}
|
|
509
|
-
function
|
|
513
|
+
function jt(n, t = "utf-8") {
|
|
510
514
|
return t = t || "utf-8", S().platform.decode(n, t);
|
|
511
515
|
}
|
|
512
516
|
function $(n) {
|
|
513
|
-
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) :
|
|
517
|
+
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) : oe(n);
|
|
514
518
|
}
|
|
515
519
|
function at(n, t = [], e = !1) {
|
|
516
|
-
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" ||
|
|
520
|
+
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" || Ct(n) || n == null || $(n) && e)
|
|
517
521
|
t.push(n);
|
|
518
522
|
else if (Array.isArray(n) || $(n))
|
|
519
523
|
for (let s = 0; s < n.length; ++s)
|
|
@@ -687,7 +691,7 @@ function rn(n, t, e, s) {
|
|
|
687
691
|
if (l.dtype !== "float32")
|
|
688
692
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input ${c} must have 'float32' dtype, but has '${l.dtype}'`);
|
|
689
693
|
const u = i.inputs[c];
|
|
690
|
-
if (
|
|
694
|
+
if (!Rt(l.shape, u.shape))
|
|
691
695
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input '${c}' has shape '${l.shape}', which does not match the shape of the input '${u.shape}'`);
|
|
692
696
|
if (n[u.id] == null)
|
|
693
697
|
n[u.id] = l;
|
|
@@ -714,15 +718,15 @@ function rn(n, t, e, s) {
|
|
|
714
718
|
* limitations under the License.
|
|
715
719
|
* =============================================================================
|
|
716
720
|
*/
|
|
717
|
-
const
|
|
721
|
+
const Kt = 20, rt = 3, yt = 7;
|
|
718
722
|
function on(n, t, e, s) {
|
|
719
|
-
const r =
|
|
723
|
+
const r = Nt(t), i = an(n, t, e, r), o = t.length, a = ut(n, t, e, r, i), c = ["Tensor"];
|
|
720
724
|
return s && (c.push(` dtype: ${e}`), c.push(` rank: ${o}`), c.push(` shape: [${t}]`), c.push(" values:")), c.push(a.map((l) => " " + l).join(`
|
|
721
725
|
`)), c.join(`
|
|
722
726
|
`);
|
|
723
727
|
}
|
|
724
728
|
function an(n, t, e, s) {
|
|
725
|
-
const r =
|
|
729
|
+
const r = G(t), i = s[s.length - 1], o = new Array(i).fill(0), a = t.length, c = e === "complex64" ? ot(n) : n;
|
|
726
730
|
if (a > 1)
|
|
727
731
|
for (let l = 0; l < r / i; l++) {
|
|
728
732
|
const u = l * i;
|
|
@@ -733,9 +737,9 @@ function an(n, t, e, s) {
|
|
|
733
737
|
}
|
|
734
738
|
function it(n, t, e) {
|
|
735
739
|
let s;
|
|
736
|
-
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(yt))} + ${parseFloat(n[1].toFixed(yt))}j` :
|
|
740
|
+
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(yt))} + ${parseFloat(n[1].toFixed(yt))}j` : xt(n) ? s = `'${n}'` : e === "bool" ? s = le(n) : s = parseFloat(n.toFixed(yt)).toString(), ct(s, t);
|
|
737
741
|
}
|
|
738
|
-
function
|
|
742
|
+
function le(n) {
|
|
739
743
|
return n === 0 ? "false" : "true";
|
|
740
744
|
}
|
|
741
745
|
function ut(n, t, e, s, r, i = !0) {
|
|
@@ -745,14 +749,14 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
745
749
|
const d = ot(n);
|
|
746
750
|
return [it(d[0], 0, e)];
|
|
747
751
|
}
|
|
748
|
-
return e === "bool" ? [
|
|
752
|
+
return e === "bool" ? [le(n[0])] : [n[0].toString()];
|
|
749
753
|
}
|
|
750
754
|
if (c === 1) {
|
|
751
|
-
if (a >
|
|
755
|
+
if (a > Kt) {
|
|
752
756
|
const k = rt * o;
|
|
753
757
|
let T = Array.from(n.slice(0, k)), nt = Array.from(n.slice((a - rt) * o, a * o));
|
|
754
758
|
return e === "complex64" && (T = ot(T), nt = ot(nt)), [
|
|
755
|
-
"[" + T.map((
|
|
759
|
+
"[" + T.map((H, J) => it(H, r[J], e)).join(", ") + ", ..., " + nt.map((H, J) => it(H, r[a - rt + J], e)).join(", ") + "]"
|
|
756
760
|
];
|
|
757
761
|
}
|
|
758
762
|
return [
|
|
@@ -760,7 +764,7 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
760
764
|
];
|
|
761
765
|
}
|
|
762
766
|
const l = t.slice(1), u = s.slice(1), h = s[0] * o, f = [];
|
|
763
|
-
if (a >
|
|
767
|
+
if (a > Kt) {
|
|
764
768
|
for (let d = 0; d < rt; d++) {
|
|
765
769
|
const k = d * h, T = k + h;
|
|
766
770
|
f.push(...ut(
|
|
@@ -834,13 +838,13 @@ function ot(n) {
|
|
|
834
838
|
*/
|
|
835
839
|
class ln {
|
|
836
840
|
constructor(t, e, s) {
|
|
837
|
-
if (this.dtype = e, this.shape = t.slice(), this.size =
|
|
841
|
+
if (this.dtype = e, this.shape = t.slice(), this.size = G(t), s != null) {
|
|
838
842
|
const r = s.length;
|
|
839
843
|
y(r === this.size, () => `Length of values '${r}' does not match the size inferred by the shape '${this.size}'.`);
|
|
840
844
|
}
|
|
841
845
|
if (e === "complex64")
|
|
842
846
|
throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
|
|
843
|
-
this.values = s ||
|
|
847
|
+
this.values = s || Me(e, this.size), this.strides = Nt(t);
|
|
844
848
|
}
|
|
845
849
|
/**
|
|
846
850
|
* Sets a value in the buffer at a given location.
|
|
@@ -918,7 +922,7 @@ function un(n) {
|
|
|
918
922
|
}
|
|
919
923
|
class x {
|
|
920
924
|
constructor(t, e, s, r) {
|
|
921
|
-
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size =
|
|
925
|
+
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size = G(t), this.strides = Nt(t), this.dataId = s, this.id = r, this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
|
|
922
926
|
}
|
|
923
927
|
get rank() {
|
|
924
928
|
return this.shape.length;
|
|
@@ -947,7 +951,7 @@ class x {
|
|
|
947
951
|
*/
|
|
948
952
|
async array() {
|
|
949
953
|
const t = await this.data();
|
|
950
|
-
return
|
|
954
|
+
return Lt(this.shape, t, this.dtype === "complex64");
|
|
951
955
|
}
|
|
952
956
|
/**
|
|
953
957
|
* Returns the tensor data as a nested array. The transfer of data is done
|
|
@@ -956,7 +960,7 @@ class x {
|
|
|
956
960
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
957
961
|
*/
|
|
958
962
|
arraySync() {
|
|
959
|
-
return
|
|
963
|
+
return Lt(this.shape, this.dataSync(), this.dtype === "complex64");
|
|
960
964
|
}
|
|
961
965
|
/**
|
|
962
966
|
* Asynchronously downloads the values from the `tf.Tensor`. Returns a
|
|
@@ -970,7 +974,7 @@ class x {
|
|
|
970
974
|
if (this.dtype === "string") {
|
|
971
975
|
const e = await t;
|
|
972
976
|
try {
|
|
973
|
-
return e.map((s) =>
|
|
977
|
+
return e.map((s) => jt(s));
|
|
974
978
|
} catch {
|
|
975
979
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
976
980
|
}
|
|
@@ -1025,7 +1029,7 @@ class x {
|
|
|
1025
1029
|
const t = R().readSync(this.dataId);
|
|
1026
1030
|
if (this.dtype === "string")
|
|
1027
1031
|
try {
|
|
1028
|
-
return t.map((e) =>
|
|
1032
|
+
return t.map((e) => jt(e));
|
|
1029
1033
|
} catch {
|
|
1030
1034
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
1031
1035
|
}
|
|
@@ -1089,10 +1093,10 @@ class x {
|
|
|
1089
1093
|
Object.defineProperty(x, Symbol.hasInstance, {
|
|
1090
1094
|
value: (n) => !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null
|
|
1091
1095
|
});
|
|
1092
|
-
function
|
|
1093
|
-
return
|
|
1096
|
+
function ce() {
|
|
1097
|
+
return _t("Tensor", () => x);
|
|
1094
1098
|
}
|
|
1095
|
-
|
|
1099
|
+
ce();
|
|
1096
1100
|
class dt extends x {
|
|
1097
1101
|
constructor(t, e, s, r) {
|
|
1098
1102
|
super(t.shape, t.dtype, t.dataId, r), this.trainable = e, this.name = s;
|
|
@@ -1108,7 +1112,7 @@ class dt extends x {
|
|
|
1108
1112
|
assign(t) {
|
|
1109
1113
|
if (t.dtype !== this.dtype)
|
|
1110
1114
|
throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);
|
|
1111
|
-
if (
|
|
1115
|
+
if (!Rt(t.shape, this.shape))
|
|
1112
1116
|
throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);
|
|
1113
1117
|
R().disposeTensor(this), this.dataId = t.dataId, R().incRef(
|
|
1114
1118
|
this,
|
|
@@ -1139,31 +1143,31 @@ Object.defineProperty(dt, Symbol.hasInstance, {
|
|
|
1139
1143
|
* limitations under the License.
|
|
1140
1144
|
* =============================================================================
|
|
1141
1145
|
*/
|
|
1142
|
-
var
|
|
1146
|
+
var Vt;
|
|
1143
1147
|
(function(n) {
|
|
1144
1148
|
n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6";
|
|
1145
|
-
})(
|
|
1146
|
-
var It;
|
|
1147
|
-
(function(n) {
|
|
1148
|
-
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1149
|
-
})(It || (It = {}));
|
|
1149
|
+
})(Vt || (Vt = {}));
|
|
1150
1150
|
var Tt;
|
|
1151
1151
|
(function(n) {
|
|
1152
|
-
n.float32 = "float32", n.int32 = "int32", n.bool = "
|
|
1152
|
+
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1153
1153
|
})(Tt || (Tt = {}));
|
|
1154
1154
|
var Et;
|
|
1155
1155
|
(function(n) {
|
|
1156
|
-
n.float32 = "float32", n.int32 = "
|
|
1156
|
+
n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
|
|
1157
1157
|
})(Et || (Et = {}));
|
|
1158
1158
|
var At;
|
|
1159
1159
|
(function(n) {
|
|
1160
|
-
n.float32 = "
|
|
1160
|
+
n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64";
|
|
1161
1161
|
})(At || (At = {}));
|
|
1162
|
+
var Bt;
|
|
1163
|
+
(function(n) {
|
|
1164
|
+
n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64";
|
|
1165
|
+
})(Bt || (Bt = {}));
|
|
1162
1166
|
const hn = {
|
|
1163
|
-
float32:
|
|
1164
|
-
int32:
|
|
1165
|
-
bool:
|
|
1166
|
-
complex64:
|
|
1167
|
+
float32: At,
|
|
1168
|
+
int32: Tt,
|
|
1169
|
+
bool: Et,
|
|
1170
|
+
complex64: Bt
|
|
1167
1171
|
};
|
|
1168
1172
|
function fn(n, t) {
|
|
1169
1173
|
if (n === "string" || t === "string") {
|
|
@@ -1173,10 +1177,10 @@ function fn(n, t) {
|
|
|
1173
1177
|
}
|
|
1174
1178
|
return hn[n][t];
|
|
1175
1179
|
}
|
|
1176
|
-
function
|
|
1180
|
+
function ue(n) {
|
|
1177
1181
|
return n != null && typeof n == "object" && "texture" in n && n.texture instanceof WebGLTexture;
|
|
1178
1182
|
}
|
|
1179
|
-
function
|
|
1183
|
+
function he(n) {
|
|
1180
1184
|
return typeof GPUBuffer < "u" && n != null && typeof n == "object" && "buffer" in n && n.buffer instanceof GPUBuffer;
|
|
1181
1185
|
}
|
|
1182
1186
|
/**
|
|
@@ -1195,17 +1199,17 @@ function ue(n) {
|
|
|
1195
1199
|
* limitations under the License.
|
|
1196
1200
|
* =============================================================================
|
|
1197
1201
|
*/
|
|
1198
|
-
function
|
|
1202
|
+
function V(n, t) {
|
|
1199
1203
|
if (n.dtype === t.dtype)
|
|
1200
1204
|
return [n, t];
|
|
1201
1205
|
const e = fn(n.dtype, t.dtype);
|
|
1202
1206
|
return [n.cast(e), t.cast(e)];
|
|
1203
1207
|
}
|
|
1204
|
-
function
|
|
1208
|
+
function fe(n) {
|
|
1205
1209
|
const t = [];
|
|
1206
|
-
return
|
|
1210
|
+
return de(n, t, /* @__PURE__ */ new Set()), t;
|
|
1207
1211
|
}
|
|
1208
|
-
function
|
|
1212
|
+
function de(n, t, e) {
|
|
1209
1213
|
if (n == null)
|
|
1210
1214
|
return;
|
|
1211
1215
|
if (n instanceof x) {
|
|
@@ -1217,7 +1221,7 @@ function fe(n, t, e) {
|
|
|
1217
1221
|
const s = n;
|
|
1218
1222
|
for (const r in s) {
|
|
1219
1223
|
const i = s[r];
|
|
1220
|
-
e.has(i) || (e.add(i),
|
|
1224
|
+
e.has(i) || (e.add(i), de(i, t, e));
|
|
1221
1225
|
}
|
|
1222
1226
|
}
|
|
1223
1227
|
function dn(n) {
|
|
@@ -1242,7 +1246,7 @@ function dn(n) {
|
|
|
1242
1246
|
function bt(n) {
|
|
1243
1247
|
return n.kernelName != null;
|
|
1244
1248
|
}
|
|
1245
|
-
class
|
|
1249
|
+
class qt {
|
|
1246
1250
|
constructor() {
|
|
1247
1251
|
this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = /* @__PURE__ */ new WeakMap(), this.profiling = !1, this.activeProfile = {
|
|
1248
1252
|
newBytes: 0,
|
|
@@ -1262,7 +1266,7 @@ class Vt {
|
|
|
1262
1266
|
}
|
|
1263
1267
|
class tt {
|
|
1264
1268
|
constructor(t) {
|
|
1265
|
-
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new
|
|
1269
|
+
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new qt();
|
|
1266
1270
|
}
|
|
1267
1271
|
async ready() {
|
|
1268
1272
|
if (this.pendingBackendInit != null)
|
|
@@ -1308,7 +1312,7 @@ class tt {
|
|
|
1308
1312
|
return t in this.registryFactory ? this.registryFactory[t].factory : null;
|
|
1309
1313
|
}
|
|
1310
1314
|
registerBackend(t, e, s = 1) {
|
|
1311
|
-
return t in this.registryFactory ? (
|
|
1315
|
+
return t in this.registryFactory ? (O(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, !0);
|
|
1312
1316
|
}
|
|
1313
1317
|
async setBackend(t) {
|
|
1314
1318
|
if (this.registryFactory[t] == null)
|
|
@@ -1322,12 +1326,12 @@ class tt {
|
|
|
1322
1326
|
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new tn(this.backendInstance), !0;
|
|
1323
1327
|
}
|
|
1324
1328
|
setupRegisteredKernels() {
|
|
1325
|
-
|
|
1329
|
+
Wt(this.backendName).forEach((e) => {
|
|
1326
1330
|
e.setupFunc != null && e.setupFunc(this.backendInstance);
|
|
1327
1331
|
});
|
|
1328
1332
|
}
|
|
1329
1333
|
disposeRegisteredKernels(t) {
|
|
1330
|
-
|
|
1334
|
+
Wt(t).forEach((s) => {
|
|
1331
1335
|
s.disposeFunc != null && s.disposeFunc(this.registry[t]);
|
|
1332
1336
|
});
|
|
1333
1337
|
}
|
|
@@ -1343,13 +1347,13 @@ class tt {
|
|
|
1343
1347
|
throw new Error(`Cannot initialize backend ${t}, no registration found.`);
|
|
1344
1348
|
try {
|
|
1345
1349
|
const s = e.factory();
|
|
1346
|
-
if (s && !(s instanceof
|
|
1347
|
-
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null,
|
|
1350
|
+
if (s && !(s instanceof Be) && typeof s.then == "function") {
|
|
1351
|
+
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, O(`Initialization of backend ${t} failed`), O(o.stack || o.message)), !1));
|
|
1348
1352
|
return this.pendingBackendInit = i, { success: i, asyncInit: !0 };
|
|
1349
1353
|
} else
|
|
1350
1354
|
return this.registry[t] = s, { success: !0, asyncInit: !1 };
|
|
1351
1355
|
} catch (s) {
|
|
1352
|
-
return
|
|
1356
|
+
return O(`Initialization of backend ${t} failed`), O(s.stack || s.message), { success: !1, asyncInit: !1 };
|
|
1353
1357
|
}
|
|
1354
1358
|
}
|
|
1355
1359
|
removeBackend(t) {
|
|
@@ -1413,11 +1417,11 @@ class tt {
|
|
|
1413
1417
|
* execution.
|
|
1414
1418
|
*/
|
|
1415
1419
|
clone(t) {
|
|
1416
|
-
const e = g.runKernel(
|
|
1420
|
+
const e = g.runKernel(re, { x: t }), s = { x: t }, r = (o) => ({
|
|
1417
1421
|
x: () => {
|
|
1418
1422
|
const a = "float32", c = { x: o }, l = { dtype: a };
|
|
1419
1423
|
return g.runKernel(
|
|
1420
|
-
|
|
1424
|
+
se,
|
|
1421
1425
|
c,
|
|
1422
1426
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1423
1427
|
l
|
|
@@ -1440,7 +1444,7 @@ class tt {
|
|
|
1440
1444
|
* tensors are not visible to the user.
|
|
1441
1445
|
*/
|
|
1442
1446
|
runKernel(t, e, s) {
|
|
1443
|
-
if (this.backendName == null && this.backend, !(
|
|
1447
|
+
if (this.backendName == null && this.backend, !(Gt(t, this.backendName) != null))
|
|
1444
1448
|
throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
|
|
1445
1449
|
return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
|
|
1446
1450
|
}
|
|
@@ -1473,18 +1477,18 @@ class tt {
|
|
|
1473
1477
|
if (bt(t)) {
|
|
1474
1478
|
const { kernelName: b, inputs: d, attrs: k } = t;
|
|
1475
1479
|
this.backendName == null && this.backend;
|
|
1476
|
-
const T =
|
|
1480
|
+
const T = Gt(b, this.backendName);
|
|
1477
1481
|
y(T != null, () => `Cannot find registered kernel '${b}' for backend '${this.backendName}'`), a = () => {
|
|
1478
1482
|
const nt = this.backend.numDataIds();
|
|
1479
1483
|
c = T.kernelFunc({ inputs: d, attrs: k, backend: this.backend });
|
|
1480
|
-
const
|
|
1481
|
-
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, nt,
|
|
1482
|
-
const
|
|
1484
|
+
const H = Array.isArray(c) ? c : [c];
|
|
1485
|
+
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, nt, H);
|
|
1486
|
+
const J = H.map((st) => st.rank != null ? st : this.makeTensorFromTensorInfo(st));
|
|
1483
1487
|
if (r) {
|
|
1484
|
-
const st = this.getTensorsForGradient(b, d,
|
|
1488
|
+
const st = this.getTensorsForGradient(b, d, J);
|
|
1485
1489
|
s = this.saveTensorsForBackwardMode(st);
|
|
1486
1490
|
}
|
|
1487
|
-
return
|
|
1491
|
+
return J;
|
|
1488
1492
|
};
|
|
1489
1493
|
} else {
|
|
1490
1494
|
const { forwardFunc: b } = t, d = (k) => {
|
|
@@ -1534,7 +1538,7 @@ class tt {
|
|
|
1534
1538
|
* @param outputs an array of output tensors from forward mode of kernel.
|
|
1535
1539
|
*/
|
|
1536
1540
|
getTensorsForGradient(t, e, s) {
|
|
1537
|
-
const r =
|
|
1541
|
+
const r = zt(t);
|
|
1538
1542
|
if (r != null) {
|
|
1539
1543
|
const i = r.inputsToSave || [], o = r.outputsToSave || [];
|
|
1540
1544
|
let a;
|
|
@@ -1554,10 +1558,10 @@ class tt {
|
|
|
1554
1558
|
throw new Error("Values passed to engine.makeTensor() are null");
|
|
1555
1559
|
s = s || "float32", r = r || this.backend;
|
|
1556
1560
|
let i = t;
|
|
1557
|
-
s === "string" &&
|
|
1561
|
+
s === "string" && xt(t[0]) && (i = t.map((c) => Ze(c)));
|
|
1558
1562
|
const o = r.write(i, e, s), a = new x(e, s, o, this.nextTensorId());
|
|
1559
1563
|
if (this.trackTensor(a, r), s === "string") {
|
|
1560
|
-
const c = this.state.tensorInfo.get(o), l =
|
|
1564
|
+
const c = this.state.tensorInfo.get(o), l = Re(i);
|
|
1561
1565
|
this.state.numBytes += l - c.bytes, c.bytes = l;
|
|
1562
1566
|
}
|
|
1563
1567
|
return a;
|
|
@@ -1645,10 +1649,10 @@ class tt {
|
|
|
1645
1649
|
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
|
|
1646
1650
|
}
|
|
1647
1651
|
addTapeNode(t, e, s, r, i, o) {
|
|
1648
|
-
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c =
|
|
1652
|
+
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c = zt(t);
|
|
1649
1653
|
c != null && (r = c.gradFunc), r != null && (a.gradient = (l) => (l = l.map((u, h) => {
|
|
1650
1654
|
if (u == null) {
|
|
1651
|
-
const f = s[h], m =
|
|
1655
|
+
const f = s[h], m = Zt(f.size, f.dtype);
|
|
1652
1656
|
return this.makeTensor(m, f.shape, f.dtype);
|
|
1653
1657
|
}
|
|
1654
1658
|
return u;
|
|
@@ -1680,7 +1684,7 @@ class tt {
|
|
|
1680
1684
|
* as scope() without the need for a function closure.
|
|
1681
1685
|
*/
|
|
1682
1686
|
endScope(t) {
|
|
1683
|
-
const e =
|
|
1687
|
+
const e = fe(t), s = new Set(e.map((i) => i.id));
|
|
1684
1688
|
for (let i = 0; i < this.state.activeScope.track.length; i++) {
|
|
1685
1689
|
const o = this.state.activeScope.track[i];
|
|
1686
1690
|
!o.kept && !s.has(o.id) && o.dispose();
|
|
@@ -1774,7 +1778,7 @@ class tt {
|
|
|
1774
1778
|
* registered backend factories.
|
|
1775
1779
|
*/
|
|
1776
1780
|
reset() {
|
|
1777
|
-
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new
|
|
1781
|
+
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new qt();
|
|
1778
1782
|
for (const t in this.registry)
|
|
1779
1783
|
this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
|
|
1780
1784
|
this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
|
|
@@ -1783,21 +1787,21 @@ class tt {
|
|
|
1783
1787
|
tt.nextTensorId = 0;
|
|
1784
1788
|
tt.nextVariableId = 0;
|
|
1785
1789
|
function gn(n) {
|
|
1786
|
-
const t =
|
|
1790
|
+
const t = De(G(n), "float32");
|
|
1787
1791
|
return g.makeTensor(t, n, "float32");
|
|
1788
1792
|
}
|
|
1789
|
-
function
|
|
1790
|
-
const n =
|
|
1793
|
+
function ge() {
|
|
1794
|
+
const n = ee();
|
|
1791
1795
|
if (n._tfengine == null) {
|
|
1792
|
-
const t = new
|
|
1796
|
+
const t = new Ce(n);
|
|
1793
1797
|
n._tfengine = new tt(t);
|
|
1794
1798
|
}
|
|
1795
|
-
return
|
|
1799
|
+
return Le(n._tfengine.ENV), cn(() => n._tfengine), n._tfengine;
|
|
1796
1800
|
}
|
|
1797
|
-
const g =
|
|
1801
|
+
const g = ge();
|
|
1798
1802
|
function mn(n, t) {
|
|
1799
1803
|
const e = { a: n, b: t };
|
|
1800
|
-
return g.runKernel(
|
|
1804
|
+
return g.runKernel(ne, e);
|
|
1801
1805
|
}
|
|
1802
1806
|
/**
|
|
1803
1807
|
* @license
|
|
@@ -1855,19 +1859,19 @@ function yn(n, t) {
|
|
|
1855
1859
|
let e = n;
|
|
1856
1860
|
if ($(n))
|
|
1857
1861
|
return t === "string" ? [] : [n.length];
|
|
1858
|
-
if (
|
|
1862
|
+
if (ue(n)) {
|
|
1859
1863
|
const r = n.channels || "RGBA";
|
|
1860
1864
|
return [n.height, n.width * r.length];
|
|
1861
|
-
} else if (
|
|
1865
|
+
} else if (he(n))
|
|
1862
1866
|
return [n.buffer.size / (t == null ? 4 : St(t))];
|
|
1863
1867
|
if (!Array.isArray(n))
|
|
1864
1868
|
return [];
|
|
1865
1869
|
const s = [];
|
|
1866
1870
|
for (; Array.isArray(e) || $(e) && t !== "string"; )
|
|
1867
1871
|
s.push(e.length), e = e[0];
|
|
1868
|
-
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") &&
|
|
1872
|
+
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && me(n, s, []), s;
|
|
1869
1873
|
}
|
|
1870
|
-
function
|
|
1874
|
+
function me(n, t, e) {
|
|
1871
1875
|
if (e = e || [], !Array.isArray(n) && !$(n)) {
|
|
1872
1876
|
y(t.length === 0, () => `Element arr[${e.join("][")}] is a primitive, but should be an array/TypedArray of ${t[0]} elements`);
|
|
1873
1877
|
return;
|
|
@@ -1875,9 +1879,9 @@ function ge(n, t, e) {
|
|
|
1875
1879
|
y(t.length > 0, () => `Element arr[${e.join("][")}] should be a primitive, but is an array of ${n.length} elements`), y(n.length === t[0], () => `Element arr[${e.join("][")}] should have ${t[0]} elements, but has ${n.length} elements`);
|
|
1876
1880
|
const s = t.slice(1);
|
|
1877
1881
|
for (let r = 0; r < n.length; ++r)
|
|
1878
|
-
|
|
1882
|
+
me(n[r], s, e.concat(r));
|
|
1879
1883
|
}
|
|
1880
|
-
function
|
|
1884
|
+
function Ht(n, t, e, s) {
|
|
1881
1885
|
if (n !== "string_or_numeric") {
|
|
1882
1886
|
if (n == null)
|
|
1883
1887
|
throw new Error("Expected dtype cannot be null.");
|
|
@@ -1886,19 +1890,19 @@ function qt(n, t, e, s) {
|
|
|
1886
1890
|
}
|
|
1887
1891
|
}
|
|
1888
1892
|
function I(n, t, e, s = "numeric") {
|
|
1889
|
-
if (n instanceof
|
|
1890
|
-
return
|
|
1893
|
+
if (n instanceof ce())
|
|
1894
|
+
return Ht(s, n.dtype, t, e), n;
|
|
1891
1895
|
let r = mt(n);
|
|
1892
|
-
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s),
|
|
1896
|
+
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s), Ht(s, r, t, e), n == null || !$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
|
|
1893
1897
|
const c = n == null ? "null" : n.constructor.name;
|
|
1894
1898
|
throw new Error(`Argument '${t}' passed to '${e}' must be a Tensor or TensorLike, but got '${c}'`);
|
|
1895
1899
|
}
|
|
1896
1900
|
const i = yn(n, r);
|
|
1897
1901
|
!$(n) && !Array.isArray(n) && (n = [n]);
|
|
1898
|
-
const a = r !== "string" ?
|
|
1902
|
+
const a = r !== "string" ? ae(n, r) : at(n, [], !0);
|
|
1899
1903
|
return g.makeTensor(a, i, r);
|
|
1900
1904
|
}
|
|
1901
|
-
function
|
|
1905
|
+
function Xs(n, t, e, s = "numeric") {
|
|
1902
1906
|
if (!Array.isArray(n))
|
|
1903
1907
|
throw new Error(`Argument ${t} passed to ${e} must be a \`Tensor[]\` or \`TensorLike[]\``);
|
|
1904
1908
|
return n.map((i, o) => I(i, `${t}[${o}]`, e, s));
|
|
@@ -1931,7 +1935,7 @@ function F(n) {
|
|
|
1931
1935
|
g.startScope(e);
|
|
1932
1936
|
try {
|
|
1933
1937
|
const o = s(...i);
|
|
1934
|
-
return
|
|
1938
|
+
return Ct(o) && console.error("Cannot return a Promise inside of tidy."), g.endScope(o), o;
|
|
1935
1939
|
} catch (o) {
|
|
1936
1940
|
throw g.endScope(null), o;
|
|
1937
1941
|
}
|
|
@@ -1959,7 +1963,7 @@ function wn(n, t, e, s) {
|
|
|
1959
1963
|
s = mt(n);
|
|
1960
1964
|
else if (s === "complex64")
|
|
1961
1965
|
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
|
|
1962
|
-
if (
|
|
1966
|
+
if (he(n) || ue(n)) {
|
|
1963
1967
|
if (s !== "float32" && s !== "int32")
|
|
1964
1968
|
throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);
|
|
1965
1969
|
return g.backend.createTensorFromGPUData(n, t || e, s);
|
|
@@ -1967,15 +1971,15 @@ function wn(n, t, e, s) {
|
|
|
1967
1971
|
if (!$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string")
|
|
1968
1972
|
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
|
|
1969
1973
|
if (t != null) {
|
|
1970
|
-
|
|
1971
|
-
const r =
|
|
1974
|
+
Dt(t);
|
|
1975
|
+
const r = G(t), i = G(e);
|
|
1972
1976
|
y(r === i, () => `Based on the provided shape, [${t}], the tensor should have ${r} values but has ${i}`);
|
|
1973
1977
|
for (let o = 0; o < e.length; ++o) {
|
|
1974
|
-
const a = e[o], c = o === e.length - 1 ? a !==
|
|
1978
|
+
const a = e[o], c = o === e.length - 1 ? a !== G(t.slice(o)) : !0;
|
|
1975
1979
|
y(e[o] === t[o] || !c, () => `Error creating a new Tensor. Inferred shape (${e}) does not match the provided shape (${t}). `);
|
|
1976
1980
|
}
|
|
1977
1981
|
}
|
|
1978
|
-
return !$(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ?
|
|
1982
|
+
return !$(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ? ae(n, s) : at(n, [], !0), g.makeTensor(n, t, s);
|
|
1979
1983
|
}
|
|
1980
1984
|
class lt {
|
|
1981
1985
|
/**
|
|
@@ -2061,24 +2065,24 @@ function Sn(n, t) {
|
|
|
2061
2065
|
* limitations under the License.
|
|
2062
2066
|
* =============================================================================
|
|
2063
2067
|
*/
|
|
2064
|
-
function
|
|
2068
|
+
function Ys() {
|
|
2065
2069
|
return g;
|
|
2066
2070
|
}
|
|
2067
2071
|
function E(n, t) {
|
|
2068
2072
|
return g.tidy(n, t);
|
|
2069
2073
|
}
|
|
2070
2074
|
function M(n) {
|
|
2071
|
-
|
|
2075
|
+
fe(n).forEach((e) => e.dispose());
|
|
2072
2076
|
}
|
|
2073
2077
|
function kn(n) {
|
|
2074
2078
|
return g.keep(n);
|
|
2075
2079
|
}
|
|
2076
|
-
const
|
|
2077
|
-
function
|
|
2078
|
-
return
|
|
2080
|
+
const Pt = typeof gt < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
|
|
2081
|
+
function Jt(n) {
|
|
2082
|
+
return Pt ? gt.byteLength(n, "utf8") : new Blob([n]).size;
|
|
2079
2083
|
}
|
|
2080
2084
|
function In(n) {
|
|
2081
|
-
if (
|
|
2085
|
+
if (Pt)
|
|
2082
2086
|
return gt.from(n).toString("base64");
|
|
2083
2087
|
const t = new Uint8Array(n);
|
|
2084
2088
|
let e = "";
|
|
@@ -2087,7 +2091,7 @@ function In(n) {
|
|
|
2087
2091
|
return btoa(e);
|
|
2088
2092
|
}
|
|
2089
2093
|
function Tn(n) {
|
|
2090
|
-
if (
|
|
2094
|
+
if (Pt) {
|
|
2091
2095
|
const s = gt.from(n, "base64");
|
|
2092
2096
|
return s.buffer.slice(s.byteOffset, s.byteOffset + s.byteLength);
|
|
2093
2097
|
}
|
|
@@ -2096,14 +2100,14 @@ function Tn(n) {
|
|
|
2096
2100
|
e.set([t.charCodeAt(s)], s);
|
|
2097
2101
|
return e.buffer;
|
|
2098
2102
|
}
|
|
2099
|
-
function
|
|
2103
|
+
function pe(n) {
|
|
2100
2104
|
if (n.modelTopology instanceof ArrayBuffer)
|
|
2101
2105
|
throw new Error("Expected JSON model topology, received ArrayBuffer.");
|
|
2102
2106
|
return {
|
|
2103
2107
|
dateSaved: /* @__PURE__ */ new Date(),
|
|
2104
2108
|
modelTopologyType: "JSON",
|
|
2105
|
-
modelTopologyBytes: n.modelTopology == null ? 0 :
|
|
2106
|
-
weightSpecsBytes: n.weightSpecs == null ? 0 :
|
|
2109
|
+
modelTopologyBytes: n.modelTopology == null ? 0 : Jt(JSON.stringify(n.modelTopology)),
|
|
2110
|
+
weightSpecsBytes: n.weightSpecs == null ? 0 : Jt(JSON.stringify(n.weightSpecs)),
|
|
2107
2111
|
weightDataBytes: n.weightData == null ? 0 : new lt(n.weightData).byteLength
|
|
2108
2112
|
};
|
|
2109
2113
|
}
|
|
@@ -2194,8 +2198,8 @@ class A {
|
|
|
2194
2198
|
* limitations under the License.
|
|
2195
2199
|
* =============================================================================
|
|
2196
2200
|
*/
|
|
2197
|
-
const
|
|
2198
|
-
function
|
|
2201
|
+
const vt = "tensorflowjs", Mt = 1, U = "models_store", P = "model_info_store";
|
|
2202
|
+
function ye() {
|
|
2199
2203
|
if (!S().getBool("IS_BROWSER"))
|
|
2200
2204
|
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
|
|
2201
2205
|
const n = typeof window > "u" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB;
|
|
@@ -2203,13 +2207,13 @@ function pe() {
|
|
|
2203
2207
|
throw new Error("The current browser does not appear to support IndexedDB.");
|
|
2204
2208
|
return t;
|
|
2205
2209
|
}
|
|
2206
|
-
function
|
|
2210
|
+
function Ft(n) {
|
|
2207
2211
|
const t = n.result;
|
|
2208
|
-
t.createObjectStore(
|
|
2212
|
+
t.createObjectStore(U, { keyPath: "modelPath" }), t.createObjectStore(P, { keyPath: "modelPath" });
|
|
2209
2213
|
}
|
|
2210
|
-
class
|
|
2214
|
+
class W {
|
|
2211
2215
|
constructor(t) {
|
|
2212
|
-
if (this.indexedDB =
|
|
2216
|
+
if (this.indexedDB = ye(), t == null || !t)
|
|
2213
2217
|
throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
|
|
2214
2218
|
this.modelPath = t;
|
|
2215
2219
|
}
|
|
@@ -2237,11 +2241,11 @@ class z {
|
|
|
2237
2241
|
*/
|
|
2238
2242
|
databaseAction(t, e) {
|
|
2239
2243
|
return new Promise((s, r) => {
|
|
2240
|
-
const i = this.indexedDB.open(
|
|
2241
|
-
i.onupgradeneeded = () =>
|
|
2244
|
+
const i = this.indexedDB.open(vt, Mt);
|
|
2245
|
+
i.onupgradeneeded = () => Ft(i), i.onsuccess = () => {
|
|
2242
2246
|
const o = i.result;
|
|
2243
2247
|
if (e == null) {
|
|
2244
|
-
const a = o.transaction(
|
|
2248
|
+
const a = o.transaction(U, "readonly"), l = a.objectStore(U).get(this.modelPath);
|
|
2245
2249
|
l.onsuccess = () => {
|
|
2246
2250
|
if (l.result == null)
|
|
2247
2251
|
return o.close(), r(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));
|
|
@@ -2249,7 +2253,7 @@ class z {
|
|
|
2249
2253
|
}, l.onerror = (u) => (o.close(), r(l.error)), a.oncomplete = () => o.close();
|
|
2250
2254
|
} else {
|
|
2251
2255
|
e.weightData = lt.join(e.weightData);
|
|
2252
|
-
const a =
|
|
2256
|
+
const a = pe(e), c = o.transaction(P, "readwrite");
|
|
2253
2257
|
let l = c.objectStore(P), u;
|
|
2254
2258
|
try {
|
|
2255
2259
|
u = l.put({ modelPath: this.modelPath, modelArtifactsInfo: a });
|
|
@@ -2258,8 +2262,8 @@ class z {
|
|
|
2258
2262
|
}
|
|
2259
2263
|
let h;
|
|
2260
2264
|
u.onsuccess = () => {
|
|
2261
|
-
h = o.transaction(
|
|
2262
|
-
const f = h.objectStore(
|
|
2265
|
+
h = o.transaction(U, "readwrite");
|
|
2266
|
+
const f = h.objectStore(U);
|
|
2263
2267
|
let m;
|
|
2264
2268
|
try {
|
|
2265
2269
|
m = f.put({
|
|
@@ -2283,24 +2287,24 @@ class z {
|
|
|
2283
2287
|
});
|
|
2284
2288
|
}
|
|
2285
2289
|
}
|
|
2286
|
-
|
|
2287
|
-
const
|
|
2288
|
-
A.registerSaveRouter(
|
|
2289
|
-
A.registerLoadRouter(
|
|
2290
|
+
W.URL_SCHEME = "indexeddb://";
|
|
2291
|
+
const be = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(W.URL_SCHEME) ? En(n.slice(W.URL_SCHEME.length)) : null;
|
|
2292
|
+
A.registerSaveRouter(be);
|
|
2293
|
+
A.registerLoadRouter(be);
|
|
2290
2294
|
function En(n) {
|
|
2291
|
-
return new
|
|
2295
|
+
return new W(n);
|
|
2292
2296
|
}
|
|
2293
2297
|
function An(n) {
|
|
2294
|
-
return n.startsWith(
|
|
2298
|
+
return n.startsWith(W.URL_SCHEME) ? n.slice(W.URL_SCHEME.length) : n;
|
|
2295
2299
|
}
|
|
2296
2300
|
class Bn {
|
|
2297
2301
|
constructor() {
|
|
2298
|
-
this.indexedDB =
|
|
2302
|
+
this.indexedDB = ye();
|
|
2299
2303
|
}
|
|
2300
2304
|
async listModels() {
|
|
2301
2305
|
return new Promise((t, e) => {
|
|
2302
|
-
const s = this.indexedDB.open(
|
|
2303
|
-
s.onupgradeneeded = () =>
|
|
2306
|
+
const s = this.indexedDB.open(vt, Mt);
|
|
2307
|
+
s.onupgradeneeded = () => Ft(s), s.onsuccess = () => {
|
|
2304
2308
|
const r = s.result, i = r.transaction(P, "readonly"), a = i.objectStore(P).getAll();
|
|
2305
2309
|
a.onsuccess = () => {
|
|
2306
2310
|
const c = {};
|
|
@@ -2313,8 +2317,8 @@ class Bn {
|
|
|
2313
2317
|
}
|
|
2314
2318
|
async removeModel(t) {
|
|
2315
2319
|
return t = An(t), new Promise((e, s) => {
|
|
2316
|
-
const r = this.indexedDB.open(
|
|
2317
|
-
r.onupgradeneeded = () =>
|
|
2320
|
+
const r = this.indexedDB.open(vt, Mt);
|
|
2321
|
+
r.onupgradeneeded = () => Ft(r), r.onsuccess = () => {
|
|
2318
2322
|
const i = r.result, o = i.transaction(P, "readwrite"), a = o.objectStore(P), c = a.get(t);
|
|
2319
2323
|
let l;
|
|
2320
2324
|
c.onsuccess = () => {
|
|
@@ -2322,8 +2326,8 @@ class Bn {
|
|
|
2322
2326
|
return i.close(), s(new Error(`Cannot find model with path '${t}' in IndexedDB.`));
|
|
2323
2327
|
{
|
|
2324
2328
|
const u = a.delete(t), h = () => {
|
|
2325
|
-
l = i.transaction(
|
|
2326
|
-
const m = l.objectStore(
|
|
2329
|
+
l = i.transaction(U, "readwrite");
|
|
2330
|
+
const m = l.objectStore(U).delete(t);
|
|
2327
2331
|
m.onsuccess = () => e(c.result.modelArtifactsInfo), m.onerror = (b) => s(c.error);
|
|
2328
2332
|
};
|
|
2329
2333
|
u.onsuccess = h, u.onerror = (f) => (h(), i.close(), s(c.error));
|
|
@@ -2351,17 +2355,17 @@ class Bn {
|
|
|
2351
2355
|
* limitations under the License.
|
|
2352
2356
|
* =============================================================================
|
|
2353
2357
|
*/
|
|
2354
|
-
const _ = "/", Y = "tensorflowjs_models",
|
|
2355
|
-
function
|
|
2358
|
+
const _ = "/", Y = "tensorflowjs_models", we = "info", vn = "model_topology", Mn = "weight_specs", Fn = "weight_data", $n = "model_metadata";
|
|
2359
|
+
function Se(n) {
|
|
2356
2360
|
return {
|
|
2357
|
-
info: [Y, n,
|
|
2361
|
+
info: [Y, n, we].join(_),
|
|
2358
2362
|
topology: [Y, n, vn].join(_),
|
|
2359
2363
|
weightSpecs: [Y, n, Mn].join(_),
|
|
2360
2364
|
weightData: [Y, n, Fn].join(_),
|
|
2361
2365
|
modelMetadata: [Y, n, $n].join(_)
|
|
2362
2366
|
};
|
|
2363
2367
|
}
|
|
2364
|
-
function
|
|
2368
|
+
function ke(n) {
|
|
2365
2369
|
for (const t of Object.values(n))
|
|
2366
2370
|
window.localStorage.removeItem(t);
|
|
2367
2371
|
}
|
|
@@ -2372,15 +2376,15 @@ function Rn(n) {
|
|
|
2372
2376
|
return t.slice(1, t.length - 1).join(_);
|
|
2373
2377
|
}
|
|
2374
2378
|
function xn(n) {
|
|
2375
|
-
return n.startsWith(
|
|
2379
|
+
return n.startsWith(j.URL_SCHEME) ? n.slice(j.URL_SCHEME.length) : n;
|
|
2376
2380
|
}
|
|
2377
|
-
class
|
|
2381
|
+
class j {
|
|
2378
2382
|
constructor(t) {
|
|
2379
2383
|
if (!S().getBool("IS_BROWSER") || typeof window > "u" || typeof window.localStorage > "u")
|
|
2380
2384
|
throw new Error("The current environment does not support local storage.");
|
|
2381
2385
|
if (this.LS = window.localStorage, t == null || !t)
|
|
2382
2386
|
throw new Error("For local storage, modelPath must not be null, undefined or empty.");
|
|
2383
|
-
this.modelPath = t, this.keys =
|
|
2387
|
+
this.modelPath = t, this.keys = Se(this.modelPath);
|
|
2384
2388
|
}
|
|
2385
2389
|
/**
|
|
2386
2390
|
* Save model artifacts to browser local storage.
|
|
@@ -2395,7 +2399,7 @@ class W {
|
|
|
2395
2399
|
if (t.modelTopology instanceof ArrayBuffer)
|
|
2396
2400
|
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
|
|
2397
2401
|
{
|
|
2398
|
-
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r =
|
|
2402
|
+
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r = pe(t), i = lt.join(t.weightData);
|
|
2399
2403
|
try {
|
|
2400
2404
|
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, s), this.LS.setItem(this.keys.weightData, In(i));
|
|
2401
2405
|
const o = {
|
|
@@ -2410,7 +2414,7 @@ class W {
|
|
|
2410
2414
|
};
|
|
2411
2415
|
return this.LS.setItem(this.keys.modelMetadata, JSON.stringify(o)), { modelArtifactsInfo: r };
|
|
2412
2416
|
} catch {
|
|
2413
|
-
throw
|
|
2417
|
+
throw ke(this.keys), new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${r.modelTopologyBytes}, weightSpecsBytes=${r.weightSpecsBytes}, weightDataBytes=${r.weightDataBytes}.`);
|
|
2414
2418
|
}
|
|
2415
2419
|
}
|
|
2416
2420
|
}
|
|
@@ -2447,19 +2451,19 @@ class W {
|
|
|
2447
2451
|
return e.weightData = Tn(o), e;
|
|
2448
2452
|
}
|
|
2449
2453
|
}
|
|
2450
|
-
|
|
2451
|
-
const
|
|
2452
|
-
A.registerSaveRouter(
|
|
2453
|
-
A.registerLoadRouter(
|
|
2454
|
+
j.URL_SCHEME = "localstorage://";
|
|
2455
|
+
const Ie = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(j.URL_SCHEME) ? Nn(n.slice(j.URL_SCHEME.length)) : null;
|
|
2456
|
+
A.registerSaveRouter(Ie);
|
|
2457
|
+
A.registerLoadRouter(Ie);
|
|
2454
2458
|
function Nn(n) {
|
|
2455
|
-
return new
|
|
2459
|
+
return new j(n);
|
|
2456
2460
|
}
|
|
2457
2461
|
class Dn {
|
|
2458
2462
|
constructor() {
|
|
2459
2463
|
y(S().getBool("IS_BROWSER"), () => "Current environment is not a web browser"), y(typeof window > "u" || typeof window.localStorage < "u", () => "Current browser does not appear to support localStorage"), this.LS = window.localStorage;
|
|
2460
2464
|
}
|
|
2461
2465
|
async listModels() {
|
|
2462
|
-
const t = {}, e = Y + _, s = _ +
|
|
2466
|
+
const t = {}, e = Y + _, s = _ + we;
|
|
2463
2467
|
for (let r = 0; r < this.LS.length; ++r) {
|
|
2464
2468
|
const i = this.LS.key(r);
|
|
2465
2469
|
if (i.startsWith(e) && i.endsWith(s)) {
|
|
@@ -2471,11 +2475,11 @@ class Dn {
|
|
|
2471
2475
|
}
|
|
2472
2476
|
async removeModel(t) {
|
|
2473
2477
|
t = xn(t);
|
|
2474
|
-
const e =
|
|
2478
|
+
const e = Se(t);
|
|
2475
2479
|
if (this.LS.getItem(e.info) == null)
|
|
2476
2480
|
throw new Error(`Cannot find model at path '${t}'`);
|
|
2477
2481
|
const s = JSON.parse(this.LS.getItem(e.info));
|
|
2478
|
-
return
|
|
2482
|
+
return ke(e), s;
|
|
2479
2483
|
}
|
|
2480
2484
|
}
|
|
2481
2485
|
/**
|
|
@@ -2494,7 +2498,7 @@ class Dn {
|
|
|
2494
2498
|
* limitations under the License.
|
|
2495
2499
|
* =============================================================================
|
|
2496
2500
|
*/
|
|
2497
|
-
const
|
|
2501
|
+
const Xt = "://";
|
|
2498
2502
|
class N {
|
|
2499
2503
|
constructor() {
|
|
2500
2504
|
this.managers = {};
|
|
@@ -2509,7 +2513,7 @@ class N {
|
|
|
2509
2513
|
* of `IOHandler` with the `save` method defined or `null`.
|
|
2510
2514
|
*/
|
|
2511
2515
|
static registerManager(t, e) {
|
|
2512
|
-
y(t != null, () => "scheme must not be undefined or null."), t.endsWith(
|
|
2516
|
+
y(t != null, () => "scheme must not be undefined or null."), t.endsWith(Xt) && (t = t.slice(0, t.indexOf(Xt))), y(t.length > 0, () => "scheme must not be an empty string.");
|
|
2513
2517
|
const s = N.getInstance();
|
|
2514
2518
|
y(s.managers[t] == null, () => `A model store manager is already registered for scheme '${t}'.`), s.managers[t] = e;
|
|
2515
2519
|
}
|
|
@@ -2577,17 +2581,17 @@ class Cn {
|
|
|
2577
2581
|
}, !0));
|
|
2578
2582
|
}
|
|
2579
2583
|
isTypedArray(t) {
|
|
2580
|
-
return
|
|
2584
|
+
return oe(t);
|
|
2581
2585
|
}
|
|
2582
2586
|
}
|
|
2583
2587
|
if (S().get("IS_BROWSER")) {
|
|
2584
2588
|
S().setPlatform("browser", new Cn());
|
|
2585
2589
|
try {
|
|
2586
|
-
N.registerManager(
|
|
2590
|
+
N.registerManager(j.URL_SCHEME, new Dn());
|
|
2587
2591
|
} catch {
|
|
2588
2592
|
}
|
|
2589
2593
|
try {
|
|
2590
|
-
N.registerManager(
|
|
2594
|
+
N.registerManager(W.URL_SCHEME, new Bn());
|
|
2591
2595
|
} catch {
|
|
2592
2596
|
}
|
|
2593
2597
|
}
|
|
@@ -2637,7 +2641,7 @@ S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new Pn()
|
|
|
2637
2641
|
* =============================================================================
|
|
2638
2642
|
*/
|
|
2639
2643
|
function On(n, t = "float32", e) {
|
|
2640
|
-
return t = t || "float32",
|
|
2644
|
+
return t = t || "float32", Dt(n), new ln(n, t, e);
|
|
2641
2645
|
}
|
|
2642
2646
|
/**
|
|
2643
2647
|
* @license
|
|
@@ -2657,14 +2661,14 @@ function On(n, t = "float32", e) {
|
|
|
2657
2661
|
*/
|
|
2658
2662
|
function Ln(n, t) {
|
|
2659
2663
|
const e = I(n, "x", "cast");
|
|
2660
|
-
if (
|
|
2664
|
+
if (!$e(t))
|
|
2661
2665
|
throw new Error(`Failed to cast to unknown dtype ${t}`);
|
|
2662
2666
|
if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string")
|
|
2663
2667
|
throw new Error("Only strings can be casted to strings");
|
|
2664
2668
|
const s = { x: e }, r = { dtype: t };
|
|
2665
|
-
return g.runKernel(
|
|
2669
|
+
return g.runKernel(se, s, r);
|
|
2666
2670
|
}
|
|
2667
|
-
const
|
|
2671
|
+
const $t = /* @__PURE__ */ F({ cast_: Ln });
|
|
2668
2672
|
/**
|
|
2669
2673
|
* @license
|
|
2670
2674
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2683,7 +2687,7 @@ const Ft = /* @__PURE__ */ F({ cast_: Ln });
|
|
|
2683
2687
|
*/
|
|
2684
2688
|
function Un(n) {
|
|
2685
2689
|
const e = { x: I(n, "x", "clone", "string_or_numeric") };
|
|
2686
|
-
return g.runKernel(
|
|
2690
|
+
return g.runKernel(re, e);
|
|
2687
2691
|
}
|
|
2688
2692
|
const Gn = /* @__PURE__ */ F({ clone_: Un });
|
|
2689
2693
|
/**
|
|
@@ -2721,10 +2725,10 @@ function zn(n, t = !1) {
|
|
|
2721
2725
|
* limitations under the License.
|
|
2722
2726
|
* =============================================================================
|
|
2723
2727
|
*/
|
|
2724
|
-
|
|
2728
|
+
ge();
|
|
2725
2729
|
const Wn = {
|
|
2726
2730
|
buffer: On,
|
|
2727
|
-
cast:
|
|
2731
|
+
cast: $t,
|
|
2728
2732
|
clone: Gn,
|
|
2729
2733
|
print: zn
|
|
2730
2734
|
};
|
|
@@ -2747,9 +2751,9 @@ un(Wn);
|
|
|
2747
2751
|
*/
|
|
2748
2752
|
function jn(n, t) {
|
|
2749
2753
|
let e = I(n, "a", "add"), s = I(t, "b", "add");
|
|
2750
|
-
[e, s] =
|
|
2754
|
+
[e, s] = V(e, s);
|
|
2751
2755
|
const r = { a: e, b: s };
|
|
2752
|
-
return g.runKernel(
|
|
2756
|
+
return g.runKernel(ne, r);
|
|
2753
2757
|
}
|
|
2754
2758
|
const w = /* @__PURE__ */ F({ add_: jn });
|
|
2755
2759
|
/**
|
|
@@ -2770,9 +2774,9 @@ const w = /* @__PURE__ */ F({ add_: jn });
|
|
|
2770
2774
|
*/
|
|
2771
2775
|
function Kn(n, t) {
|
|
2772
2776
|
let e = I(n, "a", "floorDiv"), s = I(t, "b", "floorDiv");
|
|
2773
|
-
[e, s] =
|
|
2777
|
+
[e, s] = V(e, s);
|
|
2774
2778
|
const r = { a: e, b: s };
|
|
2775
|
-
return g.runKernel(
|
|
2779
|
+
return g.runKernel(Ke, r);
|
|
2776
2780
|
}
|
|
2777
2781
|
const Vn = /* @__PURE__ */ F({ floorDiv_: Kn });
|
|
2778
2782
|
/**
|
|
@@ -2793,10 +2797,10 @@ const Vn = /* @__PURE__ */ F({ floorDiv_: Kn });
|
|
|
2793
2797
|
*/
|
|
2794
2798
|
function qn(n, t) {
|
|
2795
2799
|
let e = I(n, "a", "div"), s = I(t, "b", "div");
|
|
2796
|
-
if ([e, s] =
|
|
2800
|
+
if ([e, s] = V(e, s), e.dtype === "int32" && s.dtype === "int32")
|
|
2797
2801
|
return Vn(e, s);
|
|
2798
2802
|
const r = { a: e, b: s }, i = {};
|
|
2799
|
-
return g.runKernel(
|
|
2803
|
+
return g.runKernel(We, r, i);
|
|
2800
2804
|
}
|
|
2801
2805
|
const D = /* @__PURE__ */ F({ div_: qn });
|
|
2802
2806
|
/**
|
|
@@ -2817,9 +2821,9 @@ const D = /* @__PURE__ */ F({ div_: qn });
|
|
|
2817
2821
|
*/
|
|
2818
2822
|
function Hn(n, t) {
|
|
2819
2823
|
let e = I(n, "a", "mul"), s = I(t, "b", "mul");
|
|
2820
|
-
[e, s] =
|
|
2824
|
+
[e, s] = V(e, s);
|
|
2821
2825
|
const r = { a: e, b: s };
|
|
2822
|
-
return g.runKernel(
|
|
2826
|
+
return g.runKernel(qe, r);
|
|
2823
2827
|
}
|
|
2824
2828
|
const p = /* @__PURE__ */ F({ mul_: Hn });
|
|
2825
2829
|
/**
|
|
@@ -2842,10 +2846,10 @@ function Jn(n) {
|
|
|
2842
2846
|
const t = I(n, "x", "abs");
|
|
2843
2847
|
if (t.dtype === "complex64") {
|
|
2844
2848
|
const e = { x: t };
|
|
2845
|
-
return g.runKernel(
|
|
2849
|
+
return g.runKernel(ze, e);
|
|
2846
2850
|
} else {
|
|
2847
2851
|
const e = { x: t };
|
|
2848
|
-
return g.runKernel(
|
|
2852
|
+
return g.runKernel(Ge, e);
|
|
2849
2853
|
}
|
|
2850
2854
|
}
|
|
2851
2855
|
const Xn = /* @__PURE__ */ F({ abs_: Jn });
|
|
@@ -2866,9 +2870,9 @@ const Xn = /* @__PURE__ */ F({ abs_: Jn });
|
|
|
2866
2870
|
* =============================================================================
|
|
2867
2871
|
*/
|
|
2868
2872
|
function Yn(n, t, e) {
|
|
2869
|
-
|
|
2873
|
+
Dt(n), e = e || mt(t);
|
|
2870
2874
|
const s = { shape: n, value: t, dtype: e };
|
|
2871
|
-
return g.runKernel(
|
|
2875
|
+
return g.runKernel(je, {}, s);
|
|
2872
2876
|
}
|
|
2873
2877
|
/**
|
|
2874
2878
|
* @license
|
|
@@ -2886,7 +2890,7 @@ function Yn(n, t, e) {
|
|
|
2886
2890
|
* limitations under the License.
|
|
2887
2891
|
* =============================================================================
|
|
2888
2892
|
*/
|
|
2889
|
-
function
|
|
2893
|
+
function Qs(n, t) {
|
|
2890
2894
|
const e = [];
|
|
2891
2895
|
for (let s = 0; s < t.length; s++) {
|
|
2892
2896
|
const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
|
|
@@ -2930,7 +2934,7 @@ function Qn(n, t) {
|
|
|
2930
2934
|
*/
|
|
2931
2935
|
function Zn(n) {
|
|
2932
2936
|
const e = { x: I(n, "x", "zerosLike") };
|
|
2933
|
-
return g.runKernel(
|
|
2937
|
+
return g.runKernel(Ye, e);
|
|
2934
2938
|
}
|
|
2935
2939
|
const C = /* @__PURE__ */ F({ zerosLike_: Zn });
|
|
2936
2940
|
/**
|
|
@@ -2951,11 +2955,11 @@ const C = /* @__PURE__ */ F({ zerosLike_: Zn });
|
|
|
2951
2955
|
*/
|
|
2952
2956
|
function ts(n, t) {
|
|
2953
2957
|
let e = I(n, "base", "pow"), s = I(t, "exp", "pow");
|
|
2954
|
-
[e, s] =
|
|
2958
|
+
[e, s] = V(e, s);
|
|
2955
2959
|
const r = { a: e, b: s };
|
|
2956
|
-
return g.runKernel(
|
|
2960
|
+
return g.runKernel(He, r);
|
|
2957
2961
|
}
|
|
2958
|
-
const
|
|
2962
|
+
const Yt = /* @__PURE__ */ F({ pow_: ts });
|
|
2959
2963
|
/**
|
|
2960
2964
|
* @license
|
|
2961
2965
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -2972,7 +2976,7 @@ const Xt = /* @__PURE__ */ F({ pow_: ts });
|
|
|
2972
2976
|
* limitations under the License.
|
|
2973
2977
|
* =============================================================================
|
|
2974
2978
|
*/
|
|
2975
|
-
function
|
|
2979
|
+
function K(n, t) {
|
|
2976
2980
|
if (($(n) && t !== "string" || Array.isArray(n)) && t !== "complex64")
|
|
2977
2981
|
throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
|
|
2978
2982
|
if (t === "string" && $(n) && !(n instanceof Uint8Array))
|
|
@@ -2997,7 +3001,7 @@ function j(n, t) {
|
|
|
2997
3001
|
*/
|
|
2998
3002
|
function es(n) {
|
|
2999
3003
|
const e = { x: I(n, "x", "sqrt", "float32") };
|
|
3000
|
-
return g.runKernel(
|
|
3004
|
+
return g.runKernel(Je, e);
|
|
3001
3005
|
}
|
|
3002
3006
|
const et = /* @__PURE__ */ F({ sqrt_: es });
|
|
3003
3007
|
/**
|
|
@@ -3020,7 +3024,7 @@ function ns(n) {
|
|
|
3020
3024
|
const t = I(n, "x", "square"), e = {};
|
|
3021
3025
|
return g.runKernel("Square", { x: t }, e);
|
|
3022
3026
|
}
|
|
3023
|
-
const
|
|
3027
|
+
const z = /* @__PURE__ */ F({ square_: ns });
|
|
3024
3028
|
/**
|
|
3025
3029
|
* @license
|
|
3026
3030
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3054,7 +3058,7 @@ function ss(n, t) {
|
|
|
3054
3058
|
a[u] != null && (c[l.name] = a[u]);
|
|
3055
3059
|
}), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
|
|
3056
3060
|
}
|
|
3057
|
-
function
|
|
3061
|
+
function Zs(n) {
|
|
3058
3062
|
return g.customGrad(n);
|
|
3059
3063
|
}
|
|
3060
3064
|
/**
|
|
@@ -3075,9 +3079,9 @@ function Qs(n) {
|
|
|
3075
3079
|
*/
|
|
3076
3080
|
function rs(n, t) {
|
|
3077
3081
|
let e = I(n, "a", "sub"), s = I(t, "b", "sub");
|
|
3078
|
-
[e, s] =
|
|
3082
|
+
[e, s] = V(e, s);
|
|
3079
3083
|
const r = { a: e, b: s };
|
|
3080
|
-
return g.runKernel(
|
|
3084
|
+
return g.runKernel(Xe, r);
|
|
3081
3085
|
}
|
|
3082
3086
|
const Z = /* @__PURE__ */ F({ sub_: rs });
|
|
3083
3087
|
/**
|
|
@@ -3098,9 +3102,9 @@ const Z = /* @__PURE__ */ F({ sub_: rs });
|
|
|
3098
3102
|
*/
|
|
3099
3103
|
function is(n, t) {
|
|
3100
3104
|
let e = I(n, "a", "maximum"), s = I(t, "b", "maximum");
|
|
3101
|
-
[e, s] =
|
|
3105
|
+
[e, s] = V(e, s), e.dtype === "bool" && (e = $t(e, "int32"), s = $t(s, "int32")), Qn(e.shape, s.shape);
|
|
3102
3106
|
const r = { a: e, b: s };
|
|
3103
|
-
return g.runKernel(
|
|
3107
|
+
return g.runKernel(Ve, r);
|
|
3104
3108
|
}
|
|
3105
3109
|
const os = /* @__PURE__ */ F({ maximum_: is });
|
|
3106
3110
|
/**
|
|
@@ -3148,7 +3152,7 @@ class cs {
|
|
|
3148
3152
|
return new t(e);
|
|
3149
3153
|
}
|
|
3150
3154
|
}
|
|
3151
|
-
class
|
|
3155
|
+
class L {
|
|
3152
3156
|
constructor() {
|
|
3153
3157
|
this.classNameMap = {};
|
|
3154
3158
|
}
|
|
@@ -3156,19 +3160,19 @@ class O {
|
|
|
3156
3160
|
* Returns the singleton instance of the map.
|
|
3157
3161
|
*/
|
|
3158
3162
|
static getMap() {
|
|
3159
|
-
return
|
|
3163
|
+
return L.instance == null && (L.instance = new L()), L.instance;
|
|
3160
3164
|
}
|
|
3161
3165
|
/**
|
|
3162
3166
|
* Registers the class as serializable.
|
|
3163
3167
|
*/
|
|
3164
3168
|
static register(t) {
|
|
3165
|
-
|
|
3169
|
+
L.getMap().classNameMap[t.className] = [t, t.fromConfig];
|
|
3166
3170
|
}
|
|
3167
3171
|
}
|
|
3168
3172
|
function us(n, t, e) {
|
|
3169
3173
|
y(n.className != null, () => "Class being registered does not have the static className property defined."), y(typeof n.className == "string", () => "className is required to be a string, but got type " + typeof n.className), y(n.className.length > 0, () => "Class being registered has an empty-string as its className, which is disallowed."), typeof t > "u" && (t = "Custom"), typeof e > "u" && (e = n.className);
|
|
3170
3174
|
const s = e, r = t + ">" + s;
|
|
3171
|
-
return
|
|
3175
|
+
return L.register(n), as.set(r, n), ls.set(n, r), n;
|
|
3172
3176
|
}
|
|
3173
3177
|
/**
|
|
3174
3178
|
* @license
|
|
@@ -3186,7 +3190,7 @@ function us(n, t, e) {
|
|
|
3186
3190
|
* limitations under the License.
|
|
3187
3191
|
* =============================================================================
|
|
3188
3192
|
*/
|
|
3189
|
-
class
|
|
3193
|
+
class q extends cs {
|
|
3190
3194
|
/**
|
|
3191
3195
|
* Executes `f()` and minimizes the scalar output of `f()` by computing
|
|
3192
3196
|
* gradients of y with respect to the list of trainable variables provided by
|
|
@@ -3245,7 +3249,7 @@ class V extends cs {
|
|
|
3245
3249
|
return this.iterations_ == null && (this.iterations_ = 0), {
|
|
3246
3250
|
name: "iter",
|
|
3247
3251
|
// TODO(cais): Use 'int64' type when available.
|
|
3248
|
-
tensor:
|
|
3252
|
+
tensor: K(this.iterations_, "int32")
|
|
3249
3253
|
};
|
|
3250
3254
|
}
|
|
3251
3255
|
async getWeights() {
|
|
@@ -3265,7 +3269,7 @@ class V extends cs {
|
|
|
3265
3269
|
return this.iterations_ = (await t[0].tensor.data())[0], t.slice(1);
|
|
3266
3270
|
}
|
|
3267
3271
|
}
|
|
3268
|
-
Object.defineProperty(
|
|
3272
|
+
Object.defineProperty(q, Symbol.hasInstance, {
|
|
3269
3273
|
value: (n) => n.minimize != null && n.computeGradients != null && n.applyGradients != null
|
|
3270
3274
|
});
|
|
3271
3275
|
/**
|
|
@@ -3284,7 +3288,7 @@ Object.defineProperty(V, Symbol.hasInstance, {
|
|
|
3284
3288
|
* limitations under the License.
|
|
3285
3289
|
* =============================================================================
|
|
3286
3290
|
*/
|
|
3287
|
-
class hs extends
|
|
3291
|
+
class hs extends q {
|
|
3288
3292
|
/** @nocollapse */
|
|
3289
3293
|
static get className() {
|
|
3290
3294
|
return "Adadelta";
|
|
@@ -3307,7 +3311,7 @@ class hs extends V {
|
|
|
3307
3311
|
return;
|
|
3308
3312
|
const c = this.accumulatedGrads[r].variable, l = this.accumulatedUpdates[r].variable;
|
|
3309
3313
|
E(() => {
|
|
3310
|
-
const u = w(p(c, this.rho), p(
|
|
3314
|
+
const u = w(p(c, this.rho), p(z(a), 1 - this.rho)), h = p(D(et(w(l, this.epsilon)), et(w(c, this.epsilon))), a), f = w(p(l, this.rho), p(z(h), 1 - this.rho));
|
|
3311
3315
|
c.assign(u), l.assign(f);
|
|
3312
3316
|
const m = w(p(h, -this.learningRate), i);
|
|
3313
3317
|
i.assign(m);
|
|
@@ -3360,7 +3364,7 @@ class hs extends V {
|
|
|
3360
3364
|
* limitations under the License.
|
|
3361
3365
|
* =============================================================================
|
|
3362
3366
|
*/
|
|
3363
|
-
class fs extends
|
|
3367
|
+
class fs extends q {
|
|
3364
3368
|
/** @nocollapse */
|
|
3365
3369
|
static get className() {
|
|
3366
3370
|
return "Adagrad";
|
|
@@ -3380,7 +3384,7 @@ class fs extends V {
|
|
|
3380
3384
|
return;
|
|
3381
3385
|
const a = this.accumulatedGrads[r].variable;
|
|
3382
3386
|
E(() => {
|
|
3383
|
-
const c = w(a,
|
|
3387
|
+
const c = w(a, z(o));
|
|
3384
3388
|
a.assign(c);
|
|
3385
3389
|
const l = w(p(D(o, et(w(c, g.backend.epsilon()))), -this.learningRate), i);
|
|
3386
3390
|
i.assign(l);
|
|
@@ -3425,14 +3429,14 @@ class fs extends V {
|
|
|
3425
3429
|
* limitations under the License.
|
|
3426
3430
|
* =============================================================================
|
|
3427
3431
|
*/
|
|
3428
|
-
class ds extends
|
|
3432
|
+
class ds extends q {
|
|
3429
3433
|
/** @nocollapse */
|
|
3430
3434
|
static get className() {
|
|
3431
3435
|
return "Adam";
|
|
3432
3436
|
}
|
|
3433
3437
|
constructor(t, e, s, r = null) {
|
|
3434
3438
|
super(), this.learningRate = t, this.beta1 = e, this.beta2 = s, this.epsilon = r, this.accumulatedFirstMoment = [], this.accumulatedSecondMoment = [], E(() => {
|
|
3435
|
-
this.accBeta1 =
|
|
3439
|
+
this.accBeta1 = K(e).variable(), this.accBeta2 = K(s).variable();
|
|
3436
3440
|
}), r == null && (this.epsilon = g.backend.epsilon());
|
|
3437
3441
|
}
|
|
3438
3442
|
applyGradients(t) {
|
|
@@ -3451,7 +3455,7 @@ class ds extends V {
|
|
|
3451
3455
|
const l = Array.isArray(t) ? t[o].tensor : t[i];
|
|
3452
3456
|
if (l == null)
|
|
3453
3457
|
return;
|
|
3454
|
-
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(
|
|
3458
|
+
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(z(l), 1 - this.beta2)), b = D(f, s), d = D(m, r);
|
|
3455
3459
|
u.assign(f), h.assign(m);
|
|
3456
3460
|
const k = w(p(D(b, w(et(d), this.epsilon)), -this.learningRate), a);
|
|
3457
3461
|
a.assign(k);
|
|
@@ -3467,7 +3471,7 @@ class ds extends V {
|
|
|
3467
3471
|
}
|
|
3468
3472
|
async setWeights(t) {
|
|
3469
3473
|
t = await this.extractIterations(t), E(() => {
|
|
3470
|
-
this.accBeta1.assign(
|
|
3474
|
+
this.accBeta1.assign(Yt(this.beta1, this.iterations_ + 1)), this.accBeta2.assign(Yt(this.beta2, this.iterations_ + 1));
|
|
3471
3475
|
});
|
|
3472
3476
|
const e = t.length / 2, s = !1;
|
|
3473
3477
|
this.accumulatedFirstMoment = t.slice(0, e).map((r) => ({
|
|
@@ -3507,14 +3511,14 @@ class ds extends V {
|
|
|
3507
3511
|
* limitations under the License.
|
|
3508
3512
|
* =============================================================================
|
|
3509
3513
|
*/
|
|
3510
|
-
class gs extends
|
|
3514
|
+
class gs extends q {
|
|
3511
3515
|
/** @nocollapse */
|
|
3512
3516
|
static get className() {
|
|
3513
3517
|
return "Adamax";
|
|
3514
3518
|
}
|
|
3515
3519
|
constructor(t, e, s, r = null, i = 0) {
|
|
3516
3520
|
super(), this.learningRate = t, this.beta1 = e, this.beta2 = s, this.epsilon = r, this.decay = i, this.accumulatedFirstMoment = [], this.accumulatedWeightedInfNorm = [], E(() => {
|
|
3517
|
-
this.iteration =
|
|
3521
|
+
this.iteration = K(0).variable(), this.accBeta1 = K(e).variable();
|
|
3518
3522
|
}), r == null && (this.epsilon = g.backend.epsilon());
|
|
3519
3523
|
}
|
|
3520
3524
|
applyGradients(t) {
|
|
@@ -3579,7 +3583,7 @@ class gs extends V {
|
|
|
3579
3583
|
* limitations under the License.
|
|
3580
3584
|
* =============================================================================
|
|
3581
3585
|
*/
|
|
3582
|
-
class
|
|
3586
|
+
class Te extends q {
|
|
3583
3587
|
/** @nocollapse */
|
|
3584
3588
|
static get className() {
|
|
3585
3589
|
return "SGD";
|
|
@@ -3603,7 +3607,7 @@ class Ie extends V {
|
|
|
3603
3607
|
* Sets the learning rate of the optimizer.
|
|
3604
3608
|
*/
|
|
3605
3609
|
setLearningRate(t) {
|
|
3606
|
-
this.learningRate = t, this.c != null && this.c.dispose(), this.c = kn(
|
|
3610
|
+
this.learningRate = t, this.c != null && this.c.dispose(), this.c = kn(K(-t));
|
|
3607
3611
|
}
|
|
3608
3612
|
dispose() {
|
|
3609
3613
|
this.c.dispose();
|
|
@@ -3639,14 +3643,14 @@ class Ie extends V {
|
|
|
3639
3643
|
* limitations under the License.
|
|
3640
3644
|
* =============================================================================
|
|
3641
3645
|
*/
|
|
3642
|
-
class ms extends
|
|
3646
|
+
class ms extends Te {
|
|
3643
3647
|
/** @nocollapse */
|
|
3644
3648
|
// Name matters for Python compatibility.
|
|
3645
3649
|
static get className() {
|
|
3646
3650
|
return "Momentum";
|
|
3647
3651
|
}
|
|
3648
3652
|
constructor(t, e, s = !1) {
|
|
3649
|
-
super(t), this.learningRate = t, this.momentum = e, this.useNesterov = s, this.accumulations = [], this.m =
|
|
3653
|
+
super(t), this.learningRate = t, this.momentum = e, this.useNesterov = s, this.accumulations = [], this.m = K(this.momentum);
|
|
3650
3654
|
}
|
|
3651
3655
|
applyGradients(t) {
|
|
3652
3656
|
(Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t)).forEach((s, r) => {
|
|
@@ -3710,7 +3714,7 @@ class ms extends Ie {
|
|
|
3710
3714
|
* limitations under the License.
|
|
3711
3715
|
* =============================================================================
|
|
3712
3716
|
*/
|
|
3713
|
-
class ps extends
|
|
3717
|
+
class ps extends q {
|
|
3714
3718
|
/** @nocollapse */
|
|
3715
3719
|
static get className() {
|
|
3716
3720
|
return "RMSProp";
|
|
@@ -3737,14 +3741,14 @@ class ps extends V {
|
|
|
3737
3741
|
return;
|
|
3738
3742
|
const c = this.accumulatedMeanSquares[r].variable, l = this.accumulatedMoments[r].variable;
|
|
3739
3743
|
E(() => {
|
|
3740
|
-
const u = w(p(c, this.decay), p(
|
|
3744
|
+
const u = w(p(c, this.decay), p(z(a), 1 - this.decay));
|
|
3741
3745
|
if (this.centered) {
|
|
3742
|
-
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), et(Z(u, w(
|
|
3746
|
+
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), et(Z(u, w(z(f), this.epsilon)))), b = w(p(l, this.momentum), m);
|
|
3743
3747
|
c.assign(u), h.assign(f), l.assign(b);
|
|
3744
3748
|
const d = Z(i, b);
|
|
3745
3749
|
i.assign(d);
|
|
3746
3750
|
} else {
|
|
3747
|
-
const h = w(p(c, this.decay), p(
|
|
3751
|
+
const h = w(p(c, this.decay), p(z(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate), et(w(h, this.epsilon))));
|
|
3748
3752
|
c.assign(h), l.assign(f);
|
|
3749
3753
|
const m = Z(i, f);
|
|
3750
3754
|
i.assign(m);
|
|
@@ -3810,7 +3814,7 @@ const ys = [
|
|
|
3810
3814
|
gs,
|
|
3811
3815
|
ms,
|
|
3812
3816
|
ps,
|
|
3813
|
-
|
|
3817
|
+
Te
|
|
3814
3818
|
];
|
|
3815
3819
|
function bs() {
|
|
3816
3820
|
for (const n of ys)
|
|
@@ -3837,50 +3841,51 @@ export {
|
|
|
3837
3841
|
ds as A,
|
|
3838
3842
|
Es as B,
|
|
3839
3843
|
As as C,
|
|
3840
|
-
|
|
3844
|
+
zs as D,
|
|
3841
3845
|
g as E,
|
|
3842
|
-
|
|
3846
|
+
Bs as F,
|
|
3843
3847
|
Ms as G,
|
|
3844
|
-
|
|
3848
|
+
$s as H,
|
|
3845
3849
|
Fs as I,
|
|
3846
|
-
|
|
3847
|
-
|
|
3850
|
+
Cs as J,
|
|
3851
|
+
Ps as K,
|
|
3848
3852
|
Rs as L,
|
|
3849
3853
|
xs as M,
|
|
3850
3854
|
Ns as N,
|
|
3851
|
-
|
|
3855
|
+
Os as O,
|
|
3852
3856
|
Ds as P,
|
|
3853
|
-
|
|
3857
|
+
Us as Q,
|
|
3854
3858
|
_s as R,
|
|
3855
3859
|
Ws as S,
|
|
3856
|
-
|
|
3857
|
-
|
|
3858
|
-
|
|
3860
|
+
Vs as T,
|
|
3861
|
+
Ks as U,
|
|
3862
|
+
Qs as V,
|
|
3863
|
+
Qn as W,
|
|
3859
3864
|
qs as _,
|
|
3860
3865
|
Z as a,
|
|
3861
|
-
|
|
3866
|
+
Js as b,
|
|
3862
3867
|
I as c,
|
|
3863
|
-
|
|
3864
|
-
|
|
3865
|
-
|
|
3866
|
-
|
|
3867
|
-
|
|
3868
|
-
|
|
3869
|
-
|
|
3870
|
-
|
|
3871
|
-
|
|
3868
|
+
V as d,
|
|
3869
|
+
Ys as e,
|
|
3870
|
+
Is as f,
|
|
3871
|
+
Xs as g,
|
|
3872
|
+
y as h,
|
|
3873
|
+
Ls as i,
|
|
3874
|
+
$t as j,
|
|
3875
|
+
Dt as k,
|
|
3876
|
+
Zt as l,
|
|
3872
3877
|
p as m,
|
|
3873
|
-
|
|
3878
|
+
G as n,
|
|
3874
3879
|
F as o,
|
|
3875
|
-
|
|
3876
|
-
|
|
3880
|
+
De as p,
|
|
3881
|
+
Gs as q,
|
|
3877
3882
|
Hs as r,
|
|
3878
|
-
|
|
3879
|
-
|
|
3880
|
-
|
|
3881
|
-
|
|
3882
|
-
|
|
3883
|
-
|
|
3884
|
-
|
|
3883
|
+
K as s,
|
|
3884
|
+
vs as t,
|
|
3885
|
+
Ts as u,
|
|
3886
|
+
w as v,
|
|
3887
|
+
js as w,
|
|
3888
|
+
Zs as x,
|
|
3889
|
+
E as y,
|
|
3885
3890
|
C as z
|
|
3886
3891
|
};
|