@genai-fi/nanogpt 0.2.4 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/NanoGPTModel.js +43 -44
- package/dist/TeachableLLM.js +3 -0
- package/dist/complex-D6Bq1XDf.js +27 -0
- package/dist/data/docx.d.ts +1 -0
- package/dist/data/docx.js +15 -0
- package/dist/data/pdf.d.ts +1 -1
- package/dist/data/pdf.js +10 -8
- package/dist/data/textLoader.js +29 -24
- package/dist/{index-DcaSvB38.js → index-D1SlunD-.js} +553 -522
- package/dist/{jszip.min-CAxN99oA.js → jszip.min-CjP2V1VV.js} +1 -1
- package/dist/layers/TiedEmbedding.js +113 -178
- package/dist/main.d.ts +2 -0
- package/dist/main.js +18 -10
- package/dist/ops/gatherSub.d.ts +2 -0
- package/dist/ops/gatherSub.js +66 -0
- package/dist/ops/node/sparseCrossEntropy.d.ts +1 -0
- package/dist/ops/node/sparseCrossEntropy.js +11 -0
- package/dist/ops/scatterSub.d.ts +2 -0
- package/dist/ops/scatterSub.js +150 -0
- package/dist/stack-DB2YLlAs.js +50 -0
- package/dist/sum-02UQ5Eaq.js +49 -0
- package/dist/tokeniser/CharTokeniser.d.ts +1 -0
- package/dist/tokeniser/CharTokeniser.js +48 -39
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +3 -2
- package/dist/training/Trainer.js +3 -3
- package/dist/training/sparseCrossEntropy.d.ts +11 -0
- package/dist/training/sparseCrossEntropy.js +177 -0
- package/dist/utilities/load.js +5 -5
- package/dist/utilities/parameters.d.ts +10 -0
- package/dist/utilities/parameters.js +52 -0
- package/dist/utilities/save.js +1 -1
- package/package.json +3 -2
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { g as
|
|
2
|
-
import { p as
|
|
3
|
-
import { B as
|
|
1
|
+
import { g as Pt } from "./index-D5v913EJ.js";
|
|
2
|
+
import { p as Q } from "./index-xuotMAFm.js";
|
|
3
|
+
import { B as gt } from "./index-Tf7vU29b.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
6
6
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -17,8 +17,8 @@ import { B as dt } from "./index-Tf7vU29b.js";
|
|
|
17
17
|
* limitations under the License.
|
|
18
18
|
* =============================================================================
|
|
19
19
|
*/
|
|
20
|
-
const
|
|
21
|
-
class
|
|
20
|
+
const Te = 1e-7, Ee = 1e-4;
|
|
21
|
+
class Ae {
|
|
22
22
|
refCount(t) {
|
|
23
23
|
return v("refCount");
|
|
24
24
|
}
|
|
@@ -64,7 +64,7 @@ class Ee {
|
|
|
64
64
|
}
|
|
65
65
|
/** Returns the smallest representable number. */
|
|
66
66
|
epsilon() {
|
|
67
|
-
return this.floatPrecision() === 32 ?
|
|
67
|
+
return this.floatPrecision() === 32 ? Te : Ee;
|
|
68
68
|
}
|
|
69
69
|
dispose() {
|
|
70
70
|
return v("dispose");
|
|
@@ -89,12 +89,12 @@ function v(n) {
|
|
|
89
89
|
* limitations under the License.
|
|
90
90
|
* =============================================================================
|
|
91
91
|
*/
|
|
92
|
-
function
|
|
92
|
+
function y(n, t) {
|
|
93
93
|
if (!n)
|
|
94
94
|
throw new Error(typeof t == "string" ? t : t());
|
|
95
95
|
}
|
|
96
96
|
function Is(n, t, e = "") {
|
|
97
|
-
|
|
97
|
+
y($t(n, t), () => e + ` Shapes ${n} and ${t} must match`);
|
|
98
98
|
}
|
|
99
99
|
function U(n) {
|
|
100
100
|
if (n.length === 0)
|
|
@@ -104,7 +104,7 @@ function U(n) {
|
|
|
104
104
|
t *= n[e];
|
|
105
105
|
return t;
|
|
106
106
|
}
|
|
107
|
-
function
|
|
107
|
+
function $t(n, t) {
|
|
108
108
|
if (n === t)
|
|
109
109
|
return !0;
|
|
110
110
|
if (n == null || t == null || n.length !== t.length)
|
|
@@ -114,10 +114,17 @@ function Ft(n, t) {
|
|
|
114
114
|
return !1;
|
|
115
115
|
return !0;
|
|
116
116
|
}
|
|
117
|
+
function Be(n) {
|
|
118
|
+
return n % 1 === 0;
|
|
119
|
+
}
|
|
117
120
|
function ct(n, t) {
|
|
118
121
|
return t <= n.length ? n : n + " ".repeat(t - n.length);
|
|
119
122
|
}
|
|
120
|
-
function
|
|
123
|
+
function Ts(n, t) {
|
|
124
|
+
const e = t.length;
|
|
125
|
+
return n = n == null ? t.map((s, r) => r) : [].concat(n), y(n.every((s) => s >= -e && s < e), () => `All values in axis param must be in range [-${e}, ${e}) but got axis ${n}`), y(n.every((s) => Be(s)), () => `All values in axis param must be integers but got axis ${n}`), n.map((s) => s < 0 ? e + s : s);
|
|
126
|
+
}
|
|
127
|
+
function ve(n, t) {
|
|
121
128
|
let e = null;
|
|
122
129
|
if (n == null || n === "float32")
|
|
123
130
|
e = new Float32Array(t);
|
|
@@ -131,17 +138,17 @@ function Be(n, t) {
|
|
|
131
138
|
throw new Error(`Unknown data type ${n}`);
|
|
132
139
|
return e;
|
|
133
140
|
}
|
|
134
|
-
function
|
|
141
|
+
function Me(n, t) {
|
|
135
142
|
for (let e = 0; e < n.length; e++) {
|
|
136
143
|
const s = n[e];
|
|
137
144
|
if (isNaN(s) || !isFinite(s))
|
|
138
145
|
throw Error(`A tensor of type ${t} being uploaded contains ${s}.`);
|
|
139
146
|
}
|
|
140
147
|
}
|
|
141
|
-
function
|
|
148
|
+
function Fe(n) {
|
|
142
149
|
return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string";
|
|
143
150
|
}
|
|
144
|
-
function
|
|
151
|
+
function St(n) {
|
|
145
152
|
if (n === "float32" || n === "int32")
|
|
146
153
|
return 4;
|
|
147
154
|
if (n === "complex64")
|
|
@@ -150,7 +157,7 @@ function wt(n) {
|
|
|
150
157
|
return 1;
|
|
151
158
|
throw new Error(`Unknown dtype ${n}`);
|
|
152
159
|
}
|
|
153
|
-
function
|
|
160
|
+
function $e(n) {
|
|
154
161
|
if (n == null)
|
|
155
162
|
return 0;
|
|
156
163
|
let t = 0;
|
|
@@ -159,16 +166,16 @@ function Me(n) {
|
|
|
159
166
|
function Rt(n) {
|
|
160
167
|
return typeof n == "string" || n instanceof String;
|
|
161
168
|
}
|
|
162
|
-
function
|
|
169
|
+
function Re(n) {
|
|
163
170
|
return typeof n == "boolean";
|
|
164
171
|
}
|
|
165
|
-
function
|
|
172
|
+
function xe(n) {
|
|
166
173
|
return typeof n == "number";
|
|
167
174
|
}
|
|
168
|
-
function
|
|
169
|
-
return Array.isArray(n) ?
|
|
175
|
+
function mt(n) {
|
|
176
|
+
return Array.isArray(n) ? mt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" : xe(n) ? "float32" : Rt(n) ? "string" : Re(n) ? "bool" : "float32";
|
|
170
177
|
}
|
|
171
|
-
function
|
|
178
|
+
function kt(n) {
|
|
172
179
|
return !!(n && n.constructor && n.call && n.apply);
|
|
173
180
|
}
|
|
174
181
|
function xt(n) {
|
|
@@ -181,7 +188,7 @@ function xt(n) {
|
|
|
181
188
|
e[s] = e[s + 1] * n[s + 1];
|
|
182
189
|
return e;
|
|
183
190
|
}
|
|
184
|
-
function
|
|
191
|
+
function Yt(n, t, e, s = !1) {
|
|
185
192
|
const r = new Array();
|
|
186
193
|
if (t.length === 1) {
|
|
187
194
|
const i = t[0] * (s ? 2 : 1);
|
|
@@ -190,11 +197,11 @@ function Xt(n, t, e, s = !1) {
|
|
|
190
197
|
} else {
|
|
191
198
|
const i = t[0], o = t.slice(1), a = o.reduce((c, l) => c * l) * (s ? 2 : 1);
|
|
192
199
|
for (let c = 0; c < i; c++)
|
|
193
|
-
r[c] =
|
|
200
|
+
r[c] = Yt(n + c * a, o, e, s);
|
|
194
201
|
}
|
|
195
202
|
return r;
|
|
196
203
|
}
|
|
197
|
-
function
|
|
204
|
+
function Ot(n, t, e = !1) {
|
|
198
205
|
if (n.length === 0)
|
|
199
206
|
return t[0];
|
|
200
207
|
const s = n.reduce((r, i) => r * i) * (e ? 2 : 1);
|
|
@@ -202,15 +209,15 @@ function Pt(n, t, e = !1) {
|
|
|
202
209
|
return [];
|
|
203
210
|
if (s !== t.length)
|
|
204
211
|
throw new Error(`[${n}] does not match the input size ${t.length}${e ? " for a complex tensor" : ""}.`);
|
|
205
|
-
return
|
|
212
|
+
return Yt(0, n, t, e);
|
|
206
213
|
}
|
|
207
|
-
function
|
|
208
|
-
const e =
|
|
214
|
+
function Ne(n, t) {
|
|
215
|
+
const e = Qt(n, t);
|
|
209
216
|
for (let s = 0; s < e.length; s++)
|
|
210
217
|
e[s] = 1;
|
|
211
218
|
return e;
|
|
212
219
|
}
|
|
213
|
-
function
|
|
220
|
+
function Qt(n, t) {
|
|
214
221
|
if (t == null || t === "float32" || t === "complex64")
|
|
215
222
|
return new Float32Array(n);
|
|
216
223
|
if (t === "int32")
|
|
@@ -219,12 +226,12 @@ function Yt(n, t) {
|
|
|
219
226
|
return new Uint8Array(n);
|
|
220
227
|
throw new Error(`Unknown data type ${t}`);
|
|
221
228
|
}
|
|
222
|
-
function
|
|
229
|
+
function Nt(n) {
|
|
223
230
|
n.forEach((t) => {
|
|
224
|
-
|
|
231
|
+
y(Number.isInteger(t) && t >= 0, () => `Tensor must have a shape comprised of positive integers but got shape [${n}].`);
|
|
225
232
|
});
|
|
226
233
|
}
|
|
227
|
-
function
|
|
234
|
+
function Dt(n) {
|
|
228
235
|
return n && n.then && typeof n.then == "function";
|
|
229
236
|
}
|
|
230
237
|
/**
|
|
@@ -243,11 +250,11 @@ function Nt(n) {
|
|
|
243
250
|
* limitations under the License.
|
|
244
251
|
* =============================================================================
|
|
245
252
|
*/
|
|
246
|
-
const
|
|
247
|
-
class
|
|
253
|
+
const Lt = "tfjsflags";
|
|
254
|
+
class De {
|
|
248
255
|
// tslint:disable-next-line: no-any
|
|
249
256
|
constructor(t) {
|
|
250
|
-
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams =
|
|
257
|
+
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams = Ce, this.populateURLFlags();
|
|
251
258
|
}
|
|
252
259
|
setPlatform(t, e) {
|
|
253
260
|
this.platform != null && (S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)), this.platformName = t, this.platform = e;
|
|
@@ -265,7 +272,7 @@ class $e {
|
|
|
265
272
|
if (t in this.flags)
|
|
266
273
|
return this.flags[t];
|
|
267
274
|
const e = this.evaluateFlag(t);
|
|
268
|
-
if (
|
|
275
|
+
if (Dt(e))
|
|
269
276
|
throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);
|
|
270
277
|
return this.flags[t] = e, this.flags[t];
|
|
271
278
|
}
|
|
@@ -305,29 +312,29 @@ class $e {
|
|
|
305
312
|
if (typeof this.global > "u" || typeof this.global.location > "u" || typeof this.global.location.search > "u")
|
|
306
313
|
return;
|
|
307
314
|
const t = this.getQueryParams(this.global.location.search);
|
|
308
|
-
|
|
315
|
+
Lt in t && t[Lt].split(",").forEach((s) => {
|
|
309
316
|
const [r, i] = s.split(":");
|
|
310
|
-
this.urlFlags[r] =
|
|
317
|
+
this.urlFlags[r] = Pe(r, i);
|
|
311
318
|
});
|
|
312
319
|
}
|
|
313
320
|
}
|
|
314
|
-
function
|
|
321
|
+
function Ce(n) {
|
|
315
322
|
const t = {};
|
|
316
|
-
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (
|
|
323
|
+
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (_e(t, s[0], s[1]), s.join("="))), t;
|
|
317
324
|
}
|
|
318
|
-
function
|
|
325
|
+
function _e(n, t, e) {
|
|
319
326
|
n[decodeURIComponent(t)] = decodeURIComponent(e || "");
|
|
320
327
|
}
|
|
321
|
-
function
|
|
328
|
+
function Pe(n, t) {
|
|
322
329
|
const e = t.toLowerCase();
|
|
323
330
|
return e === "true" || e === "false" ? e === "true" : `${+e}` === e ? +e : t;
|
|
324
331
|
}
|
|
325
332
|
function S() {
|
|
326
|
-
return
|
|
333
|
+
return Zt;
|
|
327
334
|
}
|
|
328
|
-
let
|
|
329
|
-
function
|
|
330
|
-
|
|
335
|
+
let Zt = null;
|
|
336
|
+
function Oe(n) {
|
|
337
|
+
Zt = n;
|
|
331
338
|
}
|
|
332
339
|
/**
|
|
333
340
|
* @license
|
|
@@ -345,30 +352,30 @@ function _e(n) {
|
|
|
345
352
|
* limitations under the License.
|
|
346
353
|
* =============================================================================
|
|
347
354
|
*/
|
|
348
|
-
let
|
|
349
|
-
function
|
|
350
|
-
if (
|
|
355
|
+
let pt;
|
|
356
|
+
function te() {
|
|
357
|
+
if (pt == null) {
|
|
351
358
|
let n;
|
|
352
359
|
if (typeof window < "u")
|
|
353
360
|
n = window;
|
|
354
|
-
else if (typeof
|
|
355
|
-
n =
|
|
356
|
-
else if (typeof
|
|
357
|
-
n =
|
|
361
|
+
else if (typeof Pt < "u")
|
|
362
|
+
n = Pt;
|
|
363
|
+
else if (typeof Q < "u")
|
|
364
|
+
n = Q;
|
|
358
365
|
else if (typeof self < "u")
|
|
359
366
|
n = self;
|
|
360
367
|
else
|
|
361
368
|
throw new Error("Could not find a global object");
|
|
362
|
-
|
|
369
|
+
pt = n;
|
|
363
370
|
}
|
|
364
|
-
return
|
|
371
|
+
return pt;
|
|
365
372
|
}
|
|
366
|
-
function
|
|
367
|
-
const n =
|
|
373
|
+
function Le() {
|
|
374
|
+
const n = te();
|
|
368
375
|
return n._tfGlobals == null && (n._tfGlobals = /* @__PURE__ */ new Map()), n._tfGlobals;
|
|
369
376
|
}
|
|
370
|
-
function
|
|
371
|
-
const e =
|
|
377
|
+
function Ct(n, t) {
|
|
378
|
+
const e = Le();
|
|
372
379
|
if (e.has(n))
|
|
373
380
|
return e.get(n);
|
|
374
381
|
{
|
|
@@ -376,7 +383,7 @@ function Dt(n, t) {
|
|
|
376
383
|
return e.set(n, s), e.get(n);
|
|
377
384
|
}
|
|
378
385
|
}
|
|
379
|
-
const
|
|
386
|
+
const Ue = "Abs", ee = "Add", Es = "BatchMatMul", ne = "Cast", As = "Complex", Ge = "ComplexAbs", ze = "RealDiv", Bs = "Elu", vs = "Exp", We = "Fill", je = "FloorDiv", Ms = "GatherNd", se = "Identity", Fs = "Imag", $s = "LeakyRelu", Rs = "Log", xs = "Max", Ke = "Maximum", Ve = "Multiply", Ns = "Neg", Ds = "Pack", qe = "Pow", Cs = "Prelu", _s = "Range", Ps = "Real", Os = "Relu", Ls = "Reshape", Us = "Relu6", Gs = "ScatterNd", zs = "Sigmoid", He = "Sqrt", Ws = "Sum", js = "Softmax", Je = "Sub", Ks = "Transpose", Xe = "ZerosLike", Vs = "Step", qs = "_FusedMatMul";
|
|
380
387
|
/**
|
|
381
388
|
* @license
|
|
382
389
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -393,7 +400,7 @@ const Oe = "Abs", te = "Add", ks = "BatchMatMul", ee = "Cast", Ts = "Complex", L
|
|
|
393
400
|
* limitations under the License.
|
|
394
401
|
* =============================================================================
|
|
395
402
|
*/
|
|
396
|
-
function
|
|
403
|
+
function J(...n) {
|
|
397
404
|
S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(...n);
|
|
398
405
|
}
|
|
399
406
|
/**
|
|
@@ -412,16 +419,16 @@ function st(...n) {
|
|
|
412
419
|
* limitations under the License.
|
|
413
420
|
* =============================================================================
|
|
414
421
|
*/
|
|
415
|
-
const
|
|
416
|
-
function
|
|
417
|
-
const e =
|
|
418
|
-
return
|
|
419
|
-
}
|
|
420
|
-
function Ut(n) {
|
|
421
|
-
return Je.get(n);
|
|
422
|
+
const ht = Ct("kernelRegistry", () => /* @__PURE__ */ new Map()), Ye = Ct("gradRegistry", () => /* @__PURE__ */ new Map());
|
|
423
|
+
function Ut(n, t) {
|
|
424
|
+
const e = re(n, t);
|
|
425
|
+
return ht.get(e);
|
|
422
426
|
}
|
|
423
427
|
function Gt(n) {
|
|
424
|
-
|
|
428
|
+
return Ye.get(n);
|
|
429
|
+
}
|
|
430
|
+
function zt(n) {
|
|
431
|
+
const t = ht.entries(), e = [];
|
|
425
432
|
for (; ; ) {
|
|
426
433
|
const { done: s, value: r } = t.next();
|
|
427
434
|
if (s)
|
|
@@ -431,7 +438,11 @@ function Gt(n) {
|
|
|
431
438
|
}
|
|
432
439
|
return e;
|
|
433
440
|
}
|
|
434
|
-
function
|
|
441
|
+
function Hs(n) {
|
|
442
|
+
const { kernelName: t, backendName: e } = n, s = re(t, e);
|
|
443
|
+
ht.has(s) && J(`The kernel '${t}' for backend '${e}' is already registered`), ht.set(s, n);
|
|
444
|
+
}
|
|
445
|
+
function re(n, t) {
|
|
435
446
|
return `${t}_${n}`;
|
|
436
447
|
}
|
|
437
448
|
/**
|
|
@@ -450,7 +461,7 @@ function Xe(n, t) {
|
|
|
450
461
|
* limitations under the License.
|
|
451
462
|
* =============================================================================
|
|
452
463
|
*/
|
|
453
|
-
function
|
|
464
|
+
function ie(n) {
|
|
454
465
|
return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray;
|
|
455
466
|
}
|
|
456
467
|
/**
|
|
@@ -469,13 +480,13 @@ function re(n) {
|
|
|
469
480
|
* limitations under the License.
|
|
470
481
|
* =============================================================================
|
|
471
482
|
*/
|
|
472
|
-
function
|
|
483
|
+
function Qe(n, t) {
|
|
473
484
|
return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool";
|
|
474
485
|
}
|
|
475
|
-
function
|
|
486
|
+
function oe(n, t) {
|
|
476
487
|
if (t === "string")
|
|
477
488
|
throw new Error("Cannot convert a string[] to a TypedArray");
|
|
478
|
-
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") &&
|
|
489
|
+
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") && Me(n, t), Qe(n, t))
|
|
479
490
|
return n;
|
|
480
491
|
if (t == null || t === "float32" || t === "complex64")
|
|
481
492
|
return new Float32Array(n);
|
|
@@ -489,22 +500,22 @@ function ie(n, t) {
|
|
|
489
500
|
} else
|
|
490
501
|
throw new Error(`Unknown data type ${t}`);
|
|
491
502
|
}
|
|
492
|
-
function
|
|
503
|
+
function ft() {
|
|
493
504
|
return S().platform.now();
|
|
494
505
|
}
|
|
495
|
-
function
|
|
506
|
+
function Ze(n, t = "utf-8") {
|
|
496
507
|
return t = t || "utf-8", S().platform.encode(n, t);
|
|
497
508
|
}
|
|
498
|
-
function
|
|
509
|
+
function Wt(n, t = "utf-8") {
|
|
499
510
|
return t = t || "utf-8", S().platform.decode(n, t);
|
|
500
511
|
}
|
|
501
|
-
function
|
|
502
|
-
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) :
|
|
512
|
+
function $(n) {
|
|
513
|
+
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) : ie(n);
|
|
503
514
|
}
|
|
504
515
|
function at(n, t = [], e = !1) {
|
|
505
|
-
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" ||
|
|
516
|
+
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" || Dt(n) || n == null || $(n) && e)
|
|
506
517
|
t.push(n);
|
|
507
|
-
else if (Array.isArray(n) ||
|
|
518
|
+
else if (Array.isArray(n) || $(n))
|
|
508
519
|
for (let s = 0; s < n.length; ++s)
|
|
509
520
|
at(n[s], t, e);
|
|
510
521
|
else {
|
|
@@ -532,9 +543,9 @@ function at(n, t = [], e = !1) {
|
|
|
532
543
|
* limitations under the License.
|
|
533
544
|
* =============================================================================
|
|
534
545
|
*/
|
|
535
|
-
class
|
|
546
|
+
class tn {
|
|
536
547
|
constructor(t, e) {
|
|
537
|
-
this.backendTimer = t, this.logger = e, e == null && (this.logger = new
|
|
548
|
+
this.backendTimer = t, this.logger = e, e == null && (this.logger = new nn());
|
|
538
549
|
}
|
|
539
550
|
profileKernel(t, e, s) {
|
|
540
551
|
let r;
|
|
@@ -542,20 +553,20 @@ class Ze {
|
|
|
542
553
|
r = s();
|
|
543
554
|
};
|
|
544
555
|
let o;
|
|
545
|
-
const a =
|
|
556
|
+
const a = ft();
|
|
546
557
|
if (this.backendTimer.timerAvailable())
|
|
547
558
|
o = this.backendTimer.time(i);
|
|
548
559
|
else {
|
|
549
560
|
i();
|
|
550
561
|
for (const l of r)
|
|
551
562
|
l.dataSync();
|
|
552
|
-
o = Promise.resolve({ kernelMs:
|
|
563
|
+
o = Promise.resolve({ kernelMs: ft() - a });
|
|
553
564
|
}
|
|
554
565
|
if (S().getBool("CHECK_COMPUTATION_FOR_ERRORS"))
|
|
555
566
|
for (let l = 0; l < r.length; l++) {
|
|
556
567
|
const u = r[l];
|
|
557
568
|
u.data().then((h) => {
|
|
558
|
-
|
|
569
|
+
en(h, u.dtype, t);
|
|
559
570
|
});
|
|
560
571
|
}
|
|
561
572
|
return {
|
|
@@ -575,7 +586,7 @@ class Ze {
|
|
|
575
586
|
});
|
|
576
587
|
}
|
|
577
588
|
}
|
|
578
|
-
function
|
|
589
|
+
function en(n, t, e) {
|
|
579
590
|
if (t !== "float32")
|
|
580
591
|
return !1;
|
|
581
592
|
for (let s = 0; s < n.length; s++) {
|
|
@@ -585,15 +596,15 @@ function tn(n, t, e) {
|
|
|
585
596
|
}
|
|
586
597
|
return !1;
|
|
587
598
|
}
|
|
588
|
-
class
|
|
599
|
+
class nn {
|
|
589
600
|
logKernelProfile(t, e, s, r, i, o) {
|
|
590
601
|
const a = typeof r == "number" ? ct(`${r}ms`, 9) : r.error, c = ct(t, 25), l = e.rank, u = e.size, h = ct(e.shape.toString(), 14);
|
|
591
602
|
let f = "";
|
|
592
603
|
for (const m in i) {
|
|
593
|
-
const
|
|
594
|
-
if (
|
|
595
|
-
const d =
|
|
596
|
-
f += `${m}: ${
|
|
604
|
+
const b = i[m];
|
|
605
|
+
if (b != null) {
|
|
606
|
+
const d = b.shape || e.shape, k = d.length;
|
|
607
|
+
f += `${m}: ${k}D ${k > 0 ? d : ""} `;
|
|
597
608
|
}
|
|
598
609
|
}
|
|
599
610
|
console.log(`%c${c} %c${a} %c${l}D ${h} %c${u} %c${f} %c${o}`, "font-weight:bold", "color:red", "color:blue", "color: orange", "color: green", "color: steelblue");
|
|
@@ -615,7 +626,7 @@ class en {
|
|
|
615
626
|
* limitations under the License.
|
|
616
627
|
* =============================================================================
|
|
617
628
|
*/
|
|
618
|
-
function
|
|
629
|
+
function sn(n, t, e) {
|
|
619
630
|
const s = {}, r = {};
|
|
620
631
|
for (let c = 0; c < t.length; c++)
|
|
621
632
|
s[t[c].id] = !0;
|
|
@@ -624,7 +635,7 @@ function nn(n, t, e) {
|
|
|
624
635
|
for (const h in u) {
|
|
625
636
|
const f = u[h];
|
|
626
637
|
let m = !1;
|
|
627
|
-
for (let
|
|
638
|
+
for (let b = 0; b < t.length; b++)
|
|
628
639
|
if (s[f.id]) {
|
|
629
640
|
l.outputs.forEach((d) => s[d.id] = !0), m = !0, r[l.id] = !0;
|
|
630
641
|
break;
|
|
@@ -660,7 +671,7 @@ function nn(n, t, e) {
|
|
|
660
671
|
}
|
|
661
672
|
return a;
|
|
662
673
|
}
|
|
663
|
-
function
|
|
674
|
+
function rn(n, t, e, s) {
|
|
664
675
|
for (let r = t.length - 1; r >= 0; r--) {
|
|
665
676
|
const i = t[r], o = [];
|
|
666
677
|
if (i.outputs.forEach((c) => {
|
|
@@ -676,7 +687,7 @@ function sn(n, t, e, s) {
|
|
|
676
687
|
if (l.dtype !== "float32")
|
|
677
688
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input ${c} must have 'float32' dtype, but has '${l.dtype}'`);
|
|
678
689
|
const u = i.inputs[c];
|
|
679
|
-
if (
|
|
690
|
+
if (!$t(l.shape, u.shape))
|
|
680
691
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input '${c}' has shape '${l.shape}', which does not match the shape of the input '${u.shape}'`);
|
|
681
692
|
if (n[u.id] == null)
|
|
682
693
|
n[u.id] = l;
|
|
@@ -703,14 +714,14 @@ function sn(n, t, e, s) {
|
|
|
703
714
|
* limitations under the License.
|
|
704
715
|
* =============================================================================
|
|
705
716
|
*/
|
|
706
|
-
const
|
|
707
|
-
function
|
|
708
|
-
const r = xt(t), i =
|
|
717
|
+
const jt = 20, rt = 3, yt = 7;
|
|
718
|
+
function on(n, t, e, s) {
|
|
719
|
+
const r = xt(t), i = an(n, t, e, r), o = t.length, a = ut(n, t, e, r, i), c = ["Tensor"];
|
|
709
720
|
return s && (c.push(` dtype: ${e}`), c.push(` rank: ${o}`), c.push(` shape: [${t}]`), c.push(" values:")), c.push(a.map((l) => " " + l).join(`
|
|
710
721
|
`)), c.join(`
|
|
711
722
|
`);
|
|
712
723
|
}
|
|
713
|
-
function
|
|
724
|
+
function an(n, t, e, s) {
|
|
714
725
|
const r = U(t), i = s[s.length - 1], o = new Array(i).fill(0), a = t.length, c = e === "complex64" ? ot(n) : n;
|
|
715
726
|
if (a > 1)
|
|
716
727
|
for (let l = 0; l < r / i; l++) {
|
|
@@ -722,9 +733,9 @@ function on(n, t, e, s) {
|
|
|
722
733
|
}
|
|
723
734
|
function it(n, t, e) {
|
|
724
735
|
let s;
|
|
725
|
-
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(
|
|
736
|
+
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(yt))} + ${parseFloat(n[1].toFixed(yt))}j` : Rt(n) ? s = `'${n}'` : e === "bool" ? s = ae(n) : s = parseFloat(n.toFixed(yt)).toString(), ct(s, t);
|
|
726
737
|
}
|
|
727
|
-
function
|
|
738
|
+
function ae(n) {
|
|
728
739
|
return n === 0 ? "false" : "true";
|
|
729
740
|
}
|
|
730
741
|
function ut(n, t, e, s, r, i = !0) {
|
|
@@ -734,26 +745,26 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
734
745
|
const d = ot(n);
|
|
735
746
|
return [it(d[0], 0, e)];
|
|
736
747
|
}
|
|
737
|
-
return e === "bool" ? [
|
|
748
|
+
return e === "bool" ? [ae(n[0])] : [n[0].toString()];
|
|
738
749
|
}
|
|
739
750
|
if (c === 1) {
|
|
740
|
-
if (a >
|
|
741
|
-
const
|
|
742
|
-
let T = Array.from(n.slice(0,
|
|
743
|
-
return e === "complex64" && (T = ot(T),
|
|
744
|
-
"[" + T.map((q, H) => it(q, r[H], e)).join(", ") + ", ..., " +
|
|
751
|
+
if (a > jt) {
|
|
752
|
+
const k = rt * o;
|
|
753
|
+
let T = Array.from(n.slice(0, k)), nt = Array.from(n.slice((a - rt) * o, a * o));
|
|
754
|
+
return e === "complex64" && (T = ot(T), nt = ot(nt)), [
|
|
755
|
+
"[" + T.map((q, H) => it(q, r[H], e)).join(", ") + ", ..., " + nt.map((q, H) => it(q, r[a - rt + H], e)).join(", ") + "]"
|
|
745
756
|
];
|
|
746
757
|
}
|
|
747
758
|
return [
|
|
748
|
-
"[" + (e === "complex64" ? ot(n) : Array.from(n)).map((
|
|
759
|
+
"[" + (e === "complex64" ? ot(n) : Array.from(n)).map((k, T) => it(k, r[T], e)).join(", ") + "]"
|
|
749
760
|
];
|
|
750
761
|
}
|
|
751
762
|
const l = t.slice(1), u = s.slice(1), h = s[0] * o, f = [];
|
|
752
|
-
if (a >
|
|
763
|
+
if (a > jt) {
|
|
753
764
|
for (let d = 0; d < rt; d++) {
|
|
754
|
-
const
|
|
765
|
+
const k = d * h, T = k + h;
|
|
755
766
|
f.push(...ut(
|
|
756
|
-
n.slice(
|
|
767
|
+
n.slice(k, T),
|
|
757
768
|
l,
|
|
758
769
|
e,
|
|
759
770
|
u,
|
|
@@ -764,9 +775,9 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
764
775
|
}
|
|
765
776
|
f.push("...");
|
|
766
777
|
for (let d = a - rt; d < a; d++) {
|
|
767
|
-
const
|
|
778
|
+
const k = d * h, T = k + h;
|
|
768
779
|
f.push(...ut(
|
|
769
|
-
n.slice(
|
|
780
|
+
n.slice(k, T),
|
|
770
781
|
l,
|
|
771
782
|
e,
|
|
772
783
|
u,
|
|
@@ -777,9 +788,9 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
777
788
|
}
|
|
778
789
|
} else
|
|
779
790
|
for (let d = 0; d < a; d++) {
|
|
780
|
-
const
|
|
791
|
+
const k = d * h, T = k + h;
|
|
781
792
|
f.push(...ut(
|
|
782
|
-
n.slice(
|
|
793
|
+
n.slice(k, T),
|
|
783
794
|
l,
|
|
784
795
|
e,
|
|
785
796
|
u,
|
|
@@ -792,12 +803,12 @@ function ut(n, t, e, s, r, i = !0) {
|
|
|
792
803
|
f[0] = "[" + (a > 0 ? f[0] + m : "");
|
|
793
804
|
for (let d = 1; d < f.length - 1; d++)
|
|
794
805
|
f[d] = " " + f[d] + m;
|
|
795
|
-
let
|
|
806
|
+
let b = `,
|
|
796
807
|
`;
|
|
797
808
|
for (let d = 2; d < c; d++)
|
|
798
|
-
|
|
809
|
+
b += `
|
|
799
810
|
`;
|
|
800
|
-
return f[f.length - 1] = " " + f[f.length - 1] + "]" + (i ? "" :
|
|
811
|
+
return f[f.length - 1] = " " + f[f.length - 1] + "]" + (i ? "" : b), f;
|
|
801
812
|
}
|
|
802
813
|
function ot(n) {
|
|
803
814
|
const t = [];
|
|
@@ -821,15 +832,15 @@ function ot(n) {
|
|
|
821
832
|
* limitations under the License.
|
|
822
833
|
* =============================================================================
|
|
823
834
|
*/
|
|
824
|
-
class
|
|
835
|
+
class ln {
|
|
825
836
|
constructor(t, e, s) {
|
|
826
837
|
if (this.dtype = e, this.shape = t.slice(), this.size = U(t), s != null) {
|
|
827
838
|
const r = s.length;
|
|
828
|
-
|
|
839
|
+
y(r === this.size, () => `Length of values '${r}' does not match the size inferred by the shape '${this.size}'.`);
|
|
829
840
|
}
|
|
830
841
|
if (e === "complex64")
|
|
831
842
|
throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
|
|
832
|
-
this.values = s ||
|
|
843
|
+
this.values = s || ve(e, this.size), this.strides = xt(t);
|
|
833
844
|
}
|
|
834
845
|
/**
|
|
835
846
|
* Sets a value in the buffer at a given location.
|
|
@@ -840,7 +851,7 @@ class an {
|
|
|
840
851
|
* @doc {heading: 'Tensors', subheading: 'Creation'}
|
|
841
852
|
*/
|
|
842
853
|
set(t, ...e) {
|
|
843
|
-
e.length === 0 && (e = [0]),
|
|
854
|
+
e.length === 0 && (e = [0]), y(e.length === this.rank, () => `The number of provided coordinates (${e.length}) must match the rank (${this.rank})`);
|
|
844
855
|
const s = this.locToIndex(e);
|
|
845
856
|
this.values[s] = t;
|
|
846
857
|
}
|
|
@@ -895,17 +906,17 @@ class an {
|
|
|
895
906
|
* @doc {heading: 'Tensors', subheading: 'Creation'}
|
|
896
907
|
*/
|
|
897
908
|
toTensor() {
|
|
898
|
-
return
|
|
909
|
+
return R().makeTensor(this.values, this.shape, this.dtype);
|
|
899
910
|
}
|
|
900
911
|
}
|
|
901
|
-
let
|
|
902
|
-
function ln(n) {
|
|
903
|
-
x = n;
|
|
904
|
-
}
|
|
912
|
+
let R = null, X = null;
|
|
905
913
|
function cn(n) {
|
|
906
|
-
|
|
914
|
+
R = n;
|
|
907
915
|
}
|
|
908
|
-
|
|
916
|
+
function un(n) {
|
|
917
|
+
X = n;
|
|
918
|
+
}
|
|
919
|
+
class x {
|
|
909
920
|
constructor(t, e, s, r) {
|
|
910
921
|
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size = U(t), this.strides = xt(t), this.dataId = s, this.id = r, this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
|
|
911
922
|
}
|
|
@@ -919,14 +930,14 @@ class $ {
|
|
|
919
930
|
*/
|
|
920
931
|
async buffer() {
|
|
921
932
|
const t = await this.data();
|
|
922
|
-
return
|
|
933
|
+
return X.buffer(this.shape, this.dtype, t);
|
|
923
934
|
}
|
|
924
935
|
/**
|
|
925
936
|
* Returns a `tf.TensorBuffer` that holds the underlying data.
|
|
926
937
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
927
938
|
*/
|
|
928
939
|
bufferSync() {
|
|
929
|
-
return
|
|
940
|
+
return X.buffer(this.shape, this.dtype, this.dataSync());
|
|
930
941
|
}
|
|
931
942
|
/**
|
|
932
943
|
* Returns the tensor data as a nested array. The transfer of data is done
|
|
@@ -936,7 +947,7 @@ class $ {
|
|
|
936
947
|
*/
|
|
937
948
|
async array() {
|
|
938
949
|
const t = await this.data();
|
|
939
|
-
return
|
|
950
|
+
return Ot(this.shape, t, this.dtype === "complex64");
|
|
940
951
|
}
|
|
941
952
|
/**
|
|
942
953
|
* Returns the tensor data as a nested array. The transfer of data is done
|
|
@@ -945,7 +956,7 @@ class $ {
|
|
|
945
956
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
946
957
|
*/
|
|
947
958
|
arraySync() {
|
|
948
|
-
return
|
|
959
|
+
return Ot(this.shape, this.dataSync(), this.dtype === "complex64");
|
|
949
960
|
}
|
|
950
961
|
/**
|
|
951
962
|
* Asynchronously downloads the values from the `tf.Tensor`. Returns a
|
|
@@ -955,11 +966,11 @@ class $ {
|
|
|
955
966
|
*/
|
|
956
967
|
async data() {
|
|
957
968
|
this.throwIfDisposed();
|
|
958
|
-
const t =
|
|
969
|
+
const t = R().read(this.dataId);
|
|
959
970
|
if (this.dtype === "string") {
|
|
960
971
|
const e = await t;
|
|
961
972
|
try {
|
|
962
|
-
return e.map((s) =>
|
|
973
|
+
return e.map((s) => Wt(s));
|
|
963
974
|
} catch {
|
|
964
975
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
965
976
|
}
|
|
@@ -1001,7 +1012,7 @@ class $ {
|
|
|
1001
1012
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
1002
1013
|
*/
|
|
1003
1014
|
dataToGPU(t) {
|
|
1004
|
-
return this.throwIfDisposed(),
|
|
1015
|
+
return this.throwIfDisposed(), R().readToGPU(this.dataId, t);
|
|
1005
1016
|
}
|
|
1006
1017
|
/**
|
|
1007
1018
|
* Synchronously downloads the values from the `tf.Tensor`. This blocks the
|
|
@@ -1011,10 +1022,10 @@ class $ {
|
|
|
1011
1022
|
*/
|
|
1012
1023
|
dataSync() {
|
|
1013
1024
|
this.throwIfDisposed();
|
|
1014
|
-
const t =
|
|
1025
|
+
const t = R().readSync(this.dataId);
|
|
1015
1026
|
if (this.dtype === "string")
|
|
1016
1027
|
try {
|
|
1017
|
-
return t.map((e) =>
|
|
1028
|
+
return t.map((e) => Wt(e));
|
|
1018
1029
|
} catch {
|
|
1019
1030
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
1020
1031
|
}
|
|
@@ -1023,7 +1034,7 @@ class $ {
|
|
|
1023
1034
|
/** Returns the underlying bytes of the tensor's data. */
|
|
1024
1035
|
async bytes() {
|
|
1025
1036
|
this.throwIfDisposed();
|
|
1026
|
-
const t = await
|
|
1037
|
+
const t = await R().read(this.dataId);
|
|
1027
1038
|
return this.dtype === "string" ? t : new Uint8Array(t.buffer);
|
|
1028
1039
|
}
|
|
1029
1040
|
/**
|
|
@@ -1032,7 +1043,7 @@ class $ {
|
|
|
1032
1043
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
1033
1044
|
*/
|
|
1034
1045
|
dispose() {
|
|
1035
|
-
this.isDisposed || (this.kerasMask && this.kerasMask.dispose(),
|
|
1046
|
+
this.isDisposed || (this.kerasMask && this.kerasMask.dispose(), R().disposeTensor(this), this.isDisposedInternal = !0);
|
|
1036
1047
|
}
|
|
1037
1048
|
get isDisposed() {
|
|
1038
1049
|
return this.isDisposedInternal;
|
|
@@ -1050,14 +1061,14 @@ class $ {
|
|
|
1050
1061
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
1051
1062
|
*/
|
|
1052
1063
|
print(t = !1) {
|
|
1053
|
-
return
|
|
1064
|
+
return X.print(this, t);
|
|
1054
1065
|
}
|
|
1055
1066
|
/**
|
|
1056
1067
|
* Returns a copy of the tensor. See `tf.clone` for details.
|
|
1057
1068
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
1058
1069
|
*/
|
|
1059
1070
|
clone() {
|
|
1060
|
-
return this.throwIfDisposed(),
|
|
1071
|
+
return this.throwIfDisposed(), X.clone(this);
|
|
1061
1072
|
}
|
|
1062
1073
|
/**
|
|
1063
1074
|
* Returns a human-readable description of the tensor. Useful for logging.
|
|
@@ -1066,23 +1077,23 @@ class $ {
|
|
|
1066
1077
|
*/
|
|
1067
1078
|
toString(t = !1) {
|
|
1068
1079
|
const e = this.dataSync();
|
|
1069
|
-
return
|
|
1080
|
+
return on(e, this.shape, this.dtype, t);
|
|
1070
1081
|
}
|
|
1071
1082
|
cast(t) {
|
|
1072
|
-
return this.throwIfDisposed(),
|
|
1083
|
+
return this.throwIfDisposed(), X.cast(this, t);
|
|
1073
1084
|
}
|
|
1074
1085
|
variable(t = !0, e, s) {
|
|
1075
|
-
return this.throwIfDisposed(),
|
|
1086
|
+
return this.throwIfDisposed(), R().makeVariable(this, t, e, s);
|
|
1076
1087
|
}
|
|
1077
1088
|
}
|
|
1078
|
-
Object.defineProperty(
|
|
1089
|
+
Object.defineProperty(x, Symbol.hasInstance, {
|
|
1079
1090
|
value: (n) => !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null
|
|
1080
1091
|
});
|
|
1081
|
-
function
|
|
1082
|
-
return
|
|
1092
|
+
function le() {
|
|
1093
|
+
return Ct("Tensor", () => x);
|
|
1083
1094
|
}
|
|
1084
|
-
|
|
1085
|
-
class
|
|
1095
|
+
le();
|
|
1096
|
+
class dt extends x {
|
|
1086
1097
|
constructor(t, e, s, r) {
|
|
1087
1098
|
super(t.shape, t.dtype, t.dataId, r), this.trainable = e, this.name = s;
|
|
1088
1099
|
}
|
|
@@ -1097,20 +1108,20 @@ class ft extends $ {
|
|
|
1097
1108
|
assign(t) {
|
|
1098
1109
|
if (t.dtype !== this.dtype)
|
|
1099
1110
|
throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);
|
|
1100
|
-
if (
|
|
1111
|
+
if (!$t(t.shape, this.shape))
|
|
1101
1112
|
throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);
|
|
1102
|
-
|
|
1113
|
+
R().disposeTensor(this), this.dataId = t.dataId, R().incRef(
|
|
1103
1114
|
this,
|
|
1104
1115
|
null
|
|
1105
1116
|
/* backend */
|
|
1106
1117
|
);
|
|
1107
1118
|
}
|
|
1108
1119
|
dispose() {
|
|
1109
|
-
|
|
1120
|
+
R().disposeVariable(this), this.isDisposedInternal = !0;
|
|
1110
1121
|
}
|
|
1111
1122
|
}
|
|
1112
|
-
Object.defineProperty(
|
|
1113
|
-
value: (n) => n instanceof
|
|
1123
|
+
Object.defineProperty(dt, Symbol.hasInstance, {
|
|
1124
|
+
value: (n) => n instanceof x && n.assign != null && n.assign instanceof Function
|
|
1114
1125
|
});
|
|
1115
1126
|
/**
|
|
1116
1127
|
* @license
|
|
@@ -1128,44 +1139,44 @@ Object.defineProperty(ft, Symbol.hasInstance, {
|
|
|
1128
1139
|
* limitations under the License.
|
|
1129
1140
|
* =============================================================================
|
|
1130
1141
|
*/
|
|
1131
|
-
var
|
|
1142
|
+
var Kt;
|
|
1132
1143
|
(function(n) {
|
|
1133
1144
|
n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6";
|
|
1134
|
-
})(
|
|
1145
|
+
})(Kt || (Kt = {}));
|
|
1135
1146
|
var It;
|
|
1136
1147
|
(function(n) {
|
|
1137
1148
|
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1138
1149
|
})(It || (It = {}));
|
|
1139
|
-
var kt;
|
|
1140
|
-
(function(n) {
|
|
1141
|
-
n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
|
|
1142
|
-
})(kt || (kt = {}));
|
|
1143
1150
|
var Tt;
|
|
1144
1151
|
(function(n) {
|
|
1145
|
-
n.float32 = "float32", n.int32 = "
|
|
1152
|
+
n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
|
|
1146
1153
|
})(Tt || (Tt = {}));
|
|
1147
1154
|
var Et;
|
|
1148
1155
|
(function(n) {
|
|
1149
|
-
n.float32 = "
|
|
1156
|
+
n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64";
|
|
1150
1157
|
})(Et || (Et = {}));
|
|
1151
|
-
|
|
1152
|
-
|
|
1158
|
+
var At;
|
|
1159
|
+
(function(n) {
|
|
1160
|
+
n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64";
|
|
1161
|
+
})(At || (At = {}));
|
|
1162
|
+
const hn = {
|
|
1163
|
+
float32: Et,
|
|
1153
1164
|
int32: It,
|
|
1154
|
-
bool:
|
|
1155
|
-
complex64:
|
|
1165
|
+
bool: Tt,
|
|
1166
|
+
complex64: At
|
|
1156
1167
|
};
|
|
1157
|
-
function
|
|
1168
|
+
function fn(n, t) {
|
|
1158
1169
|
if (n === "string" || t === "string") {
|
|
1159
1170
|
if (n === "string" && t === "string")
|
|
1160
1171
|
return "string";
|
|
1161
1172
|
throw new Error(`Can not upcast ${n} with ${t}`);
|
|
1162
1173
|
}
|
|
1163
|
-
return
|
|
1174
|
+
return hn[n][t];
|
|
1164
1175
|
}
|
|
1165
|
-
function
|
|
1176
|
+
function ce(n) {
|
|
1166
1177
|
return n != null && typeof n == "object" && "texture" in n && n.texture instanceof WebGLTexture;
|
|
1167
1178
|
}
|
|
1168
|
-
function
|
|
1179
|
+
function ue(n) {
|
|
1169
1180
|
return typeof GPUBuffer < "u" && n != null && typeof n == "object" && "buffer" in n && n.buffer instanceof GPUBuffer;
|
|
1170
1181
|
}
|
|
1171
1182
|
/**
|
|
@@ -1187,29 +1198,29 @@ function ce(n) {
|
|
|
1187
1198
|
function K(n, t) {
|
|
1188
1199
|
if (n.dtype === t.dtype)
|
|
1189
1200
|
return [n, t];
|
|
1190
|
-
const e =
|
|
1201
|
+
const e = fn(n.dtype, t.dtype);
|
|
1191
1202
|
return [n.cast(e), t.cast(e)];
|
|
1192
1203
|
}
|
|
1193
|
-
function
|
|
1204
|
+
function he(n) {
|
|
1194
1205
|
const t = [];
|
|
1195
|
-
return
|
|
1206
|
+
return fe(n, t, /* @__PURE__ */ new Set()), t;
|
|
1196
1207
|
}
|
|
1197
|
-
function
|
|
1208
|
+
function fe(n, t, e) {
|
|
1198
1209
|
if (n == null)
|
|
1199
1210
|
return;
|
|
1200
|
-
if (n instanceof
|
|
1211
|
+
if (n instanceof x) {
|
|
1201
1212
|
t.push(n);
|
|
1202
1213
|
return;
|
|
1203
1214
|
}
|
|
1204
|
-
if (!
|
|
1215
|
+
if (!dn(n))
|
|
1205
1216
|
return;
|
|
1206
1217
|
const s = n;
|
|
1207
1218
|
for (const r in s) {
|
|
1208
1219
|
const i = s[r];
|
|
1209
|
-
e.has(i) || (e.add(i),
|
|
1220
|
+
e.has(i) || (e.add(i), fe(i, t, e));
|
|
1210
1221
|
}
|
|
1211
1222
|
}
|
|
1212
|
-
function
|
|
1223
|
+
function dn(n) {
|
|
1213
1224
|
return Array.isArray(n) || typeof n == "object";
|
|
1214
1225
|
}
|
|
1215
1226
|
/**
|
|
@@ -1228,10 +1239,10 @@ function fn(n) {
|
|
|
1228
1239
|
* limitations under the License.
|
|
1229
1240
|
* =============================================================================
|
|
1230
1241
|
*/
|
|
1231
|
-
function
|
|
1242
|
+
function bt(n) {
|
|
1232
1243
|
return n.kernelName != null;
|
|
1233
1244
|
}
|
|
1234
|
-
class
|
|
1245
|
+
class Vt {
|
|
1235
1246
|
constructor() {
|
|
1236
1247
|
this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = /* @__PURE__ */ new WeakMap(), this.profiling = !1, this.activeProfile = {
|
|
1237
1248
|
newBytes: 0,
|
|
@@ -1249,9 +1260,9 @@ class Kt {
|
|
|
1249
1260
|
this.registeredVariables[t].dispose();
|
|
1250
1261
|
}
|
|
1251
1262
|
}
|
|
1252
|
-
class
|
|
1263
|
+
class tt {
|
|
1253
1264
|
constructor(t) {
|
|
1254
|
-
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new
|
|
1265
|
+
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new Vt();
|
|
1255
1266
|
}
|
|
1256
1267
|
async ready() {
|
|
1257
1268
|
if (this.pendingBackendInit != null)
|
|
@@ -1297,7 +1308,7 @@ class Z {
|
|
|
1297
1308
|
return t in this.registryFactory ? this.registryFactory[t].factory : null;
|
|
1298
1309
|
}
|
|
1299
1310
|
registerBackend(t, e, s = 1) {
|
|
1300
|
-
return t in this.registryFactory ? (
|
|
1311
|
+
return t in this.registryFactory ? (J(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, !0);
|
|
1301
1312
|
}
|
|
1302
1313
|
async setBackend(t) {
|
|
1303
1314
|
if (this.registryFactory[t] == null)
|
|
@@ -1308,15 +1319,15 @@ class Z {
|
|
|
1308
1319
|
if (!(s ? await e : e))
|
|
1309
1320
|
return !1;
|
|
1310
1321
|
}
|
|
1311
|
-
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new
|
|
1322
|
+
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new tn(this.backendInstance), !0;
|
|
1312
1323
|
}
|
|
1313
1324
|
setupRegisteredKernels() {
|
|
1314
|
-
|
|
1325
|
+
zt(this.backendName).forEach((e) => {
|
|
1315
1326
|
e.setupFunc != null && e.setupFunc(this.backendInstance);
|
|
1316
1327
|
});
|
|
1317
1328
|
}
|
|
1318
1329
|
disposeRegisteredKernels(t) {
|
|
1319
|
-
|
|
1330
|
+
zt(t).forEach((s) => {
|
|
1320
1331
|
s.disposeFunc != null && s.disposeFunc(this.registry[t]);
|
|
1321
1332
|
});
|
|
1322
1333
|
}
|
|
@@ -1332,13 +1343,13 @@ class Z {
|
|
|
1332
1343
|
throw new Error(`Cannot initialize backend ${t}, no registration found.`);
|
|
1333
1344
|
try {
|
|
1334
1345
|
const s = e.factory();
|
|
1335
|
-
if (s && !(s instanceof
|
|
1336
|
-
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null,
|
|
1346
|
+
if (s && !(s instanceof Ae) && typeof s.then == "function") {
|
|
1347
|
+
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, J(`Initialization of backend ${t} failed`), J(o.stack || o.message)), !1));
|
|
1337
1348
|
return this.pendingBackendInit = i, { success: i, asyncInit: !0 };
|
|
1338
1349
|
} else
|
|
1339
1350
|
return this.registry[t] = s, { success: !0, asyncInit: !1 };
|
|
1340
1351
|
} catch (s) {
|
|
1341
|
-
return
|
|
1352
|
+
return J(`Initialization of backend ${t} failed`), J(s.stack || s.message), { success: !1, asyncInit: !1 };
|
|
1342
1353
|
}
|
|
1343
1354
|
}
|
|
1344
1355
|
removeBackend(t) {
|
|
@@ -1390,10 +1401,10 @@ class Z {
|
|
|
1390
1401
|
}
|
|
1391
1402
|
}
|
|
1392
1403
|
nextTensorId() {
|
|
1393
|
-
return
|
|
1404
|
+
return tt.nextTensorId++;
|
|
1394
1405
|
}
|
|
1395
1406
|
nextVariableId() {
|
|
1396
|
-
return
|
|
1407
|
+
return tt.nextVariableId++;
|
|
1397
1408
|
}
|
|
1398
1409
|
/**
|
|
1399
1410
|
* This method is called instead of the public-facing tensor.clone() when
|
|
@@ -1402,11 +1413,11 @@ class Z {
|
|
|
1402
1413
|
* execution.
|
|
1403
1414
|
*/
|
|
1404
1415
|
clone(t) {
|
|
1405
|
-
const e = g.runKernel(
|
|
1416
|
+
const e = g.runKernel(se, { x: t }), s = { x: t }, r = (o) => ({
|
|
1406
1417
|
x: () => {
|
|
1407
1418
|
const a = "float32", c = { x: o }, l = { dtype: a };
|
|
1408
1419
|
return g.runKernel(
|
|
1409
|
-
|
|
1420
|
+
ne,
|
|
1410
1421
|
c,
|
|
1411
1422
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1412
1423
|
l
|
|
@@ -1429,7 +1440,7 @@ class Z {
|
|
|
1429
1440
|
* tensors are not visible to the user.
|
|
1430
1441
|
*/
|
|
1431
1442
|
runKernel(t, e, s) {
|
|
1432
|
-
if (this.backendName == null && this.backend, !(
|
|
1443
|
+
if (this.backendName == null && this.backend, !(Ut(t, this.backendName) != null))
|
|
1433
1444
|
throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
|
|
1434
1445
|
return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
|
|
1435
1446
|
}
|
|
@@ -1458,35 +1469,35 @@ class Z {
|
|
|
1458
1469
|
let a;
|
|
1459
1470
|
this.backendName == null && this.backend;
|
|
1460
1471
|
let c;
|
|
1461
|
-
const l =
|
|
1462
|
-
if (
|
|
1463
|
-
const { kernelName:
|
|
1472
|
+
const l = bt(t) ? t.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
|
|
1473
|
+
if (bt(t)) {
|
|
1474
|
+
const { kernelName: b, inputs: d, attrs: k } = t;
|
|
1464
1475
|
this.backendName == null && this.backend;
|
|
1465
|
-
const T =
|
|
1466
|
-
|
|
1467
|
-
const
|
|
1468
|
-
c = T.kernelFunc({ inputs: d, attrs:
|
|
1476
|
+
const T = Ut(b, this.backendName);
|
|
1477
|
+
y(T != null, () => `Cannot find registered kernel '${b}' for backend '${this.backendName}'`), a = () => {
|
|
1478
|
+
const nt = this.backend.numDataIds();
|
|
1479
|
+
c = T.kernelFunc({ inputs: d, attrs: k, backend: this.backend });
|
|
1469
1480
|
const q = Array.isArray(c) ? c : [c];
|
|
1470
|
-
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(
|
|
1471
|
-
const H = q.map((
|
|
1481
|
+
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(b, nt, q);
|
|
1482
|
+
const H = q.map((st) => st.rank != null ? st : this.makeTensorFromTensorInfo(st));
|
|
1472
1483
|
if (r) {
|
|
1473
|
-
const
|
|
1474
|
-
s = this.saveTensorsForBackwardMode(
|
|
1484
|
+
const st = this.getTensorsForGradient(b, d, H);
|
|
1485
|
+
s = this.saveTensorsForBackwardMode(st);
|
|
1475
1486
|
}
|
|
1476
1487
|
return H;
|
|
1477
1488
|
};
|
|
1478
1489
|
} else {
|
|
1479
|
-
const { forwardFunc:
|
|
1480
|
-
r && (s =
|
|
1490
|
+
const { forwardFunc: b } = t, d = (k) => {
|
|
1491
|
+
r && (s = k.map((T) => this.keep(this.clone(T))));
|
|
1481
1492
|
};
|
|
1482
1493
|
a = () => {
|
|
1483
|
-
const
|
|
1484
|
-
c = this.tidy(() =>
|
|
1494
|
+
const k = this.backend.numDataIds();
|
|
1495
|
+
c = this.tidy(() => b(this.backend, d));
|
|
1485
1496
|
const T = Array.isArray(c) ? c : [c];
|
|
1486
|
-
return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(l,
|
|
1497
|
+
return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(l, k, T), T;
|
|
1487
1498
|
};
|
|
1488
1499
|
}
|
|
1489
|
-
const { inputs: u, attrs: h } = t, f =
|
|
1500
|
+
const { inputs: u, attrs: h } = t, f = bt(t) ? null : t.backwardsFunc;
|
|
1490
1501
|
let m;
|
|
1491
1502
|
return this.scopedRun(
|
|
1492
1503
|
// Stop recording to a tape when running a kernel.
|
|
@@ -1501,8 +1512,8 @@ class Z {
|
|
|
1501
1512
|
totalBytesSnapshot: this.state.numBytes,
|
|
1502
1513
|
tensorsAdded: this.state.numTensors - o,
|
|
1503
1514
|
totalTensorsSnapshot: this.state.numTensors,
|
|
1504
|
-
inputShapes: Object.keys(u).map((
|
|
1505
|
-
outputShapes: e.map((
|
|
1515
|
+
inputShapes: Object.keys(u).map((b) => u[b] != null ? u[b].shape : null),
|
|
1516
|
+
outputShapes: e.map((b) => b.shape),
|
|
1506
1517
|
kernelTimeMs: m.timeMs,
|
|
1507
1518
|
extraInfo: m.extraInfo
|
|
1508
1519
|
}), Array.isArray(c) ? e : e[0];
|
|
@@ -1523,11 +1534,11 @@ class Z {
|
|
|
1523
1534
|
* @param outputs an array of output tensors from forward mode of kernel.
|
|
1524
1535
|
*/
|
|
1525
1536
|
getTensorsForGradient(t, e, s) {
|
|
1526
|
-
const r =
|
|
1537
|
+
const r = Gt(t);
|
|
1527
1538
|
if (r != null) {
|
|
1528
1539
|
const i = r.inputsToSave || [], o = r.outputsToSave || [];
|
|
1529
1540
|
let a;
|
|
1530
|
-
r.saveAllInputs ? (
|
|
1541
|
+
r.saveAllInputs ? (y(Array.isArray(e), () => "saveAllInputs is true, expected inputs to be an array."), a = Object.keys(e).map((l) => e[l])) : a = i.map((l) => e[l]);
|
|
1531
1542
|
const c = s.filter((l, u) => o[u]);
|
|
1532
1543
|
return a.concat(c);
|
|
1533
1544
|
}
|
|
@@ -1543,10 +1554,10 @@ class Z {
|
|
|
1543
1554
|
throw new Error("Values passed to engine.makeTensor() are null");
|
|
1544
1555
|
s = s || "float32", r = r || this.backend;
|
|
1545
1556
|
let i = t;
|
|
1546
|
-
s === "string" && Rt(t[0]) && (i = t.map((c) =>
|
|
1547
|
-
const o = r.write(i, e, s), a = new
|
|
1557
|
+
s === "string" && Rt(t[0]) && (i = t.map((c) => Ze(c)));
|
|
1558
|
+
const o = r.write(i, e, s), a = new x(e, s, o, this.nextTensorId());
|
|
1548
1559
|
if (this.trackTensor(a, r), s === "string") {
|
|
1549
|
-
const c = this.state.tensorInfo.get(o), l =
|
|
1560
|
+
const c = this.state.tensorInfo.get(o), l = $e(i);
|
|
1550
1561
|
this.state.numBytes += l - c.bytes, c.bytes = l;
|
|
1551
1562
|
}
|
|
1552
1563
|
return a;
|
|
@@ -1568,12 +1579,12 @@ class Z {
|
|
|
1568
1579
|
* only increments the ref count used in memory tracking.
|
|
1569
1580
|
*/
|
|
1570
1581
|
makeTensorFromTensorInfo(t, e) {
|
|
1571
|
-
const { dataId: s, shape: r, dtype: i } = t, o = new
|
|
1582
|
+
const { dataId: s, shape: r, dtype: i } = t, o = new x(r, i, s, this.nextTensorId());
|
|
1572
1583
|
return this.trackTensor(o, e), o;
|
|
1573
1584
|
}
|
|
1574
1585
|
makeVariable(t, e = !0, s, r) {
|
|
1575
1586
|
s = s || this.nextVariableId().toString(), r != null && r !== t.dtype && (t = t.cast(r));
|
|
1576
|
-
const i = new
|
|
1587
|
+
const i = new dt(t, e, s, this.nextTensorId());
|
|
1577
1588
|
if (this.state.registeredVariables[i.name] != null)
|
|
1578
1589
|
throw new Error(`Variable with name ${i.name} was already registered`);
|
|
1579
1590
|
return this.state.registeredVariables[i.name] = i, this.incRef(i, this.backend), i;
|
|
@@ -1581,12 +1592,12 @@ class Z {
|
|
|
1581
1592
|
trackTensor(t, e) {
|
|
1582
1593
|
this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++;
|
|
1583
1594
|
let s = 0;
|
|
1584
|
-
t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size *
|
|
1595
|
+
t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size * St(t.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(t.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(t.dataId, {
|
|
1585
1596
|
backend: e || this.backend,
|
|
1586
1597
|
dtype: t.dtype,
|
|
1587
1598
|
shape: t.shape,
|
|
1588
1599
|
bytes: s
|
|
1589
|
-
})), t instanceof
|
|
1600
|
+
})), t instanceof dt || this.track(t);
|
|
1590
1601
|
}
|
|
1591
1602
|
// Track the tensor by dataId and increase the refCount for the dataId in the
|
|
1592
1603
|
// backend.
|
|
@@ -1604,7 +1615,7 @@ class Z {
|
|
|
1604
1615
|
return;
|
|
1605
1616
|
const e = this.state.tensorInfo.get(t.dataId);
|
|
1606
1617
|
if (this.state.numTensors--, t.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= e.bytes), t.dtype !== "complex64" && t.dtype !== "string") {
|
|
1607
|
-
const s = t.size *
|
|
1618
|
+
const s = t.size * St(t.dtype);
|
|
1608
1619
|
this.state.numBytes -= s;
|
|
1609
1620
|
}
|
|
1610
1621
|
e.backend.disposeData(t.dataId) && this.removeDataId(t.dataId, e.backend);
|
|
@@ -1634,10 +1645,10 @@ class Z {
|
|
|
1634
1645
|
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
|
|
1635
1646
|
}
|
|
1636
1647
|
addTapeNode(t, e, s, r, i, o) {
|
|
1637
|
-
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c =
|
|
1648
|
+
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c = Gt(t);
|
|
1638
1649
|
c != null && (r = c.gradFunc), r != null && (a.gradient = (l) => (l = l.map((u, h) => {
|
|
1639
1650
|
if (u == null) {
|
|
1640
|
-
const f = s[h], m =
|
|
1651
|
+
const f = s[h], m = Qt(f.size, f.dtype);
|
|
1641
1652
|
return this.makeTensor(m, f.shape, f.dtype);
|
|
1642
1653
|
}
|
|
1643
1654
|
return u;
|
|
@@ -1669,7 +1680,7 @@ class Z {
|
|
|
1669
1680
|
* as scope() without the need for a function closure.
|
|
1670
1681
|
*/
|
|
1671
1682
|
endScope(t) {
|
|
1672
|
-
const e =
|
|
1683
|
+
const e = he(t), s = new Set(e.map((i) => i.id));
|
|
1673
1684
|
for (let i = 0; i < this.state.activeScope.track.length; i++) {
|
|
1674
1685
|
const o = this.state.activeScope.track[i];
|
|
1675
1686
|
!o.kept && !s.has(o.id) && o.dispose();
|
|
@@ -1686,22 +1697,22 @@ class Z {
|
|
|
1686
1697
|
* gradient, which defaults to `1`.
|
|
1687
1698
|
*/
|
|
1688
1699
|
gradients(t, e, s, r = !1) {
|
|
1689
|
-
if (
|
|
1700
|
+
if (y(e.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
|
|
1690
1701
|
throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
|
|
1691
1702
|
const i = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy("forward", t));
|
|
1692
|
-
|
|
1693
|
-
const o =
|
|
1703
|
+
y(i instanceof x, () => "The result y returned by f() must be a tensor.");
|
|
1704
|
+
const o = sn(this.state.activeTape, e, i);
|
|
1694
1705
|
if (!r && o.length === 0 && e.length > 0)
|
|
1695
1706
|
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
|
|
1696
1707
|
return this.tidy("backward", () => {
|
|
1697
1708
|
const a = {};
|
|
1698
|
-
a[i.id] = s ??
|
|
1709
|
+
a[i.id] = s ?? gn(i.shape), rn(
|
|
1699
1710
|
a,
|
|
1700
1711
|
o,
|
|
1701
1712
|
// Pass the tidy function to avoid circular dep with `tape.ts`.
|
|
1702
1713
|
(l) => this.tidy(l),
|
|
1703
1714
|
// Pass an add function to avoide a circular dep with `tape.ts`.
|
|
1704
|
-
|
|
1715
|
+
mn
|
|
1705
1716
|
);
|
|
1706
1717
|
const c = e.map((l) => a[l.id]);
|
|
1707
1718
|
return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((l) => {
|
|
@@ -1711,16 +1722,16 @@ class Z {
|
|
|
1711
1722
|
});
|
|
1712
1723
|
}
|
|
1713
1724
|
customGrad(t) {
|
|
1714
|
-
return
|
|
1715
|
-
|
|
1725
|
+
return y(kt(t), () => "The f passed in customGrad(f) must be a function."), (...e) => {
|
|
1726
|
+
y(e.every((a) => a instanceof x), () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors");
|
|
1716
1727
|
let s;
|
|
1717
1728
|
const r = {};
|
|
1718
1729
|
e.forEach((a, c) => {
|
|
1719
1730
|
r[c] = a;
|
|
1720
1731
|
});
|
|
1721
|
-
const i = (a, c) => (s = t(...e, c),
|
|
1732
|
+
const i = (a, c) => (s = t(...e, c), y(s.value instanceof x, () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"), y(kt(s.gradFunc), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."), s.value), o = (a, c) => {
|
|
1722
1733
|
const l = s.gradFunc(a, c), u = Array.isArray(l) ? l : [l];
|
|
1723
|
-
|
|
1734
|
+
y(u.length === e.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."), y(u.every((f) => f instanceof x), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.");
|
|
1724
1735
|
const h = {};
|
|
1725
1736
|
return u.forEach((f, m) => {
|
|
1726
1737
|
h[m] = () => f;
|
|
@@ -1743,8 +1754,8 @@ class Z {
|
|
|
1743
1754
|
return this.state.tensorInfo.get(t).backend.readToGPU(t, e);
|
|
1744
1755
|
}
|
|
1745
1756
|
async time(t) {
|
|
1746
|
-
const e =
|
|
1747
|
-
return s.wallMs =
|
|
1757
|
+
const e = ft(), s = await this.backend.time(t);
|
|
1758
|
+
return s.wallMs = ft() - e, s;
|
|
1748
1759
|
}
|
|
1749
1760
|
/**
|
|
1750
1761
|
* Tracks a Tensor in the current scope to be automatically cleaned up
|
|
@@ -1763,30 +1774,30 @@ class Z {
|
|
|
1763
1774
|
* registered backend factories.
|
|
1764
1775
|
*/
|
|
1765
1776
|
reset() {
|
|
1766
|
-
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new
|
|
1777
|
+
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new Vt();
|
|
1767
1778
|
for (const t in this.registry)
|
|
1768
1779
|
this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
|
|
1769
1780
|
this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
|
|
1770
1781
|
}
|
|
1771
1782
|
}
|
|
1772
|
-
|
|
1773
|
-
|
|
1774
|
-
function
|
|
1775
|
-
const t =
|
|
1783
|
+
tt.nextTensorId = 0;
|
|
1784
|
+
tt.nextVariableId = 0;
|
|
1785
|
+
function gn(n) {
|
|
1786
|
+
const t = Ne(U(n), "float32");
|
|
1776
1787
|
return g.makeTensor(t, n, "float32");
|
|
1777
1788
|
}
|
|
1778
|
-
function
|
|
1779
|
-
const n =
|
|
1789
|
+
function de() {
|
|
1790
|
+
const n = te();
|
|
1780
1791
|
if (n._tfengine == null) {
|
|
1781
|
-
const t = new
|
|
1782
|
-
n._tfengine = new
|
|
1792
|
+
const t = new De(n);
|
|
1793
|
+
n._tfengine = new tt(t);
|
|
1783
1794
|
}
|
|
1784
|
-
return
|
|
1795
|
+
return Oe(n._tfengine.ENV), cn(() => n._tfengine), n._tfengine;
|
|
1785
1796
|
}
|
|
1786
|
-
const g =
|
|
1787
|
-
function
|
|
1797
|
+
const g = de();
|
|
1798
|
+
function mn(n, t) {
|
|
1788
1799
|
const e = { a: n, b: t };
|
|
1789
|
-
return g.runKernel(
|
|
1800
|
+
return g.runKernel(ee, e);
|
|
1790
1801
|
}
|
|
1791
1802
|
/**
|
|
1792
1803
|
* @license
|
|
@@ -1804,26 +1815,26 @@ function gn(n, t) {
|
|
|
1804
1815
|
* limitations under the License.
|
|
1805
1816
|
* =============================================================================
|
|
1806
1817
|
*/
|
|
1807
|
-
function
|
|
1818
|
+
function pn() {
|
|
1808
1819
|
return typeof window < "u" && window.document != null || //@ts-ignore
|
|
1809
1820
|
typeof WorkerGlobalScope < "u";
|
|
1810
1821
|
}
|
|
1811
|
-
const
|
|
1812
|
-
|
|
1822
|
+
const B = S();
|
|
1823
|
+
B.registerFlag("DEBUG", () => !1, (n) => {
|
|
1813
1824
|
n && console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
|
|
1814
1825
|
});
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1826
|
+
B.registerFlag("IS_BROWSER", () => pn());
|
|
1827
|
+
B.registerFlag("IS_NODE", () => typeof Q < "u" && typeof Q.versions < "u" && typeof Q.versions.node < "u");
|
|
1828
|
+
B.registerFlag("IS_CHROME", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor));
|
|
1829
|
+
B.registerFlag("IS_SAFARI", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Safari/.test(navigator.userAgent) && /Apple/.test(navigator.vendor));
|
|
1830
|
+
B.registerFlag("PROD", () => !1);
|
|
1831
|
+
B.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY", () => B.getBool("DEBUG"));
|
|
1832
|
+
B.registerFlag("DEPRECATION_WARNINGS_ENABLED", () => !0);
|
|
1833
|
+
B.registerFlag("IS_TEST", () => !1);
|
|
1834
|
+
B.registerFlag("CHECK_COMPUTATION_FOR_ERRORS", () => B.getBool("DEBUG"));
|
|
1835
|
+
B.registerFlag("WRAP_TO_IMAGEBITMAP", () => !1);
|
|
1836
|
+
B.registerFlag("CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU", () => !1);
|
|
1837
|
+
B.registerFlag("USE_SETTIMEOUTCUSTOM", () => !1);
|
|
1827
1838
|
/**
|
|
1828
1839
|
* @license
|
|
1829
1840
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -1840,33 +1851,33 @@ A.registerFlag("USE_SETTIMEOUTCUSTOM", () => !1);
|
|
|
1840
1851
|
* limitations under the License.
|
|
1841
1852
|
* =============================================================================
|
|
1842
1853
|
*/
|
|
1843
|
-
function
|
|
1854
|
+
function yn(n, t) {
|
|
1844
1855
|
let e = n;
|
|
1845
|
-
if (
|
|
1856
|
+
if ($(n))
|
|
1846
1857
|
return t === "string" ? [] : [n.length];
|
|
1847
|
-
if (
|
|
1858
|
+
if (ce(n)) {
|
|
1848
1859
|
const r = n.channels || "RGBA";
|
|
1849
1860
|
return [n.height, n.width * r.length];
|
|
1850
|
-
} else if (
|
|
1851
|
-
return [n.buffer.size / (t == null ? 4 :
|
|
1861
|
+
} else if (ue(n))
|
|
1862
|
+
return [n.buffer.size / (t == null ? 4 : St(t))];
|
|
1852
1863
|
if (!Array.isArray(n))
|
|
1853
1864
|
return [];
|
|
1854
1865
|
const s = [];
|
|
1855
|
-
for (; Array.isArray(e) ||
|
|
1866
|
+
for (; Array.isArray(e) || $(e) && t !== "string"; )
|
|
1856
1867
|
s.push(e.length), e = e[0];
|
|
1857
|
-
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") &&
|
|
1868
|
+
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && ge(n, s, []), s;
|
|
1858
1869
|
}
|
|
1859
|
-
function
|
|
1860
|
-
if (e = e || [], !Array.isArray(n) &&
|
|
1861
|
-
|
|
1870
|
+
function ge(n, t, e) {
|
|
1871
|
+
if (e = e || [], !Array.isArray(n) && !$(n)) {
|
|
1872
|
+
y(t.length === 0, () => `Element arr[${e.join("][")}] is a primitive, but should be an array/TypedArray of ${t[0]} elements`);
|
|
1862
1873
|
return;
|
|
1863
1874
|
}
|
|
1864
|
-
|
|
1875
|
+
y(t.length > 0, () => `Element arr[${e.join("][")}] should be a primitive, but is an array of ${n.length} elements`), y(n.length === t[0], () => `Element arr[${e.join("][")}] should have ${t[0]} elements, but has ${n.length} elements`);
|
|
1865
1876
|
const s = t.slice(1);
|
|
1866
1877
|
for (let r = 0; r < n.length; ++r)
|
|
1867
|
-
|
|
1878
|
+
ge(n[r], s, e.concat(r));
|
|
1868
1879
|
}
|
|
1869
|
-
function
|
|
1880
|
+
function qt(n, t, e, s) {
|
|
1870
1881
|
if (n !== "string_or_numeric") {
|
|
1871
1882
|
if (n == null)
|
|
1872
1883
|
throw new Error("Expected dtype cannot be null.");
|
|
@@ -1874,19 +1885,24 @@ function Vt(n, t, e, s) {
|
|
|
1874
1885
|
throw new Error(`Argument '${e}' passed to '${s}' must be ${n} tensor, but got ${t} tensor`);
|
|
1875
1886
|
}
|
|
1876
1887
|
}
|
|
1877
|
-
function
|
|
1878
|
-
if (n instanceof
|
|
1879
|
-
return
|
|
1880
|
-
let r =
|
|
1881
|
-
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s),
|
|
1888
|
+
function I(n, t, e, s = "numeric") {
|
|
1889
|
+
if (n instanceof le())
|
|
1890
|
+
return qt(s, n.dtype, t, e), n;
|
|
1891
|
+
let r = mt(n);
|
|
1892
|
+
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s), qt(s, r, t, e), n == null || !$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
|
|
1882
1893
|
const c = n == null ? "null" : n.constructor.name;
|
|
1883
1894
|
throw new Error(`Argument '${t}' passed to '${e}' must be a Tensor or TensorLike, but got '${c}'`);
|
|
1884
1895
|
}
|
|
1885
|
-
const i =
|
|
1886
|
-
|
|
1887
|
-
const a = r !== "string" ?
|
|
1896
|
+
const i = yn(n, r);
|
|
1897
|
+
!$(n) && !Array.isArray(n) && (n = [n]);
|
|
1898
|
+
const a = r !== "string" ? oe(n, r) : at(n, [], !0);
|
|
1888
1899
|
return g.makeTensor(a, i, r);
|
|
1889
1900
|
}
|
|
1901
|
+
function Js(n, t, e, s = "numeric") {
|
|
1902
|
+
if (!Array.isArray(n))
|
|
1903
|
+
throw new Error(`Argument ${t} passed to ${e} must be a \`Tensor[]\` or \`TensorLike[]\``);
|
|
1904
|
+
return n.map((i, o) => I(i, `${t}[${o}]`, e, s));
|
|
1905
|
+
}
|
|
1890
1906
|
/**
|
|
1891
1907
|
* @license
|
|
1892
1908
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -1903,19 +1919,19 @@ function k(n, t, e, s = "numeric") {
|
|
|
1903
1919
|
* limitations under the License.
|
|
1904
1920
|
* =============================================================================
|
|
1905
1921
|
*/
|
|
1906
|
-
const
|
|
1922
|
+
const bn = "__op";
|
|
1907
1923
|
function F(n) {
|
|
1908
1924
|
const t = Object.keys(n);
|
|
1909
1925
|
if (t.length !== 1)
|
|
1910
1926
|
throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${t.length} keys.`);
|
|
1911
1927
|
let e = t[0];
|
|
1912
1928
|
const s = n[e];
|
|
1913
|
-
e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e +
|
|
1929
|
+
e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e + bn;
|
|
1914
1930
|
const r = (...i) => {
|
|
1915
1931
|
g.startScope(e);
|
|
1916
1932
|
try {
|
|
1917
1933
|
const o = s(...i);
|
|
1918
|
-
return
|
|
1934
|
+
return Dt(o) && console.error("Cannot return a Promise inside of tidy."), g.endScope(o), o;
|
|
1919
1935
|
} catch (o) {
|
|
1920
1936
|
throw g.endScope(null), o;
|
|
1921
1937
|
}
|
|
@@ -1938,28 +1954,28 @@ function F(n) {
|
|
|
1938
1954
|
* limitations under the License.
|
|
1939
1955
|
* =============================================================================
|
|
1940
1956
|
*/
|
|
1941
|
-
function
|
|
1957
|
+
function wn(n, t, e, s) {
|
|
1942
1958
|
if (s == null)
|
|
1943
|
-
s =
|
|
1959
|
+
s = mt(n);
|
|
1944
1960
|
else if (s === "complex64")
|
|
1945
1961
|
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
|
|
1946
|
-
if (
|
|
1962
|
+
if (ue(n) || ce(n)) {
|
|
1947
1963
|
if (s !== "float32" && s !== "int32")
|
|
1948
1964
|
throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);
|
|
1949
1965
|
return g.backend.createTensorFromGPUData(n, t || e, s);
|
|
1950
1966
|
}
|
|
1951
|
-
if (
|
|
1967
|
+
if (!$(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string")
|
|
1952
1968
|
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
|
|
1953
1969
|
if (t != null) {
|
|
1954
|
-
|
|
1970
|
+
Nt(t);
|
|
1955
1971
|
const r = U(t), i = U(e);
|
|
1956
|
-
|
|
1972
|
+
y(r === i, () => `Based on the provided shape, [${t}], the tensor should have ${r} values but has ${i}`);
|
|
1957
1973
|
for (let o = 0; o < e.length; ++o) {
|
|
1958
1974
|
const a = e[o], c = o === e.length - 1 ? a !== U(t.slice(o)) : !0;
|
|
1959
|
-
|
|
1975
|
+
y(e[o] === t[o] || !c, () => `Error creating a new Tensor. Inferred shape (${e}) does not match the provided shape (${t}). `);
|
|
1960
1976
|
}
|
|
1961
1977
|
}
|
|
1962
|
-
return
|
|
1978
|
+
return !$(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ? oe(n, s) : at(n, [], !0), g.makeTensor(n, t, s);
|
|
1963
1979
|
}
|
|
1964
1980
|
class lt {
|
|
1965
1981
|
/**
|
|
@@ -1973,7 +1989,7 @@ class lt {
|
|
|
1973
1989
|
return new lt(t).slice();
|
|
1974
1990
|
}
|
|
1975
1991
|
constructor(t) {
|
|
1976
|
-
if (this.shards = [], this.previousShardIndex = 0, t == null || (t instanceof Array || (t = [t]), t = t.map((s) =>
|
|
1992
|
+
if (this.shards = [], this.previousShardIndex = 0, t == null || (t instanceof Array || (t = [t]), t = t.map((s) => $(s) ? s.buffer : s), t.length === 0))
|
|
1977
1993
|
return;
|
|
1978
1994
|
this.bufferUniformSize = t[0].byteLength;
|
|
1979
1995
|
let e = 0;
|
|
@@ -1996,7 +2012,7 @@ class lt {
|
|
|
1996
2012
|
const r = e - t, i = new ArrayBuffer(r), o = new Uint8Array(i);
|
|
1997
2013
|
let a = 0;
|
|
1998
2014
|
for (let c = s; c < this.shards.length; c++) {
|
|
1999
|
-
const l = this.shards[c], h = t + a - l.start, f = a,
|
|
2015
|
+
const l = this.shards[c], h = t + a - l.start, f = a, b = Math.min(e, l.end) - l.start, d = new Uint8Array(l.buffer, h, b - h);
|
|
2000
2016
|
if (o.set(d, f), a += d.length, e < l.end)
|
|
2001
2017
|
break;
|
|
2002
2018
|
}
|
|
@@ -2015,11 +2031,11 @@ class lt {
|
|
|
2015
2031
|
}
|
|
2016
2032
|
if (e(this.shards[this.previousShardIndex]) === 0)
|
|
2017
2033
|
return this.previousShardIndex;
|
|
2018
|
-
const s =
|
|
2034
|
+
const s = Sn(this.shards, e);
|
|
2019
2035
|
return s === -1 ? -1 : (this.previousShardIndex = s, this.previousShardIndex);
|
|
2020
2036
|
}
|
|
2021
2037
|
}
|
|
2022
|
-
function
|
|
2038
|
+
function Sn(n, t) {
|
|
2023
2039
|
let e = 0, s = n.length;
|
|
2024
2040
|
for (; e <= s; ) {
|
|
2025
2041
|
const r = Math.floor((s - e) / 2) + e, i = t(n[r]);
|
|
@@ -2045,34 +2061,34 @@ function wn(n, t) {
|
|
|
2045
2061
|
* limitations under the License.
|
|
2046
2062
|
* =============================================================================
|
|
2047
2063
|
*/
|
|
2048
|
-
function
|
|
2064
|
+
function Xs() {
|
|
2049
2065
|
return g;
|
|
2050
2066
|
}
|
|
2051
2067
|
function E(n, t) {
|
|
2052
2068
|
return g.tidy(n, t);
|
|
2053
2069
|
}
|
|
2054
2070
|
function M(n) {
|
|
2055
|
-
|
|
2071
|
+
he(n).forEach((e) => e.dispose());
|
|
2056
2072
|
}
|
|
2057
|
-
function
|
|
2073
|
+
function kn(n) {
|
|
2058
2074
|
return g.keep(n);
|
|
2059
2075
|
}
|
|
2060
|
-
const
|
|
2061
|
-
function
|
|
2062
|
-
return
|
|
2076
|
+
const _t = typeof gt < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
|
|
2077
|
+
function Ht(n) {
|
|
2078
|
+
return _t ? gt.byteLength(n, "utf8") : new Blob([n]).size;
|
|
2063
2079
|
}
|
|
2064
2080
|
function In(n) {
|
|
2065
|
-
if (
|
|
2066
|
-
return
|
|
2081
|
+
if (_t)
|
|
2082
|
+
return gt.from(n).toString("base64");
|
|
2067
2083
|
const t = new Uint8Array(n);
|
|
2068
2084
|
let e = "";
|
|
2069
2085
|
for (let s = 0, r = t.length; s < r; s++)
|
|
2070
2086
|
e += String.fromCharCode(t[s]);
|
|
2071
2087
|
return btoa(e);
|
|
2072
2088
|
}
|
|
2073
|
-
function
|
|
2074
|
-
if (
|
|
2075
|
-
const s =
|
|
2089
|
+
function Tn(n) {
|
|
2090
|
+
if (_t) {
|
|
2091
|
+
const s = gt.from(n, "base64");
|
|
2076
2092
|
return s.buffer.slice(s.byteOffset, s.byteOffset + s.byteLength);
|
|
2077
2093
|
}
|
|
2078
2094
|
const t = atob(n), e = new Uint8Array(t.length);
|
|
@@ -2080,14 +2096,14 @@ function kn(n) {
|
|
|
2080
2096
|
e.set([t.charCodeAt(s)], s);
|
|
2081
2097
|
return e.buffer;
|
|
2082
2098
|
}
|
|
2083
|
-
function
|
|
2099
|
+
function me(n) {
|
|
2084
2100
|
if (n.modelTopology instanceof ArrayBuffer)
|
|
2085
2101
|
throw new Error("Expected JSON model topology, received ArrayBuffer.");
|
|
2086
2102
|
return {
|
|
2087
2103
|
dateSaved: /* @__PURE__ */ new Date(),
|
|
2088
2104
|
modelTopologyType: "JSON",
|
|
2089
|
-
modelTopologyBytes: n.modelTopology == null ? 0 :
|
|
2090
|
-
weightSpecsBytes: n.weightSpecs == null ? 0 :
|
|
2105
|
+
modelTopologyBytes: n.modelTopology == null ? 0 : Ht(JSON.stringify(n.modelTopology)),
|
|
2106
|
+
weightSpecsBytes: n.weightSpecs == null ? 0 : Ht(JSON.stringify(n.weightSpecs)),
|
|
2091
2107
|
weightDataBytes: n.weightData == null ? 0 : new lt(n.weightData).byteLength
|
|
2092
2108
|
};
|
|
2093
2109
|
}
|
|
@@ -2107,12 +2123,12 @@ function ge(n) {
|
|
|
2107
2123
|
* limitations under the License.
|
|
2108
2124
|
* =============================================================================
|
|
2109
2125
|
*/
|
|
2110
|
-
class
|
|
2126
|
+
class A {
|
|
2111
2127
|
constructor() {
|
|
2112
2128
|
this.saveRouters = [], this.loadRouters = [];
|
|
2113
2129
|
}
|
|
2114
2130
|
static getInstance() {
|
|
2115
|
-
return
|
|
2131
|
+
return A.instance == null && (A.instance = new A()), A.instance;
|
|
2116
2132
|
}
|
|
2117
2133
|
/**
|
|
2118
2134
|
* Register a save-handler router.
|
|
@@ -2121,7 +2137,7 @@ class B {
|
|
|
2121
2137
|
* of `IOHandler` with the `save` method defined or `null`.
|
|
2122
2138
|
*/
|
|
2123
2139
|
static registerSaveRouter(t) {
|
|
2124
|
-
|
|
2140
|
+
A.getInstance().saveRouters.push(t);
|
|
2125
2141
|
}
|
|
2126
2142
|
/**
|
|
2127
2143
|
* Register a load-handler router.
|
|
@@ -2130,7 +2146,7 @@ class B {
|
|
|
2130
2146
|
* of `IOHandler` with the `load` method defined or `null`.
|
|
2131
2147
|
*/
|
|
2132
2148
|
static registerLoadRouter(t) {
|
|
2133
|
-
|
|
2149
|
+
A.getInstance().loadRouters.push(t);
|
|
2134
2150
|
}
|
|
2135
2151
|
/**
|
|
2136
2152
|
* Look up IOHandler for saving, given a URL-like string.
|
|
@@ -2141,7 +2157,7 @@ class B {
|
|
|
2141
2157
|
* @throws Error, if more than one match is found.
|
|
2142
2158
|
*/
|
|
2143
2159
|
static getSaveHandlers(t) {
|
|
2144
|
-
return
|
|
2160
|
+
return A.getHandlers(t, "save");
|
|
2145
2161
|
}
|
|
2146
2162
|
/**
|
|
2147
2163
|
* Look up IOHandler for loading, given a URL-like string.
|
|
@@ -2152,11 +2168,11 @@ class B {
|
|
|
2152
2168
|
* handler routers.
|
|
2153
2169
|
*/
|
|
2154
2170
|
static getLoadHandlers(t, e) {
|
|
2155
|
-
return
|
|
2171
|
+
return A.getHandlers(t, "load", e);
|
|
2156
2172
|
}
|
|
2157
2173
|
static getHandlers(t, e, s) {
|
|
2158
2174
|
const r = [];
|
|
2159
|
-
return (e === "load" ?
|
|
2175
|
+
return (e === "load" ? A.getInstance().loadRouters : A.getInstance().saveRouters).forEach((o) => {
|
|
2160
2176
|
const a = o(t, s);
|
|
2161
2177
|
a !== null && r.push(a);
|
|
2162
2178
|
}), r;
|
|
@@ -2178,8 +2194,8 @@ class B {
|
|
|
2178
2194
|
* limitations under the License.
|
|
2179
2195
|
* =============================================================================
|
|
2180
2196
|
*/
|
|
2181
|
-
const Bt = "tensorflowjs",
|
|
2182
|
-
function
|
|
2197
|
+
const Bt = "tensorflowjs", vt = 1, L = "models_store", P = "model_info_store";
|
|
2198
|
+
function pe() {
|
|
2183
2199
|
if (!S().getBool("IS_BROWSER"))
|
|
2184
2200
|
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
|
|
2185
2201
|
const n = typeof window > "u" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB;
|
|
@@ -2187,13 +2203,13 @@ function me() {
|
|
|
2187
2203
|
throw new Error("The current browser does not appear to support IndexedDB.");
|
|
2188
2204
|
return t;
|
|
2189
2205
|
}
|
|
2190
|
-
function
|
|
2206
|
+
function Mt(n) {
|
|
2191
2207
|
const t = n.result;
|
|
2192
2208
|
t.createObjectStore(L, { keyPath: "modelPath" }), t.createObjectStore(P, { keyPath: "modelPath" });
|
|
2193
2209
|
}
|
|
2194
2210
|
class z {
|
|
2195
2211
|
constructor(t) {
|
|
2196
|
-
if (this.indexedDB =
|
|
2212
|
+
if (this.indexedDB = pe(), t == null || !t)
|
|
2197
2213
|
throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
|
|
2198
2214
|
this.modelPath = t;
|
|
2199
2215
|
}
|
|
@@ -2221,8 +2237,8 @@ class z {
|
|
|
2221
2237
|
*/
|
|
2222
2238
|
databaseAction(t, e) {
|
|
2223
2239
|
return new Promise((s, r) => {
|
|
2224
|
-
const i = this.indexedDB.open(Bt,
|
|
2225
|
-
i.onupgradeneeded = () =>
|
|
2240
|
+
const i = this.indexedDB.open(Bt, vt);
|
|
2241
|
+
i.onupgradeneeded = () => Mt(i), i.onsuccess = () => {
|
|
2226
2242
|
const o = i.result;
|
|
2227
2243
|
if (e == null) {
|
|
2228
2244
|
const a = o.transaction(L, "readonly"), l = a.objectStore(L).get(this.modelPath);
|
|
@@ -2233,7 +2249,7 @@ class z {
|
|
|
2233
2249
|
}, l.onerror = (u) => (o.close(), r(l.error)), a.oncomplete = () => o.close();
|
|
2234
2250
|
} else {
|
|
2235
2251
|
e.weightData = lt.join(e.weightData);
|
|
2236
|
-
const a =
|
|
2252
|
+
const a = me(e), c = o.transaction(P, "readwrite");
|
|
2237
2253
|
let l = c.objectStore(P), u;
|
|
2238
2254
|
try {
|
|
2239
2255
|
u = l.put({ modelPath: this.modelPath, modelArtifactsInfo: a });
|
|
@@ -2251,13 +2267,13 @@ class z {
|
|
|
2251
2267
|
modelArtifacts: e,
|
|
2252
2268
|
modelArtifactsInfo: a
|
|
2253
2269
|
});
|
|
2254
|
-
} catch (
|
|
2255
|
-
return r(
|
|
2270
|
+
} catch (b) {
|
|
2271
|
+
return r(b);
|
|
2256
2272
|
}
|
|
2257
|
-
m.onsuccess = () => s({ modelArtifactsInfo: a }), m.onerror = (
|
|
2273
|
+
m.onsuccess = () => s({ modelArtifactsInfo: a }), m.onerror = (b) => {
|
|
2258
2274
|
l = c.objectStore(P);
|
|
2259
2275
|
const d = l.delete(this.modelPath);
|
|
2260
|
-
d.onsuccess = () => (o.close(), r(m.error)), d.onerror = (
|
|
2276
|
+
d.onsuccess = () => (o.close(), r(m.error)), d.onerror = (k) => (o.close(), r(m.error));
|
|
2261
2277
|
};
|
|
2262
2278
|
}, u.onerror = (f) => (o.close(), r(u.error)), c.oncomplete = () => {
|
|
2263
2279
|
h == null ? o.close() : h.oncomplete = () => o.close();
|
|
@@ -2268,23 +2284,23 @@ class z {
|
|
|
2268
2284
|
}
|
|
2269
2285
|
}
|
|
2270
2286
|
z.URL_SCHEME = "indexeddb://";
|
|
2271
|
-
const
|
|
2272
|
-
|
|
2273
|
-
|
|
2274
|
-
function
|
|
2287
|
+
const ye = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(z.URL_SCHEME) ? En(n.slice(z.URL_SCHEME.length)) : null;
|
|
2288
|
+
A.registerSaveRouter(ye);
|
|
2289
|
+
A.registerLoadRouter(ye);
|
|
2290
|
+
function En(n) {
|
|
2275
2291
|
return new z(n);
|
|
2276
2292
|
}
|
|
2277
|
-
function
|
|
2293
|
+
function An(n) {
|
|
2278
2294
|
return n.startsWith(z.URL_SCHEME) ? n.slice(z.URL_SCHEME.length) : n;
|
|
2279
2295
|
}
|
|
2280
2296
|
class Bn {
|
|
2281
2297
|
constructor() {
|
|
2282
|
-
this.indexedDB =
|
|
2298
|
+
this.indexedDB = pe();
|
|
2283
2299
|
}
|
|
2284
2300
|
async listModels() {
|
|
2285
2301
|
return new Promise((t, e) => {
|
|
2286
|
-
const s = this.indexedDB.open(Bt,
|
|
2287
|
-
s.onupgradeneeded = () =>
|
|
2302
|
+
const s = this.indexedDB.open(Bt, vt);
|
|
2303
|
+
s.onupgradeneeded = () => Mt(s), s.onsuccess = () => {
|
|
2288
2304
|
const r = s.result, i = r.transaction(P, "readonly"), a = i.objectStore(P).getAll();
|
|
2289
2305
|
a.onsuccess = () => {
|
|
2290
2306
|
const c = {};
|
|
@@ -2296,9 +2312,9 @@ class Bn {
|
|
|
2296
2312
|
});
|
|
2297
2313
|
}
|
|
2298
2314
|
async removeModel(t) {
|
|
2299
|
-
return t =
|
|
2300
|
-
const r = this.indexedDB.open(Bt,
|
|
2301
|
-
r.onupgradeneeded = () =>
|
|
2315
|
+
return t = An(t), new Promise((e, s) => {
|
|
2316
|
+
const r = this.indexedDB.open(Bt, vt);
|
|
2317
|
+
r.onupgradeneeded = () => Mt(r), r.onsuccess = () => {
|
|
2302
2318
|
const i = r.result, o = i.transaction(P, "readwrite"), a = o.objectStore(P), c = a.get(t);
|
|
2303
2319
|
let l;
|
|
2304
2320
|
c.onsuccess = () => {
|
|
@@ -2308,7 +2324,7 @@ class Bn {
|
|
|
2308
2324
|
const u = a.delete(t), h = () => {
|
|
2309
2325
|
l = i.transaction(L, "readwrite");
|
|
2310
2326
|
const m = l.objectStore(L).delete(t);
|
|
2311
|
-
m.onsuccess = () => e(c.result.modelArtifactsInfo), m.onerror = (
|
|
2327
|
+
m.onsuccess = () => e(c.result.modelArtifactsInfo), m.onerror = (b) => s(c.error);
|
|
2312
2328
|
};
|
|
2313
2329
|
u.onsuccess = h, u.onerror = (f) => (h(), i.close(), s(c.error));
|
|
2314
2330
|
}
|
|
@@ -2335,17 +2351,17 @@ class Bn {
|
|
|
2335
2351
|
* limitations under the License.
|
|
2336
2352
|
* =============================================================================
|
|
2337
2353
|
*/
|
|
2338
|
-
const _ = "/",
|
|
2339
|
-
function
|
|
2354
|
+
const _ = "/", Y = "tensorflowjs_models", be = "info", vn = "model_topology", Mn = "weight_specs", Fn = "weight_data", $n = "model_metadata";
|
|
2355
|
+
function we(n) {
|
|
2340
2356
|
return {
|
|
2341
|
-
info: [
|
|
2342
|
-
topology: [
|
|
2343
|
-
weightSpecs: [
|
|
2344
|
-
weightData: [
|
|
2345
|
-
modelMetadata: [
|
|
2357
|
+
info: [Y, n, be].join(_),
|
|
2358
|
+
topology: [Y, n, vn].join(_),
|
|
2359
|
+
weightSpecs: [Y, n, Mn].join(_),
|
|
2360
|
+
weightData: [Y, n, Fn].join(_),
|
|
2361
|
+
modelMetadata: [Y, n, $n].join(_)
|
|
2346
2362
|
};
|
|
2347
2363
|
}
|
|
2348
|
-
function
|
|
2364
|
+
function Se(n) {
|
|
2349
2365
|
for (const t of Object.values(n))
|
|
2350
2366
|
window.localStorage.removeItem(t);
|
|
2351
2367
|
}
|
|
@@ -2364,7 +2380,7 @@ class W {
|
|
|
2364
2380
|
throw new Error("The current environment does not support local storage.");
|
|
2365
2381
|
if (this.LS = window.localStorage, t == null || !t)
|
|
2366
2382
|
throw new Error("For local storage, modelPath must not be null, undefined or empty.");
|
|
2367
|
-
this.modelPath = t, this.keys =
|
|
2383
|
+
this.modelPath = t, this.keys = we(this.modelPath);
|
|
2368
2384
|
}
|
|
2369
2385
|
/**
|
|
2370
2386
|
* Save model artifacts to browser local storage.
|
|
@@ -2379,7 +2395,7 @@ class W {
|
|
|
2379
2395
|
if (t.modelTopology instanceof ArrayBuffer)
|
|
2380
2396
|
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
|
|
2381
2397
|
{
|
|
2382
|
-
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r =
|
|
2398
|
+
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r = me(t), i = lt.join(t.weightData);
|
|
2383
2399
|
try {
|
|
2384
2400
|
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, s), this.LS.setItem(this.keys.weightData, In(i));
|
|
2385
2401
|
const o = {
|
|
@@ -2394,7 +2410,7 @@ class W {
|
|
|
2394
2410
|
};
|
|
2395
2411
|
return this.LS.setItem(this.keys.modelMetadata, JSON.stringify(o)), { modelArtifactsInfo: r };
|
|
2396
2412
|
} catch {
|
|
2397
|
-
throw
|
|
2413
|
+
throw Se(this.keys), new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${r.modelTopologyBytes}, weightSpecsBytes=${r.weightSpecsBytes}, weightDataBytes=${r.weightDataBytes}.`);
|
|
2398
2414
|
}
|
|
2399
2415
|
}
|
|
2400
2416
|
}
|
|
@@ -2428,22 +2444,22 @@ class W {
|
|
|
2428
2444
|
const o = this.LS.getItem(this.keys.weightData);
|
|
2429
2445
|
if (o == null)
|
|
2430
2446
|
throw new Error(`In local storage, the binary weight values of model '${this.modelPath}' are missing.`);
|
|
2431
|
-
return e.weightData =
|
|
2447
|
+
return e.weightData = Tn(o), e;
|
|
2432
2448
|
}
|
|
2433
2449
|
}
|
|
2434
2450
|
W.URL_SCHEME = "localstorage://";
|
|
2435
|
-
const
|
|
2436
|
-
|
|
2437
|
-
|
|
2438
|
-
function
|
|
2451
|
+
const ke = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(W.URL_SCHEME) ? Nn(n.slice(W.URL_SCHEME.length)) : null;
|
|
2452
|
+
A.registerSaveRouter(ke);
|
|
2453
|
+
A.registerLoadRouter(ke);
|
|
2454
|
+
function Nn(n) {
|
|
2439
2455
|
return new W(n);
|
|
2440
2456
|
}
|
|
2441
|
-
class
|
|
2457
|
+
class Dn {
|
|
2442
2458
|
constructor() {
|
|
2443
|
-
|
|
2459
|
+
y(S().getBool("IS_BROWSER"), () => "Current environment is not a web browser"), y(typeof window > "u" || typeof window.localStorage < "u", () => "Current browser does not appear to support localStorage"), this.LS = window.localStorage;
|
|
2444
2460
|
}
|
|
2445
2461
|
async listModels() {
|
|
2446
|
-
const t = {}, e =
|
|
2462
|
+
const t = {}, e = Y + _, s = _ + be;
|
|
2447
2463
|
for (let r = 0; r < this.LS.length; ++r) {
|
|
2448
2464
|
const i = this.LS.key(r);
|
|
2449
2465
|
if (i.startsWith(e) && i.endsWith(s)) {
|
|
@@ -2455,11 +2471,11 @@ class Nn {
|
|
|
2455
2471
|
}
|
|
2456
2472
|
async removeModel(t) {
|
|
2457
2473
|
t = xn(t);
|
|
2458
|
-
const e =
|
|
2474
|
+
const e = we(t);
|
|
2459
2475
|
if (this.LS.getItem(e.info) == null)
|
|
2460
2476
|
throw new Error(`Cannot find model at path '${t}'`);
|
|
2461
2477
|
const s = JSON.parse(this.LS.getItem(e.info));
|
|
2462
|
-
return
|
|
2478
|
+
return Se(e), s;
|
|
2463
2479
|
}
|
|
2464
2480
|
}
|
|
2465
2481
|
/**
|
|
@@ -2478,7 +2494,7 @@ class Nn {
|
|
|
2478
2494
|
* limitations under the License.
|
|
2479
2495
|
* =============================================================================
|
|
2480
2496
|
*/
|
|
2481
|
-
const
|
|
2497
|
+
const Jt = "://";
|
|
2482
2498
|
class N {
|
|
2483
2499
|
constructor() {
|
|
2484
2500
|
this.managers = {};
|
|
@@ -2493,9 +2509,9 @@ class N {
|
|
|
2493
2509
|
* of `IOHandler` with the `save` method defined or `null`.
|
|
2494
2510
|
*/
|
|
2495
2511
|
static registerManager(t, e) {
|
|
2496
|
-
|
|
2512
|
+
y(t != null, () => "scheme must not be undefined or null."), t.endsWith(Jt) && (t = t.slice(0, t.indexOf(Jt))), y(t.length > 0, () => "scheme must not be an empty string.");
|
|
2497
2513
|
const s = N.getInstance();
|
|
2498
|
-
|
|
2514
|
+
y(s.managers[t] == null, () => `A model store manager is already registered for scheme '${t}'.`), s.managers[t] = e;
|
|
2499
2515
|
}
|
|
2500
2516
|
static getManager(t) {
|
|
2501
2517
|
const e = N.getInstance().managers[t];
|
|
@@ -2523,7 +2539,7 @@ class N {
|
|
|
2523
2539
|
* limitations under the License.
|
|
2524
2540
|
* =============================================================================
|
|
2525
2541
|
*/
|
|
2526
|
-
class
|
|
2542
|
+
class Cn {
|
|
2527
2543
|
constructor() {
|
|
2528
2544
|
this.messageName = "setTimeoutCustom", this.functionRefs = [], this.handledMessageCount = 0, this.hasEventListener = !1;
|
|
2529
2545
|
}
|
|
@@ -2561,13 +2577,13 @@ class Dn {
|
|
|
2561
2577
|
}, !0));
|
|
2562
2578
|
}
|
|
2563
2579
|
isTypedArray(t) {
|
|
2564
|
-
return
|
|
2580
|
+
return ie(t);
|
|
2565
2581
|
}
|
|
2566
2582
|
}
|
|
2567
2583
|
if (S().get("IS_BROWSER")) {
|
|
2568
|
-
S().setPlatform("browser", new
|
|
2584
|
+
S().setPlatform("browser", new Cn());
|
|
2569
2585
|
try {
|
|
2570
|
-
N.registerManager(W.URL_SCHEME, new
|
|
2586
|
+
N.registerManager(W.URL_SCHEME, new Dn());
|
|
2571
2587
|
} catch {
|
|
2572
2588
|
}
|
|
2573
2589
|
try {
|
|
@@ -2575,20 +2591,20 @@ if (S().get("IS_BROWSER")) {
|
|
|
2575
2591
|
} catch {
|
|
2576
2592
|
}
|
|
2577
2593
|
}
|
|
2578
|
-
const
|
|
2594
|
+
const _n = {
|
|
2579
2595
|
// tslint:disable-next-line:no-require-imports
|
|
2580
2596
|
importFetch: () => require("node-fetch")
|
|
2581
2597
|
};
|
|
2582
|
-
let
|
|
2583
|
-
class
|
|
2598
|
+
let wt;
|
|
2599
|
+
class Pn {
|
|
2584
2600
|
constructor() {
|
|
2585
2601
|
this.util = require("util"), this.textEncoder = new this.util.TextEncoder();
|
|
2586
2602
|
}
|
|
2587
2603
|
fetch(t, e) {
|
|
2588
|
-
return S().global.fetch != null ? S().global.fetch(t, e) : (
|
|
2604
|
+
return S().global.fetch != null ? S().global.fetch(t, e) : (wt == null && (wt = _n.importFetch()), wt(t, e));
|
|
2589
2605
|
}
|
|
2590
2606
|
now() {
|
|
2591
|
-
const t =
|
|
2607
|
+
const t = Q.hrtime();
|
|
2592
2608
|
return t[0] * 1e3 + t[1] / 1e6;
|
|
2593
2609
|
}
|
|
2594
2610
|
encode(t, e) {
|
|
@@ -2603,7 +2619,7 @@ class _n {
|
|
|
2603
2619
|
return this.util.types.isFloat32Array(t) || this.util.types.isInt32Array(t) || this.util.types.isUint8Array(t) || this.util.types.isUint8ClampedArray(t);
|
|
2604
2620
|
}
|
|
2605
2621
|
}
|
|
2606
|
-
S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new
|
|
2622
|
+
S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new Pn());
|
|
2607
2623
|
/**
|
|
2608
2624
|
* @license
|
|
2609
2625
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -2620,8 +2636,8 @@ S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new _n()
|
|
|
2620
2636
|
* limitations under the License.
|
|
2621
2637
|
* =============================================================================
|
|
2622
2638
|
*/
|
|
2623
|
-
function
|
|
2624
|
-
return t = t || "float32",
|
|
2639
|
+
function On(n, t = "float32", e) {
|
|
2640
|
+
return t = t || "float32", Nt(n), new ln(n, t, e);
|
|
2625
2641
|
}
|
|
2626
2642
|
/**
|
|
2627
2643
|
* @license
|
|
@@ -2639,16 +2655,16 @@ function Pn(n, t = "float32", e) {
|
|
|
2639
2655
|
* limitations under the License.
|
|
2640
2656
|
* =============================================================================
|
|
2641
2657
|
*/
|
|
2642
|
-
function
|
|
2643
|
-
const e =
|
|
2644
|
-
if (!
|
|
2658
|
+
function Ln(n, t) {
|
|
2659
|
+
const e = I(n, "x", "cast");
|
|
2660
|
+
if (!Fe(t))
|
|
2645
2661
|
throw new Error(`Failed to cast to unknown dtype ${t}`);
|
|
2646
2662
|
if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string")
|
|
2647
2663
|
throw new Error("Only strings can be casted to strings");
|
|
2648
2664
|
const s = { x: e }, r = { dtype: t };
|
|
2649
|
-
return g.runKernel(
|
|
2665
|
+
return g.runKernel(ne, s, r);
|
|
2650
2666
|
}
|
|
2651
|
-
const
|
|
2667
|
+
const Ft = /* @__PURE__ */ F({ cast_: Ln });
|
|
2652
2668
|
/**
|
|
2653
2669
|
* @license
|
|
2654
2670
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2665,11 +2681,11 @@ const Mt = /* @__PURE__ */ F({ cast_: On });
|
|
|
2665
2681
|
* limitations under the License.
|
|
2666
2682
|
* =============================================================================
|
|
2667
2683
|
*/
|
|
2668
|
-
function
|
|
2669
|
-
const e = { x:
|
|
2670
|
-
return g.runKernel(
|
|
2684
|
+
function Un(n) {
|
|
2685
|
+
const e = { x: I(n, "x", "clone", "string_or_numeric") };
|
|
2686
|
+
return g.runKernel(se, e);
|
|
2671
2687
|
}
|
|
2672
|
-
const
|
|
2688
|
+
const Gn = /* @__PURE__ */ F({ clone_: Un });
|
|
2673
2689
|
/**
|
|
2674
2690
|
* @license
|
|
2675
2691
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -2686,7 +2702,7 @@ const Un = /* @__PURE__ */ F({ clone_: Ln });
|
|
|
2686
2702
|
* limitations under the License.
|
|
2687
2703
|
* =============================================================================
|
|
2688
2704
|
*/
|
|
2689
|
-
function
|
|
2705
|
+
function zn(n, t = !1) {
|
|
2690
2706
|
console.log(n.toString(t));
|
|
2691
2707
|
}
|
|
2692
2708
|
/**
|
|
@@ -2705,14 +2721,14 @@ function Gn(n, t = !1) {
|
|
|
2705
2721
|
* limitations under the License.
|
|
2706
2722
|
* =============================================================================
|
|
2707
2723
|
*/
|
|
2708
|
-
|
|
2709
|
-
const
|
|
2710
|
-
buffer:
|
|
2711
|
-
cast:
|
|
2712
|
-
clone:
|
|
2713
|
-
print:
|
|
2724
|
+
de();
|
|
2725
|
+
const Wn = {
|
|
2726
|
+
buffer: On,
|
|
2727
|
+
cast: Ft,
|
|
2728
|
+
clone: Gn,
|
|
2729
|
+
print: zn
|
|
2714
2730
|
};
|
|
2715
|
-
|
|
2731
|
+
un(Wn);
|
|
2716
2732
|
/**
|
|
2717
2733
|
* @license
|
|
2718
2734
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2729,13 +2745,13 @@ cn(zn);
|
|
|
2729
2745
|
* limitations under the License.
|
|
2730
2746
|
* =============================================================================
|
|
2731
2747
|
*/
|
|
2732
|
-
function
|
|
2733
|
-
let e =
|
|
2748
|
+
function jn(n, t) {
|
|
2749
|
+
let e = I(n, "a", "add"), s = I(t, "b", "add");
|
|
2734
2750
|
[e, s] = K(e, s);
|
|
2735
2751
|
const r = { a: e, b: s };
|
|
2736
|
-
return g.runKernel(
|
|
2752
|
+
return g.runKernel(ee, r);
|
|
2737
2753
|
}
|
|
2738
|
-
const w = /* @__PURE__ */ F({ add_:
|
|
2754
|
+
const w = /* @__PURE__ */ F({ add_: jn });
|
|
2739
2755
|
/**
|
|
2740
2756
|
* @license
|
|
2741
2757
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2752,13 +2768,13 @@ const w = /* @__PURE__ */ F({ add_: Wn });
|
|
|
2752
2768
|
* limitations under the License.
|
|
2753
2769
|
* =============================================================================
|
|
2754
2770
|
*/
|
|
2755
|
-
function
|
|
2756
|
-
let e =
|
|
2771
|
+
function Kn(n, t) {
|
|
2772
|
+
let e = I(n, "a", "floorDiv"), s = I(t, "b", "floorDiv");
|
|
2757
2773
|
[e, s] = K(e, s);
|
|
2758
2774
|
const r = { a: e, b: s };
|
|
2759
|
-
return g.runKernel(
|
|
2775
|
+
return g.runKernel(je, r);
|
|
2760
2776
|
}
|
|
2761
|
-
const
|
|
2777
|
+
const Vn = /* @__PURE__ */ F({ floorDiv_: Kn });
|
|
2762
2778
|
/**
|
|
2763
2779
|
* @license
|
|
2764
2780
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2775,14 +2791,14 @@ const Kn = /* @__PURE__ */ F({ floorDiv_: jn });
|
|
|
2775
2791
|
* limitations under the License.
|
|
2776
2792
|
* =============================================================================
|
|
2777
2793
|
*/
|
|
2778
|
-
function
|
|
2779
|
-
let e =
|
|
2794
|
+
function qn(n, t) {
|
|
2795
|
+
let e = I(n, "a", "div"), s = I(t, "b", "div");
|
|
2780
2796
|
if ([e, s] = K(e, s), e.dtype === "int32" && s.dtype === "int32")
|
|
2781
|
-
return
|
|
2797
|
+
return Vn(e, s);
|
|
2782
2798
|
const r = { a: e, b: s }, i = {};
|
|
2783
|
-
return g.runKernel(
|
|
2799
|
+
return g.runKernel(ze, r, i);
|
|
2784
2800
|
}
|
|
2785
|
-
const D = /* @__PURE__ */ F({ div_:
|
|
2801
|
+
const D = /* @__PURE__ */ F({ div_: qn });
|
|
2786
2802
|
/**
|
|
2787
2803
|
* @license
|
|
2788
2804
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2799,13 +2815,13 @@ const D = /* @__PURE__ */ F({ div_: Vn });
|
|
|
2799
2815
|
* limitations under the License.
|
|
2800
2816
|
* =============================================================================
|
|
2801
2817
|
*/
|
|
2802
|
-
function
|
|
2803
|
-
let e =
|
|
2818
|
+
function Hn(n, t) {
|
|
2819
|
+
let e = I(n, "a", "mul"), s = I(t, "b", "mul");
|
|
2804
2820
|
[e, s] = K(e, s);
|
|
2805
2821
|
const r = { a: e, b: s };
|
|
2806
|
-
return g.runKernel(
|
|
2822
|
+
return g.runKernel(Ve, r);
|
|
2807
2823
|
}
|
|
2808
|
-
const p = /* @__PURE__ */ F({ mul_:
|
|
2824
|
+
const p = /* @__PURE__ */ F({ mul_: Hn });
|
|
2809
2825
|
/**
|
|
2810
2826
|
* @license
|
|
2811
2827
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -2822,17 +2838,17 @@ const p = /* @__PURE__ */ F({ mul_: qn });
|
|
|
2822
2838
|
* limitations under the License.
|
|
2823
2839
|
* =============================================================================
|
|
2824
2840
|
*/
|
|
2825
|
-
function
|
|
2826
|
-
const t =
|
|
2841
|
+
function Jn(n) {
|
|
2842
|
+
const t = I(n, "x", "abs");
|
|
2827
2843
|
if (t.dtype === "complex64") {
|
|
2828
2844
|
const e = { x: t };
|
|
2829
|
-
return g.runKernel(
|
|
2845
|
+
return g.runKernel(Ge, e);
|
|
2830
2846
|
} else {
|
|
2831
2847
|
const e = { x: t };
|
|
2832
|
-
return g.runKernel(
|
|
2848
|
+
return g.runKernel(Ue, e);
|
|
2833
2849
|
}
|
|
2834
2850
|
}
|
|
2835
|
-
const
|
|
2851
|
+
const Xn = /* @__PURE__ */ F({ abs_: Jn });
|
|
2836
2852
|
/**
|
|
2837
2853
|
* @license
|
|
2838
2854
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2849,10 +2865,10 @@ const Jn = /* @__PURE__ */ F({ abs_: Hn });
|
|
|
2849
2865
|
* limitations under the License.
|
|
2850
2866
|
* =============================================================================
|
|
2851
2867
|
*/
|
|
2852
|
-
function
|
|
2853
|
-
|
|
2868
|
+
function Yn(n, t, e) {
|
|
2869
|
+
Nt(n), e = e || mt(t);
|
|
2854
2870
|
const s = { shape: n, value: t, dtype: e };
|
|
2855
|
-
return g.runKernel(
|
|
2871
|
+
return g.runKernel(We, {}, s);
|
|
2856
2872
|
}
|
|
2857
2873
|
/**
|
|
2858
2874
|
* @license
|
|
@@ -2870,7 +2886,7 @@ function Xn(n, t, e) {
|
|
|
2870
2886
|
* limitations under the License.
|
|
2871
2887
|
* =============================================================================
|
|
2872
2888
|
*/
|
|
2873
|
-
function
|
|
2889
|
+
function Ys(n, t) {
|
|
2874
2890
|
const e = [];
|
|
2875
2891
|
for (let s = 0; s < t.length; s++) {
|
|
2876
2892
|
const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
|
|
@@ -2878,7 +2894,7 @@ function Ls(n, t) {
|
|
|
2878
2894
|
}
|
|
2879
2895
|
return e;
|
|
2880
2896
|
}
|
|
2881
|
-
function
|
|
2897
|
+
function Qn(n, t) {
|
|
2882
2898
|
const e = Math.max(n.length, t.length), s = new Array(e);
|
|
2883
2899
|
for (let r = 0; r < e; r++) {
|
|
2884
2900
|
let i = n[n.length - r - 1];
|
|
@@ -2912,11 +2928,11 @@ function Yn(n, t) {
|
|
|
2912
2928
|
* limitations under the License.
|
|
2913
2929
|
* =============================================================================
|
|
2914
2930
|
*/
|
|
2915
|
-
function
|
|
2916
|
-
const e = { x:
|
|
2917
|
-
return g.runKernel(
|
|
2931
|
+
function Zn(n) {
|
|
2932
|
+
const e = { x: I(n, "x", "zerosLike") };
|
|
2933
|
+
return g.runKernel(Xe, e);
|
|
2918
2934
|
}
|
|
2919
|
-
const C = /* @__PURE__ */ F({ zerosLike_:
|
|
2935
|
+
const C = /* @__PURE__ */ F({ zerosLike_: Zn });
|
|
2920
2936
|
/**
|
|
2921
2937
|
* @license
|
|
2922
2938
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2933,13 +2949,13 @@ const C = /* @__PURE__ */ F({ zerosLike_: Qn });
|
|
|
2933
2949
|
* limitations under the License.
|
|
2934
2950
|
* =============================================================================
|
|
2935
2951
|
*/
|
|
2936
|
-
function
|
|
2937
|
-
let e =
|
|
2952
|
+
function ts(n, t) {
|
|
2953
|
+
let e = I(n, "base", "pow"), s = I(t, "exp", "pow");
|
|
2938
2954
|
[e, s] = K(e, s);
|
|
2939
2955
|
const r = { a: e, b: s };
|
|
2940
|
-
return g.runKernel(
|
|
2956
|
+
return g.runKernel(qe, r);
|
|
2941
2957
|
}
|
|
2942
|
-
const
|
|
2958
|
+
const Xt = /* @__PURE__ */ F({ pow_: ts });
|
|
2943
2959
|
/**
|
|
2944
2960
|
* @license
|
|
2945
2961
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -2957,11 +2973,11 @@ const Jt = /* @__PURE__ */ F({ pow_: Zn });
|
|
|
2957
2973
|
* =============================================================================
|
|
2958
2974
|
*/
|
|
2959
2975
|
function j(n, t) {
|
|
2960
|
-
if ((
|
|
2976
|
+
if (($(n) && t !== "string" || Array.isArray(n)) && t !== "complex64")
|
|
2961
2977
|
throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
|
|
2962
|
-
if (t === "string" &&
|
|
2978
|
+
if (t === "string" && $(n) && !(n instanceof Uint8Array))
|
|
2963
2979
|
throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");
|
|
2964
|
-
return
|
|
2980
|
+
return wn(n, [], [], t);
|
|
2965
2981
|
}
|
|
2966
2982
|
/**
|
|
2967
2983
|
* @license
|
|
@@ -2979,11 +2995,11 @@ function j(n, t) {
|
|
|
2979
2995
|
* limitations under the License.
|
|
2980
2996
|
* =============================================================================
|
|
2981
2997
|
*/
|
|
2982
|
-
function
|
|
2983
|
-
const e = { x:
|
|
2984
|
-
return g.runKernel(
|
|
2998
|
+
function es(n) {
|
|
2999
|
+
const e = { x: I(n, "x", "sqrt", "float32") };
|
|
3000
|
+
return g.runKernel(He, e);
|
|
2985
3001
|
}
|
|
2986
|
-
const
|
|
3002
|
+
const et = /* @__PURE__ */ F({ sqrt_: es });
|
|
2987
3003
|
/**
|
|
2988
3004
|
* @license
|
|
2989
3005
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -3000,11 +3016,11 @@ const tt = /* @__PURE__ */ F({ sqrt_: ts });
|
|
|
3000
3016
|
* limitations under the License.
|
|
3001
3017
|
* =============================================================================
|
|
3002
3018
|
*/
|
|
3003
|
-
function
|
|
3004
|
-
const t =
|
|
3019
|
+
function ns(n) {
|
|
3020
|
+
const t = I(n, "x", "square"), e = {};
|
|
3005
3021
|
return g.runKernel("Square", { x: t }, e);
|
|
3006
3022
|
}
|
|
3007
|
-
const G = /* @__PURE__ */ F({ square_:
|
|
3023
|
+
const G = /* @__PURE__ */ F({ square_: ns });
|
|
3008
3024
|
/**
|
|
3009
3025
|
* @license
|
|
3010
3026
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3021,8 +3037,8 @@ const G = /* @__PURE__ */ F({ square_: es });
|
|
|
3021
3037
|
* limitations under the License.
|
|
3022
3038
|
* =============================================================================
|
|
3023
3039
|
*/
|
|
3024
|
-
function
|
|
3025
|
-
|
|
3040
|
+
function ss(n, t) {
|
|
3041
|
+
y(kt(n), () => "The f passed in variableGrads(f) must be a function"), y(t == null || Array.isArray(t) && t.every((l) => l instanceof dt), () => "The varList passed in variableGrads(f, varList) must be an array of variables");
|
|
3026
3042
|
const e = t != null;
|
|
3027
3043
|
if (!e) {
|
|
3028
3044
|
t = [];
|
|
@@ -3030,15 +3046,15 @@ function ns(n, t) {
|
|
|
3030
3046
|
t.push(g.registeredVariables[l]);
|
|
3031
3047
|
}
|
|
3032
3048
|
const s = e ? t.filter((l) => !l.trainable) : null, r = t.length;
|
|
3033
|
-
t = t.filter((l) => l.trainable),
|
|
3049
|
+
t = t.filter((l) => l.trainable), y(t.length > 0, () => `variableGrads() expects at least one of the input variables to be trainable, but none of the ${r} variables is trainable.`);
|
|
3034
3050
|
const i = !0, { value: o, grads: a } = g.gradients(n, t, null, i);
|
|
3035
|
-
|
|
3051
|
+
y(a.some((l) => l != null), () => "Cannot find a connection between any variable and the result of the loss function y=f(x). Please make sure the operations that use variables are inside the function f passed to minimize()."), y(o.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it returned a rank-${o.rank} tensor`);
|
|
3036
3052
|
const c = {};
|
|
3037
3053
|
return t.forEach((l, u) => {
|
|
3038
3054
|
a[u] != null && (c[l.name] = a[u]);
|
|
3039
3055
|
}), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
|
|
3040
3056
|
}
|
|
3041
|
-
function
|
|
3057
|
+
function Qs(n) {
|
|
3042
3058
|
return g.customGrad(n);
|
|
3043
3059
|
}
|
|
3044
3060
|
/**
|
|
@@ -3057,13 +3073,13 @@ function Us(n) {
|
|
|
3057
3073
|
* limitations under the License.
|
|
3058
3074
|
* =============================================================================
|
|
3059
3075
|
*/
|
|
3060
|
-
function
|
|
3061
|
-
let e =
|
|
3076
|
+
function rs(n, t) {
|
|
3077
|
+
let e = I(n, "a", "sub"), s = I(t, "b", "sub");
|
|
3062
3078
|
[e, s] = K(e, s);
|
|
3063
3079
|
const r = { a: e, b: s };
|
|
3064
|
-
return g.runKernel(
|
|
3080
|
+
return g.runKernel(Je, r);
|
|
3065
3081
|
}
|
|
3066
|
-
const
|
|
3082
|
+
const Z = /* @__PURE__ */ F({ sub_: rs });
|
|
3067
3083
|
/**
|
|
3068
3084
|
* @license
|
|
3069
3085
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3080,13 +3096,13 @@ const Q = /* @__PURE__ */ F({ sub_: ss });
|
|
|
3080
3096
|
* limitations under the License.
|
|
3081
3097
|
* =============================================================================
|
|
3082
3098
|
*/
|
|
3083
|
-
function
|
|
3084
|
-
let e =
|
|
3085
|
-
[e, s] = K(e, s), e.dtype === "bool" && (e =
|
|
3099
|
+
function is(n, t) {
|
|
3100
|
+
let e = I(n, "a", "maximum"), s = I(t, "b", "maximum");
|
|
3101
|
+
[e, s] = K(e, s), e.dtype === "bool" && (e = Ft(e, "int32"), s = Ft(s, "int32")), Qn(e.shape, s.shape);
|
|
3086
3102
|
const r = { a: e, b: s };
|
|
3087
|
-
return g.runKernel(
|
|
3103
|
+
return g.runKernel(Ke, r);
|
|
3088
3104
|
}
|
|
3089
|
-
const
|
|
3105
|
+
const os = /* @__PURE__ */ F({ maximum_: is });
|
|
3090
3106
|
/**
|
|
3091
3107
|
* @license
|
|
3092
3108
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3103,8 +3119,8 @@ const is = /* @__PURE__ */ F({ maximum_: rs });
|
|
|
3103
3119
|
* limitations under the License.
|
|
3104
3120
|
* =============================================================================
|
|
3105
3121
|
*/
|
|
3106
|
-
const
|
|
3107
|
-
class
|
|
3122
|
+
const as = /* @__PURE__ */ new Map(), ls = /* @__PURE__ */ new Map();
|
|
3123
|
+
class cs {
|
|
3108
3124
|
/**
|
|
3109
3125
|
* Return the class name for this class to use in serialization contexts.
|
|
3110
3126
|
*
|
|
@@ -3149,10 +3165,10 @@ class O {
|
|
|
3149
3165
|
O.getMap().classNameMap[t.className] = [t, t.fromConfig];
|
|
3150
3166
|
}
|
|
3151
3167
|
}
|
|
3152
|
-
function
|
|
3153
|
-
|
|
3168
|
+
function us(n, t, e) {
|
|
3169
|
+
y(n.className != null, () => "Class being registered does not have the static className property defined."), y(typeof n.className == "string", () => "className is required to be a string, but got type " + typeof n.className), y(n.className.length > 0, () => "Class being registered has an empty-string as its className, which is disallowed."), typeof t > "u" && (t = "Custom"), typeof e > "u" && (e = n.className);
|
|
3154
3170
|
const s = e, r = t + ">" + s;
|
|
3155
|
-
return O.register(n),
|
|
3171
|
+
return O.register(n), as.set(r, n), ls.set(n, r), n;
|
|
3156
3172
|
}
|
|
3157
3173
|
/**
|
|
3158
3174
|
* @license
|
|
@@ -3170,7 +3186,7 @@ function cs(n, t, e) {
|
|
|
3170
3186
|
* limitations under the License.
|
|
3171
3187
|
* =============================================================================
|
|
3172
3188
|
*/
|
|
3173
|
-
class V extends
|
|
3189
|
+
class V extends cs {
|
|
3174
3190
|
/**
|
|
3175
3191
|
* Executes `f()` and minimizes the scalar output of `f()` by computing
|
|
3176
3192
|
* gradients of y with respect to the list of trainable variables provided by
|
|
@@ -3217,7 +3233,7 @@ class V extends ls {
|
|
|
3217
3233
|
* @doc {heading: 'Training', subheading: 'Optimizers'}
|
|
3218
3234
|
*/
|
|
3219
3235
|
computeGradients(t, e) {
|
|
3220
|
-
return
|
|
3236
|
+
return ss(t, e);
|
|
3221
3237
|
}
|
|
3222
3238
|
/**
|
|
3223
3239
|
* Dispose the variables (if any) owned by this optimizer instance.
|
|
@@ -3268,7 +3284,7 @@ Object.defineProperty(V, Symbol.hasInstance, {
|
|
|
3268
3284
|
* limitations under the License.
|
|
3269
3285
|
* =============================================================================
|
|
3270
3286
|
*/
|
|
3271
|
-
class
|
|
3287
|
+
class hs extends V {
|
|
3272
3288
|
/** @nocollapse */
|
|
3273
3289
|
static get className() {
|
|
3274
3290
|
return "Adadelta";
|
|
@@ -3291,7 +3307,7 @@ class us extends V {
|
|
|
3291
3307
|
return;
|
|
3292
3308
|
const c = this.accumulatedGrads[r].variable, l = this.accumulatedUpdates[r].variable;
|
|
3293
3309
|
E(() => {
|
|
3294
|
-
const u = w(p(c, this.rho), p(G(a), 1 - this.rho)), h = p(D(
|
|
3310
|
+
const u = w(p(c, this.rho), p(G(a), 1 - this.rho)), h = p(D(et(w(l, this.epsilon)), et(w(c, this.epsilon))), a), f = w(p(l, this.rho), p(G(h), 1 - this.rho));
|
|
3295
3311
|
c.assign(u), l.assign(f);
|
|
3296
3312
|
const m = w(p(h, -this.learningRate), i);
|
|
3297
3313
|
i.assign(m);
|
|
@@ -3344,7 +3360,7 @@ class us extends V {
|
|
|
3344
3360
|
* limitations under the License.
|
|
3345
3361
|
* =============================================================================
|
|
3346
3362
|
*/
|
|
3347
|
-
class
|
|
3363
|
+
class fs extends V {
|
|
3348
3364
|
/** @nocollapse */
|
|
3349
3365
|
static get className() {
|
|
3350
3366
|
return "Adagrad";
|
|
@@ -3357,7 +3373,7 @@ class hs extends V {
|
|
|
3357
3373
|
const i = g.registeredVariables[s];
|
|
3358
3374
|
this.accumulatedGrads[r] == null && (this.accumulatedGrads[r] = {
|
|
3359
3375
|
originalName: `${s}/accumulator`,
|
|
3360
|
-
variable: E(() =>
|
|
3376
|
+
variable: E(() => Yn(i.shape, this.initialAccumulatorValue).variable(!1))
|
|
3361
3377
|
});
|
|
3362
3378
|
const o = Array.isArray(t) ? t[r].tensor : t[s];
|
|
3363
3379
|
if (o == null)
|
|
@@ -3366,7 +3382,7 @@ class hs extends V {
|
|
|
3366
3382
|
E(() => {
|
|
3367
3383
|
const c = w(a, G(o));
|
|
3368
3384
|
a.assign(c);
|
|
3369
|
-
const l = w(p(D(o,
|
|
3385
|
+
const l = w(p(D(o, et(w(c, g.backend.epsilon()))), -this.learningRate), i);
|
|
3370
3386
|
i.assign(l);
|
|
3371
3387
|
});
|
|
3372
3388
|
}), this.incrementIterations();
|
|
@@ -3409,7 +3425,7 @@ class hs extends V {
|
|
|
3409
3425
|
* limitations under the License.
|
|
3410
3426
|
* =============================================================================
|
|
3411
3427
|
*/
|
|
3412
|
-
class
|
|
3428
|
+
class ds extends V {
|
|
3413
3429
|
/** @nocollapse */
|
|
3414
3430
|
static get className() {
|
|
3415
3431
|
return "Adam";
|
|
@@ -3422,7 +3438,7 @@ class fs extends V {
|
|
|
3422
3438
|
applyGradients(t) {
|
|
3423
3439
|
const e = Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t);
|
|
3424
3440
|
E(() => {
|
|
3425
|
-
const s =
|
|
3441
|
+
const s = Z(1, this.accBeta1), r = Z(1, this.accBeta2);
|
|
3426
3442
|
e.forEach((i, o) => {
|
|
3427
3443
|
const a = g.registeredVariables[i], c = !1;
|
|
3428
3444
|
this.accumulatedFirstMoment[o] == null && (this.accumulatedFirstMoment[o] = {
|
|
@@ -3435,10 +3451,10 @@ class fs extends V {
|
|
|
3435
3451
|
const l = Array.isArray(t) ? t[o].tensor : t[i];
|
|
3436
3452
|
if (l == null)
|
|
3437
3453
|
return;
|
|
3438
|
-
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(G(l), 1 - this.beta2)),
|
|
3454
|
+
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(G(l), 1 - this.beta2)), b = D(f, s), d = D(m, r);
|
|
3439
3455
|
u.assign(f), h.assign(m);
|
|
3440
|
-
const
|
|
3441
|
-
a.assign(
|
|
3456
|
+
const k = w(p(D(b, w(et(d), this.epsilon)), -this.learningRate), a);
|
|
3457
|
+
a.assign(k);
|
|
3442
3458
|
}), this.accBeta1.assign(p(this.accBeta1, this.beta1)), this.accBeta2.assign(p(this.accBeta2, this.beta2));
|
|
3443
3459
|
}), this.incrementIterations();
|
|
3444
3460
|
}
|
|
@@ -3451,7 +3467,7 @@ class fs extends V {
|
|
|
3451
3467
|
}
|
|
3452
3468
|
async setWeights(t) {
|
|
3453
3469
|
t = await this.extractIterations(t), E(() => {
|
|
3454
|
-
this.accBeta1.assign(
|
|
3470
|
+
this.accBeta1.assign(Xt(this.beta1, this.iterations_ + 1)), this.accBeta2.assign(Xt(this.beta2, this.iterations_ + 1));
|
|
3455
3471
|
});
|
|
3456
3472
|
const e = t.length / 2, s = !1;
|
|
3457
3473
|
this.accumulatedFirstMoment = t.slice(0, e).map((r) => ({
|
|
@@ -3491,7 +3507,7 @@ class fs extends V {
|
|
|
3491
3507
|
* limitations under the License.
|
|
3492
3508
|
* =============================================================================
|
|
3493
3509
|
*/
|
|
3494
|
-
class
|
|
3510
|
+
class gs extends V {
|
|
3495
3511
|
/** @nocollapse */
|
|
3496
3512
|
static get className() {
|
|
3497
3513
|
return "Adamax";
|
|
@@ -3504,7 +3520,7 @@ class ds extends V {
|
|
|
3504
3520
|
applyGradients(t) {
|
|
3505
3521
|
const e = Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t);
|
|
3506
3522
|
E(() => {
|
|
3507
|
-
const s =
|
|
3523
|
+
const s = Z(1, this.accBeta1), r = D(-this.learningRate, w(p(this.iteration, this.decay), 1));
|
|
3508
3524
|
e.forEach((i, o) => {
|
|
3509
3525
|
const a = g.registeredVariables[i], c = !1;
|
|
3510
3526
|
this.accumulatedFirstMoment[o] == null && (this.accumulatedFirstMoment[o] = {
|
|
@@ -3517,10 +3533,10 @@ class ds extends V {
|
|
|
3517
3533
|
const l = Array.isArray(t) ? t[o].tensor : t[i];
|
|
3518
3534
|
if (l == null)
|
|
3519
3535
|
return;
|
|
3520
|
-
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedWeightedInfNorm[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = p(h, this.beta2),
|
|
3536
|
+
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedWeightedInfNorm[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = p(h, this.beta2), b = Xn(l), d = os(m, b);
|
|
3521
3537
|
u.assign(f), h.assign(d);
|
|
3522
|
-
const
|
|
3523
|
-
a.assign(
|
|
3538
|
+
const k = w(p(D(r, s), D(f, w(d, this.epsilon))), a);
|
|
3539
|
+
a.assign(k);
|
|
3524
3540
|
}), this.iteration.assign(w(this.iteration, 1)), this.accBeta1.assign(p(this.accBeta1, this.beta1));
|
|
3525
3541
|
}), this.incrementIterations();
|
|
3526
3542
|
}
|
|
@@ -3587,7 +3603,7 @@ class Ie extends V {
|
|
|
3587
3603
|
* Sets the learning rate of the optimizer.
|
|
3588
3604
|
*/
|
|
3589
3605
|
setLearningRate(t) {
|
|
3590
|
-
this.learningRate = t, this.c != null && this.c.dispose(), this.c =
|
|
3606
|
+
this.learningRate = t, this.c != null && this.c.dispose(), this.c = kn(j(-t));
|
|
3591
3607
|
}
|
|
3592
3608
|
dispose() {
|
|
3593
3609
|
this.c.dispose();
|
|
@@ -3623,7 +3639,7 @@ class Ie extends V {
|
|
|
3623
3639
|
* limitations under the License.
|
|
3624
3640
|
* =============================================================================
|
|
3625
3641
|
*/
|
|
3626
|
-
class
|
|
3642
|
+
class ms extends Ie {
|
|
3627
3643
|
/** @nocollapse */
|
|
3628
3644
|
// Name matters for Python compatibility.
|
|
3629
3645
|
static get className() {
|
|
@@ -3694,7 +3710,7 @@ class gs extends Ie {
|
|
|
3694
3710
|
* limitations under the License.
|
|
3695
3711
|
* =============================================================================
|
|
3696
3712
|
*/
|
|
3697
|
-
class
|
|
3713
|
+
class ps extends V {
|
|
3698
3714
|
/** @nocollapse */
|
|
3699
3715
|
static get className() {
|
|
3700
3716
|
return "RMSProp";
|
|
@@ -3723,14 +3739,14 @@ class ms extends V {
|
|
|
3723
3739
|
E(() => {
|
|
3724
3740
|
const u = w(p(c, this.decay), p(G(a), 1 - this.decay));
|
|
3725
3741
|
if (this.centered) {
|
|
3726
|
-
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate),
|
|
3727
|
-
c.assign(u), h.assign(f), l.assign(
|
|
3728
|
-
const d =
|
|
3742
|
+
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), et(Z(u, w(G(f), this.epsilon)))), b = w(p(l, this.momentum), m);
|
|
3743
|
+
c.assign(u), h.assign(f), l.assign(b);
|
|
3744
|
+
const d = Z(i, b);
|
|
3729
3745
|
i.assign(d);
|
|
3730
3746
|
} else {
|
|
3731
|
-
const h = w(p(c, this.decay), p(G(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate),
|
|
3747
|
+
const h = w(p(c, this.decay), p(G(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate), et(w(h, this.epsilon))));
|
|
3732
3748
|
c.assign(h), l.assign(f);
|
|
3733
|
-
const m =
|
|
3749
|
+
const m = Z(i, f);
|
|
3734
3750
|
i.assign(m);
|
|
3735
3751
|
}
|
|
3736
3752
|
});
|
|
@@ -3787,18 +3803,18 @@ class ms extends V {
|
|
|
3787
3803
|
* limitations under the License.
|
|
3788
3804
|
* =============================================================================
|
|
3789
3805
|
*/
|
|
3790
|
-
const
|
|
3791
|
-
us,
|
|
3806
|
+
const ys = [
|
|
3792
3807
|
hs,
|
|
3793
3808
|
fs,
|
|
3794
3809
|
ds,
|
|
3795
3810
|
gs,
|
|
3796
3811
|
ms,
|
|
3812
|
+
ps,
|
|
3797
3813
|
Ie
|
|
3798
3814
|
];
|
|
3799
|
-
function
|
|
3800
|
-
for (const n of
|
|
3801
|
-
|
|
3815
|
+
function bs() {
|
|
3816
|
+
for (const n of ys)
|
|
3817
|
+
us(n);
|
|
3802
3818
|
}
|
|
3803
3819
|
/**
|
|
3804
3820
|
* @license
|
|
@@ -3816,40 +3832,55 @@ function ys() {
|
|
|
3816
3832
|
* limitations under the License.
|
|
3817
3833
|
* =============================================================================
|
|
3818
3834
|
*/
|
|
3819
|
-
|
|
3835
|
+
bs();
|
|
3820
3836
|
export {
|
|
3821
|
-
|
|
3822
|
-
|
|
3823
|
-
|
|
3837
|
+
ds as A,
|
|
3838
|
+
Es as B,
|
|
3839
|
+
As as C,
|
|
3840
|
+
Bs as D,
|
|
3824
3841
|
g as E,
|
|
3825
|
-
|
|
3826
|
-
|
|
3827
|
-
|
|
3828
|
-
|
|
3829
|
-
|
|
3830
|
-
|
|
3831
|
-
|
|
3832
|
-
|
|
3833
|
-
|
|
3842
|
+
$s as F,
|
|
3843
|
+
Ms as G,
|
|
3844
|
+
Cs as H,
|
|
3845
|
+
Fs as I,
|
|
3846
|
+
Ps as J,
|
|
3847
|
+
Os as K,
|
|
3848
|
+
Rs as L,
|
|
3849
|
+
xs as M,
|
|
3850
|
+
Ns as N,
|
|
3851
|
+
Us as O,
|
|
3852
|
+
Ds as P,
|
|
3853
|
+
Vs as Q,
|
|
3854
|
+
_s as R,
|
|
3855
|
+
Ws as S,
|
|
3856
|
+
Ks as T,
|
|
3857
|
+
Ys as U,
|
|
3858
|
+
Qn as V,
|
|
3859
|
+
qs as _,
|
|
3860
|
+
Z as a,
|
|
3834
3861
|
Is as b,
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3838
|
-
|
|
3839
|
-
|
|
3840
|
-
|
|
3841
|
-
|
|
3842
|
-
|
|
3843
|
-
|
|
3844
|
-
|
|
3862
|
+
I as c,
|
|
3863
|
+
Js as d,
|
|
3864
|
+
Xs as e,
|
|
3865
|
+
y as f,
|
|
3866
|
+
Ls as g,
|
|
3867
|
+
Ft as h,
|
|
3868
|
+
Nt as i,
|
|
3869
|
+
Qt as j,
|
|
3870
|
+
U as k,
|
|
3871
|
+
Ne as l,
|
|
3845
3872
|
p as m,
|
|
3846
|
-
|
|
3873
|
+
Gs as n,
|
|
3847
3874
|
F as o,
|
|
3848
|
-
|
|
3849
|
-
|
|
3850
|
-
|
|
3875
|
+
vs as p,
|
|
3876
|
+
Ts as q,
|
|
3877
|
+
Hs as r,
|
|
3851
3878
|
j as s,
|
|
3852
|
-
|
|
3853
|
-
|
|
3854
|
-
|
|
3879
|
+
w as t,
|
|
3880
|
+
js as u,
|
|
3881
|
+
Qs as v,
|
|
3882
|
+
E as w,
|
|
3883
|
+
K as x,
|
|
3884
|
+
zs as y,
|
|
3885
|
+
C as z
|
|
3855
3886
|
};
|