@genai-fi/nanogpt 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/TeachableLLM.js +6 -4
- package/dist/data/docx.d.ts +1 -0
- package/dist/data/docx.js +15 -0
- package/dist/data/parquet.js +7 -5
- package/dist/data/pdf.d.ts +1 -0
- package/dist/data/pdf.js +14 -0
- package/dist/data/textLoader.js +24 -14
- package/dist/index-D5v913EJ.js +4 -0
- package/dist/{index-B8nyc6IR.js → index-DcaSvB38.js} +462 -506
- package/dist/index-Tf7vU29b.js +1023 -0
- package/dist/index-xuotMAFm.js +118 -0
- package/dist/{jszip.min-pMIn3RZH.js → jszip.min-CjP2V1VV.js} +42 -52
- package/dist/layers/TiedEmbedding.js +1 -1
- package/dist/parquet-BRl5lE_I.js +44956 -0
- package/dist/pdf-kJD-f258.js +19481 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/utilities/load.js +5 -5
- package/dist/utilities/save.js +1 -1
- package/package.json +4 -2
- package/dist/parquet-DpcqBLb0.js +0 -39727
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import { g as _t } from "./index-D5v913EJ.js";
|
|
2
|
+
import { p as Y } from "./index-xuotMAFm.js";
|
|
3
|
+
import { B as dt } from "./index-Tf7vU29b.js";
|
|
1
4
|
/**
|
|
2
5
|
* @license
|
|
3
6
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -14,7 +17,8 @@
|
|
|
14
17
|
* limitations under the License.
|
|
15
18
|
* =============================================================================
|
|
16
19
|
*/
|
|
17
|
-
|
|
20
|
+
const ke = 1e-7, Te = 1e-4;
|
|
21
|
+
class Ee {
|
|
18
22
|
refCount(t) {
|
|
19
23
|
return v("refCount");
|
|
20
24
|
}
|
|
@@ -60,7 +64,7 @@ class we {
|
|
|
60
64
|
}
|
|
61
65
|
/** Returns the smallest representable number. */
|
|
62
66
|
epsilon() {
|
|
63
|
-
return this.floatPrecision() === 32 ?
|
|
67
|
+
return this.floatPrecision() === 32 ? ke : Te;
|
|
64
68
|
}
|
|
65
69
|
dispose() {
|
|
66
70
|
return v("dispose");
|
|
@@ -89,8 +93,8 @@ function b(n, t) {
|
|
|
89
93
|
if (!n)
|
|
90
94
|
throw new Error(typeof t == "string" ? t : t());
|
|
91
95
|
}
|
|
92
|
-
function
|
|
93
|
-
b(
|
|
96
|
+
function Is(n, t, e = "") {
|
|
97
|
+
b(Ft(n, t), () => e + ` Shapes ${n} and ${t} must match`);
|
|
94
98
|
}
|
|
95
99
|
function U(n) {
|
|
96
100
|
if (n.length === 0)
|
|
@@ -100,7 +104,7 @@ function U(n) {
|
|
|
100
104
|
t *= n[e];
|
|
101
105
|
return t;
|
|
102
106
|
}
|
|
103
|
-
function
|
|
107
|
+
function Ft(n, t) {
|
|
104
108
|
if (n === t)
|
|
105
109
|
return !0;
|
|
106
110
|
if (n == null || t == null || n.length !== t.length)
|
|
@@ -110,10 +114,10 @@ function vt(n, t) {
|
|
|
110
114
|
return !1;
|
|
111
115
|
return !0;
|
|
112
116
|
}
|
|
113
|
-
function
|
|
117
|
+
function ct(n, t) {
|
|
114
118
|
return t <= n.length ? n : n + " ".repeat(t - n.length);
|
|
115
119
|
}
|
|
116
|
-
function
|
|
120
|
+
function Be(n, t) {
|
|
117
121
|
let e = null;
|
|
118
122
|
if (n == null || n === "float32")
|
|
119
123
|
e = new Float32Array(t);
|
|
@@ -127,17 +131,17 @@ function Se(n, t) {
|
|
|
127
131
|
throw new Error(`Unknown data type ${n}`);
|
|
128
132
|
return e;
|
|
129
133
|
}
|
|
130
|
-
function
|
|
134
|
+
function Ae(n, t) {
|
|
131
135
|
for (let e = 0; e < n.length; e++) {
|
|
132
136
|
const s = n[e];
|
|
133
137
|
if (isNaN(s) || !isFinite(s))
|
|
134
138
|
throw Error(`A tensor of type ${t} being uploaded contains ${s}.`);
|
|
135
139
|
}
|
|
136
140
|
}
|
|
137
|
-
function
|
|
141
|
+
function ve(n) {
|
|
138
142
|
return n === "bool" || n === "complex64" || n === "float32" || n === "int32" || n === "string";
|
|
139
143
|
}
|
|
140
|
-
function
|
|
144
|
+
function wt(n) {
|
|
141
145
|
if (n === "float32" || n === "int32")
|
|
142
146
|
return 4;
|
|
143
147
|
if (n === "complex64")
|
|
@@ -146,28 +150,28 @@ function yt(n) {
|
|
|
146
150
|
return 1;
|
|
147
151
|
throw new Error(`Unknown dtype ${n}`);
|
|
148
152
|
}
|
|
149
|
-
function
|
|
153
|
+
function Me(n) {
|
|
150
154
|
if (n == null)
|
|
151
155
|
return 0;
|
|
152
156
|
let t = 0;
|
|
153
157
|
return n.forEach((e) => t += e.length), t;
|
|
154
158
|
}
|
|
155
|
-
function
|
|
159
|
+
function Rt(n) {
|
|
156
160
|
return typeof n == "string" || n instanceof String;
|
|
157
161
|
}
|
|
158
|
-
function
|
|
162
|
+
function Fe(n) {
|
|
159
163
|
return typeof n == "boolean";
|
|
160
164
|
}
|
|
161
|
-
function
|
|
165
|
+
function Re(n) {
|
|
162
166
|
return typeof n == "number";
|
|
163
167
|
}
|
|
164
|
-
function
|
|
165
|
-
return Array.isArray(n) ?
|
|
168
|
+
function gt(n) {
|
|
169
|
+
return Array.isArray(n) ? gt(n[0]) : n instanceof Float32Array ? "float32" : n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray ? "int32" : Re(n) ? "float32" : Rt(n) ? "string" : Fe(n) ? "bool" : "float32";
|
|
166
170
|
}
|
|
167
|
-
function
|
|
171
|
+
function St(n) {
|
|
168
172
|
return !!(n && n.constructor && n.call && n.apply);
|
|
169
173
|
}
|
|
170
|
-
function
|
|
174
|
+
function xt(n) {
|
|
171
175
|
const t = n.length;
|
|
172
176
|
if (t < 2)
|
|
173
177
|
return [];
|
|
@@ -177,7 +181,7 @@ function Ft(n) {
|
|
|
177
181
|
e[s] = e[s + 1] * n[s + 1];
|
|
178
182
|
return e;
|
|
179
183
|
}
|
|
180
|
-
function
|
|
184
|
+
function Xt(n, t, e, s = !1) {
|
|
181
185
|
const r = new Array();
|
|
182
186
|
if (t.length === 1) {
|
|
183
187
|
const i = t[0] * (s ? 2 : 1);
|
|
@@ -186,11 +190,11 @@ function qt(n, t, e, s = !1) {
|
|
|
186
190
|
} else {
|
|
187
191
|
const i = t[0], o = t.slice(1), a = o.reduce((c, l) => c * l) * (s ? 2 : 1);
|
|
188
192
|
for (let c = 0; c < i; c++)
|
|
189
|
-
r[c] =
|
|
193
|
+
r[c] = Xt(n + c * a, o, e, s);
|
|
190
194
|
}
|
|
191
195
|
return r;
|
|
192
196
|
}
|
|
193
|
-
function
|
|
197
|
+
function Pt(n, t, e = !1) {
|
|
194
198
|
if (n.length === 0)
|
|
195
199
|
return t[0];
|
|
196
200
|
const s = n.reduce((r, i) => r * i) * (e ? 2 : 1);
|
|
@@ -198,15 +202,15 @@ function Dt(n, t, e = !1) {
|
|
|
198
202
|
return [];
|
|
199
203
|
if (s !== t.length)
|
|
200
204
|
throw new Error(`[${n}] does not match the input size ${t.length}${e ? " for a complex tensor" : ""}.`);
|
|
201
|
-
return
|
|
205
|
+
return Xt(0, n, t, e);
|
|
202
206
|
}
|
|
203
|
-
function
|
|
204
|
-
const e =
|
|
207
|
+
function xe(n, t) {
|
|
208
|
+
const e = Yt(n, t);
|
|
205
209
|
for (let s = 0; s < e.length; s++)
|
|
206
210
|
e[s] = 1;
|
|
207
211
|
return e;
|
|
208
212
|
}
|
|
209
|
-
function
|
|
213
|
+
function Yt(n, t) {
|
|
210
214
|
if (t == null || t === "float32" || t === "complex64")
|
|
211
215
|
return new Float32Array(n);
|
|
212
216
|
if (t === "int32")
|
|
@@ -215,12 +219,12 @@ function Ht(n, t) {
|
|
|
215
219
|
return new Uint8Array(n);
|
|
216
220
|
throw new Error(`Unknown data type ${t}`);
|
|
217
221
|
}
|
|
218
|
-
function
|
|
222
|
+
function $t(n) {
|
|
219
223
|
n.forEach((t) => {
|
|
220
224
|
b(Number.isInteger(t) && t >= 0, () => `Tensor must have a shape comprised of positive integers but got shape [${n}].`);
|
|
221
225
|
});
|
|
222
226
|
}
|
|
223
|
-
function
|
|
227
|
+
function Nt(n) {
|
|
224
228
|
return n && n.then && typeof n.then == "function";
|
|
225
229
|
}
|
|
226
230
|
/**
|
|
@@ -239,11 +243,11 @@ function xt(n) {
|
|
|
239
243
|
* limitations under the License.
|
|
240
244
|
* =============================================================================
|
|
241
245
|
*/
|
|
242
|
-
const
|
|
243
|
-
class
|
|
246
|
+
const Ot = "tfjsflags";
|
|
247
|
+
class $e {
|
|
244
248
|
// tslint:disable-next-line: no-any
|
|
245
249
|
constructor(t) {
|
|
246
|
-
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams =
|
|
250
|
+
this.global = t, this.flags = {}, this.flagRegistry = {}, this.urlFlags = {}, this.getQueryParams = Ne, this.populateURLFlags();
|
|
247
251
|
}
|
|
248
252
|
setPlatform(t, e) {
|
|
249
253
|
this.platform != null && (S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)), this.platformName = t, this.platform = e;
|
|
@@ -261,7 +265,7 @@ class ve {
|
|
|
261
265
|
if (t in this.flags)
|
|
262
266
|
return this.flags[t];
|
|
263
267
|
const e = this.evaluateFlag(t);
|
|
264
|
-
if (
|
|
268
|
+
if (Nt(e))
|
|
265
269
|
throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);
|
|
266
270
|
return this.flags[t] = e, this.flags[t];
|
|
267
271
|
}
|
|
@@ -301,29 +305,29 @@ class ve {
|
|
|
301
305
|
if (typeof this.global > "u" || typeof this.global.location > "u" || typeof this.global.location.search > "u")
|
|
302
306
|
return;
|
|
303
307
|
const t = this.getQueryParams(this.global.location.search);
|
|
304
|
-
|
|
308
|
+
Ot in t && t[Ot].split(",").forEach((s) => {
|
|
305
309
|
const [r, i] = s.split(":");
|
|
306
|
-
this.urlFlags[r] =
|
|
310
|
+
this.urlFlags[r] = Ce(r, i);
|
|
307
311
|
});
|
|
308
312
|
}
|
|
309
313
|
}
|
|
310
|
-
function
|
|
314
|
+
function Ne(n) {
|
|
311
315
|
const t = {};
|
|
312
|
-
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (
|
|
316
|
+
return n.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (e, ...s) => (De(t, s[0], s[1]), s.join("="))), t;
|
|
313
317
|
}
|
|
314
|
-
function
|
|
318
|
+
function De(n, t, e) {
|
|
315
319
|
n[decodeURIComponent(t)] = decodeURIComponent(e || "");
|
|
316
320
|
}
|
|
317
|
-
function
|
|
321
|
+
function Ce(n, t) {
|
|
318
322
|
const e = t.toLowerCase();
|
|
319
323
|
return e === "true" || e === "false" ? e === "true" : `${+e}` === e ? +e : t;
|
|
320
324
|
}
|
|
321
325
|
function S() {
|
|
322
|
-
return
|
|
326
|
+
return Qt;
|
|
323
327
|
}
|
|
324
|
-
let
|
|
325
|
-
function
|
|
326
|
-
|
|
328
|
+
let Qt = null;
|
|
329
|
+
function _e(n) {
|
|
330
|
+
Qt = n;
|
|
327
331
|
}
|
|
328
332
|
/**
|
|
329
333
|
* @license
|
|
@@ -341,30 +345,30 @@ function xe(n) {
|
|
|
341
345
|
* limitations under the License.
|
|
342
346
|
* =============================================================================
|
|
343
347
|
*/
|
|
344
|
-
let
|
|
345
|
-
function
|
|
346
|
-
if (
|
|
348
|
+
let mt;
|
|
349
|
+
function Zt() {
|
|
350
|
+
if (mt == null) {
|
|
347
351
|
let n;
|
|
348
352
|
if (typeof window < "u")
|
|
349
353
|
n = window;
|
|
350
|
-
else if (typeof
|
|
351
|
-
n =
|
|
352
|
-
else if (typeof
|
|
353
|
-
n =
|
|
354
|
+
else if (typeof _t < "u")
|
|
355
|
+
n = _t;
|
|
356
|
+
else if (typeof Y < "u")
|
|
357
|
+
n = Y;
|
|
354
358
|
else if (typeof self < "u")
|
|
355
359
|
n = self;
|
|
356
360
|
else
|
|
357
361
|
throw new Error("Could not find a global object");
|
|
358
|
-
|
|
362
|
+
mt = n;
|
|
359
363
|
}
|
|
360
|
-
return
|
|
364
|
+
return mt;
|
|
361
365
|
}
|
|
362
|
-
function
|
|
363
|
-
const n =
|
|
366
|
+
function Pe() {
|
|
367
|
+
const n = Zt();
|
|
364
368
|
return n._tfGlobals == null && (n._tfGlobals = /* @__PURE__ */ new Map()), n._tfGlobals;
|
|
365
369
|
}
|
|
366
|
-
function
|
|
367
|
-
const e =
|
|
370
|
+
function Dt(n, t) {
|
|
371
|
+
const e = Pe();
|
|
368
372
|
if (e.has(n))
|
|
369
373
|
return e.get(n);
|
|
370
374
|
{
|
|
@@ -372,7 +376,7 @@ function Nt(n, t) {
|
|
|
372
376
|
return e.set(n, s), e.get(n);
|
|
373
377
|
}
|
|
374
378
|
}
|
|
375
|
-
const
|
|
379
|
+
const Oe = "Abs", te = "Add", ks = "BatchMatMul", ee = "Cast", Ts = "Complex", Le = "ComplexAbs", Ue = "RealDiv", Es = "Elu", Ge = "Fill", ze = "FloorDiv", ne = "Identity", Bs = "Imag", As = "LeakyRelu", We = "Maximum", je = "Multiply", vs = "Neg", Ke = "Pow", Ms = "Prelu", Fs = "Real", Rs = "Relu", xs = "Reshape", $s = "Relu6", Ns = "Sigmoid", Ve = "Sqrt", Ds = "Sum", qe = "Sub", Cs = "Transpose", He = "ZerosLike", _s = "Step", Ps = "_FusedMatMul";
|
|
376
380
|
/**
|
|
377
381
|
* @license
|
|
378
382
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -389,7 +393,7 @@ const $e = "Abs", Yt = "Add", gs = "BatchMatMul", Qt = "Cast", ms = "Complex", D
|
|
|
389
393
|
* limitations under the License.
|
|
390
394
|
* =============================================================================
|
|
391
395
|
*/
|
|
392
|
-
function
|
|
396
|
+
function st(...n) {
|
|
393
397
|
S().getBool("IS_TEST") || S().getBool("PROD") || console.warn(...n);
|
|
394
398
|
}
|
|
395
399
|
/**
|
|
@@ -408,16 +412,16 @@ function nt(...n) {
|
|
|
408
412
|
* limitations under the License.
|
|
409
413
|
* =============================================================================
|
|
410
414
|
*/
|
|
411
|
-
const
|
|
412
|
-
function
|
|
413
|
-
const e =
|
|
414
|
-
return
|
|
415
|
+
const se = Dt("kernelRegistry", () => /* @__PURE__ */ new Map()), Je = Dt("gradRegistry", () => /* @__PURE__ */ new Map());
|
|
416
|
+
function Lt(n, t) {
|
|
417
|
+
const e = Xe(n, t);
|
|
418
|
+
return se.get(e);
|
|
415
419
|
}
|
|
416
|
-
function
|
|
417
|
-
return
|
|
420
|
+
function Ut(n) {
|
|
421
|
+
return Je.get(n);
|
|
418
422
|
}
|
|
419
|
-
function
|
|
420
|
-
const t =
|
|
423
|
+
function Gt(n) {
|
|
424
|
+
const t = se.entries(), e = [];
|
|
421
425
|
for (; ; ) {
|
|
422
426
|
const { done: s, value: r } = t.next();
|
|
423
427
|
if (s)
|
|
@@ -427,7 +431,7 @@ function Pt(n) {
|
|
|
427
431
|
}
|
|
428
432
|
return e;
|
|
429
433
|
}
|
|
430
|
-
function
|
|
434
|
+
function Xe(n, t) {
|
|
431
435
|
return `${t}_${n}`;
|
|
432
436
|
}
|
|
433
437
|
/**
|
|
@@ -446,7 +450,7 @@ function Ke(n, t) {
|
|
|
446
450
|
* limitations under the License.
|
|
447
451
|
* =============================================================================
|
|
448
452
|
*/
|
|
449
|
-
function
|
|
453
|
+
function re(n) {
|
|
450
454
|
return n instanceof Float32Array || n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray;
|
|
451
455
|
}
|
|
452
456
|
/**
|
|
@@ -465,13 +469,13 @@ function ee(n) {
|
|
|
465
469
|
* limitations under the License.
|
|
466
470
|
* =============================================================================
|
|
467
471
|
*/
|
|
468
|
-
function
|
|
472
|
+
function Ye(n, t) {
|
|
469
473
|
return n instanceof Float32Array && t === "float32" || n instanceof Int32Array && t === "int32" || n instanceof Uint8Array && t === "bool";
|
|
470
474
|
}
|
|
471
|
-
function
|
|
475
|
+
function ie(n, t) {
|
|
472
476
|
if (t === "string")
|
|
473
477
|
throw new Error("Cannot convert a string[] to a TypedArray");
|
|
474
|
-
if (Array.isArray(n) && (n =
|
|
478
|
+
if (Array.isArray(n) && (n = at(n)), S().getBool("DEBUG") && Ae(n, t), Ye(n, t))
|
|
475
479
|
return n;
|
|
476
480
|
if (t == null || t === "float32" || t === "complex64")
|
|
477
481
|
return new Float32Array(n);
|
|
@@ -485,30 +489,30 @@ function ne(n, t) {
|
|
|
485
489
|
} else
|
|
486
490
|
throw new Error(`Unknown data type ${t}`);
|
|
487
491
|
}
|
|
488
|
-
function
|
|
492
|
+
function ht() {
|
|
489
493
|
return S().platform.now();
|
|
490
494
|
}
|
|
491
|
-
function
|
|
495
|
+
function Qe(n, t = "utf-8") {
|
|
492
496
|
return t = t || "utf-8", S().platform.encode(n, t);
|
|
493
497
|
}
|
|
494
|
-
function
|
|
498
|
+
function zt(n, t = "utf-8") {
|
|
495
499
|
return t = t || "utf-8", S().platform.decode(n, t);
|
|
496
500
|
}
|
|
497
501
|
function R(n) {
|
|
498
|
-
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) :
|
|
502
|
+
return S().platform.isTypedArray != null ? S().platform.isTypedArray(n) : re(n);
|
|
499
503
|
}
|
|
500
|
-
function
|
|
501
|
-
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" ||
|
|
504
|
+
function at(n, t = [], e = !1) {
|
|
505
|
+
if (t == null && (t = []), typeof n == "boolean" || typeof n == "number" || typeof n == "string" || Nt(n) || n == null || R(n) && e)
|
|
502
506
|
t.push(n);
|
|
503
507
|
else if (Array.isArray(n) || R(n))
|
|
504
508
|
for (let s = 0; s < n.length; ++s)
|
|
505
|
-
|
|
509
|
+
at(n[s], t, e);
|
|
506
510
|
else {
|
|
507
511
|
let s = -1;
|
|
508
512
|
for (const r of Object.keys(n))
|
|
509
513
|
/^([1-9]+[0-9]*|0)$/.test(r) && (s = Math.max(s, Number(r)));
|
|
510
514
|
for (let r = 0; r <= s; r++)
|
|
511
|
-
|
|
515
|
+
at(n[r], t, e);
|
|
512
516
|
}
|
|
513
517
|
return t;
|
|
514
518
|
}
|
|
@@ -528,9 +532,9 @@ function ot(n, t = [], e = !1) {
|
|
|
528
532
|
* limitations under the License.
|
|
529
533
|
* =============================================================================
|
|
530
534
|
*/
|
|
531
|
-
class
|
|
535
|
+
class Ze {
|
|
532
536
|
constructor(t, e) {
|
|
533
|
-
this.backendTimer = t, this.logger = e, e == null && (this.logger = new
|
|
537
|
+
this.backendTimer = t, this.logger = e, e == null && (this.logger = new en());
|
|
534
538
|
}
|
|
535
539
|
profileKernel(t, e, s) {
|
|
536
540
|
let r;
|
|
@@ -538,20 +542,20 @@ class He {
|
|
|
538
542
|
r = s();
|
|
539
543
|
};
|
|
540
544
|
let o;
|
|
541
|
-
const a =
|
|
545
|
+
const a = ht();
|
|
542
546
|
if (this.backendTimer.timerAvailable())
|
|
543
547
|
o = this.backendTimer.time(i);
|
|
544
548
|
else {
|
|
545
549
|
i();
|
|
546
550
|
for (const l of r)
|
|
547
551
|
l.dataSync();
|
|
548
|
-
o = Promise.resolve({ kernelMs:
|
|
552
|
+
o = Promise.resolve({ kernelMs: ht() - a });
|
|
549
553
|
}
|
|
550
554
|
if (S().getBool("CHECK_COMPUTATION_FOR_ERRORS"))
|
|
551
555
|
for (let l = 0; l < r.length; l++) {
|
|
552
556
|
const u = r[l];
|
|
553
557
|
u.data().then((h) => {
|
|
554
|
-
|
|
558
|
+
tn(h, u.dtype, t);
|
|
555
559
|
});
|
|
556
560
|
}
|
|
557
561
|
return {
|
|
@@ -571,7 +575,7 @@ class He {
|
|
|
571
575
|
});
|
|
572
576
|
}
|
|
573
577
|
}
|
|
574
|
-
function
|
|
578
|
+
function tn(n, t, e) {
|
|
575
579
|
if (t !== "float32")
|
|
576
580
|
return !1;
|
|
577
581
|
for (let s = 0; s < n.length; s++) {
|
|
@@ -581,9 +585,9 @@ function Je(n, t, e) {
|
|
|
581
585
|
}
|
|
582
586
|
return !1;
|
|
583
587
|
}
|
|
584
|
-
class
|
|
588
|
+
class en {
|
|
585
589
|
logKernelProfile(t, e, s, r, i, o) {
|
|
586
|
-
const a = typeof r == "number" ?
|
|
590
|
+
const a = typeof r == "number" ? ct(`${r}ms`, 9) : r.error, c = ct(t, 25), l = e.rank, u = e.size, h = ct(e.shape.toString(), 14);
|
|
587
591
|
let f = "";
|
|
588
592
|
for (const m in i) {
|
|
589
593
|
const y = i[m];
|
|
@@ -611,7 +615,7 @@ class Xe {
|
|
|
611
615
|
* limitations under the License.
|
|
612
616
|
* =============================================================================
|
|
613
617
|
*/
|
|
614
|
-
function
|
|
618
|
+
function nn(n, t, e) {
|
|
615
619
|
const s = {}, r = {};
|
|
616
620
|
for (let c = 0; c < t.length; c++)
|
|
617
621
|
s[t[c].id] = !0;
|
|
@@ -656,7 +660,7 @@ function Ye(n, t, e) {
|
|
|
656
660
|
}
|
|
657
661
|
return a;
|
|
658
662
|
}
|
|
659
|
-
function
|
|
663
|
+
function sn(n, t, e, s) {
|
|
660
664
|
for (let r = t.length - 1; r >= 0; r--) {
|
|
661
665
|
const i = t[r], o = [];
|
|
662
666
|
if (i.outputs.forEach((c) => {
|
|
@@ -672,7 +676,7 @@ function Qe(n, t, e, s) {
|
|
|
672
676
|
if (l.dtype !== "float32")
|
|
673
677
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input ${c} must have 'float32' dtype, but has '${l.dtype}'`);
|
|
674
678
|
const u = i.inputs[c];
|
|
675
|
-
if (!
|
|
679
|
+
if (!Ft(l.shape, u.shape))
|
|
676
680
|
throw new Error(`Error in gradient for op ${i.kernelName}. The gradient of input '${c}' has shape '${l.shape}', which does not match the shape of the input '${u.shape}'`);
|
|
677
681
|
if (n[u.id] == null)
|
|
678
682
|
n[u.id] = l;
|
|
@@ -699,56 +703,56 @@ function Qe(n, t, e, s) {
|
|
|
699
703
|
* limitations under the License.
|
|
700
704
|
* =============================================================================
|
|
701
705
|
*/
|
|
702
|
-
const
|
|
703
|
-
function
|
|
704
|
-
const r =
|
|
706
|
+
const Wt = 20, rt = 3, pt = 7;
|
|
707
|
+
function rn(n, t, e, s) {
|
|
708
|
+
const r = xt(t), i = on(n, t, e, r), o = t.length, a = ut(n, t, e, r, i), c = ["Tensor"];
|
|
705
709
|
return s && (c.push(` dtype: ${e}`), c.push(` rank: ${o}`), c.push(` shape: [${t}]`), c.push(" values:")), c.push(a.map((l) => " " + l).join(`
|
|
706
710
|
`)), c.join(`
|
|
707
711
|
`);
|
|
708
712
|
}
|
|
709
|
-
function
|
|
710
|
-
const r = U(t), i = s[s.length - 1], o = new Array(i).fill(0), a = t.length, c = e === "complex64" ?
|
|
713
|
+
function on(n, t, e, s) {
|
|
714
|
+
const r = U(t), i = s[s.length - 1], o = new Array(i).fill(0), a = t.length, c = e === "complex64" ? ot(n) : n;
|
|
711
715
|
if (a > 1)
|
|
712
716
|
for (let l = 0; l < r / i; l++) {
|
|
713
717
|
const u = l * i;
|
|
714
718
|
for (let h = 0; h < i; h++)
|
|
715
|
-
o[h] = Math.max(o[h],
|
|
719
|
+
o[h] = Math.max(o[h], it(c[u + h], 0, e).length);
|
|
716
720
|
}
|
|
717
721
|
return o;
|
|
718
722
|
}
|
|
719
|
-
function
|
|
723
|
+
function it(n, t, e) {
|
|
720
724
|
let s;
|
|
721
|
-
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(
|
|
725
|
+
return Array.isArray(n) ? s = `${parseFloat(n[0].toFixed(pt))} + ${parseFloat(n[1].toFixed(pt))}j` : Rt(n) ? s = `'${n}'` : e === "bool" ? s = oe(n) : s = parseFloat(n.toFixed(pt)).toString(), ct(s, t);
|
|
722
726
|
}
|
|
723
|
-
function
|
|
727
|
+
function oe(n) {
|
|
724
728
|
return n === 0 ? "false" : "true";
|
|
725
729
|
}
|
|
726
|
-
function
|
|
730
|
+
function ut(n, t, e, s, r, i = !0) {
|
|
727
731
|
const o = e === "complex64" ? 2 : 1, a = t[0], c = t.length;
|
|
728
732
|
if (c === 0) {
|
|
729
733
|
if (e === "complex64") {
|
|
730
|
-
const d =
|
|
731
|
-
return [
|
|
734
|
+
const d = ot(n);
|
|
735
|
+
return [it(d[0], 0, e)];
|
|
732
736
|
}
|
|
733
|
-
return e === "bool" ? [
|
|
737
|
+
return e === "bool" ? [oe(n[0])] : [n[0].toString()];
|
|
734
738
|
}
|
|
735
739
|
if (c === 1) {
|
|
736
|
-
if (a >
|
|
737
|
-
const I =
|
|
738
|
-
let T = Array.from(n.slice(0, I)),
|
|
739
|
-
return e === "complex64" && (T =
|
|
740
|
-
"[" + T.map((q, H) =>
|
|
740
|
+
if (a > Wt) {
|
|
741
|
+
const I = rt * o;
|
|
742
|
+
let T = Array.from(n.slice(0, I)), et = Array.from(n.slice((a - rt) * o, a * o));
|
|
743
|
+
return e === "complex64" && (T = ot(T), et = ot(et)), [
|
|
744
|
+
"[" + T.map((q, H) => it(q, r[H], e)).join(", ") + ", ..., " + et.map((q, H) => it(q, r[a - rt + H], e)).join(", ") + "]"
|
|
741
745
|
];
|
|
742
746
|
}
|
|
743
747
|
return [
|
|
744
|
-
"[" + (e === "complex64" ?
|
|
748
|
+
"[" + (e === "complex64" ? ot(n) : Array.from(n)).map((I, T) => it(I, r[T], e)).join(", ") + "]"
|
|
745
749
|
];
|
|
746
750
|
}
|
|
747
751
|
const l = t.slice(1), u = s.slice(1), h = s[0] * o, f = [];
|
|
748
|
-
if (a >
|
|
749
|
-
for (let d = 0; d <
|
|
752
|
+
if (a > Wt) {
|
|
753
|
+
for (let d = 0; d < rt; d++) {
|
|
750
754
|
const I = d * h, T = I + h;
|
|
751
|
-
f.push(...
|
|
755
|
+
f.push(...ut(
|
|
752
756
|
n.slice(I, T),
|
|
753
757
|
l,
|
|
754
758
|
e,
|
|
@@ -759,9 +763,9 @@ function ct(n, t, e, s, r, i = !0) {
|
|
|
759
763
|
));
|
|
760
764
|
}
|
|
761
765
|
f.push("...");
|
|
762
|
-
for (let d = a -
|
|
766
|
+
for (let d = a - rt; d < a; d++) {
|
|
763
767
|
const I = d * h, T = I + h;
|
|
764
|
-
f.push(...
|
|
768
|
+
f.push(...ut(
|
|
765
769
|
n.slice(I, T),
|
|
766
770
|
l,
|
|
767
771
|
e,
|
|
@@ -774,7 +778,7 @@ function ct(n, t, e, s, r, i = !0) {
|
|
|
774
778
|
} else
|
|
775
779
|
for (let d = 0; d < a; d++) {
|
|
776
780
|
const I = d * h, T = I + h;
|
|
777
|
-
f.push(...
|
|
781
|
+
f.push(...ut(
|
|
778
782
|
n.slice(I, T),
|
|
779
783
|
l,
|
|
780
784
|
e,
|
|
@@ -795,7 +799,7 @@ function ct(n, t, e, s, r, i = !0) {
|
|
|
795
799
|
`;
|
|
796
800
|
return f[f.length - 1] = " " + f[f.length - 1] + "]" + (i ? "" : y), f;
|
|
797
801
|
}
|
|
798
|
-
function
|
|
802
|
+
function ot(n) {
|
|
799
803
|
const t = [];
|
|
800
804
|
for (let e = 0; e < n.length; e += 2)
|
|
801
805
|
t.push([n[e], n[e + 1]]);
|
|
@@ -817,7 +821,7 @@ function it(n) {
|
|
|
817
821
|
* limitations under the License.
|
|
818
822
|
* =============================================================================
|
|
819
823
|
*/
|
|
820
|
-
class
|
|
824
|
+
class an {
|
|
821
825
|
constructor(t, e, s) {
|
|
822
826
|
if (this.dtype = e, this.shape = t.slice(), this.size = U(t), s != null) {
|
|
823
827
|
const r = s.length;
|
|
@@ -825,7 +829,7 @@ class en {
|
|
|
825
829
|
}
|
|
826
830
|
if (e === "complex64")
|
|
827
831
|
throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
|
|
828
|
-
this.values = s ||
|
|
832
|
+
this.values = s || Be(e, this.size), this.strides = xt(t);
|
|
829
833
|
}
|
|
830
834
|
/**
|
|
831
835
|
* Sets a value in the buffer at a given location.
|
|
@@ -895,15 +899,15 @@ class en {
|
|
|
895
899
|
}
|
|
896
900
|
}
|
|
897
901
|
let x = null, J = null;
|
|
898
|
-
function
|
|
902
|
+
function ln(n) {
|
|
899
903
|
x = n;
|
|
900
904
|
}
|
|
901
|
-
function
|
|
905
|
+
function cn(n) {
|
|
902
906
|
J = n;
|
|
903
907
|
}
|
|
904
|
-
class
|
|
908
|
+
class $ {
|
|
905
909
|
constructor(t, e, s, r) {
|
|
906
|
-
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size = U(t), this.strides =
|
|
910
|
+
this.kept = !1, this.isDisposedInternal = !1, this.shape = t.slice(), this.dtype = e || "float32", this.size = U(t), this.strides = xt(t), this.dataId = s, this.id = r, this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
|
|
907
911
|
}
|
|
908
912
|
get rank() {
|
|
909
913
|
return this.shape.length;
|
|
@@ -932,7 +936,7 @@ class N {
|
|
|
932
936
|
*/
|
|
933
937
|
async array() {
|
|
934
938
|
const t = await this.data();
|
|
935
|
-
return
|
|
939
|
+
return Pt(this.shape, t, this.dtype === "complex64");
|
|
936
940
|
}
|
|
937
941
|
/**
|
|
938
942
|
* Returns the tensor data as a nested array. The transfer of data is done
|
|
@@ -941,7 +945,7 @@ class N {
|
|
|
941
945
|
* @doc {heading: 'Tensors', subheading: 'Classes'}
|
|
942
946
|
*/
|
|
943
947
|
arraySync() {
|
|
944
|
-
return
|
|
948
|
+
return Pt(this.shape, this.dataSync(), this.dtype === "complex64");
|
|
945
949
|
}
|
|
946
950
|
/**
|
|
947
951
|
* Asynchronously downloads the values from the `tf.Tensor`. Returns a
|
|
@@ -955,7 +959,7 @@ class N {
|
|
|
955
959
|
if (this.dtype === "string") {
|
|
956
960
|
const e = await t;
|
|
957
961
|
try {
|
|
958
|
-
return e.map((s) =>
|
|
962
|
+
return e.map((s) => zt(s));
|
|
959
963
|
} catch {
|
|
960
964
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
961
965
|
}
|
|
@@ -1010,7 +1014,7 @@ class N {
|
|
|
1010
1014
|
const t = x().readSync(this.dataId);
|
|
1011
1015
|
if (this.dtype === "string")
|
|
1012
1016
|
try {
|
|
1013
|
-
return t.map((e) =>
|
|
1017
|
+
return t.map((e) => zt(e));
|
|
1014
1018
|
} catch {
|
|
1015
1019
|
throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
|
|
1016
1020
|
}
|
|
@@ -1062,7 +1066,7 @@ class N {
|
|
|
1062
1066
|
*/
|
|
1063
1067
|
toString(t = !1) {
|
|
1064
1068
|
const e = this.dataSync();
|
|
1065
|
-
return
|
|
1069
|
+
return rn(e, this.shape, this.dtype, t);
|
|
1066
1070
|
}
|
|
1067
1071
|
cast(t) {
|
|
1068
1072
|
return this.throwIfDisposed(), J.cast(this, t);
|
|
@@ -1071,14 +1075,14 @@ class N {
|
|
|
1071
1075
|
return this.throwIfDisposed(), x().makeVariable(this, t, e, s);
|
|
1072
1076
|
}
|
|
1073
1077
|
}
|
|
1074
|
-
Object.defineProperty(
|
|
1078
|
+
Object.defineProperty($, Symbol.hasInstance, {
|
|
1075
1079
|
value: (n) => !!n && n.data != null && n.dataSync != null && n.throwIfDisposed != null
|
|
1076
1080
|
});
|
|
1077
|
-
function
|
|
1078
|
-
return
|
|
1081
|
+
function ae() {
|
|
1082
|
+
return Dt("Tensor", () => $);
|
|
1079
1083
|
}
|
|
1080
|
-
|
|
1081
|
-
class
|
|
1084
|
+
ae();
|
|
1085
|
+
class ft extends $ {
|
|
1082
1086
|
constructor(t, e, s, r) {
|
|
1083
1087
|
super(t.shape, t.dtype, t.dataId, r), this.trainable = e, this.name = s;
|
|
1084
1088
|
}
|
|
@@ -1093,7 +1097,7 @@ class ht extends N {
|
|
|
1093
1097
|
assign(t) {
|
|
1094
1098
|
if (t.dtype !== this.dtype)
|
|
1095
1099
|
throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);
|
|
1096
|
-
if (!
|
|
1100
|
+
if (!Ft(t.shape, this.shape))
|
|
1097
1101
|
throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);
|
|
1098
1102
|
x().disposeTensor(this), this.dataId = t.dataId, x().incRef(
|
|
1099
1103
|
this,
|
|
@@ -1105,8 +1109,8 @@ class ht extends N {
|
|
|
1105
1109
|
x().disposeVariable(this), this.isDisposedInternal = !0;
|
|
1106
1110
|
}
|
|
1107
1111
|
}
|
|
1108
|
-
Object.defineProperty(
|
|
1109
|
-
value: (n) => n instanceof
|
|
1112
|
+
Object.defineProperty(ft, Symbol.hasInstance, {
|
|
1113
|
+
value: (n) => n instanceof $ && n.assign != null && n.assign instanceof Function
|
|
1110
1114
|
});
|
|
1111
1115
|
/**
|
|
1112
1116
|
* @license
|
|
@@ -1124,44 +1128,44 @@ Object.defineProperty(ht, Symbol.hasInstance, {
|
|
|
1124
1128
|
* limitations under the License.
|
|
1125
1129
|
* =============================================================================
|
|
1126
1130
|
*/
|
|
1127
|
-
var
|
|
1131
|
+
var jt;
|
|
1128
1132
|
(function(n) {
|
|
1129
1133
|
n.R0 = "R0", n.R1 = "R1", n.R2 = "R2", n.R3 = "R3", n.R4 = "R4", n.R5 = "R5", n.R6 = "R6";
|
|
1130
|
-
})(
|
|
1131
|
-
var
|
|
1134
|
+
})(jt || (jt = {}));
|
|
1135
|
+
var It;
|
|
1132
1136
|
(function(n) {
|
|
1133
1137
|
n.float32 = "float32", n.int32 = "int32", n.bool = "int32", n.complex64 = "complex64";
|
|
1134
|
-
})(
|
|
1135
|
-
var
|
|
1138
|
+
})(It || (It = {}));
|
|
1139
|
+
var kt;
|
|
1136
1140
|
(function(n) {
|
|
1137
1141
|
n.float32 = "float32", n.int32 = "int32", n.bool = "bool", n.complex64 = "complex64";
|
|
1138
|
-
})(
|
|
1139
|
-
var
|
|
1142
|
+
})(kt || (kt = {}));
|
|
1143
|
+
var Tt;
|
|
1140
1144
|
(function(n) {
|
|
1141
1145
|
n.float32 = "float32", n.int32 = "float32", n.bool = "float32", n.complex64 = "complex64";
|
|
1142
|
-
})(
|
|
1143
|
-
var
|
|
1146
|
+
})(Tt || (Tt = {}));
|
|
1147
|
+
var Et;
|
|
1144
1148
|
(function(n) {
|
|
1145
1149
|
n.float32 = "complex64", n.int32 = "complex64", n.bool = "complex64", n.complex64 = "complex64";
|
|
1146
|
-
})(
|
|
1147
|
-
const
|
|
1148
|
-
float32:
|
|
1149
|
-
int32:
|
|
1150
|
-
bool:
|
|
1151
|
-
complex64:
|
|
1150
|
+
})(Et || (Et = {}));
|
|
1151
|
+
const un = {
|
|
1152
|
+
float32: Tt,
|
|
1153
|
+
int32: It,
|
|
1154
|
+
bool: kt,
|
|
1155
|
+
complex64: Et
|
|
1152
1156
|
};
|
|
1153
|
-
function
|
|
1157
|
+
function hn(n, t) {
|
|
1154
1158
|
if (n === "string" || t === "string") {
|
|
1155
1159
|
if (n === "string" && t === "string")
|
|
1156
1160
|
return "string";
|
|
1157
1161
|
throw new Error(`Can not upcast ${n} with ${t}`);
|
|
1158
1162
|
}
|
|
1159
|
-
return
|
|
1163
|
+
return un[n][t];
|
|
1160
1164
|
}
|
|
1161
|
-
function
|
|
1165
|
+
function le(n) {
|
|
1162
1166
|
return n != null && typeof n == "object" && "texture" in n && n.texture instanceof WebGLTexture;
|
|
1163
1167
|
}
|
|
1164
|
-
function
|
|
1168
|
+
function ce(n) {
|
|
1165
1169
|
return typeof GPUBuffer < "u" && n != null && typeof n == "object" && "buffer" in n && n.buffer instanceof GPUBuffer;
|
|
1166
1170
|
}
|
|
1167
1171
|
/**
|
|
@@ -1183,29 +1187,29 @@ function oe(n) {
|
|
|
1183
1187
|
function K(n, t) {
|
|
1184
1188
|
if (n.dtype === t.dtype)
|
|
1185
1189
|
return [n, t];
|
|
1186
|
-
const e =
|
|
1190
|
+
const e = hn(n.dtype, t.dtype);
|
|
1187
1191
|
return [n.cast(e), t.cast(e)];
|
|
1188
1192
|
}
|
|
1189
|
-
function
|
|
1193
|
+
function ue(n) {
|
|
1190
1194
|
const t = [];
|
|
1191
|
-
return
|
|
1195
|
+
return he(n, t, /* @__PURE__ */ new Set()), t;
|
|
1192
1196
|
}
|
|
1193
|
-
function
|
|
1197
|
+
function he(n, t, e) {
|
|
1194
1198
|
if (n == null)
|
|
1195
1199
|
return;
|
|
1196
|
-
if (n instanceof
|
|
1200
|
+
if (n instanceof $) {
|
|
1197
1201
|
t.push(n);
|
|
1198
1202
|
return;
|
|
1199
1203
|
}
|
|
1200
|
-
if (!
|
|
1204
|
+
if (!fn(n))
|
|
1201
1205
|
return;
|
|
1202
1206
|
const s = n;
|
|
1203
1207
|
for (const r in s) {
|
|
1204
1208
|
const i = s[r];
|
|
1205
|
-
e.has(i) || (e.add(i),
|
|
1209
|
+
e.has(i) || (e.add(i), he(i, t, e));
|
|
1206
1210
|
}
|
|
1207
1211
|
}
|
|
1208
|
-
function
|
|
1212
|
+
function fn(n) {
|
|
1209
1213
|
return Array.isArray(n) || typeof n == "object";
|
|
1210
1214
|
}
|
|
1211
1215
|
/**
|
|
@@ -1224,10 +1228,10 @@ function an(n) {
|
|
|
1224
1228
|
* limitations under the License.
|
|
1225
1229
|
* =============================================================================
|
|
1226
1230
|
*/
|
|
1227
|
-
function
|
|
1231
|
+
function yt(n) {
|
|
1228
1232
|
return n.kernelName != null;
|
|
1229
1233
|
}
|
|
1230
|
-
class
|
|
1234
|
+
class Kt {
|
|
1231
1235
|
constructor() {
|
|
1232
1236
|
this.registeredVariables = {}, this.nextTapeNodeId = 0, this.numBytes = 0, this.numTensors = 0, this.numStringTensors = 0, this.numDataBuffers = 0, this.gradientDepth = 0, this.kernelDepth = 0, this.scopeStack = [], this.numDataMovesStack = [], this.nextScopeId = 0, this.tensorInfo = /* @__PURE__ */ new WeakMap(), this.profiling = !1, this.activeProfile = {
|
|
1233
1237
|
newBytes: 0,
|
|
@@ -1245,9 +1249,9 @@ class zt {
|
|
|
1245
1249
|
this.registeredVariables[t].dispose();
|
|
1246
1250
|
}
|
|
1247
1251
|
}
|
|
1248
|
-
class
|
|
1252
|
+
class Z {
|
|
1249
1253
|
constructor(t) {
|
|
1250
|
-
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new
|
|
1254
|
+
this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new Kt();
|
|
1251
1255
|
}
|
|
1252
1256
|
async ready() {
|
|
1253
1257
|
if (this.pendingBackendInit != null)
|
|
@@ -1293,7 +1297,7 @@ class Q {
|
|
|
1293
1297
|
return t in this.registryFactory ? this.registryFactory[t].factory : null;
|
|
1294
1298
|
}
|
|
1295
1299
|
registerBackend(t, e, s = 1) {
|
|
1296
|
-
return t in this.registryFactory ? (
|
|
1300
|
+
return t in this.registryFactory ? (st(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, !0);
|
|
1297
1301
|
}
|
|
1298
1302
|
async setBackend(t) {
|
|
1299
1303
|
if (this.registryFactory[t] == null)
|
|
@@ -1304,15 +1308,15 @@ class Q {
|
|
|
1304
1308
|
if (!(s ? await e : e))
|
|
1305
1309
|
return !1;
|
|
1306
1310
|
}
|
|
1307
|
-
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new
|
|
1311
|
+
return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new Ze(this.backendInstance), !0;
|
|
1308
1312
|
}
|
|
1309
1313
|
setupRegisteredKernels() {
|
|
1310
|
-
|
|
1314
|
+
Gt(this.backendName).forEach((e) => {
|
|
1311
1315
|
e.setupFunc != null && e.setupFunc(this.backendInstance);
|
|
1312
1316
|
});
|
|
1313
1317
|
}
|
|
1314
1318
|
disposeRegisteredKernels(t) {
|
|
1315
|
-
|
|
1319
|
+
Gt(t).forEach((s) => {
|
|
1316
1320
|
s.disposeFunc != null && s.disposeFunc(this.registry[t]);
|
|
1317
1321
|
});
|
|
1318
1322
|
}
|
|
@@ -1328,13 +1332,13 @@ class Q {
|
|
|
1328
1332
|
throw new Error(`Cannot initialize backend ${t}, no registration found.`);
|
|
1329
1333
|
try {
|
|
1330
1334
|
const s = e.factory();
|
|
1331
|
-
if (s && !(s instanceof
|
|
1332
|
-
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null,
|
|
1335
|
+
if (s && !(s instanceof Ee) && typeof s.then == "function") {
|
|
1336
|
+
const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, st(`Initialization of backend ${t} failed`), st(o.stack || o.message)), !1));
|
|
1333
1337
|
return this.pendingBackendInit = i, { success: i, asyncInit: !0 };
|
|
1334
1338
|
} else
|
|
1335
1339
|
return this.registry[t] = s, { success: !0, asyncInit: !1 };
|
|
1336
1340
|
} catch (s) {
|
|
1337
|
-
return
|
|
1341
|
+
return st(`Initialization of backend ${t} failed`), st(s.stack || s.message), { success: !1, asyncInit: !1 };
|
|
1338
1342
|
}
|
|
1339
1343
|
}
|
|
1340
1344
|
removeBackend(t) {
|
|
@@ -1386,10 +1390,10 @@ class Q {
|
|
|
1386
1390
|
}
|
|
1387
1391
|
}
|
|
1388
1392
|
nextTensorId() {
|
|
1389
|
-
return
|
|
1393
|
+
return Z.nextTensorId++;
|
|
1390
1394
|
}
|
|
1391
1395
|
nextVariableId() {
|
|
1392
|
-
return
|
|
1396
|
+
return Z.nextVariableId++;
|
|
1393
1397
|
}
|
|
1394
1398
|
/**
|
|
1395
1399
|
* This method is called instead of the public-facing tensor.clone() when
|
|
@@ -1398,11 +1402,11 @@ class Q {
|
|
|
1398
1402
|
* execution.
|
|
1399
1403
|
*/
|
|
1400
1404
|
clone(t) {
|
|
1401
|
-
const e = g.runKernel(
|
|
1405
|
+
const e = g.runKernel(ne, { x: t }), s = { x: t }, r = (o) => ({
|
|
1402
1406
|
x: () => {
|
|
1403
1407
|
const a = "float32", c = { x: o }, l = { dtype: a };
|
|
1404
1408
|
return g.runKernel(
|
|
1405
|
-
|
|
1409
|
+
ee,
|
|
1406
1410
|
c,
|
|
1407
1411
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
1408
1412
|
l
|
|
@@ -1425,7 +1429,7 @@ class Q {
|
|
|
1425
1429
|
* tensors are not visible to the user.
|
|
1426
1430
|
*/
|
|
1427
1431
|
runKernel(t, e, s) {
|
|
1428
|
-
if (this.backendName == null && this.backend, !(
|
|
1432
|
+
if (this.backendName == null && this.backend, !(Lt(t, this.backendName) != null))
|
|
1429
1433
|
throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
|
|
1430
1434
|
return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
|
|
1431
1435
|
}
|
|
@@ -1454,20 +1458,20 @@ class Q {
|
|
|
1454
1458
|
let a;
|
|
1455
1459
|
this.backendName == null && this.backend;
|
|
1456
1460
|
let c;
|
|
1457
|
-
const l =
|
|
1458
|
-
if (
|
|
1461
|
+
const l = yt(t) ? t.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
|
|
1462
|
+
if (yt(t)) {
|
|
1459
1463
|
const { kernelName: y, inputs: d, attrs: I } = t;
|
|
1460
1464
|
this.backendName == null && this.backend;
|
|
1461
|
-
const T =
|
|
1465
|
+
const T = Lt(y, this.backendName);
|
|
1462
1466
|
b(T != null, () => `Cannot find registered kernel '${y}' for backend '${this.backendName}'`), a = () => {
|
|
1463
|
-
const
|
|
1467
|
+
const et = this.backend.numDataIds();
|
|
1464
1468
|
c = T.kernelFunc({ inputs: d, attrs: I, backend: this.backend });
|
|
1465
1469
|
const q = Array.isArray(c) ? c : [c];
|
|
1466
|
-
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(y,
|
|
1467
|
-
const H = q.map((
|
|
1470
|
+
this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(y, et, q);
|
|
1471
|
+
const H = q.map((nt) => nt.rank != null ? nt : this.makeTensorFromTensorInfo(nt));
|
|
1468
1472
|
if (r) {
|
|
1469
|
-
const
|
|
1470
|
-
s = this.saveTensorsForBackwardMode(
|
|
1473
|
+
const nt = this.getTensorsForGradient(y, d, H);
|
|
1474
|
+
s = this.saveTensorsForBackwardMode(nt);
|
|
1471
1475
|
}
|
|
1472
1476
|
return H;
|
|
1473
1477
|
};
|
|
@@ -1482,7 +1486,7 @@ class Q {
|
|
|
1482
1486
|
return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(l, I, T), T;
|
|
1483
1487
|
};
|
|
1484
1488
|
}
|
|
1485
|
-
const { inputs: u, attrs: h } = t, f =
|
|
1489
|
+
const { inputs: u, attrs: h } = t, f = yt(t) ? null : t.backwardsFunc;
|
|
1486
1490
|
let m;
|
|
1487
1491
|
return this.scopedRun(
|
|
1488
1492
|
// Stop recording to a tape when running a kernel.
|
|
@@ -1519,7 +1523,7 @@ class Q {
|
|
|
1519
1523
|
* @param outputs an array of output tensors from forward mode of kernel.
|
|
1520
1524
|
*/
|
|
1521
1525
|
getTensorsForGradient(t, e, s) {
|
|
1522
|
-
const r =
|
|
1526
|
+
const r = Ut(t);
|
|
1523
1527
|
if (r != null) {
|
|
1524
1528
|
const i = r.inputsToSave || [], o = r.outputsToSave || [];
|
|
1525
1529
|
let a;
|
|
@@ -1539,10 +1543,10 @@ class Q {
|
|
|
1539
1543
|
throw new Error("Values passed to engine.makeTensor() are null");
|
|
1540
1544
|
s = s || "float32", r = r || this.backend;
|
|
1541
1545
|
let i = t;
|
|
1542
|
-
s === "string" &&
|
|
1543
|
-
const o = r.write(i, e, s), a = new
|
|
1546
|
+
s === "string" && Rt(t[0]) && (i = t.map((c) => Qe(c)));
|
|
1547
|
+
const o = r.write(i, e, s), a = new $(e, s, o, this.nextTensorId());
|
|
1544
1548
|
if (this.trackTensor(a, r), s === "string") {
|
|
1545
|
-
const c = this.state.tensorInfo.get(o), l =
|
|
1549
|
+
const c = this.state.tensorInfo.get(o), l = Me(i);
|
|
1546
1550
|
this.state.numBytes += l - c.bytes, c.bytes = l;
|
|
1547
1551
|
}
|
|
1548
1552
|
return a;
|
|
@@ -1564,12 +1568,12 @@ class Q {
|
|
|
1564
1568
|
* only increments the ref count used in memory tracking.
|
|
1565
1569
|
*/
|
|
1566
1570
|
makeTensorFromTensorInfo(t, e) {
|
|
1567
|
-
const { dataId: s, shape: r, dtype: i } = t, o = new
|
|
1571
|
+
const { dataId: s, shape: r, dtype: i } = t, o = new $(r, i, s, this.nextTensorId());
|
|
1568
1572
|
return this.trackTensor(o, e), o;
|
|
1569
1573
|
}
|
|
1570
1574
|
makeVariable(t, e = !0, s, r) {
|
|
1571
1575
|
s = s || this.nextVariableId().toString(), r != null && r !== t.dtype && (t = t.cast(r));
|
|
1572
|
-
const i = new
|
|
1576
|
+
const i = new ft(t, e, s, this.nextTensorId());
|
|
1573
1577
|
if (this.state.registeredVariables[i.name] != null)
|
|
1574
1578
|
throw new Error(`Variable with name ${i.name} was already registered`);
|
|
1575
1579
|
return this.state.registeredVariables[i.name] = i, this.incRef(i, this.backend), i;
|
|
@@ -1577,12 +1581,12 @@ class Q {
|
|
|
1577
1581
|
trackTensor(t, e) {
|
|
1578
1582
|
this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++;
|
|
1579
1583
|
let s = 0;
|
|
1580
|
-
t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size *
|
|
1584
|
+
t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size * wt(t.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(t.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(t.dataId, {
|
|
1581
1585
|
backend: e || this.backend,
|
|
1582
1586
|
dtype: t.dtype,
|
|
1583
1587
|
shape: t.shape,
|
|
1584
1588
|
bytes: s
|
|
1585
|
-
})), t instanceof
|
|
1589
|
+
})), t instanceof ft || this.track(t);
|
|
1586
1590
|
}
|
|
1587
1591
|
// Track the tensor by dataId and increase the refCount for the dataId in the
|
|
1588
1592
|
// backend.
|
|
@@ -1600,7 +1604,7 @@ class Q {
|
|
|
1600
1604
|
return;
|
|
1601
1605
|
const e = this.state.tensorInfo.get(t.dataId);
|
|
1602
1606
|
if (this.state.numTensors--, t.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= e.bytes), t.dtype !== "complex64" && t.dtype !== "string") {
|
|
1603
|
-
const s = t.size *
|
|
1607
|
+
const s = t.size * wt(t.dtype);
|
|
1604
1608
|
this.state.numBytes -= s;
|
|
1605
1609
|
}
|
|
1606
1610
|
e.backend.disposeData(t.dataId) && this.removeDataId(t.dataId, e.backend);
|
|
@@ -1630,10 +1634,10 @@ class Q {
|
|
|
1630
1634
|
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
|
|
1631
1635
|
}
|
|
1632
1636
|
addTapeNode(t, e, s, r, i, o) {
|
|
1633
|
-
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c =
|
|
1637
|
+
const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c = Ut(t);
|
|
1634
1638
|
c != null && (r = c.gradFunc), r != null && (a.gradient = (l) => (l = l.map((u, h) => {
|
|
1635
1639
|
if (u == null) {
|
|
1636
|
-
const f = s[h], m =
|
|
1640
|
+
const f = s[h], m = Yt(f.size, f.dtype);
|
|
1637
1641
|
return this.makeTensor(m, f.shape, f.dtype);
|
|
1638
1642
|
}
|
|
1639
1643
|
return u;
|
|
@@ -1665,7 +1669,7 @@ class Q {
|
|
|
1665
1669
|
* as scope() without the need for a function closure.
|
|
1666
1670
|
*/
|
|
1667
1671
|
endScope(t) {
|
|
1668
|
-
const e =
|
|
1672
|
+
const e = ue(t), s = new Set(e.map((i) => i.id));
|
|
1669
1673
|
for (let i = 0; i < this.state.activeScope.track.length; i++) {
|
|
1670
1674
|
const o = this.state.activeScope.track[i];
|
|
1671
1675
|
!o.kept && !s.has(o.id) && o.dispose();
|
|
@@ -1685,19 +1689,19 @@ class Q {
|
|
|
1685
1689
|
if (b(e.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
|
|
1686
1690
|
throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
|
|
1687
1691
|
const i = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy("forward", t));
|
|
1688
|
-
b(i instanceof
|
|
1689
|
-
const o =
|
|
1692
|
+
b(i instanceof $, () => "The result y returned by f() must be a tensor.");
|
|
1693
|
+
const o = nn(this.state.activeTape, e, i);
|
|
1690
1694
|
if (!r && o.length === 0 && e.length > 0)
|
|
1691
1695
|
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
|
|
1692
1696
|
return this.tidy("backward", () => {
|
|
1693
1697
|
const a = {};
|
|
1694
|
-
a[i.id] = s ??
|
|
1698
|
+
a[i.id] = s ?? dn(i.shape), sn(
|
|
1695
1699
|
a,
|
|
1696
1700
|
o,
|
|
1697
1701
|
// Pass the tidy function to avoid circular dep with `tape.ts`.
|
|
1698
1702
|
(l) => this.tidy(l),
|
|
1699
1703
|
// Pass an add function to avoide a circular dep with `tape.ts`.
|
|
1700
|
-
|
|
1704
|
+
gn
|
|
1701
1705
|
);
|
|
1702
1706
|
const c = e.map((l) => a[l.id]);
|
|
1703
1707
|
return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((l) => {
|
|
@@ -1707,16 +1711,16 @@ class Q {
|
|
|
1707
1711
|
});
|
|
1708
1712
|
}
|
|
1709
1713
|
customGrad(t) {
|
|
1710
|
-
return b(
|
|
1711
|
-
b(e.every((a) => a instanceof
|
|
1714
|
+
return b(St(t), () => "The f passed in customGrad(f) must be a function."), (...e) => {
|
|
1715
|
+
b(e.every((a) => a instanceof $), () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors");
|
|
1712
1716
|
let s;
|
|
1713
1717
|
const r = {};
|
|
1714
1718
|
e.forEach((a, c) => {
|
|
1715
1719
|
r[c] = a;
|
|
1716
1720
|
});
|
|
1717
|
-
const i = (a, c) => (s = t(...e, c), b(s.value instanceof
|
|
1721
|
+
const i = (a, c) => (s = t(...e, c), b(s.value instanceof $, () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"), b(St(s.gradFunc), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."), s.value), o = (a, c) => {
|
|
1718
1722
|
const l = s.gradFunc(a, c), u = Array.isArray(l) ? l : [l];
|
|
1719
|
-
b(u.length === e.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."), b(u.every((f) => f instanceof
|
|
1723
|
+
b(u.length === e.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."), b(u.every((f) => f instanceof $), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.");
|
|
1720
1724
|
const h = {};
|
|
1721
1725
|
return u.forEach((f, m) => {
|
|
1722
1726
|
h[m] = () => f;
|
|
@@ -1739,8 +1743,8 @@ class Q {
|
|
|
1739
1743
|
return this.state.tensorInfo.get(t).backend.readToGPU(t, e);
|
|
1740
1744
|
}
|
|
1741
1745
|
async time(t) {
|
|
1742
|
-
const e =
|
|
1743
|
-
return s.wallMs =
|
|
1746
|
+
const e = ht(), s = await this.backend.time(t);
|
|
1747
|
+
return s.wallMs = ht() - e, s;
|
|
1744
1748
|
}
|
|
1745
1749
|
/**
|
|
1746
1750
|
* Tracks a Tensor in the current scope to be automatically cleaned up
|
|
@@ -1759,30 +1763,30 @@ class Q {
|
|
|
1759
1763
|
* registered backend factories.
|
|
1760
1764
|
*/
|
|
1761
1765
|
reset() {
|
|
1762
|
-
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new
|
|
1766
|
+
this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new Kt();
|
|
1763
1767
|
for (const t in this.registry)
|
|
1764
1768
|
this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
|
|
1765
1769
|
this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
|
|
1766
1770
|
}
|
|
1767
1771
|
}
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
function
|
|
1771
|
-
const t =
|
|
1772
|
+
Z.nextTensorId = 0;
|
|
1773
|
+
Z.nextVariableId = 0;
|
|
1774
|
+
function dn(n) {
|
|
1775
|
+
const t = xe(U(n), "float32");
|
|
1772
1776
|
return g.makeTensor(t, n, "float32");
|
|
1773
1777
|
}
|
|
1774
|
-
function
|
|
1775
|
-
const n =
|
|
1778
|
+
function fe() {
|
|
1779
|
+
const n = Zt();
|
|
1776
1780
|
if (n._tfengine == null) {
|
|
1777
|
-
const t = new
|
|
1778
|
-
n._tfengine = new
|
|
1781
|
+
const t = new $e(n);
|
|
1782
|
+
n._tfengine = new Z(t);
|
|
1779
1783
|
}
|
|
1780
|
-
return
|
|
1784
|
+
return _e(n._tfengine.ENV), ln(() => n._tfengine), n._tfengine;
|
|
1781
1785
|
}
|
|
1782
|
-
const g =
|
|
1783
|
-
function
|
|
1786
|
+
const g = fe();
|
|
1787
|
+
function gn(n, t) {
|
|
1784
1788
|
const e = { a: n, b: t };
|
|
1785
|
-
return g.runKernel(
|
|
1789
|
+
return g.runKernel(te, e);
|
|
1786
1790
|
}
|
|
1787
1791
|
/**
|
|
1788
1792
|
* @license
|
|
@@ -1800,32 +1804,16 @@ function cn(n, t) {
|
|
|
1800
1804
|
* limitations under the License.
|
|
1801
1805
|
* =============================================================================
|
|
1802
1806
|
*/
|
|
1803
|
-
function
|
|
1807
|
+
function mn() {
|
|
1804
1808
|
return typeof window < "u" && window.document != null || //@ts-ignore
|
|
1805
1809
|
typeof WorkerGlobalScope < "u";
|
|
1806
1810
|
}
|
|
1807
|
-
/**
|
|
1808
|
-
* @license
|
|
1809
|
-
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
1810
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
1811
|
-
* you may not use this file except in compliance with the License.
|
|
1812
|
-
* You may obtain a copy of the License at
|
|
1813
|
-
*
|
|
1814
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
1815
|
-
*
|
|
1816
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
1817
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
1818
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
1819
|
-
* See the License for the specific language governing permissions and
|
|
1820
|
-
* limitations under the License.
|
|
1821
|
-
* =============================================================================
|
|
1822
|
-
*/
|
|
1823
1811
|
const A = S();
|
|
1824
1812
|
A.registerFlag("DEBUG", () => !1, (n) => {
|
|
1825
1813
|
n && console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
|
|
1826
1814
|
});
|
|
1827
|
-
A.registerFlag("IS_BROWSER", () =>
|
|
1828
|
-
A.registerFlag("IS_NODE", () => typeof
|
|
1815
|
+
A.registerFlag("IS_BROWSER", () => mn());
|
|
1816
|
+
A.registerFlag("IS_NODE", () => typeof Y < "u" && typeof Y.versions < "u" && typeof Y.versions.node < "u");
|
|
1829
1817
|
A.registerFlag("IS_CHROME", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor));
|
|
1830
1818
|
A.registerFlag("IS_SAFARI", () => typeof navigator < "u" && navigator != null && navigator.userAgent != null && /Safari/.test(navigator.userAgent) && /Apple/.test(navigator.vendor));
|
|
1831
1819
|
A.registerFlag("PROD", () => !1);
|
|
@@ -1852,23 +1840,23 @@ A.registerFlag("USE_SETTIMEOUTCUSTOM", () => !1);
|
|
|
1852
1840
|
* limitations under the License.
|
|
1853
1841
|
* =============================================================================
|
|
1854
1842
|
*/
|
|
1855
|
-
function
|
|
1843
|
+
function pn(n, t) {
|
|
1856
1844
|
let e = n;
|
|
1857
1845
|
if (R(n))
|
|
1858
1846
|
return t === "string" ? [] : [n.length];
|
|
1859
|
-
if (
|
|
1847
|
+
if (le(n)) {
|
|
1860
1848
|
const r = n.channels || "RGBA";
|
|
1861
1849
|
return [n.height, n.width * r.length];
|
|
1862
|
-
} else if (
|
|
1863
|
-
return [n.buffer.size / (t == null ? 4 :
|
|
1850
|
+
} else if (ce(n))
|
|
1851
|
+
return [n.buffer.size / (t == null ? 4 : wt(t))];
|
|
1864
1852
|
if (!Array.isArray(n))
|
|
1865
1853
|
return [];
|
|
1866
1854
|
const s = [];
|
|
1867
1855
|
for (; Array.isArray(e) || R(e) && t !== "string"; )
|
|
1868
1856
|
s.push(e.length), e = e[0];
|
|
1869
|
-
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") &&
|
|
1857
|
+
return Array.isArray(n) && S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY") && de(n, s, []), s;
|
|
1870
1858
|
}
|
|
1871
|
-
function
|
|
1859
|
+
function de(n, t, e) {
|
|
1872
1860
|
if (e = e || [], !Array.isArray(n) && !R(n)) {
|
|
1873
1861
|
b(t.length === 0, () => `Element arr[${e.join("][")}] is a primitive, but should be an array/TypedArray of ${t[0]} elements`);
|
|
1874
1862
|
return;
|
|
@@ -1876,9 +1864,9 @@ function ue(n, t, e) {
|
|
|
1876
1864
|
b(t.length > 0, () => `Element arr[${e.join("][")}] should be a primitive, but is an array of ${n.length} elements`), b(n.length === t[0], () => `Element arr[${e.join("][")}] should have ${t[0]} elements, but has ${n.length} elements`);
|
|
1877
1865
|
const s = t.slice(1);
|
|
1878
1866
|
for (let r = 0; r < n.length; ++r)
|
|
1879
|
-
|
|
1867
|
+
de(n[r], s, e.concat(r));
|
|
1880
1868
|
}
|
|
1881
|
-
function
|
|
1869
|
+
function Vt(n, t, e, s) {
|
|
1882
1870
|
if (n !== "string_or_numeric") {
|
|
1883
1871
|
if (n == null)
|
|
1884
1872
|
throw new Error("Expected dtype cannot be null.");
|
|
@@ -1887,16 +1875,16 @@ function Wt(n, t, e, s) {
|
|
|
1887
1875
|
}
|
|
1888
1876
|
}
|
|
1889
1877
|
function k(n, t, e, s = "numeric") {
|
|
1890
|
-
if (n instanceof
|
|
1891
|
-
return
|
|
1892
|
-
let r =
|
|
1893
|
-
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s),
|
|
1878
|
+
if (n instanceof ae())
|
|
1879
|
+
return Vt(s, n.dtype, t, e), n;
|
|
1880
|
+
let r = gt(n);
|
|
1881
|
+
if (r !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0 && (r = s), Vt(s, r, t, e), n == null || !R(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
|
|
1894
1882
|
const c = n == null ? "null" : n.constructor.name;
|
|
1895
1883
|
throw new Error(`Argument '${t}' passed to '${e}' must be a Tensor or TensorLike, but got '${c}'`);
|
|
1896
1884
|
}
|
|
1897
|
-
const i =
|
|
1885
|
+
const i = pn(n, r);
|
|
1898
1886
|
!R(n) && !Array.isArray(n) && (n = [n]);
|
|
1899
|
-
const a = r !== "string" ?
|
|
1887
|
+
const a = r !== "string" ? ie(n, r) : at(n, [], !0);
|
|
1900
1888
|
return g.makeTensor(a, i, r);
|
|
1901
1889
|
}
|
|
1902
1890
|
/**
|
|
@@ -1915,19 +1903,19 @@ function k(n, t, e, s = "numeric") {
|
|
|
1915
1903
|
* limitations under the License.
|
|
1916
1904
|
* =============================================================================
|
|
1917
1905
|
*/
|
|
1918
|
-
const
|
|
1906
|
+
const yn = "__op";
|
|
1919
1907
|
function F(n) {
|
|
1920
1908
|
const t = Object.keys(n);
|
|
1921
1909
|
if (t.length !== 1)
|
|
1922
1910
|
throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${t.length} keys.`);
|
|
1923
1911
|
let e = t[0];
|
|
1924
1912
|
const s = n[e];
|
|
1925
|
-
e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e +
|
|
1913
|
+
e.endsWith("_") && (e = e.substring(0, e.length - 1)), e = e + yn;
|
|
1926
1914
|
const r = (...i) => {
|
|
1927
1915
|
g.startScope(e);
|
|
1928
1916
|
try {
|
|
1929
1917
|
const o = s(...i);
|
|
1930
|
-
return
|
|
1918
|
+
return Nt(o) && console.error("Cannot return a Promise inside of tidy."), g.endScope(o), o;
|
|
1931
1919
|
} catch (o) {
|
|
1932
1920
|
throw g.endScope(null), o;
|
|
1933
1921
|
}
|
|
@@ -1950,12 +1938,12 @@ function F(n) {
|
|
|
1950
1938
|
* limitations under the License.
|
|
1951
1939
|
* =============================================================================
|
|
1952
1940
|
*/
|
|
1953
|
-
function
|
|
1941
|
+
function bn(n, t, e, s) {
|
|
1954
1942
|
if (s == null)
|
|
1955
|
-
s =
|
|
1943
|
+
s = gt(n);
|
|
1956
1944
|
else if (s === "complex64")
|
|
1957
1945
|
throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
|
|
1958
|
-
if (
|
|
1946
|
+
if (ce(n) || le(n)) {
|
|
1959
1947
|
if (s !== "float32" && s !== "int32")
|
|
1960
1948
|
throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);
|
|
1961
1949
|
return g.backend.createTensorFromGPUData(n, t || e, s);
|
|
@@ -1963,7 +1951,7 @@ function dn(n, t, e, s) {
|
|
|
1963
1951
|
if (!R(n) && !Array.isArray(n) && typeof n != "number" && typeof n != "boolean" && typeof n != "string")
|
|
1964
1952
|
throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
|
|
1965
1953
|
if (t != null) {
|
|
1966
|
-
|
|
1954
|
+
$t(t);
|
|
1967
1955
|
const r = U(t), i = U(e);
|
|
1968
1956
|
b(r === i, () => `Based on the provided shape, [${t}], the tensor should have ${r} values but has ${i}`);
|
|
1969
1957
|
for (let o = 0; o < e.length; ++o) {
|
|
@@ -1971,9 +1959,9 @@ function dn(n, t, e, s) {
|
|
|
1971
1959
|
b(e[o] === t[o] || !c, () => `Error creating a new Tensor. Inferred shape (${e}) does not match the provided shape (${t}). `);
|
|
1972
1960
|
}
|
|
1973
1961
|
}
|
|
1974
|
-
return !R(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ?
|
|
1962
|
+
return !R(n) && !Array.isArray(n) && (n = [n]), t = t || e, n = s !== "string" ? ie(n, s) : at(n, [], !0), g.makeTensor(n, t, s);
|
|
1975
1963
|
}
|
|
1976
|
-
class
|
|
1964
|
+
class lt {
|
|
1977
1965
|
/**
|
|
1978
1966
|
* Concatenate a number of ArrayBuffers into one.
|
|
1979
1967
|
*
|
|
@@ -1982,7 +1970,7 @@ class at {
|
|
|
1982
1970
|
* @returns Result of concatenating `buffers` in order.
|
|
1983
1971
|
*/
|
|
1984
1972
|
static join(t) {
|
|
1985
|
-
return new
|
|
1973
|
+
return new lt(t).slice();
|
|
1986
1974
|
}
|
|
1987
1975
|
constructor(t) {
|
|
1988
1976
|
if (this.shards = [], this.previousShardIndex = 0, t == null || (t instanceof Array || (t = [t]), t = t.map((s) => R(s) ? s.buffer : s), t.length === 0))
|
|
@@ -2027,11 +2015,11 @@ class at {
|
|
|
2027
2015
|
}
|
|
2028
2016
|
if (e(this.shards[this.previousShardIndex]) === 0)
|
|
2029
2017
|
return this.previousShardIndex;
|
|
2030
|
-
const s =
|
|
2018
|
+
const s = wn(this.shards, e);
|
|
2031
2019
|
return s === -1 ? -1 : (this.previousShardIndex = s, this.previousShardIndex);
|
|
2032
2020
|
}
|
|
2033
2021
|
}
|
|
2034
|
-
function
|
|
2022
|
+
function wn(n, t) {
|
|
2035
2023
|
let e = 0, s = n.length;
|
|
2036
2024
|
for (; e <= s; ) {
|
|
2037
2025
|
const r = Math.floor((s - e) / 2) + e, i = t(n[r]);
|
|
@@ -2057,50 +2045,34 @@ function gn(n, t) {
|
|
|
2057
2045
|
* limitations under the License.
|
|
2058
2046
|
* =============================================================================
|
|
2059
2047
|
*/
|
|
2060
|
-
function
|
|
2048
|
+
function Os() {
|
|
2061
2049
|
return g;
|
|
2062
2050
|
}
|
|
2063
2051
|
function E(n, t) {
|
|
2064
2052
|
return g.tidy(n, t);
|
|
2065
2053
|
}
|
|
2066
2054
|
function M(n) {
|
|
2067
|
-
|
|
2055
|
+
ue(n).forEach((e) => e.dispose());
|
|
2068
2056
|
}
|
|
2069
|
-
function
|
|
2057
|
+
function Sn(n) {
|
|
2070
2058
|
return g.keep(n);
|
|
2071
2059
|
}
|
|
2072
|
-
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
2076
|
-
* you may not use this file except in compliance with the License.
|
|
2077
|
-
* You may obtain a copy of the License at
|
|
2078
|
-
*
|
|
2079
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
2080
|
-
*
|
|
2081
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
2082
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
2083
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
2084
|
-
* See the License for the specific language governing permissions and
|
|
2085
|
-
* limitations under the License.
|
|
2086
|
-
* =============================================================================
|
|
2087
|
-
*/
|
|
2088
|
-
const $t = typeof Buffer < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
|
|
2089
|
-
function jt(n) {
|
|
2090
|
-
return $t ? Buffer.byteLength(n, "utf8") : new Blob([n]).size;
|
|
2060
|
+
const Ct = typeof dt < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
|
|
2061
|
+
function qt(n) {
|
|
2062
|
+
return Ct ? dt.byteLength(n, "utf8") : new Blob([n]).size;
|
|
2091
2063
|
}
|
|
2092
|
-
function
|
|
2093
|
-
if (
|
|
2094
|
-
return
|
|
2064
|
+
function In(n) {
|
|
2065
|
+
if (Ct)
|
|
2066
|
+
return dt.from(n).toString("base64");
|
|
2095
2067
|
const t = new Uint8Array(n);
|
|
2096
2068
|
let e = "";
|
|
2097
2069
|
for (let s = 0, r = t.length; s < r; s++)
|
|
2098
2070
|
e += String.fromCharCode(t[s]);
|
|
2099
2071
|
return btoa(e);
|
|
2100
2072
|
}
|
|
2101
|
-
function
|
|
2102
|
-
if (
|
|
2103
|
-
const s =
|
|
2073
|
+
function kn(n) {
|
|
2074
|
+
if (Ct) {
|
|
2075
|
+
const s = dt.from(n, "base64");
|
|
2104
2076
|
return s.buffer.slice(s.byteOffset, s.byteOffset + s.byteLength);
|
|
2105
2077
|
}
|
|
2106
2078
|
const t = atob(n), e = new Uint8Array(t.length);
|
|
@@ -2108,15 +2080,15 @@ function yn(n) {
|
|
|
2108
2080
|
e.set([t.charCodeAt(s)], s);
|
|
2109
2081
|
return e.buffer;
|
|
2110
2082
|
}
|
|
2111
|
-
function
|
|
2083
|
+
function ge(n) {
|
|
2112
2084
|
if (n.modelTopology instanceof ArrayBuffer)
|
|
2113
2085
|
throw new Error("Expected JSON model topology, received ArrayBuffer.");
|
|
2114
2086
|
return {
|
|
2115
2087
|
dateSaved: /* @__PURE__ */ new Date(),
|
|
2116
2088
|
modelTopologyType: "JSON",
|
|
2117
|
-
modelTopologyBytes: n.modelTopology == null ? 0 :
|
|
2118
|
-
weightSpecsBytes: n.weightSpecs == null ? 0 :
|
|
2119
|
-
weightDataBytes: n.weightData == null ? 0 : new
|
|
2089
|
+
modelTopologyBytes: n.modelTopology == null ? 0 : qt(JSON.stringify(n.modelTopology)),
|
|
2090
|
+
weightSpecsBytes: n.weightSpecs == null ? 0 : qt(JSON.stringify(n.weightSpecs)),
|
|
2091
|
+
weightDataBytes: n.weightData == null ? 0 : new lt(n.weightData).byteLength
|
|
2120
2092
|
};
|
|
2121
2093
|
}
|
|
2122
2094
|
/**
|
|
@@ -2206,8 +2178,8 @@ class B {
|
|
|
2206
2178
|
* limitations under the License.
|
|
2207
2179
|
* =============================================================================
|
|
2208
2180
|
*/
|
|
2209
|
-
const
|
|
2210
|
-
function
|
|
2181
|
+
const Bt = "tensorflowjs", At = 1, L = "models_store", P = "model_info_store";
|
|
2182
|
+
function me() {
|
|
2211
2183
|
if (!S().getBool("IS_BROWSER"))
|
|
2212
2184
|
throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");
|
|
2213
2185
|
const n = typeof window > "u" ? self : window, t = n.indexedDB || n.mozIndexedDB || n.webkitIndexedDB || n.msIndexedDB || n.shimIndexedDB;
|
|
@@ -2215,13 +2187,13 @@ function fe() {
|
|
|
2215
2187
|
throw new Error("The current browser does not appear to support IndexedDB.");
|
|
2216
2188
|
return t;
|
|
2217
2189
|
}
|
|
2218
|
-
function
|
|
2190
|
+
function vt(n) {
|
|
2219
2191
|
const t = n.result;
|
|
2220
|
-
t.createObjectStore(L, { keyPath: "modelPath" }), t.createObjectStore(
|
|
2192
|
+
t.createObjectStore(L, { keyPath: "modelPath" }), t.createObjectStore(P, { keyPath: "modelPath" });
|
|
2221
2193
|
}
|
|
2222
2194
|
class z {
|
|
2223
2195
|
constructor(t) {
|
|
2224
|
-
if (this.indexedDB =
|
|
2196
|
+
if (this.indexedDB = me(), t == null || !t)
|
|
2225
2197
|
throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
|
|
2226
2198
|
this.modelPath = t;
|
|
2227
2199
|
}
|
|
@@ -2249,8 +2221,8 @@ class z {
|
|
|
2249
2221
|
*/
|
|
2250
2222
|
databaseAction(t, e) {
|
|
2251
2223
|
return new Promise((s, r) => {
|
|
2252
|
-
const i = this.indexedDB.open(
|
|
2253
|
-
i.onupgradeneeded = () =>
|
|
2224
|
+
const i = this.indexedDB.open(Bt, At);
|
|
2225
|
+
i.onupgradeneeded = () => vt(i), i.onsuccess = () => {
|
|
2254
2226
|
const o = i.result;
|
|
2255
2227
|
if (e == null) {
|
|
2256
2228
|
const a = o.transaction(L, "readonly"), l = a.objectStore(L).get(this.modelPath);
|
|
@@ -2260,9 +2232,9 @@ class z {
|
|
|
2260
2232
|
s(l.result.modelArtifacts);
|
|
2261
2233
|
}, l.onerror = (u) => (o.close(), r(l.error)), a.oncomplete = () => o.close();
|
|
2262
2234
|
} else {
|
|
2263
|
-
e.weightData =
|
|
2264
|
-
const a =
|
|
2265
|
-
let l = c.objectStore(
|
|
2235
|
+
e.weightData = lt.join(e.weightData);
|
|
2236
|
+
const a = ge(e), c = o.transaction(P, "readwrite");
|
|
2237
|
+
let l = c.objectStore(P), u;
|
|
2266
2238
|
try {
|
|
2267
2239
|
u = l.put({ modelPath: this.modelPath, modelArtifactsInfo: a });
|
|
2268
2240
|
} catch (f) {
|
|
@@ -2283,7 +2255,7 @@ class z {
|
|
|
2283
2255
|
return r(y);
|
|
2284
2256
|
}
|
|
2285
2257
|
m.onsuccess = () => s({ modelArtifactsInfo: a }), m.onerror = (y) => {
|
|
2286
|
-
l = c.objectStore(
|
|
2258
|
+
l = c.objectStore(P);
|
|
2287
2259
|
const d = l.delete(this.modelPath);
|
|
2288
2260
|
d.onsuccess = () => (o.close(), r(m.error)), d.onerror = (I) => (o.close(), r(m.error));
|
|
2289
2261
|
};
|
|
@@ -2296,24 +2268,24 @@ class z {
|
|
|
2296
2268
|
}
|
|
2297
2269
|
}
|
|
2298
2270
|
z.URL_SCHEME = "indexeddb://";
|
|
2299
|
-
const
|
|
2300
|
-
B.registerSaveRouter(
|
|
2301
|
-
B.registerLoadRouter(
|
|
2302
|
-
function
|
|
2271
|
+
const pe = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(z.URL_SCHEME) ? Tn(n.slice(z.URL_SCHEME.length)) : null;
|
|
2272
|
+
B.registerSaveRouter(pe);
|
|
2273
|
+
B.registerLoadRouter(pe);
|
|
2274
|
+
function Tn(n) {
|
|
2303
2275
|
return new z(n);
|
|
2304
2276
|
}
|
|
2305
|
-
function
|
|
2277
|
+
function En(n) {
|
|
2306
2278
|
return n.startsWith(z.URL_SCHEME) ? n.slice(z.URL_SCHEME.length) : n;
|
|
2307
2279
|
}
|
|
2308
|
-
class
|
|
2280
|
+
class Bn {
|
|
2309
2281
|
constructor() {
|
|
2310
|
-
this.indexedDB =
|
|
2282
|
+
this.indexedDB = me();
|
|
2311
2283
|
}
|
|
2312
2284
|
async listModels() {
|
|
2313
2285
|
return new Promise((t, e) => {
|
|
2314
|
-
const s = this.indexedDB.open(
|
|
2315
|
-
s.onupgradeneeded = () =>
|
|
2316
|
-
const r = s.result, i = r.transaction(
|
|
2286
|
+
const s = this.indexedDB.open(Bt, At);
|
|
2287
|
+
s.onupgradeneeded = () => vt(s), s.onsuccess = () => {
|
|
2288
|
+
const r = s.result, i = r.transaction(P, "readonly"), a = i.objectStore(P).getAll();
|
|
2317
2289
|
a.onsuccess = () => {
|
|
2318
2290
|
const c = {};
|
|
2319
2291
|
for (const l of a.result)
|
|
@@ -2324,10 +2296,10 @@ class Sn {
|
|
|
2324
2296
|
});
|
|
2325
2297
|
}
|
|
2326
2298
|
async removeModel(t) {
|
|
2327
|
-
return t =
|
|
2328
|
-
const r = this.indexedDB.open(
|
|
2329
|
-
r.onupgradeneeded = () =>
|
|
2330
|
-
const i = r.result, o = i.transaction(
|
|
2299
|
+
return t = En(t), new Promise((e, s) => {
|
|
2300
|
+
const r = this.indexedDB.open(Bt, At);
|
|
2301
|
+
r.onupgradeneeded = () => vt(r), r.onsuccess = () => {
|
|
2302
|
+
const i = r.result, o = i.transaction(P, "readwrite"), a = o.objectStore(P), c = a.get(t);
|
|
2331
2303
|
let l;
|
|
2332
2304
|
c.onsuccess = () => {
|
|
2333
2305
|
if (c.result == null)
|
|
@@ -2363,27 +2335,27 @@ class Sn {
|
|
|
2363
2335
|
* limitations under the License.
|
|
2364
2336
|
* =============================================================================
|
|
2365
2337
|
*/
|
|
2366
|
-
const _ = "/", X = "tensorflowjs_models",
|
|
2367
|
-
function
|
|
2338
|
+
const _ = "/", X = "tensorflowjs_models", ye = "info", An = "model_topology", vn = "weight_specs", Mn = "weight_data", Fn = "model_metadata";
|
|
2339
|
+
function be(n) {
|
|
2368
2340
|
return {
|
|
2369
|
-
info: [X, n,
|
|
2370
|
-
topology: [X, n,
|
|
2371
|
-
weightSpecs: [X, n,
|
|
2372
|
-
weightData: [X, n,
|
|
2373
|
-
modelMetadata: [X, n,
|
|
2341
|
+
info: [X, n, ye].join(_),
|
|
2342
|
+
topology: [X, n, An].join(_),
|
|
2343
|
+
weightSpecs: [X, n, vn].join(_),
|
|
2344
|
+
weightData: [X, n, Mn].join(_),
|
|
2345
|
+
modelMetadata: [X, n, Fn].join(_)
|
|
2374
2346
|
};
|
|
2375
2347
|
}
|
|
2376
|
-
function
|
|
2348
|
+
function we(n) {
|
|
2377
2349
|
for (const t of Object.values(n))
|
|
2378
2350
|
window.localStorage.removeItem(t);
|
|
2379
2351
|
}
|
|
2380
|
-
function
|
|
2352
|
+
function Rn(n) {
|
|
2381
2353
|
const t = n.split(_);
|
|
2382
2354
|
if (t.length < 3)
|
|
2383
2355
|
throw new Error(`Invalid key format: ${n}`);
|
|
2384
2356
|
return t.slice(1, t.length - 1).join(_);
|
|
2385
2357
|
}
|
|
2386
|
-
function
|
|
2358
|
+
function xn(n) {
|
|
2387
2359
|
return n.startsWith(W.URL_SCHEME) ? n.slice(W.URL_SCHEME.length) : n;
|
|
2388
2360
|
}
|
|
2389
2361
|
class W {
|
|
@@ -2392,7 +2364,7 @@ class W {
|
|
|
2392
2364
|
throw new Error("The current environment does not support local storage.");
|
|
2393
2365
|
if (this.LS = window.localStorage, t == null || !t)
|
|
2394
2366
|
throw new Error("For local storage, modelPath must not be null, undefined or empty.");
|
|
2395
|
-
this.modelPath = t, this.keys =
|
|
2367
|
+
this.modelPath = t, this.keys = be(this.modelPath);
|
|
2396
2368
|
}
|
|
2397
2369
|
/**
|
|
2398
2370
|
* Save model artifacts to browser local storage.
|
|
@@ -2407,9 +2379,9 @@ class W {
|
|
|
2407
2379
|
if (t.modelTopology instanceof ArrayBuffer)
|
|
2408
2380
|
throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
|
|
2409
2381
|
{
|
|
2410
|
-
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r =
|
|
2382
|
+
const e = JSON.stringify(t.modelTopology), s = JSON.stringify(t.weightSpecs), r = ge(t), i = lt.join(t.weightData);
|
|
2411
2383
|
try {
|
|
2412
|
-
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, s), this.LS.setItem(this.keys.weightData,
|
|
2384
|
+
this.LS.setItem(this.keys.info, JSON.stringify(r)), this.LS.setItem(this.keys.topology, e), this.LS.setItem(this.keys.weightSpecs, s), this.LS.setItem(this.keys.weightData, In(i));
|
|
2413
2385
|
const o = {
|
|
2414
2386
|
format: t.format,
|
|
2415
2387
|
generatedBy: t.generatedBy,
|
|
@@ -2422,7 +2394,7 @@ class W {
|
|
|
2422
2394
|
};
|
|
2423
2395
|
return this.LS.setItem(this.keys.modelMetadata, JSON.stringify(o)), { modelArtifactsInfo: r };
|
|
2424
2396
|
} catch {
|
|
2425
|
-
throw
|
|
2397
|
+
throw we(this.keys), new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${r.modelTopologyBytes}, weightSpecsBytes=${r.weightSpecsBytes}, weightDataBytes=${r.weightDataBytes}.`);
|
|
2426
2398
|
}
|
|
2427
2399
|
}
|
|
2428
2400
|
}
|
|
@@ -2456,38 +2428,38 @@ class W {
|
|
|
2456
2428
|
const o = this.LS.getItem(this.keys.weightData);
|
|
2457
2429
|
if (o == null)
|
|
2458
2430
|
throw new Error(`In local storage, the binary weight values of model '${this.modelPath}' are missing.`);
|
|
2459
|
-
return e.weightData =
|
|
2431
|
+
return e.weightData = kn(o), e;
|
|
2460
2432
|
}
|
|
2461
2433
|
}
|
|
2462
2434
|
W.URL_SCHEME = "localstorage://";
|
|
2463
|
-
const
|
|
2464
|
-
B.registerSaveRouter(
|
|
2465
|
-
B.registerLoadRouter(
|
|
2466
|
-
function
|
|
2435
|
+
const Se = (n) => S().getBool("IS_BROWSER") && !Array.isArray(n) && n.startsWith(W.URL_SCHEME) ? $n(n.slice(W.URL_SCHEME.length)) : null;
|
|
2436
|
+
B.registerSaveRouter(Se);
|
|
2437
|
+
B.registerLoadRouter(Se);
|
|
2438
|
+
function $n(n) {
|
|
2467
2439
|
return new W(n);
|
|
2468
2440
|
}
|
|
2469
|
-
class
|
|
2441
|
+
class Nn {
|
|
2470
2442
|
constructor() {
|
|
2471
2443
|
b(S().getBool("IS_BROWSER"), () => "Current environment is not a web browser"), b(typeof window > "u" || typeof window.localStorage < "u", () => "Current browser does not appear to support localStorage"), this.LS = window.localStorage;
|
|
2472
2444
|
}
|
|
2473
2445
|
async listModels() {
|
|
2474
|
-
const t = {}, e = X + _, s = _ +
|
|
2446
|
+
const t = {}, e = X + _, s = _ + ye;
|
|
2475
2447
|
for (let r = 0; r < this.LS.length; ++r) {
|
|
2476
2448
|
const i = this.LS.key(r);
|
|
2477
2449
|
if (i.startsWith(e) && i.endsWith(s)) {
|
|
2478
|
-
const o =
|
|
2450
|
+
const o = Rn(i);
|
|
2479
2451
|
t[o] = JSON.parse(this.LS.getItem(i));
|
|
2480
2452
|
}
|
|
2481
2453
|
}
|
|
2482
2454
|
return t;
|
|
2483
2455
|
}
|
|
2484
2456
|
async removeModel(t) {
|
|
2485
|
-
t =
|
|
2486
|
-
const e =
|
|
2457
|
+
t = xn(t);
|
|
2458
|
+
const e = be(t);
|
|
2487
2459
|
if (this.LS.getItem(e.info) == null)
|
|
2488
2460
|
throw new Error(`Cannot find model at path '${t}'`);
|
|
2489
2461
|
const s = JSON.parse(this.LS.getItem(e.info));
|
|
2490
|
-
return
|
|
2462
|
+
return we(e), s;
|
|
2491
2463
|
}
|
|
2492
2464
|
}
|
|
2493
2465
|
/**
|
|
@@ -2506,13 +2478,13 @@ class Mn {
|
|
|
2506
2478
|
* limitations under the License.
|
|
2507
2479
|
* =============================================================================
|
|
2508
2480
|
*/
|
|
2509
|
-
const
|
|
2510
|
-
class
|
|
2481
|
+
const Ht = "://";
|
|
2482
|
+
class N {
|
|
2511
2483
|
constructor() {
|
|
2512
2484
|
this.managers = {};
|
|
2513
2485
|
}
|
|
2514
2486
|
static getInstance() {
|
|
2515
|
-
return
|
|
2487
|
+
return N.instance == null && (N.instance = new N()), N.instance;
|
|
2516
2488
|
}
|
|
2517
2489
|
/**
|
|
2518
2490
|
* Register a save-handler router.
|
|
@@ -2521,18 +2493,18 @@ class $ {
|
|
|
2521
2493
|
* of `IOHandler` with the `save` method defined or `null`.
|
|
2522
2494
|
*/
|
|
2523
2495
|
static registerManager(t, e) {
|
|
2524
|
-
b(t != null, () => "scheme must not be undefined or null."), t.endsWith(
|
|
2525
|
-
const s =
|
|
2496
|
+
b(t != null, () => "scheme must not be undefined or null."), t.endsWith(Ht) && (t = t.slice(0, t.indexOf(Ht))), b(t.length > 0, () => "scheme must not be an empty string.");
|
|
2497
|
+
const s = N.getInstance();
|
|
2526
2498
|
b(s.managers[t] == null, () => `A model store manager is already registered for scheme '${t}'.`), s.managers[t] = e;
|
|
2527
2499
|
}
|
|
2528
2500
|
static getManager(t) {
|
|
2529
|
-
const e =
|
|
2501
|
+
const e = N.getInstance().managers[t];
|
|
2530
2502
|
if (e == null)
|
|
2531
2503
|
throw new Error(`Cannot find model manager for scheme '${t}'`);
|
|
2532
2504
|
return e;
|
|
2533
2505
|
}
|
|
2534
2506
|
static getSchemes() {
|
|
2535
|
-
return Object.keys(
|
|
2507
|
+
return Object.keys(N.getInstance().managers);
|
|
2536
2508
|
}
|
|
2537
2509
|
}
|
|
2538
2510
|
/**
|
|
@@ -2551,7 +2523,7 @@ class $ {
|
|
|
2551
2523
|
* limitations under the License.
|
|
2552
2524
|
* =============================================================================
|
|
2553
2525
|
*/
|
|
2554
|
-
class
|
|
2526
|
+
class Dn {
|
|
2555
2527
|
constructor() {
|
|
2556
2528
|
this.messageName = "setTimeoutCustom", this.functionRefs = [], this.handledMessageCount = 0, this.hasEventListener = !1;
|
|
2557
2529
|
}
|
|
@@ -2589,50 +2561,34 @@ class Fn {
|
|
|
2589
2561
|
}, !0));
|
|
2590
2562
|
}
|
|
2591
2563
|
isTypedArray(t) {
|
|
2592
|
-
return
|
|
2564
|
+
return re(t);
|
|
2593
2565
|
}
|
|
2594
2566
|
}
|
|
2595
2567
|
if (S().get("IS_BROWSER")) {
|
|
2596
|
-
S().setPlatform("browser", new
|
|
2568
|
+
S().setPlatform("browser", new Dn());
|
|
2597
2569
|
try {
|
|
2598
|
-
|
|
2570
|
+
N.registerManager(W.URL_SCHEME, new Nn());
|
|
2599
2571
|
} catch {
|
|
2600
2572
|
}
|
|
2601
2573
|
try {
|
|
2602
|
-
|
|
2574
|
+
N.registerManager(z.URL_SCHEME, new Bn());
|
|
2603
2575
|
} catch {
|
|
2604
2576
|
}
|
|
2605
2577
|
}
|
|
2606
|
-
|
|
2607
|
-
* @license
|
|
2608
|
-
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
2609
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
2610
|
-
* you may not use this file except in compliance with the License.
|
|
2611
|
-
* You may obtain a copy of the License at
|
|
2612
|
-
*
|
|
2613
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
2614
|
-
*
|
|
2615
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
2616
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
2617
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
2618
|
-
* See the License for the specific language governing permissions and
|
|
2619
|
-
* limitations under the License.
|
|
2620
|
-
* =============================================================================
|
|
2621
|
-
*/
|
|
2622
|
-
const Rn = {
|
|
2578
|
+
const Cn = {
|
|
2623
2579
|
// tslint:disable-next-line:no-require-imports
|
|
2624
2580
|
importFetch: () => require("node-fetch")
|
|
2625
2581
|
};
|
|
2626
|
-
let
|
|
2627
|
-
class
|
|
2582
|
+
let bt;
|
|
2583
|
+
class _n {
|
|
2628
2584
|
constructor() {
|
|
2629
2585
|
this.util = require("util"), this.textEncoder = new this.util.TextEncoder();
|
|
2630
2586
|
}
|
|
2631
2587
|
fetch(t, e) {
|
|
2632
|
-
return S().global.fetch != null ? S().global.fetch(t, e) : (
|
|
2588
|
+
return S().global.fetch != null ? S().global.fetch(t, e) : (bt == null && (bt = Cn.importFetch()), bt(t, e));
|
|
2633
2589
|
}
|
|
2634
2590
|
now() {
|
|
2635
|
-
const t =
|
|
2591
|
+
const t = Y.hrtime();
|
|
2636
2592
|
return t[0] * 1e3 + t[1] / 1e6;
|
|
2637
2593
|
}
|
|
2638
2594
|
encode(t, e) {
|
|
@@ -2647,7 +2603,7 @@ class xn {
|
|
|
2647
2603
|
return this.util.types.isFloat32Array(t) || this.util.types.isInt32Array(t) || this.util.types.isUint8Array(t) || this.util.types.isUint8ClampedArray(t);
|
|
2648
2604
|
}
|
|
2649
2605
|
}
|
|
2650
|
-
S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new
|
|
2606
|
+
S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new _n());
|
|
2651
2607
|
/**
|
|
2652
2608
|
* @license
|
|
2653
2609
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -2664,8 +2620,8 @@ S().get("IS_NODE") && !S().get("IS_BROWSER") && S().setPlatform("node", new xn()
|
|
|
2664
2620
|
* limitations under the License.
|
|
2665
2621
|
* =============================================================================
|
|
2666
2622
|
*/
|
|
2667
|
-
function
|
|
2668
|
-
return t = t || "float32",
|
|
2623
|
+
function Pn(n, t = "float32", e) {
|
|
2624
|
+
return t = t || "float32", $t(n), new an(n, t, e);
|
|
2669
2625
|
}
|
|
2670
2626
|
/**
|
|
2671
2627
|
* @license
|
|
@@ -2683,16 +2639,16 @@ function Nn(n, t = "float32", e) {
|
|
|
2683
2639
|
* limitations under the License.
|
|
2684
2640
|
* =============================================================================
|
|
2685
2641
|
*/
|
|
2686
|
-
function
|
|
2642
|
+
function On(n, t) {
|
|
2687
2643
|
const e = k(n, "x", "cast");
|
|
2688
|
-
if (!
|
|
2644
|
+
if (!ve(t))
|
|
2689
2645
|
throw new Error(`Failed to cast to unknown dtype ${t}`);
|
|
2690
2646
|
if (t === "string" && e.dtype !== "string" || t !== "string" && e.dtype === "string")
|
|
2691
2647
|
throw new Error("Only strings can be casted to strings");
|
|
2692
2648
|
const s = { x: e }, r = { dtype: t };
|
|
2693
|
-
return g.runKernel(
|
|
2649
|
+
return g.runKernel(ee, s, r);
|
|
2694
2650
|
}
|
|
2695
|
-
const
|
|
2651
|
+
const Mt = /* @__PURE__ */ F({ cast_: On });
|
|
2696
2652
|
/**
|
|
2697
2653
|
* @license
|
|
2698
2654
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2709,11 +2665,11 @@ const At = /* @__PURE__ */ F({ cast_: $n });
|
|
|
2709
2665
|
* limitations under the License.
|
|
2710
2666
|
* =============================================================================
|
|
2711
2667
|
*/
|
|
2712
|
-
function
|
|
2668
|
+
function Ln(n) {
|
|
2713
2669
|
const e = { x: k(n, "x", "clone", "string_or_numeric") };
|
|
2714
|
-
return g.runKernel(
|
|
2670
|
+
return g.runKernel(ne, e);
|
|
2715
2671
|
}
|
|
2716
|
-
const
|
|
2672
|
+
const Un = /* @__PURE__ */ F({ clone_: Ln });
|
|
2717
2673
|
/**
|
|
2718
2674
|
* @license
|
|
2719
2675
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
@@ -2730,7 +2686,7 @@ const Cn = /* @__PURE__ */ F({ clone_: Dn });
|
|
|
2730
2686
|
* limitations under the License.
|
|
2731
2687
|
* =============================================================================
|
|
2732
2688
|
*/
|
|
2733
|
-
function
|
|
2689
|
+
function Gn(n, t = !1) {
|
|
2734
2690
|
console.log(n.toString(t));
|
|
2735
2691
|
}
|
|
2736
2692
|
/**
|
|
@@ -2749,14 +2705,14 @@ function _n(n, t = !1) {
|
|
|
2749
2705
|
* limitations under the License.
|
|
2750
2706
|
* =============================================================================
|
|
2751
2707
|
*/
|
|
2752
|
-
|
|
2753
|
-
const
|
|
2754
|
-
buffer:
|
|
2755
|
-
cast:
|
|
2756
|
-
clone:
|
|
2757
|
-
print:
|
|
2708
|
+
fe();
|
|
2709
|
+
const zn = {
|
|
2710
|
+
buffer: Pn,
|
|
2711
|
+
cast: Mt,
|
|
2712
|
+
clone: Un,
|
|
2713
|
+
print: Gn
|
|
2758
2714
|
};
|
|
2759
|
-
|
|
2715
|
+
cn(zn);
|
|
2760
2716
|
/**
|
|
2761
2717
|
* @license
|
|
2762
2718
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2773,13 +2729,13 @@ sn(On);
|
|
|
2773
2729
|
* limitations under the License.
|
|
2774
2730
|
* =============================================================================
|
|
2775
2731
|
*/
|
|
2776
|
-
function
|
|
2732
|
+
function Wn(n, t) {
|
|
2777
2733
|
let e = k(n, "a", "add"), s = k(t, "b", "add");
|
|
2778
2734
|
[e, s] = K(e, s);
|
|
2779
2735
|
const r = { a: e, b: s };
|
|
2780
|
-
return g.runKernel(
|
|
2736
|
+
return g.runKernel(te, r);
|
|
2781
2737
|
}
|
|
2782
|
-
const w = /* @__PURE__ */ F({ add_:
|
|
2738
|
+
const w = /* @__PURE__ */ F({ add_: Wn });
|
|
2783
2739
|
/**
|
|
2784
2740
|
* @license
|
|
2785
2741
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2796,13 +2752,13 @@ const w = /* @__PURE__ */ F({ add_: Pn });
|
|
|
2796
2752
|
* limitations under the License.
|
|
2797
2753
|
* =============================================================================
|
|
2798
2754
|
*/
|
|
2799
|
-
function
|
|
2755
|
+
function jn(n, t) {
|
|
2800
2756
|
let e = k(n, "a", "floorDiv"), s = k(t, "b", "floorDiv");
|
|
2801
2757
|
[e, s] = K(e, s);
|
|
2802
2758
|
const r = { a: e, b: s };
|
|
2803
|
-
return g.runKernel(
|
|
2759
|
+
return g.runKernel(ze, r);
|
|
2804
2760
|
}
|
|
2805
|
-
const
|
|
2761
|
+
const Kn = /* @__PURE__ */ F({ floorDiv_: jn });
|
|
2806
2762
|
/**
|
|
2807
2763
|
* @license
|
|
2808
2764
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2819,14 +2775,14 @@ const Un = /* @__PURE__ */ F({ floorDiv_: Ln });
|
|
|
2819
2775
|
* limitations under the License.
|
|
2820
2776
|
* =============================================================================
|
|
2821
2777
|
*/
|
|
2822
|
-
function
|
|
2778
|
+
function Vn(n, t) {
|
|
2823
2779
|
let e = k(n, "a", "div"), s = k(t, "b", "div");
|
|
2824
2780
|
if ([e, s] = K(e, s), e.dtype === "int32" && s.dtype === "int32")
|
|
2825
|
-
return
|
|
2781
|
+
return Kn(e, s);
|
|
2826
2782
|
const r = { a: e, b: s }, i = {};
|
|
2827
|
-
return g.runKernel(
|
|
2783
|
+
return g.runKernel(Ue, r, i);
|
|
2828
2784
|
}
|
|
2829
|
-
const D = /* @__PURE__ */ F({ div_:
|
|
2785
|
+
const D = /* @__PURE__ */ F({ div_: Vn });
|
|
2830
2786
|
/**
|
|
2831
2787
|
* @license
|
|
2832
2788
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2843,13 +2799,13 @@ const D = /* @__PURE__ */ F({ div_: Gn });
|
|
|
2843
2799
|
* limitations under the License.
|
|
2844
2800
|
* =============================================================================
|
|
2845
2801
|
*/
|
|
2846
|
-
function
|
|
2802
|
+
function qn(n, t) {
|
|
2847
2803
|
let e = k(n, "a", "mul"), s = k(t, "b", "mul");
|
|
2848
2804
|
[e, s] = K(e, s);
|
|
2849
2805
|
const r = { a: e, b: s };
|
|
2850
|
-
return g.runKernel(
|
|
2806
|
+
return g.runKernel(je, r);
|
|
2851
2807
|
}
|
|
2852
|
-
const p = /* @__PURE__ */ F({ mul_:
|
|
2808
|
+
const p = /* @__PURE__ */ F({ mul_: qn });
|
|
2853
2809
|
/**
|
|
2854
2810
|
* @license
|
|
2855
2811
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -2866,17 +2822,17 @@ const p = /* @__PURE__ */ F({ mul_: zn });
|
|
|
2866
2822
|
* limitations under the License.
|
|
2867
2823
|
* =============================================================================
|
|
2868
2824
|
*/
|
|
2869
|
-
function
|
|
2825
|
+
function Hn(n) {
|
|
2870
2826
|
const t = k(n, "x", "abs");
|
|
2871
2827
|
if (t.dtype === "complex64") {
|
|
2872
2828
|
const e = { x: t };
|
|
2873
|
-
return g.runKernel(
|
|
2829
|
+
return g.runKernel(Le, e);
|
|
2874
2830
|
} else {
|
|
2875
2831
|
const e = { x: t };
|
|
2876
|
-
return g.runKernel(
|
|
2832
|
+
return g.runKernel(Oe, e);
|
|
2877
2833
|
}
|
|
2878
2834
|
}
|
|
2879
|
-
const
|
|
2835
|
+
const Jn = /* @__PURE__ */ F({ abs_: Hn });
|
|
2880
2836
|
/**
|
|
2881
2837
|
* @license
|
|
2882
2838
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2893,10 +2849,10 @@ const jn = /* @__PURE__ */ F({ abs_: Wn });
|
|
|
2893
2849
|
* limitations under the License.
|
|
2894
2850
|
* =============================================================================
|
|
2895
2851
|
*/
|
|
2896
|
-
function
|
|
2897
|
-
|
|
2852
|
+
function Xn(n, t, e) {
|
|
2853
|
+
$t(n), e = e || gt(t);
|
|
2898
2854
|
const s = { shape: n, value: t, dtype: e };
|
|
2899
|
-
return g.runKernel(
|
|
2855
|
+
return g.runKernel(Ge, {}, s);
|
|
2900
2856
|
}
|
|
2901
2857
|
/**
|
|
2902
2858
|
* @license
|
|
@@ -2914,7 +2870,7 @@ function Kn(n, t, e) {
|
|
|
2914
2870
|
* limitations under the License.
|
|
2915
2871
|
* =============================================================================
|
|
2916
2872
|
*/
|
|
2917
|
-
function
|
|
2873
|
+
function Ls(n, t) {
|
|
2918
2874
|
const e = [];
|
|
2919
2875
|
for (let s = 0; s < t.length; s++) {
|
|
2920
2876
|
const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
|
|
@@ -2922,7 +2878,7 @@ function xs(n, t) {
|
|
|
2922
2878
|
}
|
|
2923
2879
|
return e;
|
|
2924
2880
|
}
|
|
2925
|
-
function
|
|
2881
|
+
function Yn(n, t) {
|
|
2926
2882
|
const e = Math.max(n.length, t.length), s = new Array(e);
|
|
2927
2883
|
for (let r = 0; r < e; r++) {
|
|
2928
2884
|
let i = n[n.length - r - 1];
|
|
@@ -2956,11 +2912,11 @@ function Vn(n, t) {
|
|
|
2956
2912
|
* limitations under the License.
|
|
2957
2913
|
* =============================================================================
|
|
2958
2914
|
*/
|
|
2959
|
-
function
|
|
2915
|
+
function Qn(n) {
|
|
2960
2916
|
const e = { x: k(n, "x", "zerosLike") };
|
|
2961
|
-
return g.runKernel(
|
|
2917
|
+
return g.runKernel(He, e);
|
|
2962
2918
|
}
|
|
2963
|
-
const C = /* @__PURE__ */ F({ zerosLike_:
|
|
2919
|
+
const C = /* @__PURE__ */ F({ zerosLike_: Qn });
|
|
2964
2920
|
/**
|
|
2965
2921
|
* @license
|
|
2966
2922
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -2977,13 +2933,13 @@ const C = /* @__PURE__ */ F({ zerosLike_: qn });
|
|
|
2977
2933
|
* limitations under the License.
|
|
2978
2934
|
* =============================================================================
|
|
2979
2935
|
*/
|
|
2980
|
-
function
|
|
2936
|
+
function Zn(n, t) {
|
|
2981
2937
|
let e = k(n, "base", "pow"), s = k(t, "exp", "pow");
|
|
2982
2938
|
[e, s] = K(e, s);
|
|
2983
2939
|
const r = { a: e, b: s };
|
|
2984
|
-
return g.runKernel(
|
|
2940
|
+
return g.runKernel(Ke, r);
|
|
2985
2941
|
}
|
|
2986
|
-
const
|
|
2942
|
+
const Jt = /* @__PURE__ */ F({ pow_: Zn });
|
|
2987
2943
|
/**
|
|
2988
2944
|
* @license
|
|
2989
2945
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3005,7 +2961,7 @@ function j(n, t) {
|
|
|
3005
2961
|
throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
|
|
3006
2962
|
if (t === "string" && R(n) && !(n instanceof Uint8Array))
|
|
3007
2963
|
throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");
|
|
3008
|
-
return
|
|
2964
|
+
return bn(n, [], [], t);
|
|
3009
2965
|
}
|
|
3010
2966
|
/**
|
|
3011
2967
|
* @license
|
|
@@ -3023,11 +2979,11 @@ function j(n, t) {
|
|
|
3023
2979
|
* limitations under the License.
|
|
3024
2980
|
* =============================================================================
|
|
3025
2981
|
*/
|
|
3026
|
-
function
|
|
2982
|
+
function ts(n) {
|
|
3027
2983
|
const e = { x: k(n, "x", "sqrt", "float32") };
|
|
3028
|
-
return g.runKernel(
|
|
2984
|
+
return g.runKernel(Ve, e);
|
|
3029
2985
|
}
|
|
3030
|
-
const
|
|
2986
|
+
const tt = /* @__PURE__ */ F({ sqrt_: ts });
|
|
3031
2987
|
/**
|
|
3032
2988
|
* @license
|
|
3033
2989
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -3044,11 +3000,11 @@ const Z = /* @__PURE__ */ F({ sqrt_: Jn });
|
|
|
3044
3000
|
* limitations under the License.
|
|
3045
3001
|
* =============================================================================
|
|
3046
3002
|
*/
|
|
3047
|
-
function
|
|
3003
|
+
function es(n) {
|
|
3048
3004
|
const t = k(n, "x", "square"), e = {};
|
|
3049
3005
|
return g.runKernel("Square", { x: t }, e);
|
|
3050
3006
|
}
|
|
3051
|
-
const G = /* @__PURE__ */ F({ square_:
|
|
3007
|
+
const G = /* @__PURE__ */ F({ square_: es });
|
|
3052
3008
|
/**
|
|
3053
3009
|
* @license
|
|
3054
3010
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3065,8 +3021,8 @@ const G = /* @__PURE__ */ F({ square_: Xn });
|
|
|
3065
3021
|
* limitations under the License.
|
|
3066
3022
|
* =============================================================================
|
|
3067
3023
|
*/
|
|
3068
|
-
function
|
|
3069
|
-
b(
|
|
3024
|
+
function ns(n, t) {
|
|
3025
|
+
b(St(n), () => "The f passed in variableGrads(f) must be a function"), b(t == null || Array.isArray(t) && t.every((l) => l instanceof ft), () => "The varList passed in variableGrads(f, varList) must be an array of variables");
|
|
3070
3026
|
const e = t != null;
|
|
3071
3027
|
if (!e) {
|
|
3072
3028
|
t = [];
|
|
@@ -3082,7 +3038,7 @@ function Yn(n, t) {
|
|
|
3082
3038
|
a[u] != null && (c[l.name] = a[u]);
|
|
3083
3039
|
}), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
|
|
3084
3040
|
}
|
|
3085
|
-
function
|
|
3041
|
+
function Us(n) {
|
|
3086
3042
|
return g.customGrad(n);
|
|
3087
3043
|
}
|
|
3088
3044
|
/**
|
|
@@ -3101,13 +3057,13 @@ function Ns(n) {
|
|
|
3101
3057
|
* limitations under the License.
|
|
3102
3058
|
* =============================================================================
|
|
3103
3059
|
*/
|
|
3104
|
-
function
|
|
3060
|
+
function ss(n, t) {
|
|
3105
3061
|
let e = k(n, "a", "sub"), s = k(t, "b", "sub");
|
|
3106
3062
|
[e, s] = K(e, s);
|
|
3107
3063
|
const r = { a: e, b: s };
|
|
3108
|
-
return g.runKernel(
|
|
3064
|
+
return g.runKernel(qe, r);
|
|
3109
3065
|
}
|
|
3110
|
-
const
|
|
3066
|
+
const Q = /* @__PURE__ */ F({ sub_: ss });
|
|
3111
3067
|
/**
|
|
3112
3068
|
* @license
|
|
3113
3069
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -3124,13 +3080,13 @@ const Y = /* @__PURE__ */ F({ sub_: Qn });
|
|
|
3124
3080
|
* limitations under the License.
|
|
3125
3081
|
* =============================================================================
|
|
3126
3082
|
*/
|
|
3127
|
-
function
|
|
3083
|
+
function rs(n, t) {
|
|
3128
3084
|
let e = k(n, "a", "maximum"), s = k(t, "b", "maximum");
|
|
3129
|
-
[e, s] = K(e, s), e.dtype === "bool" && (e =
|
|
3085
|
+
[e, s] = K(e, s), e.dtype === "bool" && (e = Mt(e, "int32"), s = Mt(s, "int32")), Yn(e.shape, s.shape);
|
|
3130
3086
|
const r = { a: e, b: s };
|
|
3131
|
-
return g.runKernel(
|
|
3087
|
+
return g.runKernel(We, r);
|
|
3132
3088
|
}
|
|
3133
|
-
const
|
|
3089
|
+
const is = /* @__PURE__ */ F({ maximum_: rs });
|
|
3134
3090
|
/**
|
|
3135
3091
|
* @license
|
|
3136
3092
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -3147,8 +3103,8 @@ const ts = /* @__PURE__ */ F({ maximum_: Zn });
|
|
|
3147
3103
|
* limitations under the License.
|
|
3148
3104
|
* =============================================================================
|
|
3149
3105
|
*/
|
|
3150
|
-
const
|
|
3151
|
-
class
|
|
3106
|
+
const os = /* @__PURE__ */ new Map(), as = /* @__PURE__ */ new Map();
|
|
3107
|
+
class ls {
|
|
3152
3108
|
/**
|
|
3153
3109
|
* Return the class name for this class to use in serialization contexts.
|
|
3154
3110
|
*
|
|
@@ -3176,7 +3132,7 @@ class ss {
|
|
|
3176
3132
|
return new t(e);
|
|
3177
3133
|
}
|
|
3178
3134
|
}
|
|
3179
|
-
class
|
|
3135
|
+
class O {
|
|
3180
3136
|
constructor() {
|
|
3181
3137
|
this.classNameMap = {};
|
|
3182
3138
|
}
|
|
@@ -3184,19 +3140,19 @@ class P {
|
|
|
3184
3140
|
* Returns the singleton instance of the map.
|
|
3185
3141
|
*/
|
|
3186
3142
|
static getMap() {
|
|
3187
|
-
return
|
|
3143
|
+
return O.instance == null && (O.instance = new O()), O.instance;
|
|
3188
3144
|
}
|
|
3189
3145
|
/**
|
|
3190
3146
|
* Registers the class as serializable.
|
|
3191
3147
|
*/
|
|
3192
3148
|
static register(t) {
|
|
3193
|
-
|
|
3149
|
+
O.getMap().classNameMap[t.className] = [t, t.fromConfig];
|
|
3194
3150
|
}
|
|
3195
3151
|
}
|
|
3196
|
-
function
|
|
3152
|
+
function cs(n, t, e) {
|
|
3197
3153
|
b(n.className != null, () => "Class being registered does not have the static className property defined."), b(typeof n.className == "string", () => "className is required to be a string, but got type " + typeof n.className), b(n.className.length > 0, () => "Class being registered has an empty-string as its className, which is disallowed."), typeof t > "u" && (t = "Custom"), typeof e > "u" && (e = n.className);
|
|
3198
3154
|
const s = e, r = t + ">" + s;
|
|
3199
|
-
return
|
|
3155
|
+
return O.register(n), os.set(r, n), as.set(n, r), n;
|
|
3200
3156
|
}
|
|
3201
3157
|
/**
|
|
3202
3158
|
* @license
|
|
@@ -3214,7 +3170,7 @@ function rs(n, t, e) {
|
|
|
3214
3170
|
* limitations under the License.
|
|
3215
3171
|
* =============================================================================
|
|
3216
3172
|
*/
|
|
3217
|
-
class V extends
|
|
3173
|
+
class V extends ls {
|
|
3218
3174
|
/**
|
|
3219
3175
|
* Executes `f()` and minimizes the scalar output of `f()` by computing
|
|
3220
3176
|
* gradients of y with respect to the list of trainable variables provided by
|
|
@@ -3261,7 +3217,7 @@ class V extends ss {
|
|
|
3261
3217
|
* @doc {heading: 'Training', subheading: 'Optimizers'}
|
|
3262
3218
|
*/
|
|
3263
3219
|
computeGradients(t, e) {
|
|
3264
|
-
return
|
|
3220
|
+
return ns(t, e);
|
|
3265
3221
|
}
|
|
3266
3222
|
/**
|
|
3267
3223
|
* Dispose the variables (if any) owned by this optimizer instance.
|
|
@@ -3312,7 +3268,7 @@ Object.defineProperty(V, Symbol.hasInstance, {
|
|
|
3312
3268
|
* limitations under the License.
|
|
3313
3269
|
* =============================================================================
|
|
3314
3270
|
*/
|
|
3315
|
-
class
|
|
3271
|
+
class us extends V {
|
|
3316
3272
|
/** @nocollapse */
|
|
3317
3273
|
static get className() {
|
|
3318
3274
|
return "Adadelta";
|
|
@@ -3335,7 +3291,7 @@ class is extends V {
|
|
|
3335
3291
|
return;
|
|
3336
3292
|
const c = this.accumulatedGrads[r].variable, l = this.accumulatedUpdates[r].variable;
|
|
3337
3293
|
E(() => {
|
|
3338
|
-
const u = w(p(c, this.rho), p(G(a), 1 - this.rho)), h = p(D(
|
|
3294
|
+
const u = w(p(c, this.rho), p(G(a), 1 - this.rho)), h = p(D(tt(w(l, this.epsilon)), tt(w(c, this.epsilon))), a), f = w(p(l, this.rho), p(G(h), 1 - this.rho));
|
|
3339
3295
|
c.assign(u), l.assign(f);
|
|
3340
3296
|
const m = w(p(h, -this.learningRate), i);
|
|
3341
3297
|
i.assign(m);
|
|
@@ -3388,7 +3344,7 @@ class is extends V {
|
|
|
3388
3344
|
* limitations under the License.
|
|
3389
3345
|
* =============================================================================
|
|
3390
3346
|
*/
|
|
3391
|
-
class
|
|
3347
|
+
class hs extends V {
|
|
3392
3348
|
/** @nocollapse */
|
|
3393
3349
|
static get className() {
|
|
3394
3350
|
return "Adagrad";
|
|
@@ -3401,7 +3357,7 @@ class os extends V {
|
|
|
3401
3357
|
const i = g.registeredVariables[s];
|
|
3402
3358
|
this.accumulatedGrads[r] == null && (this.accumulatedGrads[r] = {
|
|
3403
3359
|
originalName: `${s}/accumulator`,
|
|
3404
|
-
variable: E(() =>
|
|
3360
|
+
variable: E(() => Xn(i.shape, this.initialAccumulatorValue).variable(!1))
|
|
3405
3361
|
});
|
|
3406
3362
|
const o = Array.isArray(t) ? t[r].tensor : t[s];
|
|
3407
3363
|
if (o == null)
|
|
@@ -3410,7 +3366,7 @@ class os extends V {
|
|
|
3410
3366
|
E(() => {
|
|
3411
3367
|
const c = w(a, G(o));
|
|
3412
3368
|
a.assign(c);
|
|
3413
|
-
const l = w(p(D(o,
|
|
3369
|
+
const l = w(p(D(o, tt(w(c, g.backend.epsilon()))), -this.learningRate), i);
|
|
3414
3370
|
i.assign(l);
|
|
3415
3371
|
});
|
|
3416
3372
|
}), this.incrementIterations();
|
|
@@ -3453,7 +3409,7 @@ class os extends V {
|
|
|
3453
3409
|
* limitations under the License.
|
|
3454
3410
|
* =============================================================================
|
|
3455
3411
|
*/
|
|
3456
|
-
class
|
|
3412
|
+
class fs extends V {
|
|
3457
3413
|
/** @nocollapse */
|
|
3458
3414
|
static get className() {
|
|
3459
3415
|
return "Adam";
|
|
@@ -3466,7 +3422,7 @@ class as extends V {
|
|
|
3466
3422
|
applyGradients(t) {
|
|
3467
3423
|
const e = Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t);
|
|
3468
3424
|
E(() => {
|
|
3469
|
-
const s =
|
|
3425
|
+
const s = Q(1, this.accBeta1), r = Q(1, this.accBeta2);
|
|
3470
3426
|
e.forEach((i, o) => {
|
|
3471
3427
|
const a = g.registeredVariables[i], c = !1;
|
|
3472
3428
|
this.accumulatedFirstMoment[o] == null && (this.accumulatedFirstMoment[o] = {
|
|
@@ -3481,7 +3437,7 @@ class as extends V {
|
|
|
3481
3437
|
return;
|
|
3482
3438
|
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedSecondMoment[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = w(p(h, this.beta2), p(G(l), 1 - this.beta2)), y = D(f, s), d = D(m, r);
|
|
3483
3439
|
u.assign(f), h.assign(m);
|
|
3484
|
-
const I = w(p(D(y, w(
|
|
3440
|
+
const I = w(p(D(y, w(tt(d), this.epsilon)), -this.learningRate), a);
|
|
3485
3441
|
a.assign(I);
|
|
3486
3442
|
}), this.accBeta1.assign(p(this.accBeta1, this.beta1)), this.accBeta2.assign(p(this.accBeta2, this.beta2));
|
|
3487
3443
|
}), this.incrementIterations();
|
|
@@ -3495,7 +3451,7 @@ class as extends V {
|
|
|
3495
3451
|
}
|
|
3496
3452
|
async setWeights(t) {
|
|
3497
3453
|
t = await this.extractIterations(t), E(() => {
|
|
3498
|
-
this.accBeta1.assign(
|
|
3454
|
+
this.accBeta1.assign(Jt(this.beta1, this.iterations_ + 1)), this.accBeta2.assign(Jt(this.beta2, this.iterations_ + 1));
|
|
3499
3455
|
});
|
|
3500
3456
|
const e = t.length / 2, s = !1;
|
|
3501
3457
|
this.accumulatedFirstMoment = t.slice(0, e).map((r) => ({
|
|
@@ -3535,7 +3491,7 @@ class as extends V {
|
|
|
3535
3491
|
* limitations under the License.
|
|
3536
3492
|
* =============================================================================
|
|
3537
3493
|
*/
|
|
3538
|
-
class
|
|
3494
|
+
class ds extends V {
|
|
3539
3495
|
/** @nocollapse */
|
|
3540
3496
|
static get className() {
|
|
3541
3497
|
return "Adamax";
|
|
@@ -3548,7 +3504,7 @@ class ls extends V {
|
|
|
3548
3504
|
applyGradients(t) {
|
|
3549
3505
|
const e = Array.isArray(t) ? t.map((s) => s.name) : Object.keys(t);
|
|
3550
3506
|
E(() => {
|
|
3551
|
-
const s =
|
|
3507
|
+
const s = Q(1, this.accBeta1), r = D(-this.learningRate, w(p(this.iteration, this.decay), 1));
|
|
3552
3508
|
e.forEach((i, o) => {
|
|
3553
3509
|
const a = g.registeredVariables[i], c = !1;
|
|
3554
3510
|
this.accumulatedFirstMoment[o] == null && (this.accumulatedFirstMoment[o] = {
|
|
@@ -3561,7 +3517,7 @@ class ls extends V {
|
|
|
3561
3517
|
const l = Array.isArray(t) ? t[o].tensor : t[i];
|
|
3562
3518
|
if (l == null)
|
|
3563
3519
|
return;
|
|
3564
|
-
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedWeightedInfNorm[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = p(h, this.beta2), y =
|
|
3520
|
+
const u = this.accumulatedFirstMoment[o].variable, h = this.accumulatedWeightedInfNorm[o].variable, f = w(p(u, this.beta1), p(l, 1 - this.beta1)), m = p(h, this.beta2), y = Jn(l), d = is(m, y);
|
|
3565
3521
|
u.assign(f), h.assign(d);
|
|
3566
3522
|
const I = w(p(D(r, s), D(f, w(d, this.epsilon))), a);
|
|
3567
3523
|
a.assign(I);
|
|
@@ -3607,7 +3563,7 @@ class ls extends V {
|
|
|
3607
3563
|
* limitations under the License.
|
|
3608
3564
|
* =============================================================================
|
|
3609
3565
|
*/
|
|
3610
|
-
class
|
|
3566
|
+
class Ie extends V {
|
|
3611
3567
|
/** @nocollapse */
|
|
3612
3568
|
static get className() {
|
|
3613
3569
|
return "SGD";
|
|
@@ -3631,7 +3587,7 @@ class be extends V {
|
|
|
3631
3587
|
* Sets the learning rate of the optimizer.
|
|
3632
3588
|
*/
|
|
3633
3589
|
setLearningRate(t) {
|
|
3634
|
-
this.learningRate = t, this.c != null && this.c.dispose(), this.c =
|
|
3590
|
+
this.learningRate = t, this.c != null && this.c.dispose(), this.c = Sn(j(-t));
|
|
3635
3591
|
}
|
|
3636
3592
|
dispose() {
|
|
3637
3593
|
this.c.dispose();
|
|
@@ -3667,7 +3623,7 @@ class be extends V {
|
|
|
3667
3623
|
* limitations under the License.
|
|
3668
3624
|
* =============================================================================
|
|
3669
3625
|
*/
|
|
3670
|
-
class
|
|
3626
|
+
class gs extends Ie {
|
|
3671
3627
|
/** @nocollapse */
|
|
3672
3628
|
// Name matters for Python compatibility.
|
|
3673
3629
|
static get className() {
|
|
@@ -3738,7 +3694,7 @@ class cs extends be {
|
|
|
3738
3694
|
* limitations under the License.
|
|
3739
3695
|
* =============================================================================
|
|
3740
3696
|
*/
|
|
3741
|
-
class
|
|
3697
|
+
class ms extends V {
|
|
3742
3698
|
/** @nocollapse */
|
|
3743
3699
|
static get className() {
|
|
3744
3700
|
return "RMSProp";
|
|
@@ -3767,14 +3723,14 @@ class us extends V {
|
|
|
3767
3723
|
E(() => {
|
|
3768
3724
|
const u = w(p(c, this.decay), p(G(a), 1 - this.decay));
|
|
3769
3725
|
if (this.centered) {
|
|
3770
|
-
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate),
|
|
3726
|
+
const h = this.accumulatedMeanGrads[r].variable, f = w(p(h, this.decay), p(a, 1 - this.decay)), m = D(p(a, this.learningRate), tt(Q(u, w(G(f), this.epsilon)))), y = w(p(l, this.momentum), m);
|
|
3771
3727
|
c.assign(u), h.assign(f), l.assign(y);
|
|
3772
|
-
const d =
|
|
3728
|
+
const d = Q(i, y);
|
|
3773
3729
|
i.assign(d);
|
|
3774
3730
|
} else {
|
|
3775
|
-
const h = w(p(c, this.decay), p(G(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate),
|
|
3731
|
+
const h = w(p(c, this.decay), p(G(a), 1 - this.decay)), f = w(p(l, this.momentum), D(p(a, this.learningRate), tt(w(h, this.epsilon))));
|
|
3776
3732
|
c.assign(h), l.assign(f);
|
|
3777
|
-
const m =
|
|
3733
|
+
const m = Q(i, f);
|
|
3778
3734
|
i.assign(m);
|
|
3779
3735
|
}
|
|
3780
3736
|
});
|
|
@@ -3831,18 +3787,18 @@ class us extends V {
|
|
|
3831
3787
|
* limitations under the License.
|
|
3832
3788
|
* =============================================================================
|
|
3833
3789
|
*/
|
|
3834
|
-
const
|
|
3835
|
-
is,
|
|
3836
|
-
os,
|
|
3837
|
-
as,
|
|
3838
|
-
ls,
|
|
3839
|
-
cs,
|
|
3790
|
+
const ps = [
|
|
3840
3791
|
us,
|
|
3841
|
-
|
|
3792
|
+
hs,
|
|
3793
|
+
fs,
|
|
3794
|
+
ds,
|
|
3795
|
+
gs,
|
|
3796
|
+
ms,
|
|
3797
|
+
Ie
|
|
3842
3798
|
];
|
|
3843
|
-
function
|
|
3844
|
-
for (const n of
|
|
3845
|
-
|
|
3799
|
+
function ys() {
|
|
3800
|
+
for (const n of ps)
|
|
3801
|
+
cs(n);
|
|
3846
3802
|
}
|
|
3847
3803
|
/**
|
|
3848
3804
|
* @license
|
|
@@ -3860,40 +3816,40 @@ function fs() {
|
|
|
3860
3816
|
* limitations under the License.
|
|
3861
3817
|
* =============================================================================
|
|
3862
3818
|
*/
|
|
3863
|
-
|
|
3819
|
+
ys();
|
|
3864
3820
|
export {
|
|
3865
|
-
|
|
3866
|
-
|
|
3867
|
-
|
|
3821
|
+
fs as A,
|
|
3822
|
+
ks as B,
|
|
3823
|
+
Ts as C,
|
|
3868
3824
|
g as E,
|
|
3869
|
-
|
|
3870
|
-
|
|
3871
|
-
|
|
3872
|
-
|
|
3873
|
-
|
|
3874
|
-
|
|
3875
|
-
|
|
3876
|
-
|
|
3877
|
-
|
|
3878
|
-
|
|
3825
|
+
Bs as I,
|
|
3826
|
+
As as L,
|
|
3827
|
+
vs as N,
|
|
3828
|
+
Ms as P,
|
|
3829
|
+
xs as R,
|
|
3830
|
+
Ns as S,
|
|
3831
|
+
Cs as T,
|
|
3832
|
+
Ps as _,
|
|
3833
|
+
Q as a,
|
|
3834
|
+
Is as b,
|
|
3879
3835
|
k as c,
|
|
3880
3836
|
K as d,
|
|
3881
|
-
|
|
3882
|
-
|
|
3883
|
-
|
|
3884
|
-
|
|
3885
|
-
|
|
3886
|
-
|
|
3887
|
-
|
|
3888
|
-
|
|
3837
|
+
Os as e,
|
|
3838
|
+
Es as f,
|
|
3839
|
+
Mt as g,
|
|
3840
|
+
Ds as h,
|
|
3841
|
+
Fs as i,
|
|
3842
|
+
Rs as j,
|
|
3843
|
+
$s as k,
|
|
3844
|
+
_s as l,
|
|
3889
3845
|
p as m,
|
|
3890
3846
|
b as n,
|
|
3891
3847
|
F as o,
|
|
3892
|
-
|
|
3848
|
+
Ls as p,
|
|
3893
3849
|
w as q,
|
|
3894
3850
|
U as r,
|
|
3895
3851
|
j as s,
|
|
3896
3852
|
E as t,
|
|
3897
|
-
|
|
3898
|
-
|
|
3853
|
+
Yn as u,
|
|
3854
|
+
Us as v
|
|
3899
3855
|
};
|