@genai-fi/nanogpt 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. package/LICENSE +7 -0
  2. package/README.md +20 -0
  3. package/dist/Generator.d.ts +14 -0
  4. package/dist/Generator.js +39 -0
  5. package/dist/NanoGPTModel.d.ts +35 -0
  6. package/dist/NanoGPTModel.js +129 -0
  7. package/dist/TeachableLLM.d.ts +21 -0
  8. package/dist/TeachableLLM.js +47 -0
  9. package/dist/Trainer.d.ts +19 -0
  10. package/dist/Trainer.js +34 -0
  11. package/dist/_commonjsHelpers-DaMA6jEr.js +8 -0
  12. package/dist/assets/worker-BYeSPNkq.js +1 -0
  13. package/dist/config.d.ts +11 -0
  14. package/dist/config.js +19 -0
  15. package/dist/index-B8nyc6IR.js +3899 -0
  16. package/dist/index-SOhdqzHq.js +113 -0
  17. package/dist/jszip.min-BLbRbbKt.js +2324 -0
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -0
  19. package/dist/layers/CausalSelfAttention.js +75 -0
  20. package/dist/layers/LayerNorm.d.ts +12 -0
  21. package/dist/layers/LayerNorm.js +30 -0
  22. package/dist/layers/MLP.d.ts +17 -0
  23. package/dist/layers/MLP.js +57 -0
  24. package/dist/layers/TiedEmbedding.d.ts +22 -0
  25. package/dist/layers/TiedEmbedding.js +532 -0
  26. package/dist/layers/TransformerBlock.d.ts +19 -0
  27. package/dist/layers/TransformerBlock.js +47 -0
  28. package/dist/main.d.ts +6 -0
  29. package/dist/main.js +8 -0
  30. package/dist/tokeniser/CharTokeniser.d.ts +20 -0
  31. package/dist/tokeniser/CharTokeniser.js +52 -0
  32. package/dist/tokeniser/NodeTokeniser.d.ts +19 -0
  33. package/dist/tokeniser/NodeTokeniser.js +46 -0
  34. package/dist/tokeniser/WebTokeniser.d.ts +18 -0
  35. package/dist/tokeniser/WebTokeniser.js +96 -0
  36. package/dist/tokeniser/bpe.d.ts +14 -0
  37. package/dist/tokeniser/bpe.js +102 -0
  38. package/dist/tokeniser/messages.d.ts +61 -0
  39. package/dist/tokeniser/messages.js +1 -0
  40. package/dist/tokeniser/type.d.ts +14 -0
  41. package/dist/tokeniser/type.js +1 -0
  42. package/dist/tokeniser/worker.d.ts +1 -0
  43. package/dist/tokeniser/worker.js +53 -0
  44. package/dist/training/AdamExt.d.ts +23 -0
  45. package/dist/training/AdamExt.js +43 -0
  46. package/dist/training/DatasetBuilder.d.ts +12 -0
  47. package/dist/training/DatasetBuilder.js +27 -0
  48. package/dist/training/FullTrainer.d.ts +17 -0
  49. package/dist/training/FullTrainer.js +75 -0
  50. package/dist/training/LayerTrainer.d.ts +28 -0
  51. package/dist/training/LayerTrainer.js +108 -0
  52. package/dist/training/Trainer.d.ts +73 -0
  53. package/dist/training/Trainer.js +87 -0
  54. package/dist/training/lwSchedule.d.ts +7 -0
  55. package/dist/training/lwSchedule.js +162 -0
  56. package/dist/utilities/generate.d.ts +3 -0
  57. package/dist/utilities/generate.js +22 -0
  58. package/dist/utilities/load.d.ts +7 -0
  59. package/dist/utilities/load.js +47 -0
  60. package/dist/utilities/save.d.ts +3 -0
  61. package/dist/utilities/save.js +21 -0
  62. package/dist/utilities/textLoader.d.ts +1 -0
  63. package/dist/utilities/textLoader.js +438 -0
  64. package/dist/utilities/tokenParse.d.ts +1 -0
  65. package/dist/utilities/tokenParse.js +66 -0
  66. package/dist/utilities/weights.d.ts +12 -0
  67. package/dist/utilities/weights.js +43 -0
  68. package/package.json +59 -0
@@ -0,0 +1,3899 @@
1
+ /**
2
+ * @license
3
+ * Copyright 2020 Google LLC. All Rights Reserved.
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ * =============================================================================
16
+ */
17
/**
 * Base class whose operations all delegate to `v(...)`, which throws.
 * Only `timerAvailable()` and `epsilon()` return a value here; subclasses
 * are expected to override the rest.
 */
class we {
  refCount(t) {
    return v("refCount");
  }

  incRef(t) {
    return v("incRef");
  }

  /** Always reports `true` in this base class. */
  timerAvailable() {
    return true;
  }

  time(t) {
    return v("time");
  }

  read(t) {
    return v("read");
  }

  readSync(t) {
    return v("readSync");
  }

  readToGPU(t, e) {
    return v("readToGPU");
  }

  numDataIds() {
    return v("numDataIds");
  }

  disposeData(t, e) {
    return v("disposeData");
  }

  write(t, e, s) {
    return v("write");
  }

  move(t, e, s, r, i) {
    return v("move");
  }

  createTensorFromGPUData(t, e, s) {
    return v("createTensorFromGPUData");
  }

  memory() {
    return v("memory");
  }

  /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
  floatPrecision() {
    return v("floatPrecision");
  }

  /** Returns the smallest representable number. */
  epsilon() {
    if (this.floatPrecision() === 32) {
      return 1e-7;
    }
    return 1e-4;
  }

  dispose() {
    return v("dispose");
  }
}
69
/**
 * Always throws an Error naming the unavailable operation.
 *
 * @param n Name inserted into the error message.
 * @throws {Error} unconditionally.
 */
function v(n) {
  const message = `'${n}' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen`;
  throw new Error(message);
}
72
+ /**
73
+ * @license
74
+ * Copyright 2020 Google LLC. All Rights Reserved.
75
+ * Licensed under the Apache License, Version 2.0 (the "License");
76
+ * you may not use this file except in compliance with the License.
77
+ * You may obtain a copy of the License at
78
+ *
79
+ * http://www.apache.org/licenses/LICENSE-2.0
80
+ *
81
+ * Unless required by applicable law or agreed to in writing, software
82
+ * distributed under the License is distributed on an "AS IS" BASIS,
83
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84
+ * See the License for the specific language governing permissions and
85
+ * limitations under the License.
86
+ * =============================================================================
87
+ */
88
/**
 * Assertion helper: throws when `n` is falsy.
 *
 * @param n Condition to check.
 * @param t Error message, either a string or a zero-argument function
 *     producing one (only invoked when the check fails).
 * @throws {Error} when `n` is falsy.
 */
function b(n, t) {
  if (n) {
    return;
  }
  const message = typeof t === "string" ? t : t();
  throw new Error(message);
}
92
/**
 * Asserts (via `b`) that `vt(n, t)` holds.
 *
 * @param n First shape.
 * @param t Second shape.
 * @param e Optional prefix prepended to the failure message.
 */
function ds(n, t, e = "") {
  const shapesMatch = vt(n, t);
  b(shapesMatch, () => e + ` Shapes ${n} and ${t} must match`);
}
95
/**
 * Multiplies together every entry of `n`.
 *
 * @param n Indexable list of numbers.
 * @returns The product of the entries, or 1 when `n` is empty.
 */
function U(n) {
  const count = n.length;
  if (count === 0) {
    return 1;
  }
  let product = n[0];
  let index = 1;
  while (index < count) {
    product *= n[index];
    index += 1;
  }
  return product;
}
103
/**
 * Element-wise strict equality of two indexable lists.
 *
 * @param n First list (may be null/undefined).
 * @param t Second list (may be null/undefined).
 * @returns `true` when both are the same reference, or have equal length and
 *     strictly equal entries at every index; `false` otherwise.
 */
function vt(n, t) {
  if (n === t) {
    return true;
  }
  if (n == null || t == null) {
    return false;
  }
  if (n.length !== t.length) {
    return false;
  }
  for (let index = 0; index < n.length; index++) {
    if (n[index] !== t[index]) {
      return false;
    }
  }
  return true;
}
113
/**
 * Right-pads a string with spaces up to a minimum length.
 *
 * @param n String to pad.
 * @param t Target length.
 * @returns `n` unchanged when it is already at least `t` long, otherwise `n`
 *     followed by enough spaces to reach `t` characters.
 */
function lt(n, t) {
  if (t <= n.length) {
    return n;
  }
  const padding = " ".repeat(t - n.length);
  return n + padding;
}
116
/**
 * Allocates a container of length `t` for the dtype name `n`.
 *
 * @param n dtype name; null/undefined is treated like "float32".
 * @param t Length passed to the array constructor.
 * @returns Float32Array, Int32Array, Uint8Array ("bool") or a plain Array
 *     ("string").
 * @throws {Error} for any other dtype name.
 */
function Se(n, t) {
  if (n == null) {
    return new Float32Array(t);
  }
  switch (n) {
    case "float32":
      return new Float32Array(t);
    case "int32":
      return new Int32Array(t);
    case "bool":
      return new Uint8Array(t);
    case "string":
      return new Array(t);
    default:
      throw new Error(`Unknown data type ${n}`);
  }
}
130
/**
 * Scans `n` and throws on the first NaN or non-finite entry.
 *
 * @param n Indexable list of values to check.
 * @param t dtype name, used only in the error message.
 * @throws {Error} naming the offending value.
 */
function Ie(n, t) {
  for (let index = 0; index < n.length; index++) {
    const value = n[index];
    const usable = !isNaN(value) && isFinite(value);
    if (!usable) {
      throw Error(`A tensor of type ${t} being uploaded contains ${value}.`);
    }
  }
}
137
/**
 * Tells whether `n` is one of the dtype names recognised here:
 * "bool", "complex64", "float32", "int32" or "string".
 *
 * @param n Candidate dtype name.
 * @returns `true` for a recognised name, `false` otherwise.
 */
function ke(n) {
  switch (n) {
    case "bool":
    case "complex64":
    case "float32":
    case "int32":
    case "string":
      return true;
    default:
      return false;
  }
}
140
/**
 * Maps a dtype name to a per-element byte count.
 *
 * @param n dtype name.
 * @returns 4 for "float32"/"int32", 8 for "complex64", 1 for "bool".
 * @throws {Error} for any other name.
 */
function yt(n) {
  switch (n) {
    case "float32":
    case "int32":
      return 4;
    case "complex64":
      return 8;
    case "bool":
      return 1;
    default:
      throw new Error(`Unknown dtype ${n}`);
  }
}
149
/**
 * Sums the `length` of every entry in `n`.
 *
 * @param n Collection exposing `forEach`; null/undefined yields 0.
 * @returns Total of all entry lengths.
 */
function Te(n) {
  if (n == null) {
    return 0;
  }
  let total = 0;
  n.forEach((entry) => {
    total += entry.length;
  });
  return total;
}
155
/**
 * Tells whether `n` is a primitive string or a `String` wrapper object.
 *
 * @param n Value to test.
 * @returns `true` for strings and `String` instances.
 */
function Mt(n) {
  if (typeof n === "string") {
    return true;
  }
  return n instanceof String;
}
158
/**
 * Tells whether `n` is a primitive boolean.
 *
 * @param n Value to test.
 * @returns `true` only when `typeof n` is "boolean".
 */
function Ee(n) {
  const kind = typeof n;
  return kind === "boolean";
}
161
/**
 * Tells whether `n` is a primitive number (NaN and infinities included).
 *
 * @param n Value to test.
 * @returns `true` only when `typeof n` is "number".
 */
function Be(n) {
  const kind = typeof n;
  return kind === "number";
}
164
/**
 * Infers a dtype name from a value.
 *
 * Arrays are inspected through their first element (recursively). Typed
 * arrays map to "float32" (Float32Array) or "int32" (Int32Array, Uint8Array,
 * Uint8ClampedArray). Primitive numbers give "float32", strings "string",
 * booleans "bool"; anything else falls back to "float32".
 *
 * @param n Value, typed array or (nested) array.
 * @returns The inferred dtype name.
 */
function ft(n) {
  if (Array.isArray(n)) {
    return ft(n[0]);
  }
  if (n instanceof Float32Array) {
    return "float32";
  }
  if (n instanceof Int32Array || n instanceof Uint8Array || n instanceof Uint8ClampedArray) {
    return "int32";
  }
  if (Be(n)) {
    return "float32";
  }
  if (Mt(n)) {
    return "string";
  }
  if (Ee(n)) {
    return "bool";
  }
  return "float32";
}
167
/**
 * Duck-type check for a callable: requires truthy `constructor`, `call` and
 * `apply` members on a truthy value.
 *
 * @param n Value to test.
 * @returns `true` when all four checks are truthy, `false` otherwise.
 */
function bt(n) {
  return Boolean(n && n.constructor && n.call && n.apply);
}
170
/**
 * Builds the list of trailing-dimension products for `n`.
 *
 * For a list of length r >= 2 the result has r - 1 entries, where entry k is
 * the product of `n[k + 1] .. n[r - 1]`.
 *
 * @param n Indexable list of dimension sizes.
 * @returns The products, or an empty array when `n` has fewer than 2 entries.
 */
function Ft(n) {
  const rank = n.length;
  if (rank < 2) {
    return [];
  }
  const strides = new Array(rank - 1);
  // The last entry is the innermost dimension itself.
  strides[rank - 2] = n[rank - 1];
  let index = rank - 3;
  while (index >= 0) {
    strides[index] = strides[index + 1] * n[index + 1];
    --index;
  }
  return strides;
}
180
/**
 * Recursively slices the flat list `e`, starting at offset `n`, into nested
 * arrays shaped by `t`.
 *
 * @param n Offset into `e` where this sub-block starts.
 * @param t Remaining dimension sizes (must have at least one entry).
 * @param e Flat source values.
 * @param s When true, every innermost row and every stride is doubled.
 * @returns Nested arrays; the innermost level holds values copied from `e`.
 */
function qt(n, t, e, s = !1) {
  const result = new Array();
  const widthFactor = s ? 2 : 1;
  if (t.length === 1) {
    const rowLength = t[0] * widthFactor;
    for (let k = 0; k < rowLength; k++) {
      result[k] = e[n + k];
    }
    return result;
  }
  const outerSize = t[0];
  const innerShape = t.slice(1);
  const innerLength = innerShape.reduce((acc, dim) => acc * dim) * widthFactor;
  for (let k = 0; k < outerSize; k++) {
    result[k] = qt(n + k * innerLength, innerShape, e, s);
  }
  return result;
}
193
/**
 * Converts the flat list `t` into nested arrays of shape `n`.
 *
 * @param n Dimension sizes; an empty list returns `t[0]` directly.
 * @param t Flat source values.
 * @param e When true, the expected flat length is doubled.
 * @returns `t[0]` for an empty shape, `[]` when the expected length is 0,
 *     otherwise the nested result of `qt`.
 * @throws {Error} when the expected length differs from `t.length`.
 */
function Dt(n, t, e = !1) {
  if (n.length === 0) {
    return t[0];
  }
  const widthFactor = e ? 2 : 1;
  const expectedLength = n.reduce((acc, dim) => acc * dim) * widthFactor;
  if (expectedLength === 0) {
    return [];
  }
  if (expectedLength !== t.length) {
    const suffix = e ? " for a complex tensor" : "";
    throw new Error(`[${n}] does not match the input size ${t.length}${suffix}.`);
  }
  return qt(0, n, t, e);
}
203
/**
 * Allocates a typed array through `Ht(n, t)` and sets every entry to 1.
 *
 * @param n Length forwarded to `Ht`.
 * @param t dtype name forwarded to `Ht`.
 * @returns The array returned by `Ht`, filled with ones.
 */
function Ae(n, t) {
  const ones = Ht(n, t);
  let index = 0;
  while (index < ones.length) {
    ones[index] = 1;
    index += 1;
  }
  return ones;
}
209
/**
 * Allocates a zero-initialised typed array for the dtype name `t`.
 *
 * @param n Length passed to the typed-array constructor.
 * @param t dtype name; null/undefined, "float32" and "complex64" all yield a
 *     Float32Array, "int32" an Int32Array, "bool" a Uint8Array.
 * @returns The new typed array.
 * @throws {Error} for any other dtype name.
 */
function Ht(n, t) {
  if (t == null) {
    return new Float32Array(n);
  }
  switch (t) {
    case "float32":
    case "complex64":
      return new Float32Array(n);
    case "int32":
      return new Int32Array(n);
    case "bool":
      return new Uint8Array(n);
    default:
      throw new Error(`Unknown data type ${t}`);
  }
}
218
/**
 * Asserts (via `b`) that every entry of `n` is an integer >= 0.
 *
 * @param n Shape to validate; must expose `forEach`.
 */
function Rt(n) {
  n.forEach((dim) => {
    const valid = Number.isInteger(dim) && dim >= 0;
    b(valid, () => `Tensor must have a shape comprised of positive integers but got shape [${n}].`);
  });
}
223
/**
 * Thenable check.
 *
 * Note the result is not always a boolean: a falsy `n` is returned as-is, a
 * falsy `n.then` is returned as-is, and only otherwise is the boolean
 * `typeof n.then == "function"` returned.
 *
 * @param n Value to test.
 * @returns Truthy when `n` has a function-valued `then`.
 */
function xt(n) {
  if (!n) {
    return n;
  }
  const then = n.then;
  if (!then) {
    return then;
  }
  return typeof n.then === "function";
}
226
+ /**
227
+ * @license
228
+ * Copyright 2017 Google LLC. All Rights Reserved.
229
+ * Licensed under the Apache License, Version 2.0 (the "License");
230
+ * you may not use this file except in compliance with the License.
231
+ * You may obtain a copy of the License at
232
+ *
233
+ * http://www.apache.org/licenses/LICENSE-2.0
234
+ *
235
+ * Unless required by applicable law or agreed to in writing, software
236
+ * distributed under the License is distributed on an "AS IS" BASIS,
237
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
238
+ * See the License for the specific language governing permissions and
239
+ * limitations under the License.
240
+ * =============================================================================
241
+ */
242
// Query-string key whose value carries comma-separated `name:value` flag overrides.
const Ct = "tfjsflags";

/**
 * Flag store: keeps registered flag evaluators, cached flag values, and
 * overrides parsed from the query string of `global.location.search`.
 */
class ve {
  // tslint:disable-next-line: no-any
  constructor(t) {
    this.global = t;
    this.flags = {};
    this.flagRegistry = {};
    this.urlFlags = {};
    this.getQueryParams = Me;
    this.populateURLFlags();
  }

  /**
   * Stores the platform name and object. Warns when a platform was already
   * set, unless the IS_TEST or PROD flag is truthy.
   */
  setPlatform(t, e) {
    if (this.platform != null) {
      const silenced = S().getBool("IS_TEST") || S().getBool("PROD");
      if (!silenced) {
        console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`);
      }
    }
    this.platformName = t;
    this.platform = e;
  }

  /**
   * Registers an evaluator (and optional set hook) for flag `t`. If a URL
   * override exists for the flag it is applied immediately through `set`.
   */
  registerFlag(t, e, s) {
    this.flagRegistry[t] = { evaluationFn: e, setHook: s };
    if (this.urlFlags[t] == null) {
      return;
    }
    const override = this.urlFlags[t];
    const silenced = S().getBool("IS_TEST") || S().getBool("PROD");
    if (!silenced) {
      console.warn(`Setting feature override from URL ${t}: ${override}.`);
    }
    this.set(t, override);
  }

  /** Returns the cached value of flag `t`, awaiting its evaluator if needed. */
  async getAsync(t) {
    if (t in this.flags) {
      return this.flags[t];
    }
    this.flags[t] = await this.evaluateFlag(t);
    return this.flags[t];
  }

  /**
   * Returns the cached value of flag `t`, evaluating it synchronously if
   * needed. Throws when the evaluator returns a thenable.
   */
  get(t) {
    if (t in this.flags) {
      return this.flags[t];
    }
    const value = this.evaluateFlag(t);
    if (xt(value)) {
      throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);
    }
    this.flags[t] = value;
    return this.flags[t];
  }

  getNumber(t) {
    return this.get(t);
  }

  getBool(t) {
    return this.get(t);
  }

  getString(t) {
    return this.get(t);
  }

  getFlags() {
    return this.flags;
  }

  // For backwards compatibility.
  get features() {
    return this.flags;
  }

  /** Sets flag `t` to `e` and runs its set hook; throws for unregistered flags. */
  set(t, e) {
    if (this.flagRegistry[t] == null) {
      throw new Error(`Cannot set flag ${t} as it has not been registered.`);
    }
    this.flags[t] = e;
    if (this.flagRegistry[t].setHook != null) {
      this.flagRegistry[t].setHook(e);
    }
  }

  /** Runs the registered evaluator for flag `t`; throws for unregistered flags. */
  evaluateFlag(t) {
    if (this.flagRegistry[t] == null) {
      throw new Error(`Cannot evaluate flag '${t}': no evaluation function found.`);
    }
    return this.flagRegistry[t].evaluationFn();
  }

  /** Replaces the cached flag values with a shallow copy of `t`. */
  setFlags(t) {
    this.flags = Object.assign({}, t);
  }

  /** Clears cached flags and URL overrides, then re-parses the URL. */
  reset() {
    this.flags = {};
    this.urlFlags = {};
    this.populateURLFlags();
  }

  /**
   * Reads `global.location.search` (when all three exist) and records every
   * `name:value` pair found under the `tfjsflags` query parameter.
   */
  populateURLFlags() {
    if (
      typeof this.global === "undefined" ||
      typeof this.global.location === "undefined" ||
      typeof this.global.location.search === "undefined"
    ) {
      return;
    }
    const params = this.getQueryParams(this.global.location.search);
    if (!(Ct in params)) {
      return;
    }
    params[Ct].split(",").forEach((pair) => {
      const [name, rawValue] = pair.split(":");
      this.urlFlags[name] = Re(name, rawValue);
    });
  }
}
310
/**
 * Parses a query string into a plain object.
 *
 * Every `?key=value` / `&key=value` match is handed to `Fe`, which stores the
 * decoded pair on the result object.
 *
 * @param n Query string (e.g. the value of `location.search`).
 * @returns Object populated by `Fe`.
 */
function Me(n) {
  const params = {};
  const pairPattern = /[?&]([^=?&]+)(?:=([^&]*))?/g;
  // `replace` is used purely to iterate the matches; its result is discarded.
  n.replace(pairPattern, (match, ...rest) => {
    Fe(params, rest[0], rest[1]);
    return rest.join("=");
  });
  return params;
}
314
/**
 * Stores a URI-decoded key/value pair on `n`.
 *
 * @param n Target object (mutated).
 * @param t Encoded key.
 * @param e Encoded value; a falsy value is stored as the empty string.
 */
function Fe(n, t, e) {
  const key = decodeURIComponent(t);
  const value = decodeURIComponent(e || "");
  n[key] = value;
}
317
/**
 * Interprets a flag value given as text.
 *
 * @param n Flag name (unused by the conversion).
 * @param t Raw text value.
 * @returns A boolean for "true"/"false" (case-insensitive), a number when the
 *     lower-cased text round-trips through numeric conversion unchanged,
 *     otherwise the original text `t`.
 */
function Re(n, t) {
  const lowered = t.toLowerCase();
  if (lowered === "true" || lowered === "false") {
    return lowered === "true";
  }
  const numeric = +lowered;
  if (`${numeric}` === lowered) {
    return numeric;
  }
  return t;
}
321
/**
 * Returns the value most recently stored by `xe` (initially `null`).
 */
function S() {
  const current = Jt;
  return current;
}

// Module-level slot read by `S` and written by `xe`.
let Jt = null;

/**
 * Stores `n` so that subsequent `S()` calls return it.
 *
 * @param n Value to store.
 */
function xe(n) {
  Jt = n;
}
328
+ /**
329
+ * @license
330
+ * Copyright 2020 Google LLC. All Rights Reserved.
331
+ * Licensed under the Apache License, Version 2.0 (the "License");
332
+ * you may not use this file except in compliance with the License.
333
+ * You may obtain a copy of the License at
334
+ *
335
+ * http://www.apache.org/licenses/LICENSE-2.0
336
+ *
337
+ * Unless required by applicable law or agreed to in writing, software
338
+ * distributed under the License is distributed on an "AS IS" BASIS,
339
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
340
+ * See the License for the specific language governing permissions and
341
+ * limitations under the License.
342
+ * =============================================================================
343
+ */
344
// Cached result of `Xt`.
let dt;

/**
 * Locates a global namespace object, trying `window`, `global`, `process`
 * and `self` in that order, and caches the first one found.
 *
 * @returns The cached global object.
 * @throws {Error} when none of the four candidates is defined.
 */
function Xt() {
  if (dt != null) {
    return dt;
  }
  let found;
  if (typeof window !== "undefined") {
    found = window;
  } else if (typeof global !== "undefined") {
    found = global;
  } else if (typeof process !== "undefined") {
    found = process;
  } else if (typeof self !== "undefined") {
    found = self;
  } else {
    throw new Error("Could not find a global object");
  }
  dt = found;
  return dt;
}
362
/**
 * Returns the `_tfGlobals` Map stored on the object returned by `Xt()`,
 * creating it on first use.
 */
function Ne() {
  const root = Xt();
  if (root._tfGlobals == null) {
    root._tfGlobals = /* @__PURE__ */ new Map();
  }
  return root._tfGlobals;
}
366
/**
 * Get-or-create lookup in the Map returned by `Ne()`.
 *
 * @param n Key.
 * @param t Zero-argument factory, called only when `n` is not yet present.
 * @returns The value stored under `n`.
 */
function Nt(n, t) {
  const registry = Ne();
  if (!registry.has(n)) {
    const created = t();
    registry.set(n, created);
  }
  return registry.get(n);
}
375
// Kernel-name string constants, one declaration per name.
const $e = "Abs";
const Yt = "Add";
const gs = "BatchMatMul";
const Qt = "Cast";
const ms = "Complex";
const De = "ComplexAbs";
const Ce = "RealDiv";
const ps = "Elu";
const _e = "Fill";
const Oe = "FloorDiv";
const Zt = "Identity";
const ys = "Imag";
const bs = "LeakyRelu";
const Pe = "Maximum";
const Le = "Multiply";
const ws = "Neg";
const Ue = "Pow";
const Ss = "Prelu";
const Is = "Real";
const ks = "Relu";
const Ts = "Reshape";
const Es = "Relu6";
const Bs = "Sigmoid";
const Ge = "Sqrt";
const As = "Sum";
const ze = "Sub";
const vs = "Transpose";
const We = "ZerosLike";
const Ms = "Step";
const Fs = "_FusedMatMul";
376
+ /**
377
+ * @license
378
+ * Copyright 2018 Google LLC. All Rights Reserved.
379
+ * Licensed under the Apache License, Version 2.0 (the "License");
380
+ * you may not use this file except in compliance with the License.
381
+ * You may obtain a copy of the License at
382
+ *
383
+ * http://www.apache.org/licenses/LICENSE-2.0
384
+ *
385
+ * Unless required by applicable law or agreed to in writing, software
386
+ * distributed under the License is distributed on an "AS IS" BASIS,
387
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
388
+ * See the License for the specific language governing permissions and
389
+ * limitations under the License.
390
+ * =============================================================================
391
+ */
392
/**
 * Forwards its arguments to `console.warn`, unless the IS_TEST or PROD flag
 * (read through `S().getBool`) is truthy.
 *
 * @param n Arguments passed through to `console.warn`.
 */
function nt(...n) {
  if (S().getBool("IS_TEST")) {
    return;
  }
  if (S().getBool("PROD")) {
    return;
  }
  console.warn(...n);
}
395
+ /**
396
+ * @license
397
+ * Copyright 2019 Google LLC. All Rights Reserved.
398
+ * Licensed under the Apache License, Version 2.0 (the "License");
399
+ * you may not use this file except in compliance with the License.
400
+ * You may obtain a copy of the License at
401
+ *
402
+ * http://www.apache.org/licenses/LICENSE-2.0
403
+ *
404
+ * Unless required by applicable law or agreed to in writing, software
405
+ * distributed under the License is distributed on an "AS IS" BASIS,
406
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
407
+ * See the License for the specific language governing permissions and
408
+ * limitations under the License.
409
+ * =============================================================================
410
+ */
411
// Shared Maps obtained through `Nt`, keyed "kernelRegistry" and "gradRegistry".
const te = Nt("kernelRegistry", () => /* @__PURE__ */ new Map());
const je = Nt("gradRegistry", () => /* @__PURE__ */ new Map());
412
/**
 * Looks up an entry of the `te` Map under the key built by `Ke(n, t)`.
 *
 * @param n First key component.
 * @param t Second key component.
 * @returns The stored entry, or `undefined` when absent.
 */
function _t(n, t) {
  return te.get(Ke(n, t));
}
416
/**
 * Looks up `n` in the `je` Map.
 *
 * @param n Key.
 * @returns The stored entry, or `undefined` when absent.
 */
function Ot(n) {
  const entry = je.get(n);
  return entry;
}
419
/**
 * Collects every value of the `te` Map whose key, cut at the first "_",
 * equals `n`.
 *
 * @param n Key prefix to match.
 * @returns Matching values in Map iteration order.
 */
function Pt(n) {
  const matches = [];
  for (const [key, value] of te.entries()) {
    const [prefix] = key.split("_");
    if (prefix === n) {
      matches.push(value);
    }
  }
  return matches;
}
430
/**
 * Builds a composite key of the form `<t>_<n>`.
 *
 * @param n Part placed after the underscore.
 * @param t Part placed before the underscore.
 * @returns The joined key.
 */
function Ke(n, t) {
  const key = `${t}_${n}`;
  return key;
}
433
+ /**
434
+ * @license
435
+ * Copyright 2023 Google LLC.
436
+ * Licensed under the Apache License, Version 2.0 (the "License");
437
+ * you may not use this file except in compliance with the License.
438
+ * You may obtain a copy of the License at
439
+ *
440
+ * http://www.apache.org/licenses/LICENSE-2.0
441
+ *
442
+ * Unless required by applicable law or agreed to in writing, software
443
+ * distributed under the License is distributed on an "AS IS" BASIS,
444
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
445
+ * See the License for the specific language governing permissions and
446
+ * limitations under the License.
447
+ * =============================================================================
448
+ */
449
/**
 * Tells whether `n` is a Float32Array, Int32Array, Uint8Array or
 * Uint8ClampedArray.
 *
 * @param n Value to test.
 * @returns `true` for one of those four typed-array kinds.
 */
function ee(n) {
  if (n instanceof Float32Array) {
    return true;
  }
  if (n instanceof Int32Array) {
    return true;
  }
  if (n instanceof Uint8Array) {
    return true;
  }
  return n instanceof Uint8ClampedArray;
}
452
+ /**
453
+ * @license
454
+ * Copyright 2017 Google LLC. All Rights Reserved.
455
+ * Licensed under the Apache License, Version 2.0 (the "License");
456
+ * you may not use this file except in compliance with the License.
457
+ * You may obtain a copy of the License at
458
+ *
459
+ * http://www.apache.org/licenses/LICENSE-2.0
460
+ *
461
+ * Unless required by applicable law or agreed to in writing, software
462
+ * distributed under the License is distributed on an "AS IS" BASIS,
463
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
464
+ * See the License for the specific language governing permissions and
465
+ * limitations under the License.
466
+ * =============================================================================
467
+ */
468
/**
 * Tells whether `n` is already the typed-array kind used for dtype `t`:
 * Float32Array for "float32", Int32Array for "int32", Uint8Array for "bool".
 *
 * @param n Candidate array.
 * @param t dtype name.
 * @returns `true` on a match, `false` otherwise.
 */
function Ve(n, t) {
  if (n instanceof Float32Array && t === "float32") {
    return true;
  }
  if (n instanceof Int32Array && t === "int32") {
    return true;
  }
  return n instanceof Uint8Array && t === "bool";
}
471
/**
 * Converts `n` to the typed array used for dtype `t`.
 *
 * Plain arrays are first flattened with `ot`. When the DEBUG flag is truthy
 * the values are checked with `Ie`. An input that already matches the dtype
 * (per `Ve`) is returned as-is.
 *
 * @param n Array or typed array of values.
 * @param t dtype name; null/undefined, "float32" and "complex64" produce a
 *     Float32Array, "int32" an Int32Array, "bool" a Uint8Array of 0/1.
 * @returns The converted (or original) typed array.
 * @throws {Error} for "string" and for unknown dtype names.
 */
function ne(n, t) {
  if (t === "string") {
    throw new Error("Cannot convert a string[] to a TypedArray");
  }
  let values = n;
  if (Array.isArray(values)) {
    values = ot(values);
  }
  if (S().getBool("DEBUG")) {
    Ie(values, t);
  }
  if (Ve(values, t)) {
    return values;
  }
  if (t == null || t === "float32" || t === "complex64") {
    return new Float32Array(values);
  }
  if (t === "int32") {
    return new Int32Array(values);
  }
  if (t !== "bool") {
    throw new Error(`Unknown data type ${t}`);
  }
  const flags = new Uint8Array(values.length);
  for (let index = 0; index < flags.length; ++index) {
    // Anything that rounds to a non-zero value counts as true.
    if (Math.round(values[index]) !== 0) {
      flags[index] = 1;
    }
  }
  return flags;
}
488
/**
 * Returns the result of `now()` on the platform object exposed by `S()`.
 */
function ut() {
  const platform = S().platform;
  return platform.now();
}
491
/**
 * Encodes `n` through the platform object exposed by `S()`.
 *
 * @param n Value to encode.
 * @param t Encoding name; any falsy value falls back to "utf-8".
 * @returns Whatever `platform.encode` returns.
 */
function qe(n, t = "utf-8") {
  const encoding = t || "utf-8";
  return S().platform.encode(n, encoding);
}
494
/**
 * Decodes `n` through the platform object exposed by `S()`.
 *
 * @param n Value to decode.
 * @param t Encoding name; any falsy value falls back to "utf-8".
 * @returns Whatever `platform.decode` returns.
 */
function Lt(n, t = "utf-8") {
  const encoding = t || "utf-8";
  return S().platform.decode(n, encoding);
}
497
/**
 * Typed-array check that prefers the platform's own `isTypedArray` when the
 * platform provides one, and otherwise falls back to `ee`.
 *
 * @param n Value to test.
 * @returns The result of whichever check was used.
 */
function R(n) {
  if (S().platform.isTypedArray != null) {
    return S().platform.isTypedArray(n);
  }
  return ee(n);
}
500
/**
 * Recursively flattens `n` into the accumulator `t`.
 *
 * Booleans, numbers, strings, thenables (per `xt`), null/undefined, and —
 * when `e` is truthy — typed arrays (per `R`) are pushed as single items.
 * Arrays and typed arrays are walked element by element. Any other object is
 * walked over indices 0..max, where max is its largest canonical
 * non-negative-integer key.
 *
 * @param n Value to flatten.
 * @param t Accumulator; a fresh array is used when null/undefined.
 * @param e When truthy, typed arrays are kept whole instead of expanded.
 * @returns The accumulator.
 */
function ot(n, t = [], e = !1) {
  const out = t == null ? [] : t;
  const isLeaf =
    typeof n === "boolean" ||
    typeof n === "number" ||
    typeof n === "string" ||
    xt(n) ||
    n == null ||
    (R(n) && e);
  if (isLeaf) {
    out.push(n);
    return out;
  }
  if (Array.isArray(n) || R(n)) {
    for (let index = 0; index < n.length; ++index) {
      ot(n[index], out, e);
    }
    return out;
  }
  // Array-like object: find its highest integer key, then walk 0..max.
  let maxIndex = -1;
  for (const key of Object.keys(n)) {
    if (/^([1-9]+[0-9]*|0)$/.test(key)) {
      maxIndex = Math.max(maxIndex, Number(key));
    }
  }
  for (let index = 0; index <= maxIndex; index++) {
    ot(n[index], out, e);
  }
  return out;
}
515
+ /**
516
+ * @license
517
+ * Copyright 2018 Google LLC. All Rights Reserved.
518
+ * Licensed under the Apache License, Version 2.0 (the "License");
519
+ * you may not use this file except in compliance with the License.
520
+ * You may obtain a copy of the License at
521
+ *
522
+ * http://www.apache.org/licenses/LICENSE-2.0
523
+ *
524
+ * Unless required by applicable law or agreed to in writing, software
525
+ * distributed under the License is distributed on an "AS IS" BASIS,
526
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
527
+ * See the License for the specific language governing permissions and
528
+ * limitations under the License.
529
+ * =============================================================================
530
+ */
531
/**
 * Runs a kernel function under a timer and reports the outcome to a logger.
 */
class He {
  /**
   * @param t Timer object; must provide `timerAvailable()` and `time(fn)`.
   * @param e Logger with `logKernelProfile(...)`; a new `Xe` is used when
   *     null/undefined.
   */
  constructor(t, e) {
    this.backendTimer = t;
    this.logger = e;
    if (e == null) {
      this.logger = new Xe();
    }
  }

  /**
   * Invokes `s` and measures it.
   *
   * When the timer is available, `s` is run through `backendTimer.time`.
   * Otherwise it is run directly, `dataSync()` is called on each output, and
   * the elapsed `ut()` time is reported. If the CHECK_COMPUTATION_FOR_ERRORS
   * flag is truthy, each output's data is passed to `Je` once it resolves.
   *
   * @param t Kernel name.
   * @param e Kernel inputs (passed through to the result).
   * @param s Zero-argument function producing the kernel outputs.
   * @returns `{kernelName, outputs, inputs, timeMs, extraInfo}` where the
   *     last two are promises derived from the timing promise.
   */
  profileKernel(t, e, s) {
    let outputs;
    const runKernel = () => {
      outputs = s();
    };
    let timing;
    const startedAt = ut();
    if (this.backendTimer.timerAvailable()) {
      timing = this.backendTimer.time(runKernel);
    } else {
      runKernel();
      for (const output of outputs) {
        output.dataSync();
      }
      timing = Promise.resolve({ kernelMs: ut() - startedAt });
    }
    if (S().getBool("CHECK_COMPUTATION_FOR_ERRORS")) {
      for (let index = 0; index < outputs.length; index++) {
        const output = outputs[index];
        output.data().then((values) => {
          Je(values, output.dtype, t);
        });
      }
    }
    return {
      kernelName: t,
      outputs: outputs,
      inputs: e,
      timeMs: timing.then((info) => info.kernelMs),
      extraInfo: timing.then((info) => (info.getExtraProfileInfo != null ? info.getExtraProfileInfo() : ""))
    };
  }

  /**
   * For each output, waits for its data plus the timing promises and then
   * forwards everything to `logger.logKernelProfile`.
   *
   * @param t Profile object as returned by `profileKernel`.
   */
  logKernelProfile(t) {
    const { kernelName, outputs, timeMs, inputs, extraInfo } = t;
    outputs.forEach((output) => {
      Promise.all([output.data(), timeMs, extraInfo]).then(([values, elapsed, extra]) => {
        this.logger.logKernelProfile(kernelName, output, values, elapsed, inputs, extra);
      });
    });
  }
}
574
/**
 * Scans float32 data for NaN / non-finite values.
 *
 * @param n Indexable list of values.
 * @param t dtype name; anything other than "float32" is skipped.
 * @param e Name used in the warning message.
 * @returns `true` (after a `console.warn`) at the first bad value, `false`
 *     when none is found or the dtype is not "float32".
 */
function Je(n, t, e) {
  if (t !== "float32") {
    return false;
  }
  for (let index = 0; index < n.length; index++) {
    const value = n[index];
    if (isNaN(value) || !isFinite(value)) {
      console.warn(`Found ${value} in the result of '${e}'`);
      return true;
    }
  }
  return false;
}
584
/**
 * Logger that prints one styled `console.log` line per kernel profile.
 */
class Xe {
  /**
   * @param t Kernel name.
   * @param e Output object; its `rank`, `size` and `shape` are printed.
   * @param s Output values (not used by this logger).
   * @param r Elapsed time as a number, or an object whose `error` is printed.
   * @param i Map of input name to input object.
   * @param o Extra text appended at the end of the line.
   */
  logKernelProfile(t, e, s, r, i, o) {
    let timeText;
    if (typeof r === "number") {
      timeText = lt(`${r}ms`, 9);
    } else {
      timeText = r.error;
    }
    const nameText = lt(t, 25);
    const rank = e.rank;
    const size = e.size;
    const shapeText = lt(e.shape.toString(), 14);
    let inputsText = "";
    for (const inputName in i) {
      const input = i[inputName];
      if (input == null) {
        continue;
      }
      // Inputs without their own shape are described with the output's shape.
      const inputShape = input.shape || e.shape;
      const inputRank = inputShape.length;
      inputsText += `${inputName}: ${inputRank}D ${inputRank > 0 ? inputShape : ""} `;
    }
    console.log(
      `%c${nameText} %c${timeText} %c${rank}D ${shapeText} %c${size} %c${inputsText} %c${o}`,
      "font-weight:bold",
      "color:red",
      "color:blue",
      "color: orange",
      "color: green",
      "color: steelblue"
    );
  }
}
598
+ /**
599
+ * @license
600
+ * Copyright 2017 Google LLC. All Rights Reserved.
601
+ * Licensed under the Apache License, Version 2.0 (the "License");
602
+ * you may not use this file except in compliance with the License.
603
+ * You may obtain a copy of the License at
604
+ *
605
+ * http://www.apache.org/licenses/LICENSE-2.0
606
+ *
607
+ * Unless required by applicable law or agreed to in writing, software
608
+ * distributed under the License is distributed on an "AS IS" BASIS,
609
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
610
+ * See the License for the specific language governing permissions and
611
+ * limitations under the License.
612
+ * =============================================================================
613
+ */
614
/**
 * Filters the node list `n` down to the nodes that lie on a path from any
 * of the `t` entries to `e`, judged by matching `id` fields.
 *
 * Pass 1 (forward) marks every node with an input reachable from `t`, and
 * marks that node's outputs as reachable. Pass 2 (backward) marks every node
 * with an output that leads to `e`, and marks that node's inputs as leading
 * to `e`. A node is kept when it is marked by both passes; each kept node is
 * a shallow copy whose `inputs` contains only the inputs reachable from `t`.
 *
 * @param n Ordered nodes, each with `id`, `inputs` (name -> object with
 *     `id`) and `outputs` (array of objects with `id`).
 * @param t Source objects (each with `id`).
 * @param e Target object (with `id`).
 * @returns The filtered, input-pruned copies in their original order.
 */
function Ye(n, t, e) {
  const reachableFromSources = {};
  const nodesFromSources = {};
  for (let k = 0; k < t.length; k++) {
    reachableFromSources[t[k].id] = true;
  }
  for (let k = 0; k < n.length; k++) {
    const node = n[k];
    const nodeInputs = node.inputs;
    for (const inputName in nodeInputs) {
      const input = nodeInputs[inputName];
      // The first reachable input is enough to mark the whole node.
      if (t.length > 0 && reachableFromSources[input.id]) {
        node.outputs.forEach((output) => {
          reachableFromSources[output.id] = true;
        });
        nodesFromSources[node.id] = true;
        break;
      }
    }
  }

  const leadsToTarget = {};
  leadsToTarget[e.id] = true;
  const nodesToTarget = {};
  for (let k = n.length - 1; k >= 0; k--) {
    const node = n[k];
    const nodeInputs = node.inputs;
    for (let j = 0; j < node.outputs.length; j++) {
      if (leadsToTarget[node.outputs[j].id]) {
        // The node is only marked when it has at least one input.
        for (const inputName in nodeInputs) {
          leadsToTarget[nodeInputs[inputName].id] = true;
          nodesToTarget[node.id] = true;
        }
        break;
      }
    }
  }

  const filtered = [];
  for (let k = 0; k < n.length; k++) {
    const node = n[k];
    if (!(nodesFromSources[node.id] && nodesToTarget[node.id])) {
      continue;
    }
    const prunedInputs = {};
    for (const inputName in node.inputs) {
      const input = node.inputs[inputName];
      if (reachableFromSources[input.id]) {
        prunedInputs[inputName] = input;
      }
    }
    const copy = Object.assign({}, node);
    copy.inputs = prunedInputs;
    copy.outputs = node.outputs;
    filtered.push(copy);
  }
  return filtered;
}
659
/**
 * Walks the node list `t` from last to first and accumulates a gradient for
 * every node input into the map `n` (keyed by `id`).
 *
 * For each node, the gradients already recorded for its outputs (or null)
 * are passed to `node.gradient`, which must return one thunk per input name.
 * Each thunk is evaluated inside `e`; the result must have dtype "float32"
 * and the same shape as the input. A first gradient is stored directly;
 * later ones are combined with `s` and the previous value is disposed.
 *
 * @param n Map from `id` to accumulated gradient (mutated).
 * @param t Ordered nodes with `outputs`, `inputs`, `gradient`, `kernelName`.
 * @param e Wrapper that receives a thunk and returns its result.
 * @param s Function combining two gradients for the same input.
 * @throws {Error} when a node has no gradient function, an input has no
 *     gradient, or a gradient has the wrong dtype or shape.
 */
function Qe(n, t, e, s) {
  for (let index = t.length - 1; index >= 0; index--) {
    const node = t[index];
    const outputGrads = [];
    node.outputs.forEach((output) => {
      const recorded = n[output.id];
      outputGrads.push(recorded != null ? recorded : null);
    });
    if (node.gradient == null) {
      throw new Error(`Cannot compute gradient: gradient function not found for ${node.kernelName}.`);
    }
    const inputGradThunks = node.gradient(outputGrads);
    for (const inputName in node.inputs) {
      if (!(inputName in inputGradThunks)) {
        throw new Error(`Cannot backprop through input ${inputName}. Available gradients found: ${Object.keys(inputGradThunks)}.`);
      }
      const grad = e(() => inputGradThunks[inputName]());
      if (grad.dtype !== "float32") {
        throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ${inputName} must have 'float32' dtype, but has '${grad.dtype}'`);
      }
      const input = node.inputs[inputName];
      if (!vt(grad.shape, input.shape)) {
        throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input '${inputName}' has shape '${grad.shape}', which does not match the shape of the input '${input.shape}'`);
      }
      if (n[input.id] == null) {
        n[input.id] = grad;
        continue;
      }
      const previous = n[input.id];
      n[input.id] = s(previous, grad);
      previous.dispose();
    }
  }
}
686
+ /**
687
+ * @license
688
+ * Copyright 2018 Google LLC. All Rights Reserved.
689
+ * Licensed under the Apache License, Version 2.0 (the "License");
690
+ * you may not use this file except in compliance with the License.
691
+ * You may obtain a copy of the License at
692
+ *
693
+ * http://www.apache.org/licenses/LICENSE-2.0
694
+ *
695
+ * Unless required by applicable law or agreed to in writing, software
696
+ * distributed under the License is distributed on an "AS IS" BASIS,
697
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
698
+ * See the License for the specific language governing permissions and
699
+ * limitations under the License.
700
+ * =============================================================================
701
+ */
702
// Formatting limits used when rendering values as text:
// `Ut` – a dimension longer than this is abbreviated with "...".
const Ut = 20;
// `st` – number of leading and trailing entries kept when abbreviating.
const st = 3;
// `gt` – digits passed to `toFixed` when formatting numbers.
const gt = 7;
703
/**
 * Renders values as a multi-line "Tensor" description.
 *
 * @param n Flat values.
 * @param t Shape.
 * @param e dtype name.
 * @param s When truthy, dtype / rank / shape header lines are included.
 * @returns The newline-joined text.
 */
function Ze(n, t, e, s) {
  const strides = Ft(t);
  const columnWidths = tn(n, t, e, strides);
  const rank = t.length;
  const valueLines = ct(n, t, e, strides, columnWidths);
  const lines = ["Tensor"];
  if (s) {
    lines.push(` dtype: ${e}`);
    lines.push(` rank: ${rank}`);
    lines.push(` shape: [${t}]`);
    lines.push(" values:");
  }
  const indented = valueLines.map((line) => " " + line);
  lines.push(indented.join("\n"));
  return lines.join("\n");
}
709
/**
 * Computes, per column, the widest formatted value (via `rt`), for shapes of
 * rank greater than 1. For lower ranks the zero-filled array is returned
 * untouched.
 *
 * @param n Flat values.
 * @param t Shape.
 * @param e dtype name; "complex64" values are first paired up with `it`.
 * @param s Strides for `t`; the last entry is used as the column count.
 * @returns Array of per-column widths.
 */
function tn(n, t, e, s) {
  const totalSize = U(t);
  const columnCount = s[s.length - 1];
  const widths = new Array(columnCount).fill(0);
  const rank = t.length;
  const values = e === "complex64" ? it(n) : n;
  if (rank <= 1) {
    return widths;
  }
  for (let row = 0; row < totalSize / columnCount; row++) {
    const rowOffset = row * columnCount;
    for (let column = 0; column < columnCount; column++) {
      const formatted = rt(values[rowOffset + column], 0, e);
      widths[column] = Math.max(widths[column], formatted.length);
    }
  }
  return widths;
}
719
/**
 * Formats a single value as text and right-pads it with `lt`.
 *
 * Arrays are printed as `<re> + <im>j` from their first two entries, strings
 * are wrapped in single quotes, "bool" values go through `se`, and anything
 * else is printed as a number rounded to `gt` decimals.
 *
 * @param n Value (or two-entry array) to format.
 * @param t Minimum width of the result.
 * @param e dtype name.
 * @returns The padded text.
 */
function rt(n, t, e) {
  let text;
  if (Array.isArray(n)) {
    text = `${parseFloat(n[0].toFixed(gt))} + ${parseFloat(n[1].toFixed(gt))}j`;
  } else if (Mt(n)) {
    text = `'${n}'`;
  } else if (e === "bool") {
    text = se(n);
  } else {
    text = parseFloat(n.toFixed(gt)).toString();
  }
  return lt(text, t);
}
723
/**
 * Renders a boolean-like value as text.
 *
 * @param n Value to render.
 * @returns "false" only when `n` is strictly 0, otherwise "true".
 */
function se(n) {
  if (n === 0) {
    return "false";
  }
  return "true";
}
726
/**
 * Recursively renders a (sub-)block of values as an array of text lines.
 *
 * Rank 0 yields a single formatted value. Rank 1 yields one bracketed line,
 * abbreviated to the first and last `st` entries when longer than `Ut`.
 * Higher ranks recurse per leading index (also abbreviated beyond `Ut`) and
 * then add brackets and separators around the collected lines.
 *
 * @param n Flat values of this block.
 * @param t Shape of this block.
 * @param e dtype name; "complex64" values occupy two slots each.
 * @param s Strides for `t`.
 * @param r Per-column padding widths passed to `rt`.
 * @param i Whether this block is the last one at its level; when false a
 *     trailing separator is appended to the final line.
 * @returns The rendered lines.
 */
function ct(n, t, e, s, r, i = !0) {
  const slotsPerValue = e === "complex64" ? 2 : 1;
  const size = t[0];
  const rank = t.length;

  if (rank === 0) {
    if (e === "complex64") {
      const pairs = it(n);
      return [rt(pairs[0], 0, e)];
    }
    if (e === "bool") {
      return [se(n[0])];
    }
    return [n[0].toString()];
  }

  if (rank === 1) {
    if (size > Ut) {
      const headLength = st * slotsPerValue;
      let head = Array.from(n.slice(0, headLength));
      let tail = Array.from(n.slice((size - st) * slotsPerValue, size * slotsPerValue));
      if (e === "complex64") {
        head = it(head);
        tail = it(tail);
      }
      const headText = head.map((value, column) => rt(value, r[column], e)).join(", ");
      const tailText = tail.map((value, column) => rt(value, r[size - st + column], e)).join(", ");
      return ["[" + headText + ", ..., " + tailText + "]"];
    }
    const values = e === "complex64" ? it(n) : Array.from(n);
    const text = values.map((value, column) => rt(value, r[column], e)).join(", ");
    return ["[" + text + "]"];
  }

  const innerShape = t.slice(1);
  const innerStrides = s.slice(1);
  const blockLength = s[0] * slotsPerValue;
  const lines = [];
  // Renders the sub-block at leading index `position` and appends its lines.
  const appendBlock = (position, isLast) => {
    const start = position * blockLength;
    const end = start + blockLength;
    lines.push(...ct(n.slice(start, end), innerShape, e, innerStrides, r, isLast));
  };

  if (size > Ut) {
    for (let position = 0; position < st; position++) {
      appendBlock(position, false);
    }
    lines.push("...");
    for (let position = size - st; position < size; position++) {
      appendBlock(position, position === size - 1);
    }
  } else {
    for (let position = 0; position < size; position++) {
      appendBlock(position, position === size - 1);
    }
  }

  const rowSeparator = rank === 2 ? "," : "";
  lines[0] = "[" + (size > 0 ? lines[0] + rowSeparator : "");
  for (let k = 1; k < lines.length - 1; k++) {
    lines[k] = " " + lines[k] + rowSeparator;
  }
  // One extra newline per rank above 2 separates higher-level blocks.
  let blockSeparator = ",\n";
  for (let k = 2; k < rank; k++) {
    blockSeparator += "\n";
  }
  const lastIndex = lines.length - 1;
  lines[lastIndex] = " " + lines[lastIndex] + "]" + (i ? "" : blockSeparator);
  return lines;
}
798
/**
 * Groups a flat list into consecutive two-entry pairs.
 *
 * @param n Indexable flat list.
 * @returns Array of `[n[k], n[k + 1]]` for k = 0, 2, 4, ...; an odd-length
 *     input yields `undefined` as the second entry of the last pair.
 */
function it(n) {
  const pairs = [];
  let index = 0;
  while (index < n.length) {
    pairs.push([n[index], n[index + 1]]);
    index += 2;
  }
  return pairs;
}
804
+ /**
805
+ * @license
806
+ * Copyright 2017 Google LLC. All Rights Reserved.
807
+ * Licensed under the Apache License, Version 2.0 (the "License");
808
+ * you may not use this file except in compliance with the License.
809
+ * You may obtain a copy of the License at
810
+ *
811
+ * http://www.apache.org/licenses/LICENSE-2.0
812
+ *
813
+ * Unless required by applicable law or agreed to in writing, software
814
+ * distributed under the License is distributed on an "AS IS" BASIS,
815
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
816
+ * See the License for the specific language governing permissions and
817
+ * limitations under the License.
818
+ * =============================================================================
819
+ */
820
/**
 * Mutable n-dimensional value buffer addressed by coordinates.
 * Values live in a flat `values` array; `strides` (from `Ft`) map all but the
 * last coordinate to a flat offset.
 */
class en {
  /**
   * @param t Shape of the buffer (copied).
   * @param e dtype; "complex64" is rejected.
   * @param s Optional backing values; its length must equal the size computed
   *     from the shape. When omitted, storage is allocated via `Se`.
   */
  constructor(t, e, s) {
    this.dtype = e;
    this.shape = t.slice();
    this.size = U(t);
    if (s != null) {
      const provided = s.length;
      b(provided === this.size, () => `Length of values '${provided}' does not match the size inferred by the shape '${this.size}'.`);
    }
    if (e === "complex64") {
      throw new Error("complex64 dtype TensorBuffers are not supported. Please create a TensorBuffer for the real and imaginary parts separately and call tf.complex(real, imag).");
    }
    this.values = s || Se(e, this.size);
    this.strides = Ft(t);
  }
  /**
   * Sets a value in the buffer at a given location.
   *
   * @param value The value to set.
   * @param locs The location indices (defaults to `[0]` when omitted).
   *
   * @doc {heading: 'Tensors', subheading: 'Creation'}
   */
  set(t, ...e) {
    if (e.length === 0) {
      e = [0];
    }
    b(e.length === this.rank, () => `The number of provided coordinates (${e.length}) must match the rank (${this.rank})`);
    this.values[this.locToIndex(e)] = t;
  }
  /**
   * Returns the value in the buffer at the provided location.
   * Throws when a coordinate is negative or not below the matching dimension.
   *
   * @param locs The location indices (defaults to `[0]` when omitted).
   *
   * @doc {heading: 'Tensors', subheading: 'Creation'}
   */
  get(...t) {
    const locs = t.length === 0 ? [0] : t;
    locs.forEach((loc, axis) => {
      if (loc < 0 || loc >= this.shape[axis]) {
        const message = `Requested out of range element at ${locs}. Buffer shape=${this.shape}`;
        throw new Error(message);
      }
    });
    const lastAxis = locs.length - 1;
    let flatIndex = locs[lastAxis];
    for (let axis = 0; axis < lastAxis; ++axis) {
      flatIndex += this.strides[axis] * locs[axis];
    }
    return this.values[flatIndex];
  }
  /** Converts coordinates into a flat index into `values`. */
  locToIndex(t) {
    switch (this.rank) {
      case 0:
        return 0;
      case 1:
        return t[0];
    }
    const lastAxis = t.length - 1;
    let flatIndex = t[lastAxis];
    for (let axis = 0; axis < lastAxis; ++axis) {
      flatIndex += this.strides[axis] * t[axis];
    }
    return flatIndex;
  }
  /** Converts a flat index into an array of coordinates (one per dimension). */
  indexToLoc(t) {
    if (this.rank === 0) {
      return [];
    }
    if (this.rank === 1) {
      return [t];
    }
    const loc = new Array(this.shape.length);
    let remainder = t;
    for (let axis = 0; axis < loc.length - 1; ++axis) {
      const coord = Math.floor(remainder / this.strides[axis]);
      loc[axis] = coord;
      remainder -= coord * this.strides[axis];
    }
    loc[loc.length - 1] = remainder;
    return loc;
  }
  /** Number of dimensions of the buffer. */
  get rank() {
    return this.shape.length;
  }
  /**
   * Creates a tensor from this buffer by passing its values, shape and dtype
   * to `x().makeTensor`.
   *
   * @doc {heading: 'Tensors', subheading: 'Creation'}
   */
  toTensor() {
    return x().makeTensor(this.values, this.shape, this.dtype);
  }
}
897
// Late-bound module hooks used by the tensor classes in this file.
// `x` is invoked as a function (`x()`) and its result receives calls such as
// read/readSync/disposeTensor/makeTensor; `J` is used as an object exposing
// buffer/print/clone/cast. Both stay `null` until the setters below run.
let x = null;
let J = null;
/** Installs the value stored in `x`. */
function nn(n) {
  x = n;
}
/** Installs the value stored in `J`. */
function sn(n) {
  J = n;
}
904
/**
 * Tensor handle: shape/dtype metadata plus a `dataId` whose values are read
 * through the object returned by `x()`. Op-style helpers (buffer, print,
 * clone, cast) are delegated to `J`.
 */
class N {
  /**
   * @param t Shape (copied).
   * @param e dtype; falls back to "float32" when falsy.
   * @param s Data id handed to `x()` for reads.
   * @param r Numeric/opaque id of this tensor.
   */
  constructor(t, e, s, r) {
    this.kept = false;
    this.isDisposedInternal = false;
    this.shape = t.slice();
    this.dtype = e || "float32";
    this.size = U(t);
    this.strides = Ft(t);
    this.dataId = s;
    this.id = r;
    // Ranks 0-4 are spelled out as strings; anything larger is "higher".
    this.rankType = this.rank < 5 ? this.rank.toString() : "higher";
  }
  /** Number of dimensions. */
  get rank() {
    return this.shape.length;
  }
  /**
   * Returns a promise of the result of `J.buffer` built from the
   * asynchronously read data.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  async buffer() {
    const values = await this.data();
    return J.buffer(this.shape, this.dtype, values);
  }
  /**
   * Synchronous counterpart of `buffer()`.
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  bufferSync() {
    return J.buffer(this.shape, this.dtype, this.dataSync());
  }
  /**
   * Returns the tensor data converted by `Dt` (nested-array form), reading
   * the data asynchronously.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  async array() {
    const values = await this.data();
    const isComplex = this.dtype === "complex64";
    return Dt(this.shape, values, isComplex);
  }
  /**
   * Synchronous counterpart of `array()`.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  arraySync() {
    return Dt(this.shape, this.dataSync(), this.dtype === "complex64");
  }
  /**
   * Asynchronously reads the values via `x().read`. For "string" tensors each
   * element is decoded with `Lt`; a decoding failure is reported as an Error.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  async data() {
    this.throwIfDisposed();
    const pending = x().read(this.dataId);
    if (this.dtype !== "string") {
      return pending;
    }
    const encoded = await pending;
    try {
      return encoded.map((bytes) => Lt(bytes));
    } catch {
      throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
    }
  }
  /**
   * Forwards to `x().readToGPU(dataId, options)` after checking the tensor is
   * not disposed. The shape of the returned GPU data is defined by the
   * backend, not by this class.
   *
   * @param t Options object passed through unchanged.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  dataToGPU(t) {
    this.throwIfDisposed();
    return x().readToGPU(this.dataId, t);
  }
  /**
   * Synchronously reads the values via `x().readSync`; "string" tensors are
   * decoded element-wise with `Lt`.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  dataSync() {
    this.throwIfDisposed();
    const values = x().readSync(this.dataId);
    if (this.dtype !== "string") {
      return values;
    }
    try {
      return values.map((bytes) => Lt(bytes));
    } catch {
      throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().");
    }
  }
  /** Returns the underlying bytes of the tensor's data. */
  async bytes() {
    this.throwIfDisposed();
    const raw = await x().read(this.dataId);
    if (this.dtype === "string") {
      return raw;
    }
    return new Uint8Array(raw.buffer);
  }
  /**
   * Disposes the tensor (and its `kerasMask`, when set). Calling it again on
   * an already disposed tensor is a no-op.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  dispose() {
    if (this.isDisposed) {
      return;
    }
    if (this.kerasMask) {
      this.kerasMask.dispose();
    }
    x().disposeTensor(this);
    this.isDisposedInternal = true;
  }
  /** Whether `dispose()` has completed on this tensor. */
  get isDisposed() {
    return this.isDisposedInternal;
  }
  /** Throws "Tensor is disposed." when the tensor has been disposed. */
  throwIfDisposed() {
    if (this.isDisposed) {
      throw new Error("Tensor is disposed.");
    }
  }
  /**
   * Prints the tensor through `J.print`.
   *
   * @param verbose Forwarded to `J.print`; defaults to false.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  print(t = false) {
    return J.print(this, t);
  }
  /**
   * Returns `J.clone(this)` after the disposed check.
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  clone() {
    this.throwIfDisposed();
    return J.clone(this);
  }
  /**
   * Returns the string produced by `Ze` from the synchronously read values.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  toString(t = false) {
    const values = this.dataSync();
    return Ze(values, this.shape, this.dtype, t);
  }
  /** Returns `J.cast(this, dtype)` after the disposed check. */
  cast(t) {
    this.throwIfDisposed();
    return J.cast(this, t);
  }
  /** Returns `x().makeVariable(this, trainable, name, dtype)`; `trainable` defaults to true. */
  variable(t = true, e, s) {
    this.throwIfDisposed();
    return x().makeVariable(this, t, e, s);
  }
}
1074
// Duck-typed `instanceof N`: any truthy value exposing `data`, `dataSync` and
// `throwIfDisposed` counts as an instance, regardless of its prototype chain.
Object.defineProperty(N, Symbol.hasInstance, {
  value(candidate) {
    if (!candidate) {
      return false;
    }
    return candidate.data != null && candidate.dataSync != null && candidate.throwIfDisposed != null;
  }
});
/** Passes the name "Tensor" and a getter for the `N` class to `Nt` and returns its result. */
function re() {
  const getTensorClass = () => N;
  return Nt("Tensor", getTensorClass);
}
re();
1081
/**
 * Named tensor whose underlying data can be swapped via `assign`.
 * Construction copies shape, dtype and dataId from an existing tensor.
 */
class ht extends N {
  /**
   * @param t Tensor supplying shape, dtype and dataId.
   * @param e Stored as `trainable`.
   * @param s Stored as `name`.
   * @param r Tensor id forwarded to the base constructor.
   */
  constructor(t, e, s, r) {
    super(t.shape, t.dtype, t.dataId, r);
    this.trainable = e;
    this.name = s;
  }
  /**
   * Points this variable at the data of `newValue`. The new tensor must have
   * the same dtype and shape as the current one, otherwise an Error is thrown.
   *
   * @param newValue New tensor to be assigned to this variable.
   *
   * @doc {heading: 'Tensors', subheading: 'Classes'}
   */
  assign(t) {
    if (t.dtype !== this.dtype) {
      throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);
    }
    if (!vt(t.shape, this.shape)) {
      throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);
    }
    // Release the old data first, then adopt and re-register the new data id.
    x().disposeTensor(this);
    this.dataId = t.dataId;
    x().incRef(this, null /* backend */);
  }
  /** Disposes through `x().disposeVariable` and marks the variable disposed. */
  dispose() {
    x().disposeVariable(this);
    this.isDisposedInternal = true;
  }
}
// `instanceof ht` holds for any `N` instance that has a function-valued `assign`.
Object.defineProperty(ht, Symbol.hasInstance, {
  value(candidate) {
    return candidate instanceof N && candidate.assign != null && candidate.assign instanceof Function;
  }
});
1111
+ /**
1112
+ * @license
1113
+ * Copyright 2017 Google LLC. All Rights Reserved.
1114
+ * Licensed under the Apache License, Version 2.0 (the "License");
1115
+ * you may not use this file except in compliance with the License.
1116
+ * You may obtain a copy of the License at
1117
+ *
1118
+ * http://www.apache.org/licenses/LICENSE-2.0
1119
+ *
1120
+ * Unless required by applicable law or agreed to in writing, software
1121
+ * distributed under the License is distributed on an "AS IS" BASIS,
1122
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1123
+ * See the License for the specific language governing permissions and
1124
+ * limitations under the License.
1125
+ * =============================================================================
1126
+ */
1127
// String-keyed lookup objects. `var` + `X = X || {}` keeps the original
// "reuse if already defined" behaviour of the compiled enum pattern.

// Rank labels R0..R6.
var Gt;
Gt = Gt || {};
Object.assign(Gt, { R0: "R0", R1: "R1", R2: "R2", R3: "R3", R4: "R4", R5: "R5", R6: "R6" });

// Result dtype when the first operand is int32, keyed by the second operand.
var wt;
wt = wt || {};
Object.assign(wt, { float32: "float32", int32: "int32", bool: "int32", complex64: "complex64" });

// Result dtype when the first operand is bool.
var St;
St = St || {};
Object.assign(St, { float32: "float32", int32: "int32", bool: "bool", complex64: "complex64" });

// Result dtype when the first operand is float32.
var It;
It = It || {};
Object.assign(It, { float32: "float32", int32: "float32", bool: "float32", complex64: "complex64" });

// Result dtype when the first operand is complex64.
var kt;
kt = kt || {};
Object.assign(kt, { float32: "complex64", int32: "complex64", bool: "complex64", complex64: "complex64" });

// Two-level table: rn[firstDtype][secondDtype] -> upcast dtype.
const rn = {
  float32: It,
  int32: wt,
  bool: St,
  complex64: kt
};
1153
/**
 * Returns the dtype two operands should be upcast to.
 *
 * @param n dtype of the first operand.
 * @param t dtype of the second operand.
 * @returns "string" when both are "string"; otherwise `rn[n][t]`.
 * @throws Error when exactly one of the dtypes is "string".
 */
function on(n, t) {
  const firstIsString = n === "string";
  const secondIsString = t === "string";
  if (firstIsString && secondIsString) {
    return "string";
  }
  if (firstIsString || secondIsString) {
    throw new Error(`Can not upcast ${n} with ${t}`);
  }
  return rn[n][t];
}
1161
/**
 * Tells whether `n` is an object carrying a `texture` property that is a
 * `WebGLTexture`.
 *
 * The `typeof WebGLTexture` guard mirrors the `GPUBuffer` guard in `oe`:
 * outside a browser (Node, workers without WebGL) the global does not exist,
 * and evaluating `instanceof WebGLTexture` would throw a ReferenceError for
 * any object that happens to have a `texture` key.
 *
 * @param n Arbitrary value.
 * @returns true only when WebGLTexture exists and `n.texture` is an instance.
 */
function ie(n) {
  if (typeof WebGLTexture === "undefined") {
    return false;
  }
  return n != null && typeof n == "object" && "texture" in n && n.texture instanceof WebGLTexture;
}
1164
/**
 * Tells whether `n` is an object carrying a `buffer` property that is a
 * `GPUBuffer`. Always false when the `GPUBuffer` global is not defined.
 *
 * @param n Arbitrary value.
 */
function oe(n) {
  if (typeof GPUBuffer === "undefined") {
    return false;
  }
  if (n == null || typeof n != "object") {
    return false;
  }
  return "buffer" in n && n.buffer instanceof GPUBuffer;
}
1167
+ /**
1168
+ * @license
1169
+ * Copyright 2018 Google LLC. All Rights Reserved.
1170
+ * Licensed under the Apache License, Version 2.0 (the "License");
1171
+ * you may not use this file except in compliance with the License.
1172
+ * You may obtain a copy of the License at
1173
+ *
1174
+ * http://www.apache.org/licenses/LICENSE-2.0
1175
+ *
1176
+ * Unless required by applicable law or agreed to in writing, software
1177
+ * distributed under the License is distributed on an "AS IS" BASIS,
1178
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1179
+ * See the License for the specific language governing permissions and
1180
+ * limitations under the License.
1181
+ * =============================================================================
1182
+ */
1183
/**
 * Brings two tensors to a common dtype.
 *
 * @param n First tensor (needs `dtype` and `cast`).
 * @param t Second tensor (needs `dtype` and `cast`).
 * @returns `[n, t]` untouched when the dtypes already match; otherwise both
 *     cast to the dtype chosen by `on`.
 */
function K(n, t) {
  if (n.dtype !== t.dtype) {
    const common = on(n.dtype, t.dtype);
    return [n.cast(common), t.cast(common)];
  }
  return [n, t];
}
1189
/**
 * Collects the tensors found inside `n` by delegating the traversal to `le`
 * with a fresh result list and a fresh "already seen" set.
 *
 * @param n Value or container to search.
 * @returns The list filled in by `le`.
 */
function ae(n) {
  const collected = [];
  const seen = new Set();
  le(n, collected, seen);
  return collected;
}
1193
/**
 * Recursive worker for `ae`: appends every `N` instance reachable from `n`
 * to `t`.
 *
 * @param n Current value; null/undefined and non-containers are ignored.
 * @param t Output array receiving the tensors.
 * @param e Set of property values already visited; prevents revisiting shared
 *     or cyclic references.
 */
function le(n, t, e) {
  if (n == null) {
    return;
  }
  if (n instanceof N) {
    t.push(n);
    return;
  }
  if (!an(n)) {
    return;
  }
  for (const key in n) {
    const child = n[key];
    if (e.has(child)) {
      continue;
    }
    e.add(child);
    le(child, t, e);
  }
}
1208
/**
 * Tells whether `n` can be walked with `for...in` by `le`: true for arrays
 * and for anything whose `typeof` is "object" (which includes `null`).
 */
function an(n) {
  if (Array.isArray(n)) {
    return true;
  }
  return typeof n == "object";
}
1211
+ /**
1212
+ * @license
1213
+ * Copyright 2018 Google LLC. All Rights Reserved.
1214
+ * Licensed under the Apache License, Version 2.0 (the "License");
1215
+ * you may not use this file except in compliance with the License.
1216
+ * You may obtain a copy of the License at
1217
+ *
1218
+ * http://www.apache.org/licenses/LICENSE-2.0
1219
+ *
1220
+ * Unless required by applicable law or agreed to in writing, software
1221
+ * distributed under the License is distributed on an "AS IS" BASIS,
1222
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1223
+ * See the License for the specific language governing permissions and
1224
+ * limitations under the License.
1225
+ * =============================================================================
1226
+ */
1227
/**
 * Tells whether the invocation descriptor `n` names a kernel, i.e. its
 * `kernelName` property is neither null nor undefined.
 */
function mt(n) {
  const name = n.kernelName;
  return name !== null && name !== undefined;
}
1230
/**
 * Plain container for the engine's bookkeeping: registered variables,
 * counters, scope stacks, per-dataId info and the active profile record.
 */
class zt {
  constructor() {
    this.registeredVariables = {};
    this.nextTapeNodeId = 0;
    this.numBytes = 0;
    this.numTensors = 0;
    this.numStringTensors = 0;
    this.numDataBuffers = 0;
    this.gradientDepth = 0;
    this.kernelDepth = 0;
    this.scopeStack = [];
    this.numDataMovesStack = [];
    this.nextScopeId = 0;
    this.tensorInfo = new WeakMap();
    this.profiling = false;
    this.activeProfile = {
      newBytes: 0,
      newTensors: 0,
      peakBytes: 0,
      kernels: [],
      result: null,
      // De-duplicated kernel names, in first-seen order.
      get kernelNames() {
        const names = this.kernels.map((kernel) => kernel.name);
        return Array.from(new Set(names));
      }
    };
  }
  /** Calls `dispose()` on every registered variable. */
  dispose() {
    for (const name in this.registeredVariables) {
      const variable = this.registeredVariables[name];
      variable.dispose();
    }
  }
}
1248
+ class Q {
1249
+ constructor(t) {
1250
+ this.ENV = t, this.registry = {}, this.registryFactory = {}, this.pendingBackendInitId = 0, this.state = new zt();
1251
+ }
1252
+ async ready() {
1253
+ if (this.pendingBackendInit != null)
1254
+ return this.pendingBackendInit.then(() => {
1255
+ });
1256
+ if (this.backendInstance != null)
1257
+ return;
1258
+ const t = this.getSortedBackends();
1259
+ for (let e = 0; e < t.length; e++) {
1260
+ const s = t[e];
1261
+ if (await this.initializeBackend(s).success) {
1262
+ await this.setBackend(s);
1263
+ return;
1264
+ }
1265
+ }
1266
+ throw new Error("Could not initialize any backends, all backend initializations failed.");
1267
+ }
1268
+ get backend() {
1269
+ if (this.pendingBackendInit != null)
1270
+ throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`);
1271
+ if (this.backendInstance == null) {
1272
+ const { name: t, asyncInit: e } = this.initializeBackendsAndReturnBest();
1273
+ if (e)
1274
+ throw new Error(`The highest priority backend '${t}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`);
1275
+ this.setBackend(t);
1276
+ }
1277
+ return this.backendInstance;
1278
+ }
1279
+ backendNames() {
1280
+ return Object.keys(this.registryFactory);
1281
+ }
1282
+ findBackend(t) {
1283
+ if (!(t in this.registry))
1284
+ if (t in this.registryFactory) {
1285
+ const { asyncInit: e } = this.initializeBackend(t);
1286
+ if (e)
1287
+ return null;
1288
+ } else
1289
+ return null;
1290
+ return this.registry[t];
1291
+ }
1292
+ findBackendFactory(t) {
1293
+ return t in this.registryFactory ? this.registryFactory[t].factory : null;
1294
+ }
1295
+ registerBackend(t, e, s = 1) {
1296
+ return t in this.registryFactory ? (nt(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, !0);
1297
+ }
1298
+ async setBackend(t) {
1299
+ if (this.registryFactory[t] == null)
1300
+ throw new Error(`Backend name '${t}' not found in registry`);
1301
+ if (this.backendName = t, this.registry[t] == null) {
1302
+ this.backendInstance = null;
1303
+ const { success: e, asyncInit: s } = this.initializeBackend(t);
1304
+ if (!(s ? await e : e))
1305
+ return !1;
1306
+ }
1307
+ return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new He(this.backendInstance), !0;
1308
+ }
1309
+ setupRegisteredKernels() {
1310
+ Pt(this.backendName).forEach((e) => {
1311
+ e.setupFunc != null && e.setupFunc(this.backendInstance);
1312
+ });
1313
+ }
1314
+ disposeRegisteredKernels(t) {
1315
+ Pt(t).forEach((s) => {
1316
+ s.disposeFunc != null && s.disposeFunc(this.registry[t]);
1317
+ });
1318
+ }
1319
+ /**
1320
+ * Initializes a backend by looking up the backend name in the factory
1321
+ * registry and calling the factory method. Returns a boolean representing
1322
+ * whether the initialization of the backend succeeded. Throws an error if
1323
+ * there is no backend in the factory registry.
1324
+ */
1325
+ initializeBackend(t) {
1326
+ const e = this.registryFactory[t];
1327
+ if (e == null)
1328
+ throw new Error(`Cannot initialize backend ${t}, no registration found.`);
1329
+ try {
1330
+ const s = e.factory();
1331
+ if (s && !(s instanceof we) && typeof s.then == "function") {
1332
+ const r = ++this.pendingBackendInitId, i = s.then((o) => r < this.pendingBackendInitId ? !1 : (this.registry[t] = o, this.pendingBackendInit = null, !0)).catch((o) => (r < this.pendingBackendInitId || (this.pendingBackendInit = null, nt(`Initialization of backend ${t} failed`), nt(o.stack || o.message)), !1));
1333
+ return this.pendingBackendInit = i, { success: i, asyncInit: !0 };
1334
+ } else
1335
+ return this.registry[t] = s, { success: !0, asyncInit: !1 };
1336
+ } catch (s) {
1337
+ return nt(`Initialization of backend ${t} failed`), nt(s.stack || s.message), { success: !1, asyncInit: !1 };
1338
+ }
1339
+ }
1340
+ removeBackend(t) {
1341
+ if (!(t in this.registryFactory))
1342
+ throw new Error(`${t} backend not found in registry`);
1343
+ this.backendName === t && this.pendingBackendInit != null && this.pendingBackendInitId++, t in this.registry && (this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]), delete this.registryFactory[t], this.backendName === t && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
1344
+ }
1345
+ getSortedBackends() {
1346
+ if (Object.keys(this.registryFactory).length === 0)
1347
+ throw new Error("No backend found in registry.");
1348
+ return Object.keys(this.registryFactory).sort((t, e) => this.registryFactory[e].priority - this.registryFactory[t].priority);
1349
+ }
1350
+ initializeBackendsAndReturnBest() {
1351
+ const t = this.getSortedBackends();
1352
+ for (let e = 0; e < t.length; e++) {
1353
+ const s = t[e], { success: r, asyncInit: i } = this.initializeBackend(s);
1354
+ if (i || r)
1355
+ return { name: s, asyncInit: i };
1356
+ }
1357
+ throw new Error("Could not initialize any backends, all backend initializations failed.");
1358
+ }
1359
+ moveData(t, e) {
1360
+ const s = this.state.tensorInfo.get(e), r = s.backend, i = this.readSync(e), o = r.refCount(e);
1361
+ r.disposeData(e, !0), s.backend = t, t.move(e, i, s.shape, s.dtype, o), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
1362
+ }
1363
+ tidy(t, e) {
1364
+ let s = null;
1365
+ if (e == null) {
1366
+ if (typeof t != "function")
1367
+ throw new Error("Please provide a function to tidy()");
1368
+ e = t;
1369
+ } else {
1370
+ if (typeof t != "string" && !(t instanceof String))
1371
+ throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
1372
+ if (typeof e != "function")
1373
+ throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
1374
+ s = t;
1375
+ }
1376
+ let r;
1377
+ return this.scopedRun(() => this.startScope(s), () => this.endScope(r), () => (r = e(), r instanceof Promise && console.error("Cannot return a Promise inside of tidy."), r));
1378
+ }
1379
+ scopedRun(t, e, s) {
1380
+ t();
1381
+ try {
1382
+ const r = s();
1383
+ return e(), r;
1384
+ } catch (r) {
1385
+ throw e(), r;
1386
+ }
1387
+ }
1388
+ nextTensorId() {
1389
+ return Q.nextTensorId++;
1390
+ }
1391
+ nextVariableId() {
1392
+ return Q.nextVariableId++;
1393
+ }
1394
+ /**
1395
+ * This method is called instead of the public-facing tensor.clone() when
1396
+ * saving a tensor for backwards pass. It makes sure to add the clone
1397
+ * operation to the tape regardless of being called inside a kernel
1398
+ * execution.
1399
+ */
1400
+ clone(t) {
1401
+ const e = g.runKernel(Zt, { x: t }), s = { x: t }, r = (o) => ({
1402
+ x: () => {
1403
+ const a = "float32", c = { x: o }, l = { dtype: a };
1404
+ return g.runKernel(
1405
+ Qt,
1406
+ c,
1407
+ // tslint:disable-next-line: no-unnecessary-type-assertion
1408
+ l
1409
+ );
1410
+ }
1411
+ }), i = [];
1412
+ return this.addTapeNode(this.state.activeScope.name, s, [e], r, i, {}), e;
1413
+ }
1414
+ /**
1415
+ * Execute a kernel with the given name and return the output tensor.
1416
+ *
1417
+ * @param kernelName The name of the kernel to execute.
1418
+ * @param inputs A map of input names to tensors.
1419
+ * @param attrs A map of attribute names to their values. An attribute is a
1420
+ * primitive (non-tensor) input to the kernel.
1421
+ * @param inputsToSave A list of tensors, inputs to save for the backprop
1422
+ * computation.
1423
+ * @param outputsToSave A list of booleans, specifying which output to save
1424
+ * for the backprop computation. These are booleans since the output
1425
+ * tensors are not visible to the user.
1426
+ */
1427
+ runKernel(t, e, s) {
1428
+ if (this.backendName == null && this.backend, !(_t(t, this.backendName) != null))
1429
+ throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
1430
+ return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
1431
+ }
1432
+ shouldCheckForMemLeaks() {
1433
+ return this.ENV.getBool("IS_TEST");
1434
+ }
1435
+ checkKernelForMemLeak(t, e, s) {
1436
+ const r = this.backend.numDataIds();
1437
+ let i = 0;
1438
+ s.forEach((c) => {
1439
+ i += c.dtype === "complex64" ? 3 : 1;
1440
+ });
1441
+ const o = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], a = r - e - i - o;
1442
+ if (a > 0)
1443
+ throw new Error(`Backend '${this.backendName}' has an internal memory leak (${a} data ids) after running '${t}'`);
1444
+ }
1445
+ /**
1446
+ * Internal helper method to execute a kernel Func
1447
+ *
1448
+ * Use `runKernel` to execute kernels from outside of engine.
1449
+ */
1450
+ runKernelFunc(t) {
1451
+ let e, s = [];
1452
+ const r = this.isTapeOn(), i = this.state.numBytes, o = this.state.numTensors;
1453
+ this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
1454
+ let a;
1455
+ this.backendName == null && this.backend;
1456
+ let c;
1457
+ const l = mt(t) ? t.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
1458
+ if (mt(t)) {
1459
+ const { kernelName: y, inputs: d, attrs: I } = t;
1460
+ this.backendName == null && this.backend;
1461
+ const T = _t(y, this.backendName);
1462
+ b(T != null, () => `Cannot find registered kernel '${y}' for backend '${this.backendName}'`), a = () => {
1463
+ const tt = this.backend.numDataIds();
1464
+ c = T.kernelFunc({ inputs: d, attrs: I, backend: this.backend });
1465
+ const q = Array.isArray(c) ? c : [c];
1466
+ this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(y, tt, q);
1467
+ const H = q.map((et) => et.rank != null ? et : this.makeTensorFromTensorInfo(et));
1468
+ if (r) {
1469
+ const et = this.getTensorsForGradient(y, d, H);
1470
+ s = this.saveTensorsForBackwardMode(et);
1471
+ }
1472
+ return H;
1473
+ };
1474
+ } else {
1475
+ const { forwardFunc: y } = t, d = (I) => {
1476
+ r && (s = I.map((T) => this.keep(this.clone(T))));
1477
+ };
1478
+ a = () => {
1479
+ const I = this.backend.numDataIds();
1480
+ c = this.tidy(() => y(this.backend, d));
1481
+ const T = Array.isArray(c) ? c : [c];
1482
+ return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(l, I, T), T;
1483
+ };
1484
+ }
1485
+ const { inputs: u, attrs: h } = t, f = mt(t) ? null : t.backwardsFunc;
1486
+ let m;
1487
+ return this.scopedRun(
1488
+ // Stop recording to a tape when running a kernel.
1489
+ () => this.state.kernelDepth++,
1490
+ () => this.state.kernelDepth--,
1491
+ () => {
1492
+ !this.ENV.getBool("DEBUG") && !this.state.profiling ? e = a() : (m = this.profiler.profileKernel(l, u, () => a()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(m), e = m.outputs);
1493
+ }
1494
+ ), r && this.addTapeNode(l, u, e, f, s, h), this.state.profiling && this.state.activeProfile.kernels.push({
1495
+ name: l,
1496
+ bytesAdded: this.state.numBytes - i,
1497
+ totalBytesSnapshot: this.state.numBytes,
1498
+ tensorsAdded: this.state.numTensors - o,
1499
+ totalTensorsSnapshot: this.state.numTensors,
1500
+ inputShapes: Object.keys(u).map((y) => u[y] != null ? u[y].shape : null),
1501
+ outputShapes: e.map((y) => y.shape),
1502
+ kernelTimeMs: m.timeMs,
1503
+ extraInfo: m.extraInfo
1504
+ }), Array.isArray(c) ? e : e[0];
1505
+ }
1506
+ /**
1507
+ * Saves tensors used in forward mode for use in backward mode.
1508
+ *
1509
+ * @param tensors the list of tensors to save.
1510
+ */
1511
+ saveTensorsForBackwardMode(t) {
1512
+ return t.map((s) => this.keep(this.clone(s)));
1513
+ }
1514
+ /**
1515
+ * Returns a list of tensors to save for a given gradient calculation.
1516
+ *
1517
+ * @param kernelName name of kernel to look up gradient for.
1518
+ * @param inputs a map of input tensors.
1519
+ * @param outputs an array of output tensors from forward mode of kernel.
1520
+ */
1521
+ getTensorsForGradient(t, e, s) {
1522
+ const r = Ot(t);
1523
+ if (r != null) {
1524
+ const i = r.inputsToSave || [], o = r.outputsToSave || [];
1525
+ let a;
1526
+ r.saveAllInputs ? (b(Array.isArray(e), () => "saveAllInputs is true, expected inputs to be an array."), a = Object.keys(e).map((l) => e[l])) : a = i.map((l) => e[l]);
1527
+ const c = s.filter((l, u) => o[u]);
1528
+ return a.concat(c);
1529
+ }
1530
+ return [];
1531
+ }
1532
+ /**
1533
+ * Internal method used by public APIs for tensor creation. Makes a new
1534
+ * tensor with the provided shape, dtype and values. It always
1535
+ * creates a new data id and writes the values to the underlying backend.
1536
+ */
1537
+ makeTensor(t, e, s, r) {
1538
+ if (t == null)
1539
+ throw new Error("Values passed to engine.makeTensor() are null");
1540
+ s = s || "float32", r = r || this.backend;
1541
+ let i = t;
1542
+ s === "string" && Mt(t[0]) && (i = t.map((c) => qe(c)));
1543
+ const o = r.write(i, e, s), a = new N(e, s, o, this.nextTensorId());
1544
+ if (this.trackTensor(a, r), s === "string") {
1545
+ const c = this.state.tensorInfo.get(o), l = Te(i);
1546
+ this.state.numBytes += l - c.bytes, c.bytes = l;
1547
+ }
1548
+ return a;
1549
+ }
1550
+ /**
1551
+ * Internal method used by backends. Makes a new tensor
1552
+ * that is a wrapper around an existing data id. It doesn't create
1553
+ * a new data id, only increments the ref count used in memory tracking.
1554
+ * @deprecated
1555
+ */
1556
+ makeTensorFromDataId(t, e, s, r) {
1557
+ s = s || "float32";
1558
+ const i = { dataId: t, shape: e, dtype: s };
1559
+ return this.makeTensorFromTensorInfo(i, r);
1560
+ }
1561
+ /**
1562
+ * Internal method used by backends. Makes a new tensor that is a wrapper
1563
+ * around an existing data id in TensorInfo. It doesn't create a new data id,
1564
+ * only increments the ref count used in memory tracking.
1565
+ */
1566
+ makeTensorFromTensorInfo(t, e) {
1567
+ const { dataId: s, shape: r, dtype: i } = t, o = new N(r, i, s, this.nextTensorId());
1568
+ return this.trackTensor(o, e), o;
1569
+ }
1570
+ makeVariable(t, e = !0, s, r) {
1571
+ s = s || this.nextVariableId().toString(), r != null && r !== t.dtype && (t = t.cast(r));
1572
+ const i = new ht(t, e, s, this.nextTensorId());
1573
+ if (this.state.registeredVariables[i.name] != null)
1574
+ throw new Error(`Variable with name ${i.name} was already registered`);
1575
+ return this.state.registeredVariables[i.name] = i, this.incRef(i, this.backend), i;
1576
+ }
1577
+ trackTensor(t, e) {
1578
+ this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++;
1579
+ let s = 0;
1580
+ t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size * yt(t.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(t.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(t.dataId, {
1581
+ backend: e || this.backend,
1582
+ dtype: t.dtype,
1583
+ shape: t.shape,
1584
+ bytes: s
1585
+ })), t instanceof ht || this.track(t);
1586
+ }
1587
+ // Track the tensor by dataId and increase the refCount for the dataId in the
1588
+ // backend.
1589
+ // TODO(pyu10055): This is currently used by makeVariable method, to increase
1590
+ // refCount on the backend for the dataId. It can potentially be replaced with
1591
+ // Identity op indead of calling backend directly.
1592
+ incRef(t, e) {
1593
+ this.trackTensor(t, e), this.backend.incRef(t.dataId);
1594
+ }
1595
+ removeDataId(t, e) {
1596
+ this.state.tensorInfo.has(t) && this.state.tensorInfo.get(t).backend === e && (this.state.tensorInfo.delete(t), this.state.numDataBuffers--);
1597
+ }
1598
+ disposeTensor(t) {
1599
+ if (!this.state.tensorInfo.has(t.dataId))
1600
+ return;
1601
+ const e = this.state.tensorInfo.get(t.dataId);
1602
+ if (this.state.numTensors--, t.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= e.bytes), t.dtype !== "complex64" && t.dtype !== "string") {
1603
+ const s = t.size * yt(t.dtype);
1604
+ this.state.numBytes -= s;
1605
+ }
1606
+ e.backend.disposeData(t.dataId) && this.removeDataId(t.dataId, e.backend);
1607
+ }
1608
+ disposeVariables() {
1609
+ for (const t in this.state.registeredVariables) {
1610
+ const e = this.state.registeredVariables[t];
1611
+ this.disposeVariable(e);
1612
+ }
1613
+ }
1614
+ disposeVariable(t) {
1615
+ this.disposeTensor(t), this.state.registeredVariables[t.name] != null && delete this.state.registeredVariables[t.name];
1616
+ }
1617
+ memory() {
1618
+ const t = this.backend.memory();
1619
+ return t.numTensors = this.state.numTensors, t.numDataBuffers = this.state.numDataBuffers, t.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (t.unreliable = !0, t.reasons == null && (t.reasons = []), t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), t;
1620
+ }
1621
+ async profile(t) {
1622
+ this.state.profiling = !0;
1623
+ const e = this.state.numBytes, s = this.state.numTensors;
1624
+ this.state.activeProfile.kernels = [], this.state.activeProfile.result = await t(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map((r) => r.totalBytesSnapshot)), this.state.activeProfile.newBytes = this.state.numBytes - e, this.state.activeProfile.newTensors = this.state.numTensors - s;
1625
+ for (const r of this.state.activeProfile.kernels)
1626
+ r.kernelTimeMs = await r.kernelTimeMs, r.extraInfo = await r.extraInfo;
1627
+ return this.state.activeProfile;
1628
+ }
1629
+ isTapeOn() {
1630
+ return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
1631
+ }
1632
+ addTapeNode(t, e, s, r, i, o) {
1633
+ const a = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: i }, c = Ot(t);
1634
+ c != null && (r = c.gradFunc), r != null && (a.gradient = (l) => (l = l.map((u, h) => {
1635
+ if (u == null) {
1636
+ const f = s[h], m = Ht(f.size, f.dtype);
1637
+ return this.makeTensor(m, f.shape, f.dtype);
1638
+ }
1639
+ return u;
1640
+ }), r(l.length > 1 ? l : l[0], i, o))), this.state.activeTape.push(a);
1641
+ }
1642
+ keep(t) {
1643
+ return t.kept = !0, t;
1644
+ }
1645
+ startTape() {
1646
+ this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
1647
+ }
1648
+ endTape() {
1649
+ this.state.gradientDepth--;
1650
+ }
1651
+ /**
1652
+ * Start a scope. Use this with endScope() to achieve the same functionality
1653
+ * as scope() without the need for a function closure.
1654
+ */
1655
+ startScope(t) {
1656
+ const e = {
1657
+ track: [],
1658
+ name: "unnamed scope",
1659
+ id: this.state.nextScopeId++
1660
+ };
1661
+ t && (e.name = t), this.state.scopeStack.push(e), this.state.activeScope = e;
1662
+ }
1663
+ /**
1664
+ * End a scope. Use this with startScope() to achieve the same functionality
1665
+ * as scope() without the need for a function closure.
1666
+ */
1667
+ endScope(t) {
1668
+ const e = ae(t), s = new Set(e.map((i) => i.id));
1669
+ for (let i = 0; i < this.state.activeScope.track.length; i++) {
1670
+ const o = this.state.activeScope.track[i];
1671
+ !o.kept && !s.has(o.id) && o.dispose();
1672
+ }
1673
+ const r = this.state.scopeStack.pop();
1674
+ this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], e.forEach((i) => {
1675
+ !i.kept && i.scopeId === r.id && this.track(i);
1676
+ });
1677
+ }
1678
+ /**
1679
+ * Returns gradients of `f` with respect to each of the `xs`. The gradients
1680
+ * returned are of the same length as `xs`, but some might be null if `f`
1681
+ * was not a function of that `x`. It also takes optional dy to multiply the
1682
+ * gradient, which defaults to `1`.
1683
+ */
1684
+ gradients(t, e, s, r = !1) {
1685
+ if (b(e.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
1686
+ throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
1687
+ const i = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy("forward", t));
1688
+ b(i instanceof N, () => "The result y returned by f() must be a tensor.");
1689
+ const o = Ye(this.state.activeTape, e, i);
1690
+ if (!r && o.length === 0 && e.length > 0)
1691
+ throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");
1692
+ return this.tidy("backward", () => {
1693
+ const a = {};
1694
+ a[i.id] = s ?? ln(i.shape), Qe(
1695
+ a,
1696
+ o,
1697
+ // Pass the tidy function to avoid circular dep with `tape.ts`.
1698
+ (l) => this.tidy(l),
1699
+ // Pass an add function to avoide a circular dep with `tape.ts`.
1700
+ cn
1701
+ );
1702
+ const c = e.map((l) => a[l.id]);
1703
+ return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((l) => {
1704
+ for (const u of l.saved)
1705
+ u.dispose();
1706
+ }), this.state.activeTape = null), { value: i, grads: c };
1707
+ });
1708
+ }
1709
+ customGrad(t) {
1710
+ return b(bt(t), () => "The f passed in customGrad(f) must be a function."), (...e) => {
1711
+ b(e.every((a) => a instanceof N), () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors");
1712
+ let s;
1713
+ const r = {};
1714
+ e.forEach((a, c) => {
1715
+ r[c] = a;
1716
+ });
1717
+ const i = (a, c) => (s = t(...e, c), b(s.value instanceof N, () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"), b(bt(s.gradFunc), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."), s.value), o = (a, c) => {
1718
+ const l = s.gradFunc(a, c), u = Array.isArray(l) ? l : [l];
1719
+ b(u.length === e.length, () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."), b(u.every((f) => f instanceof N), () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors.");
1720
+ const h = {};
1721
+ return u.forEach((f, m) => {
1722
+ h[m] = () => f;
1723
+ }), h;
1724
+ };
1725
+ return this.runKernelFunc({
1726
+ forwardFunc: i,
1727
+ backwardsFunc: o,
1728
+ inputs: r
1729
+ });
1730
+ };
1731
+ }
1732
+ readSync(t) {
1733
+ return this.state.tensorInfo.get(t).backend.readSync(t);
1734
+ }
1735
+ read(t) {
1736
+ return this.state.tensorInfo.get(t).backend.read(t);
1737
+ }
1738
+ readToGPU(t, e) {
1739
+ return this.state.tensorInfo.get(t).backend.readToGPU(t, e);
1740
+ }
1741
+ async time(t) {
1742
+ const e = ut(), s = await this.backend.time(t);
1743
+ return s.wallMs = ut() - e, s;
1744
+ }
1745
+ /**
1746
+ * Tracks a Tensor in the current scope to be automatically cleaned up
1747
+ * when the current scope ends, and returns the value.
1748
+ *
1749
+ * @param result The Tensor to track in the current scope.
1750
+ */
1751
+ track(t) {
1752
+ return this.state.activeScope != null && (t.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(t)), t;
1753
+ }
1754
+ get registeredVariables() {
1755
+ return this.state.registeredVariables;
1756
+ }
1757
+ /**
1758
+ * Resets the engine state. Removes all backends but does not remove
1759
+ * registered backend factories.
1760
+ */
1761
+ reset() {
1762
+ this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new zt();
1763
+ for (const t in this.registry)
1764
+ this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
1765
+ this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
1766
+ }
1767
+ }
1768
+ Q.nextTensorId = 0;
1769
+ Q.nextVariableId = 0;
1770
/**
 * Builds a float32 tensor of shape `n` from the values `Ae(U(n), "float32")`.
 * NOTE(review): presumably `U` is the element count and `Ae` fills with ones — confirm.
 *
 * @param n Tensor shape.
 */
function ln(n) {
  const elementCount = U(n);
  const values = Ae(elementCount, "float32");
  return g.makeTensor(values, n, "float32");
}
1774
/**
 * Returns the engine stored on the object from `Xt()` under `_tfengine`,
 * creating it on first use, and registers it through `xe` and `nn`.
 */
function ce() {
  const namespace = Xt();
  if (namespace._tfengine == null) {
    namespace._tfengine = new Q(new ve(namespace));
  }
  xe(namespace._tfengine.ENV);
  // The getter reads the namespace lazily so a replaced engine is picked up.
  nn(() => namespace._tfengine);
  return namespace._tfengine;
}
// Shared engine instance used by the rest of this file.
const g = ce();
1783
/**
 * Runs kernel `Yt` on the engine with `n` and `t` as inputs `a` and `b`.
 * It is passed to the tape backward pass as its "add" function.
 */
function cn(n, t) {
  return g.runKernel(Yt, { a: n, b: t });
}
1787
+ /**
1788
+ * @license
1789
+ * Copyright 2017 Google LLC. All Rights Reserved.
1790
+ * Licensed under the Apache License, Version 2.0 (the "License");
1791
+ * you may not use this file except in compliance with the License.
1792
+ * You may obtain a copy of the License at
1793
+ *
1794
+ * http://www.apache.org/licenses/LICENSE-2.0
1795
+ *
1796
+ * Unless required by applicable law or agreed to in writing, software
1797
+ * distributed under the License is distributed on an "AS IS" BASIS,
1798
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1799
+ * See the License for the specific language governing permissions and
1800
+ * limitations under the License.
1801
+ * =============================================================================
1802
+ */
1803
/**
 * Reports whether the code runs in a window with a document or in a
 * worker global scope.
 */
function un() {
  const hasDocument = typeof window !== "undefined" && window.document != null;
  if (hasDocument) {
    return true;
  }
  //@ts-ignore
  return typeof WorkerGlobalScope !== "undefined";
}
1807
+ /**
1808
+ * @license
1809
+ * Copyright 2019 Google LLC. All Rights Reserved.
1810
+ * Licensed under the Apache License, Version 2.0 (the "License");
1811
+ * you may not use this file except in compliance with the License.
1812
+ * You may obtain a copy of the License at
1813
+ *
1814
+ * http://www.apache.org/licenses/LICENSE-2.0
1815
+ *
1816
+ * Unless required by applicable law or agreed to in writing, software
1817
+ * distributed under the License is distributed on an "AS IS" BASIS,
1818
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1819
+ * See the License for the specific language governing permissions and
1820
+ * limitations under the License.
1821
+ * =============================================================================
1822
+ */
1823
// Environment flag registrations. Each flag gets a lazy evaluator; DEBUG also
// gets a hook that warns when it receives a truthy value.
const A = S();
A.registerFlag("DEBUG", () => false, (debugValue) => {
  if (debugValue) {
    console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.");
  }
});
{
  // Shared navigator sniffing for the IS_CHROME / IS_SAFARI evaluators.
  const navigatorMatches = (agentPattern, vendorPattern) =>
    typeof navigator !== "undefined" &&
    navigator != null &&
    navigator.userAgent != null &&
    agentPattern.test(navigator.userAgent) &&
    vendorPattern.test(navigator.vendor);
  const hookFreeFlags = [
    ["IS_BROWSER", () => un()],
    ["IS_NODE", () => typeof process !== "undefined" && typeof process.versions !== "undefined" && typeof process.versions.node !== "undefined"],
    ["IS_CHROME", () => navigatorMatches(/Chrome/, /Google Inc/)],
    ["IS_SAFARI", () => navigatorMatches(/Safari/, /Apple/)],
    ["PROD", () => false],
    ["TENSORLIKE_CHECK_SHAPE_CONSISTENCY", () => A.getBool("DEBUG")],
    ["DEPRECATION_WARNINGS_ENABLED", () => true],
    ["IS_TEST", () => false],
    ["CHECK_COMPUTATION_FOR_ERRORS", () => A.getBool("DEBUG")],
    ["WRAP_TO_IMAGEBITMAP", () => false],
    ["CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU", () => false],
    ["USE_SETTIMEOUTCUSTOM", () => false]
  ];
  // Registration order matches the listing order above.
  for (const [flagName, evaluate] of hookFreeFlags) {
    A.registerFlag(flagName, evaluate);
  }
}
1839
+ /**
1840
+ * @license
1841
+ * Copyright 2018 Google LLC. All Rights Reserved.
1842
+ * Licensed under the Apache License, Version 2.0 (the "License");
1843
+ * you may not use this file except in compliance with the License.
1844
+ * You may obtain a copy of the License at
1845
+ *
1846
+ * http://www.apache.org/licenses/LICENSE-2.0
1847
+ *
1848
+ * Unless required by applicable law or agreed to in writing, software
1849
+ * distributed under the License is distributed on an "AS IS" BASIS,
1850
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1851
+ * See the License for the specific language governing permissions and
1852
+ * limitations under the License.
1853
+ * =============================================================================
1854
+ */
1855
/**
 * Infers the shape of a tensor-like value.
 *
 * @param n Value to inspect: something accepted by `R`, `ie` or `oe`, a
 *     (nested) array, or a scalar.
 * @param t Optional dtype; "string" makes `R`-typed values count as scalars.
 * @returns The inferred shape as an array of dimension sizes.
 */
function hn(n, t) {
  if (R(n)) {
    return t === "string" ? [] : [n.length];
  }
  if (ie(n)) {
    // One value per channel character; "RGBA" when no channels are given.
    const channels = n.channels || "RGBA";
    return [n.height, n.width * channels.length];
  }
  if (oe(n)) {
    // Buffer size divided by the per-element width (4 when dtype is unknown).
    return [n.buffer.size / (t == null ? 4 : yt(t))];
  }
  if (!Array.isArray(n)) {
    return [];
  }
  // Walk down the first element of every nesting level.
  const shape = [];
  let level = n;
  while (Array.isArray(level) || (R(level) && t !== "string")) {
    shape.push(level.length);
    level = level[0];
  }
  if (S().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")) {
    ue(n, shape, []);
  }
  return shape;
}
1871
/**
 * Recursively asserts that nested value `n` has exactly the shape `t`.
 *
 * @param n Value (or sub-value) to check.
 * @param t Expected shape of `n`.
 * @param e Index path of `n` within the root value, used in error messages.
 */
function ue(n, t, e) {
  const path = e || [];
  const isArrayLike = Array.isArray(n) || R(n);
  if (!isArrayLike) {
    // A primitive is valid only where no dimensions remain.
    b(t.length === 0, () => `Element arr[${path.join("][")}] is a primitive, but should be an array/TypedArray of ${t[0]} elements`);
    return;
  }
  b(t.length > 0, () => `Element arr[${path.join("][")}] should be a primitive, but is an array of ${n.length} elements`);
  b(n.length === t[0], () => `Element arr[${path.join("][")}] should have ${t[0]} elements, but has ${n.length} elements`);
  const innerShape = t.slice(1);
  for (let index = 0; index < n.length; ++index) {
    ue(n[index], innerShape, path.concat(index));
  }
}
1881
/**
 * Checks an actual dtype against the expected dtype class.
 *
 * @param n Expected dtype, "numeric" (anything but "string"), or
 *     "string_or_numeric" (accepts everything).
 * @param t Actual dtype.
 * @param e Argument name, used in the error message.
 * @param s Function name, used in the error message.
 * @throws Error when `n` is null/undefined or the dtypes are incompatible.
 */
function Wt(n, t, e, s) {
  if (n === "string_or_numeric") {
    return;
  }
  if (n == null) {
    throw new Error("Expected dtype cannot be null.");
  }
  const incompatible = n === "numeric" ? t === "string" : n !== t;
  if (incompatible) {
    throw new Error(`Argument '${e}' passed to '${s}' must be ${n} tensor, but got ${t} tensor`);
  }
}
1889
/**
 * Converts a tensor-like value into a tensor, validating its dtype.
 *
 * @param n Value to convert; instances of the class returned by `re()` are
 *     returned as-is after the dtype check.
 * @param t Argument name, used in error messages.
 * @param e Function name, used in error messages.
 * @param s Expected dtype class (defaults to "numeric").
 * @throws Error when the dtype check fails or `n` is not tensor-like.
 */
function k(n, t, e, s = "numeric") {
  if (n instanceof re()) {
    Wt(s, n.dtype, t, e);
    return n;
  }
  let dtype = ft(n);
  // A concrete bool/int32/float32 expectation overrides the inferred dtype,
  // except for inferred strings.
  if (dtype !== "string" && ["bool", "int32", "float32"].indexOf(s) >= 0) {
    dtype = s;
  }
  Wt(s, dtype, t, e);
  const isTensorLike =
    n != null &&
    (R(n) || Array.isArray(n) || typeof n == "number" || typeof n == "boolean" || typeof n == "string");
  if (!isTensorLike) {
    const typeName = n == null ? "null" : n.constructor.name;
    throw new Error(`Argument '${t}' passed to '${e}' must be a Tensor or TensorLike, but got '${typeName}'`);
  }
  const shape = hn(n, dtype);
  // Scalars are wrapped so the value converters always receive a list.
  const listed = !R(n) && !Array.isArray(n) ? [n] : n;
  const values = dtype !== "string" ? ne(listed, dtype) : ot(listed, [], true);
  return g.makeTensor(values, shape, dtype);
}
1902
+ /**
1903
+ * @license
1904
+ * Copyright 2018 Google LLC. All Rights Reserved.
1905
+ * Licensed under the Apache License, Version 2.0 (the "License");
1906
+ * you may not use this file except in compliance with the License.
1907
+ * You may obtain a copy of the License at
1908
+ *
1909
+ * http://www.apache.org/licenses/LICENSE-2.0
1910
+ *
1911
+ * Unless required by applicable law or agreed to in writing, software
1912
+ * distributed under the License is distributed on an "AS IS" BASIS,
1913
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1914
+ * See the License for the specific language governing permissions and
1915
+ * limitations under the License.
1916
+ * =============================================================================
1917
+ */
1918
// Suffix appended to every wrapped op's function name.
const fn = "__op";
/**
 * Wraps a single-entry `{ name: implementation }` object into a function that
 * runs the implementation between `g.startScope` and `g.endScope`.
 *
 * @param n Object with exactly one key mapping to the op implementation.
 * @returns The wrapped function; its `name` is the key (minus one trailing
 *     underscore, if any) followed by the `__op` suffix.
 * @throws Error when `n` does not have exactly one key.
 */
function F(n) {
  const keys = Object.keys(n);
  if (keys.length !== 1) {
    throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${keys.length} keys.`);
  }
  const [rawName] = keys;
  const implementation = n[rawName];
  const baseName = rawName.endsWith("_") ? rawName.slice(0, -1) : rawName;
  const opName = baseName + fn;
  const wrapped = (...args) => {
    g.startScope(opName);
    try {
      const result = implementation(...args);
      if (xt(result)) {
        console.error("Cannot return a Promise inside of tidy.");
      }
      g.endScope(result);
      return result;
    } catch (error) {
      // Close the scope without a result before propagating the failure.
      g.endScope(null);
      throw error;
    }
  };
  Object.defineProperty(wrapped, "name", { value: opName, configurable: true });
  return wrapped;
}
1937
+ /**
1938
+ * @license
1939
+ * Copyright 2018 Google LLC. All Rights Reserved.
1940
+ * Licensed under the Apache License, Version 2.0 (the "License");
1941
+ * you may not use this file except in compliance with the License.
1942
+ * You may obtain a copy of the License at
1943
+ *
1944
+ * http://www.apache.org/licenses/LICENSE-2.0
1945
+ *
1946
+ * Unless required by applicable law or agreed to in writing, software
1947
+ * distributed under the License is distributed on an "AS IS" BASIS,
1948
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1949
+ * See the License for the specific language governing permissions and
1950
+ * limitations under the License.
1951
+ * =============================================================================
1952
+ */
1953
/**
 * Creates a tensor from raw values after validating dtype and shape.
 *
 * @param n Values: scalar, (nested) array, something accepted by `R`, or a
 *     GPU-data object accepted by `oe` / `ie`.
 * @param t Shape requested by the caller; may be null.
 * @param e Shape inferred from the values.
 * @param s Requested dtype; inferred through `ft` when null.
 * @throws Error for "complex64", unsupported GPU dtypes, non tensor-like
 *     values, or a requested shape that disagrees with the inferred one.
 */
function dn(n, t, e, s) {
  let dtype = s;
  if (dtype == null) {
    dtype = ft(n);
  } else if (dtype === "complex64") {
    throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");
  }
  if (oe(n) || ie(n)) {
    if (dtype !== "float32" && dtype !== "int32") {
      throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${dtype}.`);
    }
    return g.backend.createTensorFromGPUData(n, t || e, dtype);
  }
  const isScalar = !R(n) && !Array.isArray(n);
  if (isScalar && typeof n != "number" && typeof n != "boolean" && typeof n != "string") {
    throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");
  }
  if (t != null) {
    Rt(t);
    const requestedSize = U(t);
    const inferredSize = U(e);
    b(requestedSize === inferredSize, () => `Based on the provided shape, [${t}], the tensor should have ${requestedSize} values but has ${inferredSize}`);
    for (let axis = 0; axis < e.length; ++axis) {
      const inferredDim = e[axis];
      // The last inferred dimension may stand for all remaining requested
      // dimensions flattened together; any other mismatch is an error.
      const mustMatch = axis === e.length - 1 ? inferredDim !== U(t.slice(axis)) : true;
      b(e[axis] === t[axis] || !mustMatch, () => `Error creating a new Tensor. Inferred shape (${e}) does not match the provided shape (${t}). `);
    }
  }
  const listed = isScalar ? [n] : n;
  const shape = t || e;
  const values = dtype !== "string" ? ne(listed, dtype) : ot(listed, [], true);
  return g.makeTensor(values, shape, dtype);
}
1976
/**
 * A read-only view over several ArrayBuffers that behaves like their
 * concatenation, without copying until `slice()` is called.
 */
class at {
  /**
   * Concatenate a number of ArrayBuffers into one.
   *
   * @param buffers An array of ArrayBuffers to concatenate, or a single
   *     ArrayBuffer.
   * @returns Result of concatenating `buffers` in order.
   */
  static join(t) {
    return new at(t).slice();
  }
  /**
   * @param t A buffer, a typed-array view accepted by `R`, an array of
   *     those, or null/undefined for an empty composite.
   */
  constructor(t) {
    this.shards = [];
    this.previousShardIndex = 0;
    // Defined up front so an empty composite reports 0 rather than undefined.
    this.byteLength = 0;
    if (t == null) {
      return;
    }
    const inputs = t instanceof Array ? t : [t];
    const buffers = inputs.map((item) => {
      if (!R(item)) {
        return item;
      }
      // A view may cover only part of its backing buffer (subarray, pooled
      // Node Buffer); copy just the viewed bytes in that case.
      const coversWholeBuffer = item.byteOffset === 0 && item.byteLength === item.buffer.byteLength;
      return coversWholeBuffer
        ? item.buffer
        : item.buffer.slice(item.byteOffset, item.byteOffset + item.byteLength);
    });
    if (buffers.length === 0) {
      return;
    }
    // Stays defined only while every buffer except the last has this size.
    this.bufferUniformSize = buffers[0].byteLength;
    let offset = 0;
    for (let index = 0; index < buffers.length; index++) {
      const buffer = buffers[index];
      if (index !== buffers.length - 1 && buffer.byteLength !== this.bufferUniformSize) {
        this.bufferUniformSize = void 0;
      }
      const end = offset + buffer.byteLength;
      this.shards.push({ buffer, start: offset, end });
      offset = end;
    }
    this.byteLength = offset;
  }
  /**
   * Copies bytes `[t, e)` of the concatenation into a new ArrayBuffer.
   *
   * @param t Start byte (clamped to 0; NaN is treated as 0).
   * @param e End byte, exclusive (clamped to `byteLength`; NaN is treated as 0).
   * @throws Error when no shard contains the start byte.
   */
  slice(t = 0, e = this.byteLength) {
    if (this.shards.length === 0) {
      return new ArrayBuffer(0);
    }
    const start = Math.max(0, isNaN(Number(t)) ? 0 : t);
    const end = Math.min(this.byteLength, isNaN(Number(e)) ? 0 : e);
    if (end <= start) {
      return new ArrayBuffer(0);
    }
    const firstShard = this.findShardForByte(start);
    if (firstShard === -1) {
      throw new Error(`Could not find start shard for byte ${start}`);
    }
    const output = new ArrayBuffer(end - start);
    const outputBytes = new Uint8Array(output);
    let copied = 0;
    for (let index = firstShard; index < this.shards.length; index++) {
      const shard = this.shards[index];
      const from = start + copied - shard.start;
      const to = Math.min(end, shard.end) - shard.start;
      const chunk = new Uint8Array(shard.buffer, from, to - from);
      outputBytes.set(chunk, copied);
      copied += chunk.length;
      if (end < shard.end) {
        break;
      }
    }
    return output;
  }
  /**
   * Get the index of the shard that contains the byte at `byteIndex`.
   *
   * @returns The shard index, or -1 when the byte is out of range.
   */
  findShardForByte(t) {
    if (this.shards.length === 0 || t < 0 || t >= this.byteLength) {
      return -1;
    }
    // Fast path for uniformly sized shards. Skipped for a zero size, which
    // would divide by zero, and clamped because the last shard may be larger
    // than the others.
    if (this.bufferUniformSize > 0) {
      this.previousShardIndex = Math.min(
        Math.floor(t / this.bufferUniformSize),
        this.shards.length - 1
      );
      return this.previousShardIndex;
    }
    const compare = (shard) => (t < shard.start ? -1 : t >= shard.end ? 1 : 0);
    // Sequential reads usually hit the same shard as the previous lookup.
    if (compare(this.shards[this.previousShardIndex]) === 0) {
      return this.previousShardIndex;
    }
    const found = gn(this.shards, compare);
    if (found === -1) {
      return -1;
    }
    this.previousShardIndex = found;
    return this.previousShardIndex;
  }
}
2034
/**
 * Binary search over a sorted array using a three-way comparator.
 *
 * @param n Sorted array to search.
 * @param t Comparator called with an element; returns a negative number when
 *     the target lies before the element, a positive number when it lies
 *     after it, and 0 on a match.
 * @returns Index of a matching element, or -1 when none matches.
 */
function gn(n, t) {
  // Inclusive bounds: the probe index always stays inside the array and the
  // range shrinks on every step, so a missing target ends the loop with -1
  // instead of probing past the end or spinning forever.
  let low = 0;
  let high = n.length - 1;
  while (low <= high) {
    const middle = low + Math.floor((high - low) / 2);
    const side = t(n[middle]);
    if (side === 0) {
      return middle;
    }
    if (side < 0) {
      high = middle - 1;
    } else {
      low = middle + 1;
    }
  }
  return -1;
}
2044
+ /**
2045
+ * @license
2046
+ * Copyright 2018 Google LLC. All Rights Reserved.
2047
+ * Licensed under the Apache License, Version 2.0 (the "License");
2048
+ * you may not use this file except in compliance with the License.
2049
+ * You may obtain a copy of the License at
2050
+ *
2051
+ * http://www.apache.org/licenses/LICENSE-2.0
2052
+ *
2053
+ * Unless required by applicable law or agreed to in writing, software
2054
+ * distributed under the License is distributed on an "AS IS" BASIS,
2055
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2056
+ * See the License for the specific language governing permissions and
2057
+ * limitations under the License.
2058
+ * =============================================================================
2059
+ */
2060
/** Returns the shared engine instance. */
function Rs() {
  return g;
}
/** Forwards both arguments to the engine's `tidy`. */
function E(n, t) {
  const outcome = g.tidy(n, t);
  return outcome;
}
/** Disposes every tensor that `ae` extracts from `n`. */
function M(n) {
  const tensors = ae(n);
  tensors.forEach((tensor) => {
    tensor.dispose();
  });
}
/** Forwards `n` to the engine's `keep` and returns its result. */
function mn(n) {
  const kept = g.keep(n);
  return kept;
}
2072
+ /**
2073
+ * @license
2074
+ * Copyright 2018 Google LLC. All Rights Reserved.
2075
+ * Licensed under the Apache License, Version 2.0 (the "License");
2076
+ * you may not use this file except in compliance with the License.
2077
+ * You may obtain a copy of the License at
2078
+ *
2079
+ * http://www.apache.org/licenses/LICENSE-2.0
2080
+ *
2081
+ * Unless required by applicable law or agreed to in writing, software
2082
+ * distributed under the License is distributed on an "AS IS" BASIS,
2083
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2084
+ * See the License for the specific language governing permissions and
2085
+ * limitations under the License.
2086
+ * =============================================================================
2087
+ */
2088
// True when `Buffer` exists but `Blob`, `atob` or `btoa` is missing, in which
// case the helpers below use Buffer instead of the web APIs.
const $t = typeof Buffer < "u" && (typeof Blob > "u" || typeof atob > "u" || typeof btoa > "u");
/**
 * Size in bytes of a string: UTF-8 length through Buffer on the Buffer path,
 * otherwise the size of a Blob wrapping the string.
 */
function jt(n) {
  if ($t) {
    return Buffer.byteLength(n, "utf8");
  }
  return new Blob([n]).size;
}
2092
/**
 * Encodes the bytes of an ArrayBuffer as a base64 string.
 *
 * @param n ArrayBuffer to encode.
 */
function pn(n) {
  if ($t) {
    return Buffer.from(n).toString("base64");
  }
  // btoa works on a binary string, so map every byte to one character first.
  let binary = "";
  for (const byte of new Uint8Array(n)) {
    binary += String.fromCharCode(byte);
  }
  return btoa(binary);
}
2101
/**
 * Decodes a base64 string into an ArrayBuffer.
 *
 * @param n Base64 text.
 */
function yn(n) {
  if ($t) {
    const decoded = Buffer.from(n, "base64");
    // Copy only the decoded window out of the backing buffer.
    const windowStart = decoded.byteOffset;
    return decoded.buffer.slice(windowStart, windowStart + decoded.byteLength);
  }
  const binary = atob(n);
  const bytes = new Uint8Array(binary.length);
  for (let index = 0; index < binary.length; ++index) {
    bytes[index] = binary.charCodeAt(index);
  }
  return bytes.buffer;
}
2111
/**
 * Summarises model artifacts whose topology is JSON.
 *
 * @param n Artifacts with optional `modelTopology`, `weightSpecs` and
 *     `weightData`.
 * @returns Save date, topology type, and byte sizes (0 for absent parts).
 * @throws Error when `modelTopology` is an ArrayBuffer.
 */
function he(n) {
  if (n.modelTopology instanceof ArrayBuffer) {
    throw new Error("Expected JSON model topology, received ArrayBuffer.");
  }
  // Size of the value's JSON text, or 0 when the value is absent.
  const jsonBytes = (value) => (value == null ? 0 : jt(JSON.stringify(value)));
  return {
    dateSaved: /* @__PURE__ */ new Date(),
    modelTopologyType: "JSON",
    modelTopologyBytes: jsonBytes(n.modelTopology),
    weightSpecsBytes: jsonBytes(n.weightSpecs),
    weightDataBytes: n.weightData == null ? 0 : new at(n.weightData).byteLength
  };
}
2122
+ /**
2123
+ * @license
2124
+ * Copyright 2018 Google LLC. All Rights Reserved.
2125
+ * Licensed under the Apache License, Version 2.0 (the "License");
2126
+ * you may not use this file except in compliance with the License.
2127
+ * You may obtain a copy of the License at
2128
+ *
2129
+ * http://www.apache.org/licenses/LICENSE-2.0
2130
+ *
2131
+ * Unless required by applicable law or agreed to in writing, software
2132
+ * distributed under the License is distributed on an "AS IS" BASIS,
2133
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2134
+ * See the License for the specific language governing permissions and
2135
+ * limitations under the License.
2136
+ * =============================================================================
2137
+ */
2138
/**
 * Singleton registry of save/load "routers": functions that map a URL-like
 * string onto an IO handler, or return null when they do not apply.
 */
class B {
  constructor() {
    this.saveRouters = [];
    this.loadRouters = [];
  }
  /** Returns the shared registry, creating it on first use. */
  static getInstance() {
    if (B.instance == null) {
      B.instance = new B();
    }
    return B.instance;
  }
  /**
   * Register a save-handler router.
   *
   * @param saveRouter A function that maps a URL-like string onto an instance
   * of `IOHandler` with the `save` method defined or `null`.
   */
  static registerSaveRouter(t) {
    B.getInstance().saveRouters.push(t);
  }
  /**
   * Register a load-handler router.
   *
   * @param loadRouter A function that maps a URL-like string onto an instance
   * of `IOHandler` with the `load` method defined or `null`.
   */
  static registerLoadRouter(t) {
    B.getInstance().loadRouters.push(t);
  }
  /**
   * Look up IOHandlers for saving, given a URL-like string.
   *
   * @param url
   * @returns All handlers produced for `url` by the currently registered
   * save routers; an empty array when no router matches.
   */
  static getSaveHandlers(t) {
    return B.getHandlers(t, "save");
  }
  /**
   * Look up IOHandlers for loading, given a URL-like string.
   *
   * @param url
   * @param loadOptions Optional, custom load options.
   * @returns All valid handlers for `url`, given the currently registered
   * handler routers.
   */
  static getLoadHandlers(t, e) {
    return B.getHandlers(t, "load", e);
  }
  /**
   * Runs every router of the requested kind and collects the handlers.
   *
   * @param t URL-like string passed to each router.
   * @param e "load" selects the load routers; anything else the save routers.
   * @param s Optional load options forwarded as the router's second argument.
   */
  static getHandlers(t, e, s) {
    const registry = B.getInstance();
    const routers = e === "load" ? registry.loadRouters : registry.saveRouters;
    const handlers = [];
    for (const router of routers) {
      const handler = router(t, s);
      // Treat `undefined` like `null`: a router that falls through without
      // returning anything has not matched and must not yield a handler.
      if (handler != null) {
        handlers.push(handler);
      }
    }
    return handlers;
  }
}
2193
+ /**
2194
+ * @license
2195
+ * Copyright 2018 Google LLC. All Rights Reserved.
2196
+ * Licensed under the Apache License, Version 2.0 (the "License");
2197
+ * you may not use this file except in compliance with the License.
2198
+ * You may obtain a copy of the License at
2199
+ *
2200
+ * http://www.apache.org/licenses/LICENSE-2.0
2201
+ *
2202
+ * Unless required by applicable law or agreed to in writing, software
2203
+ * distributed under the License is distributed on an "AS IS" BASIS,
2204
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2205
+ * See the License for the specific language governing permissions and
2206
+ * limitations under the License.
2207
+ * =============================================================================
2208
+ */
2209
// IndexedDB database name and version, followed by the names of the two
// object stores created by the upgrade handler.
const Tt = "tensorflowjs", Et = 1, L = "models_store", O = "model_info_store";
/**
 * Returns the IndexedDB factory of the current global scope, trying the
 * standard name first and then vendor-prefixed / shim variants.
 *
 * @throws Error when the IS_BROWSER flag is false or no factory is found.
 */
function fe() {
  if (!S().getBool("IS_BROWSER")) {
    throw new Error("Failed to obtain IndexedDB factory because the current environment is not a web browser.");
  }
  // Workers have no `window`; fall back to `self` there.
  const scope = typeof window > "u" ? self : window;
  const factory = scope.indexedDB || scope.mozIndexedDB || scope.webkitIndexedDB || scope.msIndexedDB || scope.shimIndexedDB;
  if (factory == null) {
    throw new Error("The current browser does not appear to support IndexedDB.");
  }
  return factory;
}
2218
/**
 * Upgrade handler: creates both object stores, each keyed by "modelPath".
 *
 * @param n The open request whose `result` is the database being upgraded.
 */
function Bt(n) {
  const database = n.result;
  for (const storeName of [L, O]) {
    database.createObjectStore(storeName, { keyPath: "modelPath" });
  }
}
2222
/**
 * IO handler that stores and retrieves model artifacts in IndexedDB under a
 * given model path.
 */
class z {
  /**
   * @param t Model path; must be a non-empty value.
   * @throws Error when the IndexedDB factory cannot be obtained or the path
   *     is null, undefined or empty.
   */
  constructor(t) {
    this.indexedDB = fe();
    if (t == null || !t) {
      throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");
    }
    this.modelPath = t;
  }
  /**
   * Saves model artifacts under this handler's model path.
   *
   * @throws Error when the model topology is an ArrayBuffer.
   */
  async save(t) {
    if (t.modelTopology instanceof ArrayBuffer) {
      throw new Error("BrowserIndexedDB.save() does not support saving model topology in binary formats yet.");
    }
    return this.databaseAction(this.modelPath, t);
  }
  /** Loads the model artifacts stored under this handler's model path. */
  async load() {
    return this.databaseAction(this.modelPath);
  }
  /**
   * Perform database action to put model artifacts into or read model artifacts
   * from IndexedDB object store.
   *
   * Whether the action is put or get depends on whether `modelArtifacts` is
   * specified. If it is specified, the action will be put; otherwise the action
   * will be get.
   *
   * @param modelPath A unique string path for the model.
   * @param modelArtifacts If specified, it will be the model artifacts to be
   *   stored in IndexedDB.
   * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
   *   of `ModelArtifacts`, if the action is get.
   */
  databaseAction(t, e) {
    return new Promise((resolve, reject) => {
      const openRequest = this.indexedDB.open(Tt, Et);
      openRequest.onupgradeneeded = () => Bt(openRequest);
      openRequest.onsuccess = () => {
        const db = openRequest.result;
        if (e == null) {
          // Read path.
          const readTx = db.transaction(L, "readonly");
          const getRequest = readTx.objectStore(L).get(this.modelPath);
          getRequest.onsuccess = () => {
            if (getRequest.result == null) {
              db.close();
              return reject(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));
            }
            resolve(getRequest.result.modelArtifacts);
          };
          getRequest.onerror = () => {
            db.close();
            return reject(getRequest.error);
          };
          readTx.oncomplete = () => db.close();
          return;
        }
        // Write path: the info record goes in first, then the model record.
        // The weight data is flattened into a single ArrayBuffer in place.
        e.weightData = at.join(e.weightData);
        const artifactsInfo = he(e);
        const infoTx = db.transaction(O, "readwrite");
        let infoStore = infoTx.objectStore(O);
        let putInfoRequest;
        try {
          putInfoRequest = infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo: artifactsInfo });
        } catch (error) {
          // A synchronous put() failure: release the connection before rejecting.
          db.close();
          return reject(error);
        }
        let modelTx;
        putInfoRequest.onsuccess = () => {
          modelTx = db.transaction(L, "readwrite");
          const modelStore = modelTx.objectStore(L);
          let putModelRequest;
          try {
            putModelRequest = modelStore.put({
              modelPath: this.modelPath,
              modelArtifacts: e,
              modelArtifactsInfo: artifactsInfo
            });
          } catch (error) {
            db.close();
            return reject(error);
          }
          putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo: artifactsInfo });
          putModelRequest.onerror = () => {
            // Storing the model failed: try to remove the info record again,
            // then report the original put error either way.
            infoStore = infoTx.objectStore(O);
            const deleteInfoRequest = infoStore.delete(this.modelPath);
            deleteInfoRequest.onsuccess = () => {
              db.close();
              return reject(putModelRequest.error);
            };
            deleteInfoRequest.onerror = () => {
              db.close();
              return reject(putModelRequest.error);
            };
          };
        };
        putInfoRequest.onerror = () => {
          db.close();
          return reject(putInfoRequest.error);
        };
        infoTx.oncomplete = () => {
          // Close right away when no model transaction was started; otherwise
          // close once that transaction completes.
          if (modelTx == null) {
            db.close();
          } else {
            modelTx.oncomplete = () => db.close();
          }
        };
      };
      openRequest.onerror = () => reject(openRequest.error);
    });
  }
}
z.URL_SCHEME = "indexeddb://";
2299
/**
 * Router for "indexeddb://" URLs: returns a handler for the path after the
 * scheme, or null when the IS_BROWSER flag is false, `n` is an array, or the
 * scheme does not match.
 */
const de = (n) => {
  if (!S().getBool("IS_BROWSER")) {
    return null;
  }
  if (Array.isArray(n) || !n.startsWith(z.URL_SCHEME)) {
    return null;
  }
  return bn(n.slice(z.URL_SCHEME.length));
};
// The same router serves both saving and loading.
B.registerSaveRouter(de);
B.registerLoadRouter(de);
/** Creates an IndexedDB IO handler for the model path `n`. */
function bn(n) {
  const handler = new z(n);
  return handler;
}
2305
/** Removes a leading "indexeddb://" scheme from `n`, if present. */
function wn(n) {
  if (!n.startsWith(z.URL_SCHEME)) {
    return n;
  }
  return n.slice(z.URL_SCHEME.length);
}
2308
/** Lists and removes models stored in IndexedDB. */
class Sn {
  constructor() {
    this.indexedDB = fe();
  }
  /**
   * Lists every stored model.
   *
   * @returns Promise of an object mapping model path to its artifacts info.
   */
  async listModels() {
    return new Promise((resolve, reject) => {
      const openRequest = this.indexedDB.open(Tt, Et);
      openRequest.onupgradeneeded = () => Bt(openRequest);
      openRequest.onsuccess = () => {
        const db = openRequest.result;
        const infoTx = db.transaction(O, "readonly");
        const getAllRequest = infoTx.objectStore(O).getAll();
        getAllRequest.onsuccess = () => {
          const infoByPath = {};
          for (const record of getAllRequest.result) {
            infoByPath[record.modelPath] = record.modelArtifactsInfo;
          }
          resolve(infoByPath);
        };
        getAllRequest.onerror = () => {
          db.close();
          return reject(getAllRequest.error);
        };
        infoTx.oncomplete = () => db.close();
      };
      openRequest.onerror = () => reject(openRequest.error);
    });
  }
  /**
   * Removes a model's info record and model record.
   *
   * @param t Model path, with or without the "indexeddb://" scheme.
   * @returns Promise of the removed model's artifacts info.
   */
  async removeModel(t) {
    const path = wn(t);
    return new Promise((resolve, reject) => {
      const openRequest = this.indexedDB.open(Tt, Et);
      openRequest.onupgradeneeded = () => Bt(openRequest);
      openRequest.onsuccess = () => {
        const db = openRequest.result;
        const infoTx = db.transaction(O, "readwrite");
        const infoStore = infoTx.objectStore(O);
        const getInfoRequest = infoStore.get(path);
        let modelTx;
        getInfoRequest.onsuccess = () => {
          if (getInfoRequest.result == null) {
            db.close();
            return reject(new Error(`Cannot find model with path '${path}' in IndexedDB.`));
          }
          const deleteInfoRequest = infoStore.delete(path);
          const deleteModelData = () => {
            modelTx = db.transaction(L, "readwrite");
            const deleteModelRequest = modelTx.objectStore(L).delete(path);
            deleteModelRequest.onsuccess = () => resolve(getInfoRequest.result.modelArtifactsInfo);
            // Report the error of the request that actually failed; the get
            // request succeeded, so its `error` would be null.
            deleteModelRequest.onerror = () => reject(deleteModelRequest.error);
          };
          // The model record is deleted whether or not the info delete worked.
          deleteInfoRequest.onsuccess = deleteModelData;
          deleteInfoRequest.onerror = () => {
            deleteModelData();
            db.close();
            return reject(deleteInfoRequest.error);
          };
        };
        getInfoRequest.onerror = () => {
          db.close();
          return reject(getInfoRequest.error);
        };
        infoTx.oncomplete = () => {
          // Close right away when no model transaction was started; otherwise
          // close once that transaction completes.
          if (modelTx == null) {
            db.close();
          } else {
            modelTx.oncomplete = () => db.close();
          }
        };
      };
      openRequest.onerror = () => reject(openRequest.error);
    });
  }
}
2350
+ /**
2351
+ * @license
2352
+ * Copyright 2018 Google LLC. All Rights Reserved.
2353
+ * Licensed under the Apache License, Version 2.0 (the "License");
2354
+ * you may not use this file except in compliance with the License.
2355
+ * You may obtain a copy of the License at
2356
+ *
2357
+ * http://www.apache.org/licenses/LICENSE-2.0
2358
+ *
2359
+ * Unless required by applicable law or agreed to in writing, software
2360
+ * distributed under the License is distributed on an "AS IS" BASIS,
2361
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2362
+ * See the License for the specific language governing permissions and
2363
+ * limitations under the License.
2364
+ * =============================================================================
2365
+ */
2366
// Separator placed between the parts of a local-storage key.
const _ = "/";
// First part of every key built by this module.
const X = "tensorflowjs_models";
// Final key part for each kind of stored artifact.
const ge = "info";
const In = "model_topology";
const kn = "weight_specs";
const Tn = "weight_data";
const En = "model_metadata";
2367
/**
 * Builds the five storage keys used for one model.
 *
 * Each key is `X`, the model path and a per-artifact suffix joined by `_`.
 *
 * @param n Model path placed in the middle of every key.
 * @returns Object with `info`, `topology`, `weightSpecs`, `weightData` and
 *     `modelMetadata` keys.
 */
function me(n) {
  const keyFor = (suffix) => [X, n, suffix].join(_);
  return {
    info: keyFor(ge),
    topology: keyFor(In),
    weightSpecs: keyFor(kn),
    weightData: keyFor(Tn),
    modelMetadata: keyFor(En)
  };
}
2376
/**
 * Removes from `window.localStorage` every key held as a value of `n`.
 *
 * @param n Object whose values are storage keys (e.g. the result of `me`).
 */
function pe(n) {
  Object.values(n).forEach((storageKey) => {
    window.localStorage.removeItem(storageKey);
  });
}
2380
/**
 * Extracts the middle of a storage key: everything between its first and
 * last `_`-separated parts.
 *
 * @param n Storage key to split.
 * @returns The inner parts re-joined with `_`.
 * @throws Error when the key has fewer than three parts.
 */
function Bn(n) {
  const parts = n.split(_);
  const hasMiddle = parts.length >= 3;
  if (!hasMiddle) {
    throw new Error(`Invalid key format: ${n}`);
  }
  // Drop the first and the last part, keep whatever lies between.
  return parts.slice(1, -1).join(_);
}
2386
/**
 * Removes the leading `W.URL_SCHEME` prefix from a key, if present.
 *
 * @param n Key that may or may not start with the scheme.
 * @returns The key without the scheme prefix; unchanged when it has none.
 */
function An(n) {
  const scheme = W.URL_SCHEME;
  if (!n.startsWith(scheme)) {
    return n;
  }
  return n.slice(scheme.length);
}
2389
/**
 * IO handler that saves a model to, and loads it from, `window.localStorage`.
 *
 * The artifacts of one model are spread over the keys produced by
 * `me(modelPath)`: info, topology, weightSpecs, weightData and modelMetadata.
 */
class W {
  /**
   * @param t Model path used to derive the storage keys.
   * @throws Error when the environment has no local storage, or when `t` is
   *     null, undefined or empty.
   */
  constructor(t) {
    if (!S().getBool("IS_BROWSER") || typeof window === "undefined" || typeof window.localStorage === "undefined") {
      throw new Error("The current environment does not support local storage.");
    }
    this.LS = window.localStorage;
    if (t == null || !t) {
      throw new Error("For local storage, modelPath must not be null, undefined or empty.");
    }
    this.modelPath = t;
    this.keys = me(this.modelPath);
  }
  /**
   * Save model artifacts to browser local storage.
   *
   * See the documentation to `browserLocalStorage` for details on the saved
   * artifacts.
   *
   * @param modelArtifacts The model artifacts to be stored.
   * @returns An instance of SaveResult.
   * @throws Error when the topology is an `ArrayBuffer`, or when any write
   *     fails. On a failed write all keys of this model are removed and the
   *     underlying exception is available as `cause` on the thrown error.
   */
  async save(t) {
    if (t.modelTopology instanceof ArrayBuffer) {
      throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");
    }
    const topologyJson = JSON.stringify(t.modelTopology);
    const weightSpecsJson = JSON.stringify(t.weightSpecs);
    const info = he(t);
    const joinedWeights = at.join(t.weightData);
    try {
      this.LS.setItem(this.keys.info, JSON.stringify(info));
      this.LS.setItem(this.keys.topology, topologyJson);
      this.LS.setItem(this.keys.weightSpecs, weightSpecsJson);
      // `pn` turns the joined weight data into the value that is stored.
      this.LS.setItem(this.keys.weightData, pn(joinedWeights));
      const metadata = {
        format: t.format,
        generatedBy: t.generatedBy,
        convertedBy: t.convertedBy,
        signature: t.signature != null ? t.signature : void 0,
        userDefinedMetadata: t.userDefinedMetadata != null ? t.userDefinedMetadata : void 0,
        modelInitializer: t.modelInitializer != null ? t.modelInitializer : void 0,
        initializerSignature: t.initializerSignature != null ? t.initializerSignature : void 0,
        trainingConfig: t.trainingConfig != null ? t.trainingConfig : void 0
      };
      this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
      return { modelArtifactsInfo: info };
    } catch (err) {
      // Do not leave a partially written model behind.
      pe(this.keys);
      // Keep the real failure reachable instead of discarding it.
      throw new Error(`Failed to save model '${this.modelPath}' to local storage: size quota being exceeded is a possible cause of this failure: modelTopologyBytes=${info.modelTopologyBytes}, weightSpecsBytes=${info.weightSpecsBytes}, weightDataBytes=${info.weightDataBytes}.`, { cause: err });
    }
  }
  /**
   * Load a model from local storage.
   *
   * See the documentation to `browserLocalStorage` for details on the saved
   * artifacts.
   *
   * @returns The loaded model (if loading succeeds).
   * @throws Error when the info, topology, weight specs or weight data entry
   *     is missing, or when the stored topology type is not "JSON".
   */
  async load() {
    const info = JSON.parse(this.LS.getItem(this.keys.info));
    if (info == null) {
      throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
    }
    if (info.modelTopologyType !== "JSON") {
      throw new Error("BrowserLocalStorage does not support loading non-JSON model topology yet.");
    }
    const artifacts = {};
    const topology = JSON.parse(this.LS.getItem(this.keys.topology));
    if (topology == null) {
      throw new Error(`In local storage, the topology of model '${this.modelPath}' is missing.`);
    }
    artifacts.modelTopology = topology;
    const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
    if (weightSpecs == null) {
      throw new Error(`In local storage, the weight specs of model '${this.modelPath}' are missing.`);
    }
    artifacts.weightSpecs = weightSpecs;
    // The metadata entry is optional; a missing one is not an error.
    const metadataJson = this.LS.getItem(this.keys.modelMetadata);
    if (metadataJson != null) {
      const metadata = JSON.parse(metadataJson);
      artifacts.format = metadata.format;
      artifacts.generatedBy = metadata.generatedBy;
      artifacts.convertedBy = metadata.convertedBy;
      const optionalKeys = ["signature", "userDefinedMetadata", "modelInitializer", "initializerSignature", "trainingConfig"];
      for (const key of optionalKeys) {
        if (metadata[key] != null) {
          artifacts[key] = metadata[key];
        }
      }
    }
    const storedWeights = this.LS.getItem(this.keys.weightData);
    if (storedWeights == null) {
      throw new Error(`In local storage, the binary weight values of model '${this.modelPath}' are missing.`);
    }
    // `yn` converts the stored value back into weight data.
    artifacts.weightData = yn(storedWeights);
    return artifacts;
  }
}
2462
// URL prefix that routes a save/load request to the local-storage handler `W`.
W.URL_SCHEME = "localstorage://";
/**
 * Router registered for both saving and loading.
 *
 * @param n URL-like value to inspect (an array is never handled here).
 * @returns The handler built by `vn` for the path that follows the scheme,
 *     or `null` when not in a browser or when `n` does not use this scheme.
 */
const ye = (n) => {
  if (!S().getBool("IS_BROWSER")) {
    return null;
  }
  if (Array.isArray(n) || !n.startsWith(W.URL_SCHEME)) {
    return null;
  }
  return vn(n.slice(W.URL_SCHEME.length));
};
B.registerSaveRouter(ye);
B.registerLoadRouter(ye);
2466
/**
 * Factory for the local-storage handler class `W`.
 *
 * @param n Model path, passed to the `W` constructor unchanged.
 * @returns A new `W` instance.
 */
function vn(n) {
  const handler = new W(n);
  return handler;
}
2469
/**
 * Lists and removes models kept in `window.localStorage`.
 */
class Mn {
  /**
   * @throws Whatever `b` throws when the environment is not a browser or
   *     when `window.localStorage` is unavailable.
   */
  constructor() {
    b(S().getBool("IS_BROWSER"), () => "Current environment is not a web browser");
    // Both `window` and its `localStorage` must exist: the next statement
    // dereferences them, so a missing `window` must fail the assertion here
    // rather than surface as a ReferenceError.
    b(typeof window !== "undefined" && typeof window.localStorage !== "undefined", () => "Current browser does not appear to support localStorage");
    this.LS = window.localStorage;
  }
  /**
   * Scans all storage keys that start with `X + _` and end with `_ + ge`.
   *
   * @returns Promise resolving to an object that maps the model path parsed
   *     from each matching key (via `Bn`) to the JSON-parsed stored value.
   */
  async listModels() {
    const models = {};
    const prefix = X + _;
    const suffix = _ + ge;
    for (let index = 0; index < this.LS.length; ++index) {
      const key = this.LS.key(index);
      if (key.startsWith(prefix) && key.endsWith(suffix)) {
        models[Bn(key)] = JSON.parse(this.LS.getItem(key));
      }
    }
    return models;
  }
  /**
   * Removes every storage key belonging to one model.
   *
   * @param t Model path; a leading URL scheme is stripped by `An`.
   * @returns Promise resolving to the JSON-parsed info entry of the removed
   *     model.
   * @throws Error when no info entry exists for the path.
   */
  async removeModel(t) {
    t = An(t);
    const keys = me(t);
    const storedInfo = this.LS.getItem(keys.info);
    if (storedInfo == null) {
      throw new Error(`Cannot find model at path '${t}'`);
    }
    const info = JSON.parse(storedInfo);
    pe(keys);
    return info;
  }
}
2493
+ /**
2494
+ * @license
2495
+ * Copyright 2018 Google LLC. All Rights Reserved.
2496
+ * Licensed under the Apache License, Version 2.0 (the "License");
2497
+ * you may not use this file except in compliance with the License.
2498
+ * You may obtain a copy of the License at
2499
+ *
2500
+ * http://www.apache.org/licenses/LICENSE-2.0
2501
+ *
2502
+ * Unless required by applicable law or agreed to in writing, software
2503
+ * distributed under the License is distributed on an "AS IS" BASIS,
2504
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2505
+ * See the License for the specific language governing permissions and
2506
+ * limitations under the License.
2507
+ * =============================================================================
2508
+ */
2509
// Suffix that may trail a scheme name passed to `registerManager`.
const Kt = "://";
/**
 * Process-wide registry that maps a URL scheme name to a model store manager.
 */
class $ {
  constructor() {
    this.managers = {};
  }
  /**
   * @returns The single shared registry, created on first use.
   */
  static getInstance() {
    if ($.instance == null) {
      $.instance = new $();
    }
    return $.instance;
  }
  /**
   * Registers a manager for a scheme.
   *
   * @param t Scheme name; when it ends with "://" everything from the first
   *     "://" onward is dropped before registration.
   * @param e Manager to store under the scheme.
   * @throws Whatever `b` throws when the scheme is null/undefined, empty, or
   *     already has a manager.
   */
  static registerManager(t, e) {
    b(t != null, () => "scheme must not be undefined or null.");
    if (t.endsWith(Kt)) {
      t = t.slice(0, t.indexOf(Kt));
    }
    b(t.length > 0, () => "scheme must not be an empty string.");
    const registry = $.getInstance();
    b(registry.managers[t] == null, () => `A model store manager is already registered for scheme '${t}'.`);
    registry.managers[t] = e;
  }
  /**
   * @param t Scheme name as stored (without "://").
   * @returns The manager registered for the scheme.
   * @throws Error when nothing is registered for the scheme.
   */
  static getManager(t) {
    const manager = $.getInstance().managers[t];
    if (manager == null) {
      throw new Error(`Cannot find model manager for scheme '${t}'`);
    }
    return manager;
  }
  /**
   * @returns The names of all schemes that currently have a manager.
   */
  static getSchemes() {
    const { managers } = $.getInstance();
    return Object.keys(managers);
  }
}
2538
+ /**
2539
+ * @license
2540
+ * Copyright 2019 Google LLC. All Rights Reserved.
2541
+ * Licensed under the Apache License, Version 2.0 (the "License");
2542
+ * you may not use this file except in compliance with the License.
2543
+ * You may obtain a copy of the License at
2544
+ *
2545
+ * http://www.apache.org/licenses/LICENSE-2.0
2546
+ *
2547
+ * Unless required by applicable law or agreed to in writing, software
2548
+ * distributed under the License is distributed on an "AS IS" BASIS,
2549
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2550
+ * See the License for the specific language governing permissions and
2551
+ * limitations under the License.
2552
+ * =============================================================================
2553
+ */
2554
/**
 * Platform implementation for browser environments: fetch, timing, text
 * encoding/decoding and a `setTimeout` variant that goes through
 * `window.postMessage`.
 */
class Fn {
  constructor() {
    // Tag that marks messages posted by `setTimeoutCustom`.
    this.messageName = "setTimeoutCustom";
    // Callbacks waiting for their message, addressed by array index.
    this.functionRefs = [];
    this.handledMessageCount = 0;
    this.hasEventListener = false;
  }
  /** Delegates to the global `fetch`. */
  fetch(t, e) {
    return fetch(t, e);
  }
  /** @returns The value of `performance.now()`. */
  now() {
    return performance.now();
  }
  /**
   * @param t Text to encode.
   * @param e Encoding name; only "utf-8" and "utf8" are accepted.
   * @returns The bytes produced by a lazily created, reused `TextEncoder`.
   * @throws Error for any other encoding name.
   */
  encode(t, e) {
    if (e !== "utf-8" && e !== "utf8") {
      throw new Error(`Browser's encoder only supports utf-8, but got ${e}`);
    }
    if (this.textEncoder == null) {
      this.textEncoder = new TextEncoder();
    }
    return this.textEncoder.encode(t);
  }
  /**
   * @param t Bytes to decode.
   * @param e Encoding label passed to `TextDecoder`.
   * @returns The decoded string.
   */
  decode(t, e) {
    return new TextDecoder(e).decode(t);
  }
  // If the setTimeout nesting level is greater than 5 and timeout is less
  // than 4ms, timeout will be clamped to 4ms, which hurts the perf.
  // Interleaving window.postMessage and setTimeout will trick the browser and
  // avoid the clamp.
  /**
   * Schedules `t` after `e` milliseconds.
   *
   * Falls back to plain `setTimeout` when there is no `window` or the
   * `USE_SETTIMEOUTCUSTOM` flag is off.
   *
   * @param t Callback to run.
   * @param e Delay in milliseconds.
   */
  setTimeoutCustom(t, e) {
    if (typeof window === "undefined" || !S().getBool("USE_SETTIMEOUTCUSTOM")) {
      setTimeout(t, e);
      return;
    }
    // Remember this callback's own slot now. Reading the array length when
    // the timer fires would point at whichever callback was queued last, so
    // overlapping timers would run the wrong callback and drop others.
    const index = this.functionRefs.push(t) - 1;
    setTimeout(() => {
      window.postMessage({ name: this.messageName, index }, "*");
    }, e);
    if (!this.hasEventListener) {
      this.hasEventListener = true;
      window.addEventListener("message", (event) => {
        if (event.source === window && event.data.name === this.messageName) {
          event.stopPropagation();
          const callback = this.functionRefs[event.data.index];
          callback();
          this.handledMessageCount++;
          // Every queued callback has run: start over with an empty queue.
          if (this.handledMessageCount === this.functionRefs.length) {
            this.functionRefs = [];
            this.handledMessageCount = 0;
          }
        }
      }, true);
    }
  }
  /** Delegates to `ee`. */
  isTypedArray(t) {
    return ee(t);
  }
}
2595
// Browser bootstrap: install the browser platform, then try to register the
// local-storage and IndexedDB model managers.
if (S().get("IS_BROWSER")) {
  S().setPlatform("browser", new Fn());
  const registrations = [
    () => $.registerManager(W.URL_SCHEME, new Mn()),
    () => $.registerManager(z.URL_SCHEME, new Sn())
  ];
  for (const register of registrations) {
    try {
      register();
    } catch {
      // Best effort: a manager that cannot be created or registered is
      // skipped and the next one is still attempted.
    }
  }
}
2606
+ /**
2607
+ * @license
2608
+ * Copyright 2019 Google LLC. All Rights Reserved.
2609
+ * Licensed under the Apache License, Version 2.0 (the "License");
2610
+ * you may not use this file except in compliance with the License.
2611
+ * You may obtain a copy of the License at
2612
+ *
2613
+ * http://www.apache.org/licenses/LICENSE-2.0
2614
+ *
2615
+ * Unless required by applicable law or agreed to in writing, software
2616
+ * distributed under the License is distributed on an "AS IS" BASIS,
2617
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2618
+ * See the License for the specific language governing permissions and
2619
+ * limitations under the License.
2620
+ * =============================================================================
2621
+ */
2622
const Rn = {
  // tslint:disable-next-line:no-require-imports
  importFetch: () => require("node-fetch")
};
// Cache for the fetch implementation returned by `Rn.importFetch()`.
let pt;
/**
 * Platform implementation for Node.js environments, built on the `util`
 * module.
 */
class xn {
  constructor() {
    this.util = require("util");
    this.textEncoder = new this.util.TextEncoder();
  }
  /**
   * Uses `S().global.fetch` when it exists; otherwise loads `node-fetch`
   * once through `Rn.importFetch()` and reuses it.
   */
  fetch(t, e) {
    if (S().global.fetch != null) {
      return S().global.fetch(t, e);
    }
    if (pt == null) {
      pt = Rn.importFetch();
    }
    return pt(t, e);
  }
  /**
   * @returns `process.hrtime()` converted to milliseconds.
   */
  now() {
    const [seconds, nanoseconds] = process.hrtime();
    return seconds * 1e3 + nanoseconds / 1e6;
  }
  /**
   * @param t Text to encode.
   * @param e Encoding name; only "utf-8" and "utf8" are accepted.
   * @throws Error for any other encoding name.
   */
  encode(t, e) {
    const isUtf8 = e === "utf-8" || e === "utf8";
    if (!isUtf8) {
      throw new Error(`Node built-in encoder only supports utf-8, but got ${e}`);
    }
    return this.textEncoder.encode(t);
  }
  /**
   * @param t Bytes to decode; zero-length input yields "" without creating
   *     a decoder.
   * @param e Encoding label passed to `TextDecoder`.
   */
  decode(t, e) {
    if (t.length === 0) {
      return "";
    }
    const decoder = new this.util.TextDecoder(e);
    return decoder.decode(t);
  }
  /**
   * @returns true when `t` is a Float32Array, Int32Array, Uint8Array or
   *     Uint8ClampedArray according to `util.types`.
   */
  isTypedArray(t) {
    const { types } = this.util;
    return types.isFloat32Array(t) || types.isInt32Array(t) || types.isUint8Array(t) || types.isUint8ClampedArray(t);
  }
}
// Node bootstrap: install the node platform when running in Node and not in
// a browser.
if (S().get("IS_NODE") && !S().get("IS_BROWSER")) {
  S().setPlatform("node", new xn());
}
2651
+ /**
2652
+ * @license
2653
+ * Copyright 2020 Google Inc. All Rights Reserved.
2654
+ * Licensed under the Apache License, Version 2.0 (the "License");
2655
+ * you may not use this file except in compliance with the License.
2656
+ * You may obtain a copy of the License at
2657
+ *
2658
+ * http://www.apache.org/licenses/LICENSE-2.0
2659
+ *
2660
+ * Unless required by applicable law or agreed to in writing, software
2661
+ * distributed under the License is distributed on an "AS IS" BASIS,
2662
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2663
+ * See the License for the specific language governing permissions and
2664
+ * limitations under the License.
2665
+ * =============================================================================
2666
+ */
2667
/**
 * Creates a buffer object via `new en(...)`.
 *
 * @param n Shape; passed through `Rt` first.
 * @param t Dtype; any falsy value is replaced by "float32".
 * @param e Optional values forwarded to the `en` constructor.
 * @returns The new `en` instance.
 */
function Nn(n, t = "float32", e) {
  const dtype = t || "float32";
  Rt(n);
  return new en(n, dtype, e);
}
2670
+ /**
2671
+ * @license
2672
+ * Copyright 2020 Google Inc. All Rights Reserved.
2673
+ * Licensed under the Apache License, Version 2.0 (the "License");
2674
+ * you may not use this file except in compliance with the License.
2675
+ * You may obtain a copy of the License at
2676
+ *
2677
+ * http://www.apache.org/licenses/LICENSE-2.0
2678
+ *
2679
+ * Unless required by applicable law or agreed to in writing, software
2680
+ * distributed under the License is distributed on an "AS IS" BASIS,
2681
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2682
+ * See the License for the specific language governing permissions and
2683
+ * limitations under the License.
2684
+ * =============================================================================
2685
+ */
2686
/**
 * "cast" op: runs the kernel named by `Qt` with the requested dtype.
 *
 * @param n Input, converted with `k(n, "x", "cast")`.
 * @param t Target dtype; must be accepted by `ke`.
 * @returns The result of `g.runKernel`.
 * @throws Error when `ke(t)` is falsy, or when exactly one of the input
 *     dtype and the target dtype is "string".
 */
function $n(n, t) {
  const input = k(n, "x", "cast");
  if (!ke(t)) {
    throw new Error(`Failed to cast to unknown dtype ${t}`);
  }
  const targetIsString = t === "string";
  const inputIsString = input.dtype === "string";
  // String tensors convert only to string, and nothing else converts to it.
  if (targetIsString !== inputIsString) {
    throw new Error("Only strings can be casted to strings");
  }
  const inputs = { x: input };
  const attrs = { dtype: t };
  return g.runKernel(Qt, inputs, attrs);
}
const At = /* @__PURE__ */ F({ cast_: $n });
2696
+ /**
2697
+ * @license
2698
+ * Copyright 2020 Google LLC. All Rights Reserved.
2699
+ * Licensed under the Apache License, Version 2.0 (the "License");
2700
+ * you may not use this file except in compliance with the License.
2701
+ * You may obtain a copy of the License at
2702
+ *
2703
+ * http://www.apache.org/licenses/LICENSE-2.0
2704
+ *
2705
+ * Unless required by applicable law or agreed to in writing, software
2706
+ * distributed under the License is distributed on an "AS IS" BASIS,
2707
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2708
+ * See the License for the specific language governing permissions and
2709
+ * limitations under the License.
2710
+ * =============================================================================
2711
+ */
2712
/**
 * "clone" op: runs the kernel named by `Zt` on the input.
 *
 * @param n Input, converted with `k(n, "x", "clone", "string_or_numeric")`.
 * @returns The result of `g.runKernel`.
 */
function Dn(n) {
  const input = k(n, "x", "clone", "string_or_numeric");
  const inputs = { x: input };
  return g.runKernel(Zt, inputs);
}
const Cn = /* @__PURE__ */ F({ clone_: Dn });
2717
+ /**
2718
+ * @license
2719
+ * Copyright 2020 Google Inc. All Rights Reserved.
2720
+ * Licensed under the Apache License, Version 2.0 (the "License");
2721
+ * you may not use this file except in compliance with the License.
2722
+ * You may obtain a copy of the License at
2723
+ *
2724
+ * http://www.apache.org/licenses/LICENSE-2.0
2725
+ *
2726
+ * Unless required by applicable law or agreed to in writing, software
2727
+ * distributed under the License is distributed on an "AS IS" BASIS,
2728
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2729
+ * See the License for the specific language governing permissions and
2730
+ * limitations under the License.
2731
+ * =============================================================================
2732
+ */
2733
/**
 * Logs `n.toString(t)` to the console.
 *
 * @param n Object whose `toString` accepts the verbosity flag.
 * @param t Verbosity flag forwarded to `toString`; defaults to false.
 */
function _n(n, t = false) {
  const text = n.toString(t);
  console.log(text);
}
2736
+ /**
2737
+ * @license
2738
+ * Copyright 2020 Google Inc. All Rights Reserved.
2739
+ * Licensed under the Apache License, Version 2.0 (the "License");
2740
+ * you may not use this file except in compliance with the License.
2741
+ * You may obtain a copy of the License at
2742
+ *
2743
+ * http://www.apache.org/licenses/LICENSE-2.0
2744
+ *
2745
+ * Unless required by applicable law or agreed to in writing, software
2746
+ * distributed under the License is distributed on an "AS IS" BASIS,
2747
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2748
+ * See the License for the specific language governing permissions and
2749
+ * limitations under the License.
2750
+ * =============================================================================
2751
+ */
2752
// Module initialisation: run `ce()`, then hand the four op implementations
// defined above to `sn`.
ce();
const On = { buffer: Nn, cast: At, clone: Cn, print: _n };
sn(On);
2760
+ /**
2761
+ * @license
2762
+ * Copyright 2020 Google LLC. All Rights Reserved.
2763
+ * Licensed under the Apache License, Version 2.0 (the "License");
2764
+ * you may not use this file except in compliance with the License.
2765
+ * You may obtain a copy of the License at
2766
+ *
2767
+ * http://www.apache.org/licenses/LICENSE-2.0
2768
+ *
2769
+ * Unless required by applicable law or agreed to in writing, software
2770
+ * distributed under the License is distributed on an "AS IS" BASIS,
2771
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2772
+ * See the License for the specific language governing permissions and
2773
+ * limitations under the License.
2774
+ * =============================================================================
2775
+ */
2776
/**
 * "add" op: runs the kernel named by `Yt` on two operands.
 *
 * @param n First operand, converted with `k(n, "a", "add")`.
 * @param t Second operand, converted with `k(t, "b", "add")`.
 * @returns The result of `g.runKernel`.
 */
function Pn(n, t) {
  const rawA = k(n, "a", "add");
  const rawB = k(t, "b", "add");
  // `K` returns the pair of operands that is actually sent to the kernel.
  const [a, b] = K(rawA, rawB);
  return g.runKernel(Yt, { a, b });
}
const w = /* @__PURE__ */ F({ add_: Pn });
2783
+ /**
2784
+ * @license
2785
+ * Copyright 2020 Google LLC. All Rights Reserved.
2786
+ * Licensed under the Apache License, Version 2.0 (the "License");
2787
+ * you may not use this file except in compliance with the License.
2788
+ * You may obtain a copy of the License at
2789
+ *
2790
+ * http://www.apache.org/licenses/LICENSE-2.0
2791
+ *
2792
+ * Unless required by applicable law or agreed to in writing, software
2793
+ * distributed under the License is distributed on an "AS IS" BASIS,
2794
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2795
+ * See the License for the specific language governing permissions and
2796
+ * limitations under the License.
2797
+ * =============================================================================
2798
+ */
2799
/**
 * "floorDiv" op: runs the kernel named by `Oe` on two operands.
 *
 * @param n First operand, converted with `k(n, "a", "floorDiv")`.
 * @param t Second operand, converted with `k(t, "b", "floorDiv")`.
 * @returns The result of `g.runKernel`.
 */
function Ln(n, t) {
  const rawA = k(n, "a", "floorDiv");
  const rawB = k(t, "b", "floorDiv");
  // `K` returns the pair of operands that is actually sent to the kernel.
  const [a, b] = K(rawA, rawB);
  return g.runKernel(Oe, { a, b });
}
const Un = /* @__PURE__ */ F({ floorDiv_: Ln });
2806
+ /**
2807
+ * @license
2808
+ * Copyright 2020 Google LLC. All Rights Reserved.
2809
+ * Licensed under the Apache License, Version 2.0 (the "License");
2810
+ * you may not use this file except in compliance with the License.
2811
+ * You may obtain a copy of the License at
2812
+ *
2813
+ * http://www.apache.org/licenses/LICENSE-2.0
2814
+ *
2815
+ * Unless required by applicable law or agreed to in writing, software
2816
+ * distributed under the License is distributed on an "AS IS" BASIS,
2817
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2818
+ * See the License for the specific language governing permissions and
2819
+ * limitations under the License.
2820
+ * =============================================================================
2821
+ */
2822
/**
 * "div" op.
 *
 * When both operands returned by `K` have dtype "int32" the work is handed
 * to `Un`; otherwise the kernel named by `Ce` is run with empty attributes.
 *
 * @param n First operand, converted with `k(n, "a", "div")`.
 * @param t Second operand, converted with `k(t, "b", "div")`.
 * @returns The result of `Un` or of `g.runKernel`.
 */
function Gn(n, t) {
  const rawA = k(n, "a", "div");
  const rawB = k(t, "b", "div");
  const [a, b] = K(rawA, rawB);
  const bothInt32 = a.dtype === "int32" && b.dtype === "int32";
  if (bothInt32) {
    return Un(a, b);
  }
  const attrs = {};
  return g.runKernel(Ce, { a, b }, attrs);
}
const D = /* @__PURE__ */ F({ div_: Gn });
2830
+ /**
2831
+ * @license
2832
+ * Copyright 2020 Google LLC. All Rights Reserved.
2833
+ * Licensed under the Apache License, Version 2.0 (the "License");
2834
+ * you may not use this file except in compliance with the License.
2835
+ * You may obtain a copy of the License at
2836
+ *
2837
+ * http://www.apache.org/licenses/LICENSE-2.0
2838
+ *
2839
+ * Unless required by applicable law or agreed to in writing, software
2840
+ * distributed under the License is distributed on an "AS IS" BASIS,
2841
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2842
+ * See the License for the specific language governing permissions and
2843
+ * limitations under the License.
2844
+ * =============================================================================
2845
+ */
2846
/**
 * "mul" op: runs the kernel named by `Le` on two operands.
 *
 * @param n First operand, converted with `k(n, "a", "mul")`.
 * @param t Second operand, converted with `k(t, "b", "mul")`.
 * @returns The result of `g.runKernel`.
 */
function zn(n, t) {
  const rawA = k(n, "a", "mul");
  const rawB = k(t, "b", "mul");
  // `K` returns the pair of operands that is actually sent to the kernel.
  const [a, b] = K(rawA, rawB);
  return g.runKernel(Le, { a, b });
}
const p = /* @__PURE__ */ F({ mul_: zn });
2853
+ /**
2854
+ * @license
2855
+ * Copyright 2018 Google LLC. All Rights Reserved.
2856
+ * Licensed under the Apache License, Version 2.0 (the "License");
2857
+ * you may not use this file except in compliance with the License.
2858
+ * You may obtain a copy of the License at
2859
+ *
2860
+ * http://www.apache.org/licenses/LICENSE-2.0
2861
+ *
2862
+ * Unless required by applicable law or agreed to in writing, software
2863
+ * distributed under the License is distributed on an "AS IS" BASIS,
2864
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2865
+ * See the License for the specific language governing permissions and
2866
+ * limitations under the License.
2867
+ * =============================================================================
2868
+ */
2869
/**
 * "abs" op.
 *
 * Inputs with dtype "complex64" go to the kernel named by `De`; every other
 * dtype goes to the kernel named by `$e`.
 *
 * @param n Input, converted with `k(n, "x", "abs")`.
 * @returns The result of `g.runKernel`.
 */
function Wn(n) {
  const input = k(n, "x", "abs");
  const kernelName = input.dtype === "complex64" ? De : $e;
  return g.runKernel(kernelName, { x: input });
}
const jn = /* @__PURE__ */ F({ abs_: Wn });
2880
+ /**
2881
+ * @license
2882
+ * Copyright 2020 Google LLC. All Rights Reserved.
2883
+ * Licensed under the Apache License, Version 2.0 (the "License");
2884
+ * you may not use this file except in compliance with the License.
2885
+ * You may obtain a copy of the License at
2886
+ *
2887
+ * http://www.apache.org/licenses/LICENSE-2.0
2888
+ *
2889
+ * Unless required by applicable law or agreed to in writing, software
2890
+ * distributed under the License is distributed on an "AS IS" BASIS,
2891
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2892
+ * See the License for the specific language governing permissions and
2893
+ * limitations under the License.
2894
+ * =============================================================================
2895
+ */
2896
/**
 * Runs the kernel named by `_e` with shape, value and dtype attributes and
 * no inputs.
 *
 * @param n Shape; passed through `Rt` first.
 * @param t Fill value.
 * @param e Dtype; when falsy it is derived from the value with `ft(t)`.
 * @returns The result of `g.runKernel`.
 */
function Kn(n, t, e) {
  Rt(n);
  const dtype = e || ft(t);
  const attrs = { shape: n, value: t, dtype };
  const inputs = {};
  return g.runKernel(_e, inputs, attrs);
}
2901
+ /**
2902
+ * @license
2903
+ * Copyright 2017 Google LLC. All Rights Reserved.
2904
+ * Licensed under the Apache License, Version 2.0 (the "License");
2905
+ * you may not use this file except in compliance with the License.
2906
+ * You may obtain a copy of the License at
2907
+ *
2908
+ * http://www.apache.org/licenses/LICENSE-2.0
2909
+ *
2910
+ * Unless required by applicable law or agreed to in writing, software
2911
+ * distributed under the License is distributed on an "AS IS" BASIS,
2912
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2913
+ * See the License for the specific language governing permissions and
2914
+ * limitations under the License.
2915
+ * =============================================================================
2916
+ */
2917
/**
 * Compares two shapes aligned at their last dimension and collects the axes
 * of `t` where `n` has no dimension, or has size 1 while `t` is larger
 * than 1.
 *
 * @param n First shape.
 * @param t Second shape; returned axes index into this shape.
 * @returns The matching axes of `t` in ascending order.
 */
function xs(n, t) {
  const axes = [];
  // Offset between an axis of `t` and the aligned axis of `n`.
  const shift = n.length - t.length;
  for (let axis = t.length - 1; axis >= 0; axis--) {
    const inDim = n[axis + shift];
    const outDim = t[axis];
    const missing = inDim == null;
    const stretched = inDim === 1 && outDim > 1;
    if (missing || stretched) {
      axes.unshift(axis);
    }
  }
  return axes;
}
2925
/**
 * Combines two shapes dimension by dimension, aligned at the last dimension.
 *
 * A missing dimension counts as 1. Where one side is 1 the other side's size
 * is used; otherwise both sizes must be equal.
 *
 * @param n First shape.
 * @param t Second shape.
 * @returns The combined shape, as long as the longer input.
 * @throws Error when two aligned sizes differ and neither is 1.
 */
function Vn(n, t) {
  const rank = Math.max(n.length, t.length);
  const combined = new Array(rank);
  for (let offset = 1; offset <= rank; offset++) {
    const rawA = n[n.length - offset];
    const rawB = t[t.length - offset];
    const a = rawA == null ? 1 : rawA;
    const b = rawB == null ? 1 : rawB;
    if (a !== 1 && b !== 1 && a !== b) {
      const message = `Operands could not be broadcast together with shapes ${n} and ${t}.`;
      throw new Error(message);
    }
    combined[rank - offset] = a === 1 ? b : a;
  }
  return combined;
}
2943
+ /**
2944
+ * @license
2945
+ * Copyright 2018 Google LLC. All Rights Reserved.
2946
+ * Licensed under the Apache License, Version 2.0 (the "License");
2947
+ * you may not use this file except in compliance with the License.
2948
+ * You may obtain a copy of the License at
2949
+ *
2950
+ * http://www.apache.org/licenses/LICENSE-2.0
2951
+ *
2952
+ * Unless required by applicable law or agreed to in writing, software
2953
+ * distributed under the License is distributed on an "AS IS" BASIS,
2954
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2955
+ * See the License for the specific language governing permissions and
2956
+ * limitations under the License.
2957
+ * =============================================================================
2958
+ */
2959
/**
 * "zerosLike" op: runs the kernel named by `We` on the input.
 *
 * @param n Input, converted with `k(n, "x", "zerosLike")`.
 * @returns The result of `g.runKernel`.
 */
function qn(n) {
  const input = k(n, "x", "zerosLike");
  const inputs = { x: input };
  return g.runKernel(We, inputs);
}
const C = /* @__PURE__ */ F({ zerosLike_: qn });
2964
+ /**
2965
+ * @license
2966
+ * Copyright 2020 Google LLC. All Rights Reserved.
2967
+ * Licensed under the Apache License, Version 2.0 (the "License");
2968
+ * you may not use this file except in compliance with the License.
2969
+ * You may obtain a copy of the License at
2970
+ *
2971
+ * http://www.apache.org/licenses/LICENSE-2.0
2972
+ *
2973
+ * Unless required by applicable law or agreed to in writing, software
2974
+ * distributed under the License is distributed on an "AS IS" BASIS,
2975
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2976
+ * See the License for the specific language governing permissions and
2977
+ * limitations under the License.
2978
+ * =============================================================================
2979
+ */
2980
/**
 * "pow" op: runs the kernel named by `Ue` on a base and an exponent.
 *
 * @param n Base, converted with `k(n, "base", "pow")`.
 * @param t Exponent, converted with `k(t, "exp", "pow")`.
 * @returns The result of `g.runKernel`.
 */
function Hn(n, t) {
  const rawBase = k(n, "base", "pow");
  const rawExp = k(t, "exp", "pow");
  // `K` returns the pair of operands that is actually sent to the kernel.
  const [a, b] = K(rawBase, rawExp);
  return g.runKernel(Ue, { a, b });
}
const Vt = /* @__PURE__ */ F({ pow_: Hn });
2987
+ /**
2988
+ * @license
2989
+ * Copyright 2018 Google LLC. All Rights Reserved.
2990
+ * Licensed under the Apache License, Version 2.0 (the "License");
2991
+ * you may not use this file except in compliance with the License.
2992
+ * You may obtain a copy of the License at
2993
+ *
2994
+ * http://www.apache.org/licenses/LICENSE-2.0
2995
+ *
2996
+ * Unless required by applicable law or agreed to in writing, software
2997
+ * distributed under the License is distributed on an "AS IS" BASIS,
2998
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2999
+ * See the License for the specific language governing permissions and
3000
+ * limitations under the License.
3001
+ * =============================================================================
3002
+ */
3003
/**
 * Creates a scalar through `dn(n, [], [], t)` after validating the value.
 *
 * @param n Scalar value.
 * @param t Dtype.
 * @returns The result of `dn`.
 * @throws Error when the dtype is not "complex64" and the value is an array,
 *     or is recognised by `R` while the dtype is not "string"; also when the
 *     dtype is "string" and the value is recognised by `R` but is not a
 *     `Uint8Array`.
 */
function j(n, t) {
  const isNonPrimitive = (R(n) && t !== "string") || Array.isArray(n);
  if (isNonPrimitive && t !== "complex64") {
    throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");
  }
  const isBadEncodedString = t === "string" && R(n) && !(n instanceof Uint8Array);
  if (isBadEncodedString) {
    throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");
  }
  const shape = [];
  const inferredShape = [];
  return dn(n, shape, inferredShape, t);
}
3010
+ /**
3011
+ * @license
3012
+ * Copyright 2018 Google LLC. All Rights Reserved.
3013
+ * Licensed under the Apache License, Version 2.0 (the "License");
3014
+ * you may not use this file except in compliance with the License.
3015
+ * You may obtain a copy of the License at
3016
+ *
3017
+ * http://www.apache.org/licenses/LICENSE-2.0
3018
+ *
3019
+ * Unless required by applicable law or agreed to in writing, software
3020
+ * distributed under the License is distributed on an "AS IS" BASIS,
3021
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3022
+ * See the License for the specific language governing permissions and
3023
+ * limitations under the License.
3024
+ * =============================================================================
3025
+ */
3026
/**
 * "sqrt" op: runs the kernel named by `Ge` on the input.
 *
 * @param n Input, converted with `k(n, "x", "sqrt", "float32")`.
 * @returns The result of `g.runKernel`.
 */
function Jn(n) {
  const input = k(n, "x", "sqrt", "float32");
  const inputs = { x: input };
  return g.runKernel(Ge, inputs);
}
const Z = /* @__PURE__ */ F({ sqrt_: Jn });
3031
+ /**
3032
+ * @license
3033
+ * Copyright 2019 Google LLC. All Rights Reserved.
3034
+ * Licensed under the Apache License, Version 2.0 (the "License");
3035
+ * you may not use this file except in compliance with the License.
3036
+ * You may obtain a copy of the License at
3037
+ *
3038
+ * http://www.apache.org/licenses/LICENSE-2.0
3039
+ *
3040
+ * Unless required by applicable law or agreed to in writing, software
3041
+ * distributed under the License is distributed on an "AS IS" BASIS,
3042
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3043
+ * See the License for the specific language governing permissions and
3044
+ * limitations under the License.
3045
+ * =============================================================================
3046
+ */
3047
/**
 * Body of the `square_` op.
 *
 * Normalizes the argument through `k` (tagged as argument "x" of op "square")
 * and runs the "Square" kernel on it with an empty attribute object.
 *
 * @param n Value to square; accepted forms are whatever `k` accepts.
 * @returns Whatever `g.runKernel("Square", ...)` returns.
 */
function Xn(n) {
  const input = k(n, "x", "square");
  const attrs = {};
  return g.runKernel("Square", { x: input }, attrs);
}
// Op handle obtained by passing `Xn` to `F` under the key `square_`.
const G = /* @__PURE__ */ F({ square_: Xn });
3052
+ /**
3053
+ * @license
3054
+ * Copyright 2018 Google LLC. All Rights Reserved.
3055
+ * Licensed under the Apache License, Version 2.0 (the "License");
3056
+ * you may not use this file except in compliance with the License.
3057
+ * You may obtain a copy of the License at
3058
+ *
3059
+ * http://www.apache.org/licenses/LICENSE-2.0
3060
+ *
3061
+ * Unless required by applicable law or agreed to in writing, software
3062
+ * distributed under the License is distributed on an "AS IS" BASIS,
3063
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3064
+ * See the License for the specific language governing permissions and
3065
+ * limitations under the License.
3066
+ * =============================================================================
3067
+ */
3068
/**
 * Computes the value of `n()` together with its gradients with respect to a
 * set of trainable variables, keyed by variable name.
 *
 * @param n Function to differentiate; must satisfy `bt(n)` and produce a
 *     rank-0 value.
 * @param t Optional array of variables (instances of `ht`). When omitted,
 *     every entry of `g.registeredVariables` is considered.
 * @returns `{ value, grads }` where `grads` maps variable names to gradients.
 *     When `t` was supplied, its non-trainable members are present with a
 *     `null` gradient.
 * @throws Through `b` when an argument is malformed, when no candidate is
 *     trainable, when every gradient is null, or when the value is not rank 0.
 */
function Yn(n, t) {
  b(bt(n), () => "The f passed in variableGrads(f) must be a function");
  b(
    t == null || (Array.isArray(t) && t.every((candidate) => candidate instanceof ht)),
    () => "The varList passed in variableGrads(f, varList) must be an array of variables"
  );

  const listWasProvided = t != null;
  let candidates = t;
  if (!listWasProvided) {
    // No explicit list: fall back to everything the engine has registered.
    candidates = [];
    for (const key in g.registeredVariables) {
      candidates.push(g.registeredVariables[key]);
    }
  }

  // Only an explicitly supplied list reports its non-trainable members back.
  const frozen = listWasProvided ? candidates.filter((variable) => !variable.trainable) : null;
  const candidateCount = candidates.length;
  const trainable = candidates.filter((variable) => variable.trainable);
  b(
    trainable.length > 0,
    () => `variableGrads() expects at least one of the input variables to be trainable, but none of the ${candidateCount} variables is trainable.`
  );

  // The trailing `true` presumably lets individual gradients come back null
  // instead of throwing (confirm against the engine's `gradients` signature).
  const { value, grads } = g.gradients(n, trainable, null, true);
  b(
    grads.some((gradient) => gradient != null),
    () => "Cannot find a connection between any variable and the result of the loss function y=f(x). Please make sure the operations that use variables are inside the function f passed to minimize()."
  );
  b(
    value.rank === 0,
    () => `The f passed in variableGrads(f) must return a scalar, but it returned a rank-${value.rank} tensor`
  );

  const namedGrads = {};
  trainable.forEach((variable, position) => {
    if (grads[position] != null) {
      namedGrads[variable.name] = grads[position];
    }
  });
  if (frozen != null) {
    frozen.forEach((variable) => {
      namedGrads[variable.name] = null;
    });
  }
  return { value, grads: namedGrads };
}
3085
/**
 * Thin passthrough to the engine's `customGrad`.
 *
 * @param n Argument forwarded unchanged to `g.customGrad`.
 * @returns Whatever `g.customGrad(n)` returns.
 */
function Ns(n) {
  const wrapped = g.customGrad(n);
  return wrapped;
}
3088
+ /**
3089
+ * @license
3090
+ * Copyright 2020 Google LLC. All Rights Reserved.
3091
+ * Licensed under the Apache License, Version 2.0 (the "License");
3092
+ * you may not use this file except in compliance with the License.
3093
+ * You may obtain a copy of the License at
3094
+ *
3095
+ * http://www.apache.org/licenses/LICENSE-2.0
3096
+ *
3097
+ * Unless required by applicable law or agreed to in writing, software
3098
+ * distributed under the License is distributed on an "AS IS" BASIS,
3099
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3100
+ * See the License for the specific language governing permissions and
3101
+ * limitations under the License.
3102
+ * =============================================================================
3103
+ */
3104
/**
 * Body of the `sub_` op.
 *
 * Both operands are normalized through `k` (as arguments "a" and "b" of op
 * "sub"), passed as a pair through `K`, and handed to the kernel `ze`.
 *
 * @param n Left operand.
 * @param t Right operand.
 * @returns Whatever `g.runKernel(ze, { a, b })` returns.
 */
function Qn(n, t) {
  const [left, right] = K(k(n, "a", "sub"), k(t, "b", "sub"));
  const inputs = { a: left, b: right };
  return g.runKernel(ze, inputs);
}
// Op handle obtained by passing `Qn` to `F` under the key `sub_`.
const Y = /* @__PURE__ */ F({ sub_: Qn });
3111
+ /**
3112
+ * @license
3113
+ * Copyright 2020 Google LLC. All Rights Reserved.
3114
+ * Licensed under the Apache License, Version 2.0 (the "License");
3115
+ * you may not use this file except in compliance with the License.
3116
+ * You may obtain a copy of the License at
3117
+ *
3118
+ * http://www.apache.org/licenses/LICENSE-2.0
3119
+ *
3120
+ * Unless required by applicable law or agreed to in writing, software
3121
+ * distributed under the License is distributed on an "AS IS" BASIS,
3122
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3123
+ * See the License for the specific language governing permissions and
3124
+ * limitations under the License.
3125
+ * =============================================================================
3126
+ */
3127
/**
 * Body of the `maximum_` op.
 *
 * Both operands are normalized through `k` and paired through `K`. When the
 * first operand's dtype is "bool", both are converted with `At(..., "int32")`.
 * `Vn` is then called on the two shapes purely for its side effect
 * (presumably a broadcast-compatibility check, confirm) before the kernel
 * `Pe` is run.
 *
 * @param n First operand.
 * @param t Second operand.
 * @returns Whatever `g.runKernel(Pe, { a, b })` returns.
 */
function Zn(n, t) {
  const [first, second] = K(k(n, "a", "maximum"), k(t, "b", "maximum"));
  let left = first;
  let right = second;
  if (left.dtype === "bool") {
    left = At(left, "int32");
    right = At(right, "int32");
  }
  Vn(left.shape, right.shape);
  return g.runKernel(Pe, { a: left, b: right });
}
// Op handle obtained by passing `Zn` to `F` under the key `maximum_`.
const ts = /* @__PURE__ */ F({ maximum_: Zn });
3134
+ /**
3135
+ * @license
3136
+ * Copyright 2018 Google LLC. All Rights Reserved.
3137
+ * Licensed under the Apache License, Version 2.0 (the "License");
3138
+ * you may not use this file except in compliance with the License.
3139
+ * You may obtain a copy of the License at
3140
+ *
3141
+ * http://www.apache.org/licenses/LICENSE-2.0
3142
+ *
3143
+ * Unless required by applicable law or agreed to in writing, software
3144
+ * distributed under the License is distributed on an "AS IS" BASIS,
3145
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3146
+ * See the License for the specific language governing permissions and
3147
+ * limitations under the License.
3148
+ * =============================================================================
3149
+ */
3150
// Registry lookups filled in by `rs`: `es` maps "<package>><name>" keys to
// classes, `ns` maps classes back to those keys.
const es = /* @__PURE__ */ new Map();
const ns = /* @__PURE__ */ new Map();
3151
/**
 * Base class for objects that can be described by, and rebuilt from, a plain
 * configuration object.
 */
class ss {
  /**
   * Returns the name used for this class in serialization contexts.
   *
   * The value is read from the static `className` of the instance's
   * constructor rather than from `constructor.name`, so that it survives
   * minification. Some hierarchies (e.g. initializers.VarianceScaling)
   * deliberately serialize under a non-leaf class name.
   */
  getClassName() {
    const { className } = this.constructor;
    return className;
  }
  /**
   * Builds an instance of `t` from a configuration object.
   *
   * Suitable for most subclasses; a few override it to unpack the config
   * into positional constructor arguments.
   * @param t Constructor of the class to instantiate.
   * @param e Configuration handed to the constructor as its only argument.
   */
  /** @nocollapse */
  static fromConfig(t, e) {
    const instance = new t(e);
    return instance;
  }
}
3179
/**
 * Process-wide table of serializable classes, keyed by their static
 * `className`. Each entry is the pair `[class, class.fromConfig]`.
 */
class P {
  constructor() {
    this.classNameMap = {};
  }
  /**
   * Returns the single shared instance, creating it on first use.
   */
  static getMap() {
    if (P.instance == null) {
      P.instance = new P();
    }
    return P.instance;
  }
  /**
   * Records `t` in the shared table under `t.className`, replacing any
   * earlier entry with the same name.
   */
  static register(t) {
    const registry = P.getMap();
    registry.classNameMap[t.className] = [t, t.fromConfig];
  }
}
3196
/**
 * Registers a serializable class and returns it.
 *
 * The class is added to the `P` singleton table and to the two lookup maps:
 * `es` under the key "<package>><name>", and `ns` from the class back to
 * that key.
 *
 * @param n Class to register; must expose a non-empty string static `className`.
 * @param t Package prefix for the key; "Custom" when undefined.
 * @param e Name part of the key; `n.className` when undefined.
 * @returns `n`, unchanged.
 * @throws Through `b` when `className` is missing, not a string, or empty.
 */
function rs(n, t, e) {
  b(
    n.className != null,
    () => "Class being registered does not have the static className property defined."
  );
  b(
    typeof n.className == "string",
    () => "className is required to be a string, but got type " + typeof n.className
  );
  b(
    n.className.length > 0,
    () => "Class being registered has an empty-string as its className, which is disallowed."
  );

  const pkg = t === undefined ? "Custom" : t;
  const name = e === undefined ? n.className : e;
  const registryKey = pkg + ">" + name;

  P.register(n);
  es.set(registryKey, n);
  ns.set(n, registryKey);
  return n;
}
3201
+ /**
3202
+ * @license
3203
+ * Copyright 2018 Google LLC. All Rights Reserved.
3204
+ * Licensed under the Apache License, Version 2.0 (the "License");
3205
+ * you may not use this file except in compliance with the License.
3206
+ * You may obtain a copy of the License at
3207
+ *
3208
+ * http://www.apache.org/licenses/LICENSE-2.0
3209
+ *
3210
+ * Unless required by applicable law or agreed to in writing, software
3211
+ * distributed under the License is distributed on an "AS IS" BASIS,
3212
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3213
+ * See the License for the specific language governing permissions and
3214
+ * limitations under the License.
3215
+ * =============================================================================
3216
+ */
3217
/**
 * Common base for the optimizers in this file. Subclasses supply
 * `applyGradients`; this class provides the minimize loop, the iteration
 * counter and the weight (de)serialization scaffolding.
 */
class V extends ss {
  /**
   * Runs `t()`, computes the gradients of its scalar result with respect to
   * the trainable variables in `s` (all trainable variables when `s` is
   * omitted), and applies them through `applyGradients`.
   *
   * @param t The function to execute and whose output to minimize.
   * @param e Whether to return the scalar cost produced by `t()`. When
   *     false the cost is disposed and `null` is returned.
   * @param s Optional list of variables to update. When given, gradients
   *     are passed to `applyGradients` as a `{name, tensor}` array in the
   *     same order as `s`.
   *
   * @doc {heading: 'Training', subheading: 'Optimizers'}
   */
  minimize(t, e = !1, s) {
    const { value: cost, grads: gradsByName } = this.computeGradients(t, s);
    if (s == null) {
      this.applyGradients(gradsByName);
    } else {
      const ordered = s.map((variable) => ({
        name: variable.name,
        tensor: gradsByName[variable.name]
      }));
      this.applyGradients(ordered);
    }
    M(gradsByName);
    if (e) {
      return cost;
    }
    cost.dispose();
    return null;
  }
  /**
   * The number of times this optimizer instance has applied gradients.
   * Lazily initialized to 0 on first read.
   */
  get iterations() {
    if (this.iterations_ == null) {
      this.iterations_ = 0;
    }
    return this.iterations_;
  }
  /** Advances the iteration counter by one. */
  incrementIterations() {
    this.iterations_ = this.iterations + 1;
  }
  /**
   * Runs `t()` and computes the gradient of its scalar output with respect
   * to the trainable variables in `e` (all trainable variables when `e` is
   * omitted). Delegates to `Yn`.
   *
   * @param t The function to execute and differentiate.
   * @param e Optional list of variables to differentiate with respect to.
   *
   * @doc {heading: 'Training', subheading: 'Optimizers'}
   */
  computeGradients(t, e) {
    return Yn(t, e);
  }
  /**
   * Dispose the variables (if any) owned by this optimizer instance.
   */
  dispose() {
    if (this.iterations_ != null) {
      M(this.iterations_);
    }
  }
  /**
   * Produces the `{name: "iter", tensor}` entry that leads every saved
   * weight list, initializing the counter to 0 if it was never touched.
   */
  async saveIterations() {
    if (this.iterations_ == null) {
      this.iterations_ = 0;
    }
    return {
      name: "iter",
      // TODO(cais): Use 'int64' type when available.
      tensor: j(this.iterations_, "int32")
    };
  }
  /** Subclasses override this; the base implementation always rejects. */
  async getWeights() {
    throw new Error("getWeights() is not implemented for this optimizer yet.");
  }
  /** Subclasses override this; the base implementation always rejects. */
  async setWeights(t) {
    throw new Error(`setWeights() is not implemented for this optimizer class ${this.getClassName()}`);
  }
  /**
   * Reads the first value of the first weight entry into the iteration
   * counter of this optimizer.
   *
   * @param t Weight entries, the first of which holds the iteration count.
   * @returns The weight entries with the first element removed.
   */
  async extractIterations(t) {
    const leadingValues = await t[0].tensor.data();
    this.iterations_ = leadingValues[0];
    return t.slice(1);
  }
}
3296
// Make `x instanceof V` structural: any object exposing the three optimizer
// entry points counts as an optimizer, whatever its prototype chain.
Object.defineProperty(V, Symbol.hasInstance, {
  value: (candidate) => {
    const provides = (member) => candidate[member] != null;
    return provides("minimize") && provides("computeGradients") && provides("applyGradients");
  }
});
3299
+ /**
3300
+ * @license
3301
+ * Copyright 2018 Google LLC. All Rights Reserved.
3302
+ * Licensed under the Apache License, Version 2.0 (the "License");
3303
+ * you may not use this file except in compliance with the License.
3304
+ * You may obtain a copy of the License at
3305
+ *
3306
+ * http://www.apache.org/licenses/LICENSE-2.0
3307
+ *
3308
+ * Unless required by applicable law or agreed to in writing, software
3309
+ * distributed under the License is distributed on an "AS IS" BASIS,
3310
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3311
+ * See the License for the specific language governing permissions and
3312
+ * limitations under the License.
3313
+ * =============================================================================
3314
+ */
3315
/**
 * Adadelta optimizer. Keeps two accumulator variables per optimized
 * variable: a decayed average of squared gradients and a decayed average of
 * squared updates.
 */
class is extends V {
  /** @nocollapse */
  static get className() {
    return "Adadelta";
  }
  /**
   * @param t Learning rate.
   * @param e Decay factor `rho` of both accumulators.
   * @param s Epsilon added inside the square roots; the backend's epsilon
   *     is used when null or undefined.
   */
  constructor(t, e, s = null) {
    super();
    this.learningRate = t;
    this.rho = e;
    this.epsilon = s;
    this.accumulatedGrads = [];
    this.accumulatedUpdates = [];
    if (s == null) {
      this.epsilon = g.backend.epsilon();
    }
  }
  /**
   * Applies one Adadelta step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulators are still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    names.forEach((name, slot) => {
      const target = g.registeredVariables[name];
      const trainable = false;
      if (this.accumulatedGrads[slot] == null) {
        this.accumulatedGrads[slot] = {
          originalName: `${name}/accum_grad`,
          variable: E(() => C(target).variable(trainable))
        };
      }
      if (this.accumulatedUpdates[slot] == null) {
        this.accumulatedUpdates[slot] = {
          originalName: `${name}/accum_var`,
          variable: E(() => C(target).variable(trainable))
        };
      }
      const gradient = givenAsList ? t[slot].tensor : t[name];
      if (gradient == null) {
        return;
      }
      const gradAccumulator = this.accumulatedGrads[slot].variable;
      const updateAccumulator = this.accumulatedUpdates[slot].variable;
      E(() => {
        const nextGradAccumulator = w(p(gradAccumulator, this.rho), p(G(gradient), 1 - this.rho));
        // The step is scaled with the accumulators as they were before
        // this call; both are reassigned only afterwards.
        const step = p(
          D(Z(w(updateAccumulator, this.epsilon)), Z(w(gradAccumulator, this.epsilon))),
          gradient
        );
        const nextUpdateAccumulator = w(p(updateAccumulator, this.rho), p(G(step), 1 - this.rho));
        gradAccumulator.assign(nextGradAccumulator);
        updateAccumulator.assign(nextUpdateAccumulator);
        const nextValue = w(p(step, -this.learningRate), target);
        target.assign(nextValue);
      });
    });
    this.incrementIterations();
  }
  /** Disposes both sets of accumulator variables. */
  dispose() {
    if (this.accumulatedUpdates != null) {
      M(this.accumulatedGrads.map((slot) => slot.variable));
      M(this.accumulatedUpdates.map((slot) => slot.variable));
    }
  }
  /**
   * @returns The iteration entry followed by every gradient accumulator and
   *     then every update accumulator, as `{name, tensor}` entries.
   */
  async getWeights() {
    const slots = [...this.accumulatedGrads, ...this.accumulatedUpdates];
    const iterationEntry = await this.saveIterations();
    const slotEntries = slots.map((slot) => ({ name: slot.originalName, tensor: slot.variable }));
    return [iterationEntry].concat(slotEntries);
  }
  /**
   * Restores state written by `getWeights`: the iteration entry first, then
   * the remaining entries split in half between the two accumulator lists.
   */
  async setWeights(t) {
    t = await this.extractIterations(t);
    const half = t.length / 2;
    const trainable = false;
    const toSlot = (entry) => ({
      originalName: entry.name,
      variable: entry.tensor.variable(trainable)
    });
    this.accumulatedGrads = t.slice(0, half).map((entry) => toSlot(entry));
    this.accumulatedUpdates = t.slice(half, half * 2).map((entry) => toSlot(entry));
  }
  getConfig() {
    const { learningRate, rho, epsilon } = this;
    return { learningRate, rho, epsilon };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    const { learningRate, rho, epsilon } = e;
    return new t(learningRate, rho, epsilon);
  }
}
3375
+ /**
3376
+ * @license
3377
+ * Copyright 2018 Google LLC. All Rights Reserved.
3378
+ * Licensed under the Apache License, Version 2.0 (the "License");
3379
+ * you may not use this file except in compliance with the License.
3380
+ * You may obtain a copy of the License at
3381
+ *
3382
+ * http://www.apache.org/licenses/LICENSE-2.0
3383
+ *
3384
+ * Unless required by applicable law or agreed to in writing, software
3385
+ * distributed under the License is distributed on an "AS IS" BASIS,
3386
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3387
+ * See the License for the specific language governing permissions and
3388
+ * limitations under the License.
3389
+ * =============================================================================
3390
+ */
3391
/**
 * Adagrad optimizer. Keeps one accumulator variable per optimized variable
 * holding the running sum of squared gradients.
 */
class os extends V {
  /** @nocollapse */
  static get className() {
    return "Adagrad";
  }
  /**
   * @param t Learning rate.
   * @param e Value every accumulator element starts at (default 0.1).
   */
  constructor(t, e = 0.1) {
    super();
    this.learningRate = t;
    this.initialAccumulatorValue = e;
    this.accumulatedGrads = [];
  }
  /**
   * Applies one Adagrad step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulator is still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    names.forEach((name, slot) => {
      const target = g.registeredVariables[name];
      if (this.accumulatedGrads[slot] == null) {
        this.accumulatedGrads[slot] = {
          originalName: `${name}/accumulator`,
          variable: E(() => Kn(target.shape, this.initialAccumulatorValue).variable(!1))
        };
      }
      const gradient = givenAsList ? t[slot].tensor : t[name];
      if (gradient == null) {
        return;
      }
      const accumulator = this.accumulatedGrads[slot].variable;
      E(() => {
        const nextAccumulator = w(accumulator, G(gradient));
        accumulator.assign(nextAccumulator);
        const scaledGradient = D(gradient, Z(w(nextAccumulator, g.backend.epsilon())));
        const nextValue = w(p(scaledGradient, -this.learningRate), target);
        target.assign(nextValue);
      });
    });
    this.incrementIterations();
  }
  /** Disposes the accumulator variables. */
  dispose() {
    if (this.accumulatedGrads != null) {
      M(this.accumulatedGrads.map((slot) => slot.variable));
    }
  }
  /**
   * @returns The iteration entry followed by one `{name, tensor}` entry per
   *     accumulator.
   */
  async getWeights() {
    const iterationEntry = await this.saveIterations();
    const slotEntries = this.accumulatedGrads.map((slot) => ({
      name: slot.originalName,
      tensor: slot.variable
    }));
    return [iterationEntry].concat(slotEntries);
  }
  /** Restores state written by `getWeights`. */
  async setWeights(t) {
    t = await this.extractIterations(t);
    const trainable = false;
    this.accumulatedGrads = t.map((entry) => ({
      originalName: entry.name,
      variable: entry.tensor.variable(trainable)
    }));
  }
  getConfig() {
    const { learningRate, initialAccumulatorValue } = this;
    return { learningRate, initialAccumulatorValue };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    const { learningRate, initialAccumulatorValue } = e;
    return new t(learningRate, initialAccumulatorValue);
  }
}
3440
+ /**
3441
+ * @license
3442
+ * Copyright 2018 Google LLC. All Rights Reserved.
3443
+ * Licensed under the Apache License, Version 2.0 (the "License");
3444
+ * you may not use this file except in compliance with the License.
3445
+ * You may obtain a copy of the License at
3446
+ *
3447
+ * http://www.apache.org/licenses/LICENSE-2.0
3448
+ *
3449
+ * Unless required by applicable law or agreed to in writing, software
3450
+ * distributed under the License is distributed on an "AS IS" BASIS,
3451
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3452
+ * See the License for the specific language governing permissions and
3453
+ * limitations under the License.
3454
+ * =============================================================================
3455
+ */
3456
/**
 * Adam optimizer. Keeps first- and second-moment accumulator variables per
 * optimized variable, plus two scalar variables (`accBeta1`, `accBeta2`)
 * that start at beta1/beta2 and are multiplied by them after every step.
 */
class as extends V {
  /** @nocollapse */
  static get className() {
    return "Adam";
  }
  /**
   * @param t Learning rate.
   * @param e beta1, decay factor of the first-moment accumulator.
   * @param s beta2, decay factor of the second-moment accumulator.
   * @param r Epsilon added to the denominator; the backend's epsilon is
   *     used when null or undefined.
   */
  constructor(t, e, s, r = null) {
    super();
    this.learningRate = t;
    this.beta1 = e;
    this.beta2 = s;
    this.epsilon = r;
    this.accumulatedFirstMoment = [];
    this.accumulatedSecondMoment = [];
    E(() => {
      this.accBeta1 = j(e).variable();
      this.accBeta2 = j(s).variable();
    });
    if (r == null) {
      this.epsilon = g.backend.epsilon();
    }
  }
  /**
   * Applies one Adam step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulators are still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    E(() => {
      // Bias-correction denominators, shared by every variable this step.
      const oneMinusAccBeta1 = Y(1, this.accBeta1);
      const oneMinusAccBeta2 = Y(1, this.accBeta2);
      names.forEach((name, slot) => {
        const target = g.registeredVariables[name];
        const trainable = false;
        if (this.accumulatedFirstMoment[slot] == null) {
          this.accumulatedFirstMoment[slot] = {
            originalName: `${name}/m`,
            variable: E(() => C(target).variable(trainable))
          };
        }
        if (this.accumulatedSecondMoment[slot] == null) {
          this.accumulatedSecondMoment[slot] = {
            originalName: `${name}/v`,
            variable: E(() => C(target).variable(trainable))
          };
        }
        const gradient = givenAsList ? t[slot].tensor : t[name];
        if (gradient == null) {
          return;
        }
        const firstMoment = this.accumulatedFirstMoment[slot].variable;
        const secondMoment = this.accumulatedSecondMoment[slot].variable;
        const nextFirstMoment = w(p(firstMoment, this.beta1), p(gradient, 1 - this.beta1));
        const nextSecondMoment = w(p(secondMoment, this.beta2), p(G(gradient), 1 - this.beta2));
        const correctedFirst = D(nextFirstMoment, oneMinusAccBeta1);
        const correctedSecond = D(nextSecondMoment, oneMinusAccBeta2);
        firstMoment.assign(nextFirstMoment);
        secondMoment.assign(nextSecondMoment);
        const step = D(correctedFirst, w(Z(correctedSecond), this.epsilon));
        const nextValue = w(p(step, -this.learningRate), target);
        target.assign(nextValue);
      });
      this.accBeta1.assign(p(this.accBeta1, this.beta1));
      this.accBeta2.assign(p(this.accBeta2, this.beta2));
    });
    this.incrementIterations();
  }
  /** Disposes the beta-power scalars and both sets of moment variables. */
  dispose() {
    this.accBeta1.dispose();
    this.accBeta2.dispose();
    if (this.accumulatedFirstMoment != null) {
      M(this.accumulatedFirstMoment.map((slot) => slot.variable));
    }
    if (this.accumulatedSecondMoment != null) {
      M(this.accumulatedSecondMoment.map((slot) => slot.variable));
    }
  }
  /**
   * @returns The iteration entry followed by every first-moment slot and
   *     then every second-moment slot, as `{name, tensor}` entries.
   */
  async getWeights() {
    const slots = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
    const iterationEntry = await this.saveIterations();
    const slotEntries = slots.map((slot) => ({ name: slot.originalName, tensor: slot.variable }));
    return [iterationEntry].concat(slotEntries);
  }
  /**
   * Restores state written by `getWeights`. The beta-power scalars are
   * recomputed from the restored iteration count via `Vt`, and the remaining
   * entries are split in half between the two moment lists.
   */
  async setWeights(t) {
    t = await this.extractIterations(t);
    E(() => {
      this.accBeta1.assign(Vt(this.beta1, this.iterations_ + 1));
      this.accBeta2.assign(Vt(this.beta2, this.iterations_ + 1));
    });
    const half = t.length / 2;
    const trainable = false;
    const toSlot = (entry) => ({
      originalName: entry.name,
      variable: entry.tensor.variable(trainable)
    });
    this.accumulatedFirstMoment = t.slice(0, half).map((entry) => toSlot(entry));
    this.accumulatedSecondMoment = t.slice(half, half * 2).map((entry) => toSlot(entry));
  }
  getConfig() {
    const { learningRate, beta1, beta2, epsilon } = this;
    return { learningRate, beta1, beta2, epsilon };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    const { learningRate, beta1, beta2, epsilon } = e;
    return new t(learningRate, beta1, beta2, epsilon);
  }
}
3522
+ /**
3523
+ * @license
3524
+ * Copyright 2018 Google LLC. All Rights Reserved.
3525
+ * Licensed under the Apache License, Version 2.0 (the "License");
3526
+ * you may not use this file except in compliance with the License.
3527
+ * You may obtain a copy of the License at
3528
+ *
3529
+ * http://www.apache.org/licenses/LICENSE-2.0
3530
+ *
3531
+ * Unless required by applicable law or agreed to in writing, software
3532
+ * distributed under the License is distributed on an "AS IS" BASIS,
3533
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3534
+ * See the License for the specific language governing permissions and
3535
+ * limitations under the License.
3536
+ * =============================================================================
3537
+ */
3538
/**
 * Adamax optimizer. Keeps a first-moment accumulator and a weighted
 * infinity-norm accumulator per optimized variable, plus scalar variables
 * for the step count (`iteration`) and the running power of beta1
 * (`accBeta1`). Weight (de)serialization is not implemented.
 */
class ls extends V {
  /** @nocollapse */
  static get className() {
    return "Adamax";
  }
  /**
   * @param t Learning rate.
   * @param e beta1, decay factor of the first-moment accumulator.
   * @param s beta2, decay factor applied to the infinity-norm accumulator.
   * @param r Epsilon added to the denominator; the backend's epsilon is
   *     used when null or undefined.
   * @param i Learning-rate decay per iteration (default 0).
   */
  constructor(t, e, s, r = null, i = 0) {
    super();
    this.learningRate = t;
    this.beta1 = e;
    this.beta2 = s;
    this.epsilon = r;
    this.decay = i;
    this.accumulatedFirstMoment = [];
    this.accumulatedWeightedInfNorm = [];
    E(() => {
      this.iteration = j(0).variable();
      this.accBeta1 = j(e).variable();
    });
    if (r == null) {
      this.epsilon = g.backend.epsilon();
    }
  }
  /**
   * Applies one Adamax step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulators are still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    E(() => {
      const oneMinusAccBeta1 = Y(1, this.accBeta1);
      // Negated learning rate divided by (1 + iteration * decay).
      const decayedRate = D(-this.learningRate, w(p(this.iteration, this.decay), 1));
      names.forEach((name, slot) => {
        const target = g.registeredVariables[name];
        const trainable = false;
        if (this.accumulatedFirstMoment[slot] == null) {
          this.accumulatedFirstMoment[slot] = {
            originalName: `${name}/m`,
            variable: C(target).variable(trainable)
          };
        }
        if (this.accumulatedWeightedInfNorm[slot] == null) {
          this.accumulatedWeightedInfNorm[slot] = {
            originalName: `${name}/v`,
            variable: C(target).variable(trainable)
          };
        }
        const gradient = givenAsList ? t[slot].tensor : t[name];
        if (gradient == null) {
          return;
        }
        const firstMoment = this.accumulatedFirstMoment[slot].variable;
        const infNorm = this.accumulatedWeightedInfNorm[slot].variable;
        const nextFirstMoment = w(p(firstMoment, this.beta1), p(gradient, 1 - this.beta1));
        const decayedNorm = p(infNorm, this.beta2);
        const gradientMagnitude = jn(gradient);
        const nextInfNorm = ts(decayedNorm, gradientMagnitude);
        firstMoment.assign(nextFirstMoment);
        infNorm.assign(nextInfNorm);
        const nextValue = w(
          p(D(decayedRate, oneMinusAccBeta1), D(nextFirstMoment, w(nextInfNorm, this.epsilon))),
          target
        );
        target.assign(nextValue);
      });
      this.iteration.assign(w(this.iteration, 1));
      this.accBeta1.assign(p(this.accBeta1, this.beta1));
    });
    this.incrementIterations();
  }
  /** Disposes the scalar variables and both sets of accumulators. */
  dispose() {
    this.accBeta1.dispose();
    this.iteration.dispose();
    if (this.accumulatedFirstMoment != null) {
      M(this.accumulatedFirstMoment.map((slot) => slot.variable));
    }
    if (this.accumulatedWeightedInfNorm != null) {
      M(this.accumulatedWeightedInfNorm.map((slot) => slot.variable));
    }
  }
  /** Not supported for this optimizer; always rejects. */
  async getWeights() {
    throw new Error("getWeights() is not implemented for Adamax yet.");
  }
  /** Not supported for this optimizer; always rejects. */
  async setWeights(t) {
    throw new Error("setWeights() is not implemented for Adamax yet.");
  }
  getConfig() {
    const { learningRate, beta1, beta2, epsilon, decay } = this;
    return { learningRate, beta1, beta2, epsilon, decay };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    const { learningRate, beta1, beta2, epsilon, decay } = e;
    return new t(learningRate, beta1, beta2, epsilon, decay);
  }
}
3594
+ /**
3595
+ * @license
3596
+ * Copyright 2018 Google LLC. All Rights Reserved.
3597
+ * Licensed under the Apache License, Version 2.0 (the "License");
3598
+ * you may not use this file except in compliance with the License.
3599
+ * You may obtain a copy of the License at
3600
+ *
3601
+ * http://www.apache.org/licenses/LICENSE-2.0
3602
+ *
3603
+ * Unless required by applicable law or agreed to in writing, software
3604
+ * distributed under the License is distributed on an "AS IS" BASIS,
3605
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3606
+ * See the License for the specific language governing permissions and
3607
+ * limitations under the License.
3608
+ * =============================================================================
3609
+ */
3610
/**
 * Plain stochastic gradient descent. The negated learning rate is held as a
 * scalar in `this.c` (built with `j` and wrapped with `mn`), so each step is
 * `variable = c * gradient + variable`.
 */
class be extends V {
  /** @nocollapse */
  static get className() {
    return "SGD";
  }
  /**
   * @param t Learning rate.
   */
  constructor(t) {
    super();
    this.learningRate = t;
    this.setLearningRate(t);
  }
  /**
   * Applies one SGD step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    names.forEach((name, slot) => {
      const gradient = givenAsList ? t[slot].tensor : t[name];
      if (gradient == null) {
        return;
      }
      const target = g.registeredVariables[name];
      E(() => {
        const nextValue = w(p(this.c, gradient), target);
        target.assign(nextValue);
      });
    });
    this.incrementIterations();
  }
  /**
   * Sets the learning rate of the optimizer, disposing the previous
   * `this.c` scalar (if any) and building a new one.
   */
  setLearningRate(t) {
    this.learningRate = t;
    if (this.c != null) {
      this.c.dispose();
    }
    this.c = mn(j(-t));
  }
  /** Disposes the learning-rate scalar. */
  dispose() {
    this.c.dispose();
  }
  /** @returns A single-element list holding the iteration entry. */
  async getWeights() {
    const iterationEntry = await this.saveIterations();
    return [iterationEntry];
  }
  /**
   * Restores the iteration count.
   * @throws When anything beyond the iteration entry is supplied.
   */
  async setWeights(t) {
    t = await this.extractIterations(t);
    if (t.length !== 0) {
      throw new Error("SGD optimizer does not have settable weights.");
    }
  }
  getConfig() {
    const { learningRate } = this;
    return { learningRate };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    const { learningRate } = e;
    return new t(learningRate);
  }
}
3654
+ /**
3655
+ * @license
3656
+ * Copyright 2018 Google LLC. All Rights Reserved.
3657
+ * Licensed under the Apache License, Version 2.0 (the "License");
3658
+ * you may not use this file except in compliance with the License.
3659
+ * You may obtain a copy of the License at
3660
+ *
3661
+ * http://www.apache.org/licenses/LICENSE-2.0
3662
+ *
3663
+ * Unless required by applicable law or agreed to in writing, software
3664
+ * distributed under the License is distributed on an "AS IS" BASIS,
3665
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3666
+ * See the License for the specific language governing permissions and
3667
+ * limitations under the License.
3668
+ * =============================================================================
3669
+ */
3670
/**
 * SGD with momentum (optionally Nesterov). Keeps one accumulator variable
 * per optimized variable. The momentum coefficient is held as a scalar in
 * `this.m`, which is the value `applyGradients` actually reads.
 */
class cs extends be {
  /** @nocollapse */
  // Name matters for Python compatibility.
  static get className() {
    return "Momentum";
  }
  /**
   * @param t Learning rate.
   * @param e Momentum coefficient.
   * @param s Whether to use the Nesterov form of the update (default false).
   */
  constructor(t, e, s = !1) {
    super(t);
    this.learningRate = t;
    this.momentum = e;
    this.useNesterov = s;
    this.accumulations = [];
    // Wrapped with `mn` exactly as the parent wraps its `this.c` scalar, so
    // both scalars held by this optimizer are treated the same way.
    this.m = mn(j(this.momentum));
  }
  /**
   * Applies one momentum step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulator is still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    names.forEach((name, slot) => {
      const target = g.registeredVariables[name];
      if (this.accumulations[slot] == null) {
        this.accumulations[slot] = {
          originalName: `${name}/momentum`,
          variable: E(() => C(target).variable(!1))
        };
      }
      const accumulation = this.accumulations[slot].variable;
      const gradient = givenAsList ? t[slot].tensor : t[name];
      if (gradient == null) {
        return;
      }
      E(() => {
        const nextAccumulation = w(p(this.m, accumulation), gradient);
        const nextValue = this.useNesterov
          ? w(p(this.c, w(gradient, p(nextAccumulation, this.m))), target)
          : w(p(this.c, nextAccumulation), target);
        accumulation.assign(nextAccumulation);
        target.assign(nextValue);
      });
    });
    this.incrementIterations();
  }
  /**
   * Disposes everything this optimizer owns: the learning-rate scalar held
   * by the parent class, the momentum scalar, and the accumulators.
   */
  dispose() {
    // The parent owns `this.c`; without this call it was never released.
    super.dispose();
    this.m.dispose();
    if (this.accumulations != null) {
      M(this.accumulations.map((slot) => slot.variable));
    }
  }
  /**
   * Sets the momentum of the optimizer.
   *
   * Rebuilds the `this.m` scalar as well as the plain number, because
   * `applyGradients` reads `this.m`; updating only `this.momentum` would
   * leave training unchanged.
   *
   * @param momentum
   */
  setMomentum(t) {
    this.momentum = t;
    if (this.m != null) {
      this.m.dispose();
    }
    this.m = mn(j(t));
  }
  /**
   * @returns The iteration entry followed by one `{name, tensor}` entry per
   *     accumulator.
   */
  async getWeights() {
    const iterationEntry = await this.saveIterations();
    const slotEntries = this.accumulations.map((slot) => ({
      name: slot.originalName,
      tensor: slot.variable
    }));
    return [iterationEntry].concat(slotEntries);
  }
  /** Restores state written by `getWeights`. */
  async setWeights(t) {
    t = await this.extractIterations(t);
    const trainable = false;
    this.accumulations = t.map((entry) => ({
      originalName: entry.name,
      variable: entry.tensor.variable(trainable)
    }));
  }
  getConfig() {
    return {
      learningRate: this.learningRate,
      momentum: this.momentum,
      useNesterov: this.useNesterov
    };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    return new t(e.learningRate, e.momentum, e.useNesterov);
  }
}
3725
+ /**
3726
+ * @license
3727
+ * Copyright 2018 Google LLC. All Rights Reserved.
3728
+ * Licensed under the Apache License, Version 2.0 (the "License");
3729
+ * you may not use this file except in compliance with the License.
3730
+ * You may obtain a copy of the License at
3731
+ *
3732
+ * http://www.apache.org/licenses/LICENSE-2.0
3733
+ *
3734
+ * Unless required by applicable law or agreed to in writing, software
3735
+ * distributed under the License is distributed on an "AS IS" BASIS,
3736
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3737
+ * See the License for the specific language governing permissions and
3738
+ * limitations under the License.
3739
+ * =============================================================================
3740
+ */
3741
/**
 * RMSProp optimizer. Keeps a decayed mean-square accumulator and a momentum
 * accumulator per optimized variable; in centered mode it additionally keeps
 * a decayed mean-gradient accumulator.
 */
class us extends V {
  /** @nocollapse */
  static get className() {
    return "RMSProp";
  }
  /**
   * @param t Learning rate; required.
   * @param e Decay factor of the mean-square (and mean-gradient)
   *     accumulators (default 0.9).
   * @param s Momentum coefficient (default 0).
   * @param r Epsilon added under the square root; the backend's epsilon is
   *     used when null or undefined.
   * @param i Whether to use the centered variant (default false).
   * @throws Error when `t` is null or undefined.
   */
  constructor(t, e = 0.9, s = 0, r = null, i = !1) {
    super();
    this.learningRate = t;
    this.decay = e;
    this.momentum = s;
    this.epsilon = r;
    this.accumulatedMeanSquares = [];
    this.accumulatedMoments = [];
    this.accumulatedMeanGrads = [];
    this.centered = i;
    if (r == null) {
      this.epsilon = g.backend.epsilon();
    }
    if (t == null) {
      throw new Error("learningRate for RMSPropOptimizer must be defined.");
    }
  }
  /**
   * Applies one RMSProp step.
   *
   * @param t Either an array of `{name, tensor}` entries or an object
   *     mapping variable names to gradients. Entries with a null gradient
   *     are skipped, but their accumulators are still created.
   */
  applyGradients(t) {
    const givenAsList = Array.isArray(t);
    const names = givenAsList ? t.map((entry) => entry.name) : Object.keys(t);
    names.forEach((name, slot) => {
      const target = g.registeredVariables[name];
      const trainable = false;
      if (this.accumulatedMeanSquares[slot] == null) {
        this.accumulatedMeanSquares[slot] = {
          originalName: `${name}/rms`,
          variable: E(() => C(target).variable(trainable))
        };
      }
      if (this.accumulatedMoments[slot] == null) {
        this.accumulatedMoments[slot] = {
          originalName: `${name}/momentum`,
          variable: E(() => C(target).variable(trainable))
        };
      }
      if (this.accumulatedMeanGrads[slot] == null && this.centered) {
        this.accumulatedMeanGrads[slot] = {
          originalName: `${name}/mg`,
          variable: E(() => C(target).variable(trainable))
        };
      }
      const gradient = givenAsList ? t[slot].tensor : t[name];
      if (gradient == null) {
        return;
      }
      const meanSquare = this.accumulatedMeanSquares[slot].variable;
      const moment = this.accumulatedMoments[slot].variable;
      E(() => {
        // Computed once and shared by both variants below; the plain
        // variant previously rebuilt this identical expression.
        const nextMeanSquare = w(p(meanSquare, this.decay), p(G(gradient), 1 - this.decay));
        if (this.centered) {
          const meanGrad = this.accumulatedMeanGrads[slot].variable;
          const nextMeanGrad = w(p(meanGrad, this.decay), p(gradient, 1 - this.decay));
          const scaledGradient = D(
            p(gradient, this.learningRate),
            Z(Y(nextMeanSquare, w(G(nextMeanGrad), this.epsilon)))
          );
          const nextMoment = w(p(moment, this.momentum), scaledGradient);
          meanSquare.assign(nextMeanSquare);
          meanGrad.assign(nextMeanGrad);
          moment.assign(nextMoment);
          target.assign(Y(target, nextMoment));
        } else {
          const nextMoment = w(
            p(moment, this.momentum),
            D(p(gradient, this.learningRate), Z(w(nextMeanSquare, this.epsilon)))
          );
          meanSquare.assign(nextMeanSquare);
          moment.assign(nextMoment);
          target.assign(Y(target, nextMoment));
        }
      });
    });
    this.incrementIterations();
  }
  /** Disposes every accumulator variable this optimizer created. */
  dispose() {
    if (this.accumulatedMeanSquares != null) {
      M(this.accumulatedMeanSquares.map((slot) => slot.variable));
    }
    if (this.accumulatedMeanGrads != null && this.centered) {
      M(this.accumulatedMeanGrads.map((slot) => slot.variable));
    }
    if (this.accumulatedMoments != null) {
      M(this.accumulatedMoments.map((slot) => slot.variable));
    }
  }
  /**
   * @returns The iteration entry followed by the mean-square slots, the
   *     momentum slots and, in centered mode, the mean-gradient slots.
   */
  async getWeights() {
    const slots = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
    if (this.centered) {
      slots.push(...this.accumulatedMeanGrads);
    }
    const iterationEntry = await this.saveIterations();
    const slotEntries = slots.map((slot) => ({ name: slot.originalName, tensor: slot.variable }));
    return [iterationEntry].concat(slotEntries);
  }
  /**
   * Restores state written by `getWeights`: the iteration entry first, then
   * the remaining entries split into two (or, when centered, three) equal
   * groups in the same order `getWeights` emits them.
   */
  async setWeights(t) {
    t = await this.extractIterations(t);
    const groupSize = this.centered ? t.length / 3 : t.length / 2;
    const trainable = false;
    const toSlot = (entry) => ({
      originalName: entry.name,
      variable: entry.tensor.variable(trainable)
    });
    this.accumulatedMeanSquares = t.slice(0, groupSize).map((entry) => toSlot(entry));
    this.accumulatedMoments = t.slice(groupSize, groupSize * 2).map((entry) => toSlot(entry));
    if (this.centered) {
      this.accumulatedMeanGrads = t.slice(groupSize * 2, groupSize * 3).map((entry) => toSlot(entry));
    }
  }
  getConfig() {
    return {
      learningRate: this.learningRate,
      decay: this.decay,
      momentum: this.momentum,
      epsilon: this.epsilon,
      centered: this.centered
    };
  }
  /** @nocollapse */
  static fromConfig(t, e) {
    return new t(e.learningRate, e.decay, e.momentum, e.epsilon, e.centered);
  }
}
3818
+ /**
3819
+ * @license
3820
+ * Copyright 2022 Google LLC.
3821
+ * Licensed under the Apache License, Version 2.0 (the "License");
3822
+ * you may not use this file except in compliance with the License.
3823
+ * You may obtain a copy of the License at
3824
+ *
3825
+ * http://www.apache.org/licenses/LICENSE-2.0
3826
+ *
3827
+ * Unless required by applicable law or agreed to in writing, software
3828
+ * distributed under the License is distributed on an "AS IS" BASIS,
3829
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3830
+ * See the License for the specific language governing permissions and
3831
+ * limitations under the License.
3832
+ * =============================================================================
3833
+ */
3834
// Classes that `fs()` hands to `rs`, one at a time and in this order.
// `us` is the RMSProp optimizer; the remaining entries are declared earlier
// in this bundle.
const hs = [is, os, as, ls, cs, us, be];
3843
/**
 * Passes every entry of `hs` to `rs`, in array order.
 * NOTE(review): `rs` is defined elsewhere in the bundle; presumably it
 * registers the class for serialization — confirm against its definition.
 */
function fs() {
  // Explicit single-argument arrow so `rs` never receives the index/array
  // arguments that `forEach` supplies to its callback.
  hs.forEach((cls) => {
    rs(cls);
  });
}
3847
+ /**
3848
+ * @license
3849
+ * Copyright 2017 Google LLC. All Rights Reserved.
3850
+ * Licensed under the Apache License, Version 2.0 (the "License");
3851
+ * you may not use this file except in compliance with the License.
3852
+ * You may obtain a copy of the License at
3853
+ *
3854
+ * http://www.apache.org/licenses/LICENSE-2.0
3855
+ *
3856
+ * Unless required by applicable law or agreed to in writing, software
3857
+ * distributed under the License is distributed on an "AS IS" BASIS,
3858
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3859
+ * See the License for the specific language governing permissions and
3860
+ * limitations under the License.
3861
+ * =============================================================================
3862
+ */
3863
+ fs();
3864
+ export {
3865
+ as as A,
3866
+ gs as B,
3867
+ ms as C,
3868
+ g as E,
3869
+ ys as I,
3870
+ bs as L,
3871
+ ws as N,
3872
+ Ss as P,
3873
+ Ts as R,
3874
+ Bs as S,
3875
+ vs as T,
3876
+ Fs as _,
3877
+ Y as a,
3878
+ ds as b,
3879
+ k as c,
3880
+ K as d,
3881
+ Rs as e,
3882
+ ps as f,
3883
+ At as g,
3884
+ As as h,
3885
+ Is as i,
3886
+ ks as j,
3887
+ Es as k,
3888
+ Ms as l,
3889
+ p as m,
3890
+ b as n,
3891
+ F as o,
3892
+ xs as p,
3893
+ w as q,
3894
+ U as r,
3895
+ j as s,
3896
+ E as t,
3897
+ Vn as u,
3898
+ Ns as v
3899
+ };