@genai-fi/nanogpt 0.10.1 → 0.10.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. package/dist/Generator.js +14 -14
  2. package/dist/{RealDiv-DgA3z9oO.js → RealDiv-zz7FpkKX.js} +17 -17
  3. package/dist/{Reshape-CF6odzV4.js → Reshape-CDVLyVfz.js} +3 -3
  4. package/dist/{Reshape-_kILl6tK.js → Reshape-CHdUjC72.js} +4 -4
  5. package/dist/TeachableLLM.js +8 -8
  6. package/dist/{axis_util-BvHEw88j.js → axis_util-BsIr9ZNu.js} +1 -1
  7. package/dist/backend.js +2 -2
  8. package/dist/{backend_util-D-rUb2ty.js → backend_util-B1XRLuq9.js} +31 -31
  9. package/dist/{backend_webgpu-B0u2ndUn.js → backend_webgpu-CqpfEImu.js} +5 -5
  10. package/dist/{broadcast_to-CwF7XIeu.js → broadcast_to-B0ChcDaz.js} +4 -4
  11. package/dist/checks/appendCache.js +2 -2
  12. package/dist/checks/attentionMask.js +3 -3
  13. package/dist/checks/gelu.js +2 -2
  14. package/dist/checks/matMulGelu.js +5 -5
  15. package/dist/checks/normRMS.js +4 -4
  16. package/dist/checks/normRMSGrad.js +3 -3
  17. package/dist/checks/packUnpack.js +2 -2
  18. package/dist/checks/qkv.js +3 -3
  19. package/dist/checks/rope.js +2 -2
  20. package/dist/{complex-CSlYz-2T.js → complex-BBiRlsVq.js} +3 -3
  21. package/dist/{concat-BHlIJeyT.js → concat-DmBLPVGC.js} +3 -3
  22. package/dist/{concat_util-DcJk7YHS.js → concat_util-iBYIyuQe.js} +1 -1
  23. package/dist/{dataset-0xP8GjwI.js → dataset-D2P7rHAw.js} +5 -5
  24. package/dist/{dropout-C1pM3f11.js → dropout-B1x1kYMa.js} +3 -3
  25. package/dist/{expand_dims-BPG4fwBP.js → expand_dims-ouvfxQ1n.js} +3 -3
  26. package/dist/{exports_initializers-xuidcwI4.js → exports_initializers-CZSUJoVE.js} +1 -1
  27. package/dist/{gather-DykLGqmW.js → gather-CH9sdacz.js} +2 -2
  28. package/dist/{gelu-CNLFZWea.js → gelu-Bmhopi0J.js} +2 -2
  29. package/dist/{gpgpu_math-DDVJCn6-.js → gpgpu_math-DsCcikas.js} +3 -3
  30. package/dist/{index-ZyQhjEPo.js → index-D6Q1lPZO.js} +55 -55
  31. package/dist/{index-CjOj7j-u.js → index-DRyE072i.js} +15 -15
  32. package/dist/{kernel_funcs_utils-Dg_-E44D.js → kernel_funcs_utils-CWfOAPGO.js} +9 -9
  33. package/dist/layers/BaseLayer.js +10 -10
  34. package/dist/layers/CausalSelfAttention.js +6 -6
  35. package/dist/layers/MLP.js +4 -4
  36. package/dist/layers/PositionEmbedding.js +5 -5
  37. package/dist/layers/RMSNorm.js +3 -3
  38. package/dist/layers/RoPECache.js +4 -4
  39. package/dist/layers/TiedEmbedding.js +6 -6
  40. package/dist/layers/TransformerBlock.js +1 -1
  41. package/dist/loader/loadTransformers.js +1 -1
  42. package/dist/loader/oldZipLoad.js +8 -8
  43. package/dist/{log_sum_exp-DWI-76TI.js → log_sum_exp-D3ftBNY5.js} +6 -6
  44. package/dist/main.js +8 -8
  45. package/dist/{matMul16--R5hOwDG.js → matMul16-fEAJ4smh.js} +4 -4
  46. package/dist/{mat_mul-DeAh4uTH.js → mat_mul-C59XWcJd.js} +2 -2
  47. package/dist/{mod-Gt1rMB4n.js → mod-DESSvHIU.js} +2 -2
  48. package/dist/models/NanoGPTV1.js +2 -2
  49. package/dist/models/model.js +8 -8
  50. package/dist/{mulmat_packed_gpu-BMFhLwta.js → mulmat_packed_gpu-Coh6qbJk.js} +1 -1
  51. package/dist/{ones-CAMiP4I2.js → ones-jU9jlQvM.js} +4 -4
  52. package/dist/ops/adamAdjust.js +1 -1
  53. package/dist/ops/adamMoments.js +1 -1
  54. package/dist/ops/add16.js +1 -1
  55. package/dist/ops/appendCache.js +3 -3
  56. package/dist/ops/attentionMask.js +1 -1
  57. package/dist/ops/concat16.js +2 -2
  58. package/dist/ops/cpu/adamAdjust.js +2 -2
  59. package/dist/ops/cpu/adamMoments.js +3 -3
  60. package/dist/ops/cpu/appendCache.js +3 -3
  61. package/dist/ops/cpu/attentionMask.js +6 -6
  62. package/dist/ops/cpu/fusedSoftmax.js +3 -3
  63. package/dist/ops/cpu/gatherSub.js +4 -4
  64. package/dist/ops/cpu/gelu.js +2 -2
  65. package/dist/ops/cpu/matMul16.js +3 -3
  66. package/dist/ops/cpu/matMulGelu.js +4 -4
  67. package/dist/ops/cpu/matMulMul.js +2 -2
  68. package/dist/ops/cpu/mulDropout.js +2 -2
  69. package/dist/ops/cpu/normRMS.js +2 -2
  70. package/dist/ops/cpu/qkv.js +4 -4
  71. package/dist/ops/cpu/rope.js +6 -6
  72. package/dist/ops/cpu/scatterSub.js +7 -7
  73. package/dist/ops/dot16.js +2 -2
  74. package/dist/ops/gatherSub.js +1 -1
  75. package/dist/ops/gelu.js +2 -2
  76. package/dist/ops/grads/add16.js +2 -2
  77. package/dist/ops/grads/attentionMask.js +3 -3
  78. package/dist/ops/grads/gelu.js +3 -3
  79. package/dist/ops/grads/matMul16.js +4 -4
  80. package/dist/ops/grads/matMulGelu.js +2 -2
  81. package/dist/ops/grads/normRMS.js +2 -2
  82. package/dist/ops/grads/pack16.js +4 -4
  83. package/dist/ops/grads/qkv.js +4 -4
  84. package/dist/ops/grads/rope.js +3 -3
  85. package/dist/ops/grads/softmax16.js +2 -2
  86. package/dist/ops/grads/unpack16.js +3 -3
  87. package/dist/ops/matMul16.js +3 -3
  88. package/dist/ops/matMulGelu.js +1 -1
  89. package/dist/ops/matMulMul.js +1 -1
  90. package/dist/ops/mul16.js +1 -1
  91. package/dist/ops/mulDrop.js +1 -1
  92. package/dist/ops/normRMS.js +1 -1
  93. package/dist/ops/pack16.js +2 -2
  94. package/dist/ops/qkv.js +1 -1
  95. package/dist/ops/reshape16.js +3 -3
  96. package/dist/ops/rope.js +5 -5
  97. package/dist/ops/scatterSub.js +1 -1
  98. package/dist/ops/slice16.js +2 -2
  99. package/dist/ops/softmax16.js +1 -1
  100. package/dist/ops/sub16.js +1 -1
  101. package/dist/ops/sum16.js +2 -2
  102. package/dist/ops/transpose16.js +4 -4
  103. package/dist/ops/unpack16.js +2 -2
  104. package/dist/ops/webgl/adamAdjust.js +3 -3
  105. package/dist/ops/webgl/adamMoments.js +2 -2
  106. package/dist/ops/webgl/appendCache.js +2 -2
  107. package/dist/ops/webgl/attentionMask.js +2 -2
  108. package/dist/ops/webgl/fusedSoftmax.js +6 -6
  109. package/dist/ops/webgl/gatherSub.js +2 -2
  110. package/dist/ops/webgl/gelu.js +3 -3
  111. package/dist/ops/webgl/log.js +4 -4
  112. package/dist/ops/webgl/matMul16.js +5 -5
  113. package/dist/ops/webgl/matMulGelu.js +6 -6
  114. package/dist/ops/webgl/matMulMul.js +2 -2
  115. package/dist/ops/webgl/mulDropout.js +2 -2
  116. package/dist/ops/webgl/normRMS.js +3 -3
  117. package/dist/ops/webgl/qkv.js +2 -2
  118. package/dist/ops/webgl/rope.js +2 -2
  119. package/dist/ops/webgl/scatterSub.js +2 -2
  120. package/dist/ops/webgpu/adamAdjust.js +5 -5
  121. package/dist/ops/webgpu/adamMoments.js +5 -5
  122. package/dist/ops/webgpu/add16.js +2 -2
  123. package/dist/ops/webgpu/appendCache.js +5 -5
  124. package/dist/ops/webgpu/attentionMask.js +4 -4
  125. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  126. package/dist/ops/webgpu/concat16.js +7 -7
  127. package/dist/ops/webgpu/gatherSub.js +5 -5
  128. package/dist/ops/webgpu/gelu.js +4 -4
  129. package/dist/ops/webgpu/matMul16.js +6 -6
  130. package/dist/ops/webgpu/matMul16_program.js +3 -3
  131. package/dist/ops/webgpu/mul16.js +2 -2
  132. package/dist/ops/webgpu/normRMS.js +4 -4
  133. package/dist/ops/webgpu/normRMSGrad.js +6 -6
  134. package/dist/ops/webgpu/pack16.js +2 -2
  135. package/dist/ops/webgpu/pack16_program.js +2 -2
  136. package/dist/ops/webgpu/qkv.js +4 -4
  137. package/dist/ops/webgpu/rope.js +5 -5
  138. package/dist/ops/webgpu/scatterSub.js +5 -5
  139. package/dist/ops/webgpu/slice16.js +6 -6
  140. package/dist/ops/webgpu/softmax16.js +4 -4
  141. package/dist/ops/webgpu/softmax16_program.js +2 -2
  142. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  143. package/dist/ops/webgpu/softmax16grad.js +2 -2
  144. package/dist/ops/webgpu/sub16.js +2 -2
  145. package/dist/ops/webgpu/sum16.js +5 -5
  146. package/dist/ops/webgpu/transpose16.js +3 -3
  147. package/dist/ops/webgpu/transpose16_program.js +2 -2
  148. package/dist/ops/webgpu/transpose16_shared_program.js +4 -4
  149. package/dist/ops/webgpu/unpack16.js +4 -4
  150. package/dist/ops/webgpu/utils/binary_op.js +4 -4
  151. package/dist/ops/webgpu/utils/reductions.js +5 -5
  152. package/dist/{ops-CNI3TwqM.js → ops-BFDtP6th.js} +24 -24
  153. package/dist/{pack16-CFUqumar.js → pack16-CmVZs6af.js} +3 -3
  154. package/dist/patches/PackedTensor.js +1 -1
  155. package/dist/patches/engine.js +7 -5
  156. package/dist/patches/tape.js +1 -1
  157. package/dist/patches/webgpu_backend.js +5 -5
  158. package/dist/patches/webgpu_base.js +1 -1
  159. package/dist/patches/webgpu_program.js +3 -3
  160. package/dist/{random_width-DY6Kk2Dl.js → random_width-BVV9HveY.js} +31 -31
  161. package/dist/{range-BMS52eQi.js → range-ZZZD60Fx.js} +2 -2
  162. package/dist/{reciprocal-CTmshQ9J.js → reciprocal-CrYlsAGD.js} +2 -2
  163. package/dist/{register_all_kernels-Bwu1PTuU.js → register_all_kernels-nvj2k7OC.js} +41 -41
  164. package/dist/{relu-yZ2-7WxU.js → relu-BYDneVPn.js} +2 -2
  165. package/dist/{reshape-DevtBWtf.js → reshape-CaPQzFvz.js} +2 -2
  166. package/dist/{rope-B5UUMsPi.js → rope-s4W2XO9B.js} +5 -5
  167. package/dist/{scatter_nd_util-5EL-8VAQ.js → scatter_nd_util-C7zXRT_h.js} +1 -1
  168. package/dist/{selu_util-D1w6yyTO.js → selu_util-BGPXmd4B.js} +16 -16
  169. package/dist/{shared-BRksrJb3.js → shared-CHhxz-O5.js} +1 -1
  170. package/dist/{shared-BuAXb4CI.js → shared-D2NP_CpY.js} +8 -8
  171. package/dist/{sin-BGfy2HZo.js → sin-Djs4aQiu.js} +2 -2
  172. package/dist/{slice-D_gkkqZK.js → slice-DvovR5wq.js} +2 -2
  173. package/dist/{slice_util-DtEldBfK.js → slice_util-DyjSAD0u.js} +1 -1
  174. package/dist/{softmax-ZHVebtR1.js → softmax-C9JQEtnO.js} +2 -2
  175. package/dist/{split-DrfihRpZ.js → split-DBck65sX.js} +2 -2
  176. package/dist/{squeeze-DZEpeblb.js → squeeze-C00Ipm_7.js} +3 -3
  177. package/dist/{stack-yOIAalTq.js → stack-ChnHwRpX.js} +3 -3
  178. package/dist/{sum-_fzj5ZTB.js → sum-ywRJj3Zr.js} +2 -2
  179. package/dist/{tensor-f35l8Odg.js → tensor-0r5yOo2R.js} +1 -1
  180. package/dist/{tensor-DdQUJZlz.js → tensor-CzmOBsdf.js} +21 -21
  181. package/dist/{tensor1d-CeZuc-Rv.js → tensor1d-BlUT89BP.js} +2 -2
  182. package/dist/{tensor2d-G4Ys2GxX.js → tensor2d-CSB4KOb0.js} +2 -2
  183. package/dist/{tensor4d-B8roDgtc.js → tensor4d-D7bLqGqz.js} +2 -2
  184. package/dist/{tensor_util-DV-FP5Q3.js → tensor_util-DfwaWayG.js} +12 -12
  185. package/dist/{tfjs_backend-kNyO5L2d.js → tfjs_backend-CNkSTL0c.js} +38 -38
  186. package/dist/{tile-BzyEiF-F.js → tile-CR074jmp.js} +3 -3
  187. package/dist/training/Adam.js +2 -2
  188. package/dist/training/AdamExt.js +1 -1
  189. package/dist/training/DatasetBuilder.js +2 -2
  190. package/dist/training/FullTrainer.js +1 -1
  191. package/dist/training/Trainer.js +2 -2
  192. package/dist/training/sparseCrossEntropy.js +3 -3
  193. package/dist/{transpose-DKELTqhe.js → transpose-DH4gmHvu.js} +4 -4
  194. package/dist/utilities/dummy.js +3 -3
  195. package/dist/utilities/multinomialCPU.js +2 -2
  196. package/dist/utilities/packed.js +338 -304
  197. package/dist/utilities/performance.js +1 -1
  198. package/dist/utilities/profile.js +1 -1
  199. package/dist/utilities/safetensors.js +2 -2
  200. package/dist/utilities/sentences.js +5 -5
  201. package/dist/utilities/weights.js +2 -2
  202. package/dist/{variable-Bhn5bHYv.js → variable-DzfrwYuP.js} +1 -1
  203. package/dist/{webgpu_program-Cigz-7RF.js → webgpu_program-DzaQiqel.js} +2 -2
  204. package/dist/{webgpu_util-BBCnKm2X.js → webgpu_util-0_ubCEHJ.js} +2 -2
  205. package/dist/{zeros-2gldETuK.js → zeros-DBFVbpv5.js} +3 -3
  206. package/package.json +1 -1
@@ -1,68 +1,102 @@
- import { w as T, o as R, p as x, K as j, I as M, q as V, s as $, t as O, v as L, x as _, A as U } from "../tensor_util-DV-FP5Q3.js";
- import { b as q, o as W, E as H, q as X, a as p, r as Y, t as Z, u as J, v as K, V as Q, w as N, T as B, x as P, f as tt, m as et, s as st, y as nt } from "../tensor-DdQUJZlz.js";
- import { PackableTensor as A, PackableVariable as rt } from "../patches/PackedTensor.js";
- function lt() {
+ import { g as $ } from "../index-D5v913EJ.js";
+ import { p as K } from "../index-xuotMAFm.js";
+ import { w, o as W, p as N, K as H, I as P, q as A, s as z, t as X, v as Y, x as Z, A as J } from "../tensor_util-DfwaWayG.js";
+ import { b as Q, E as L, a as p, o as ee, q as te, r as se, t as C, V as ne, u as G, T as B, v as R, m as re, s as ie, w as ae, f as oe, x as ce } from "../tensor-CzmOBsdf.js";
+ import { PackableTensor as j, PackableVariable as de } from "../patches/PackedTensor.js";
+ function be() {
  return y.backendName === "webgpu";
  }
- function E(o) {
- return o.packed !== void 0;
+ function x(i) {
+ return i.packed !== void 0;
  }
- function w(o) {
- return E(o) && o.packed;
+ function T(i) {
+ return x(i) && i.packed;
  }
- function z(o) {
- if (E(o)) {
- if (o.dtype !== "int32")
+ function O(i) {
+ if (x(i)) {
+ if (i.dtype !== "int32")
  throw new Error("packTensor: only int32 tensors can be packed.");
- return o.packed = !0, o;
+ return i.packed = !0, i;
  } else
- throw console.error("Tensor:", o), new Error("Tensor is not packable");
+ throw console.error("Tensor:", i), new Error("Tensor is not packable");
  }
- function ft(o) {
- if (E(o)) {
- if (o.dtype !== "float32")
+ function we(i) {
+ if (x(i)) {
+ if (i.dtype !== "float32")
  throw new Error("unpackTensor: only float32 tensors can be unpacked.");
- o.packed = !1;
+ i.packed = !1;
  }
- return o;
+ return i;
  }
- function it(o, t, e, s) {
- for (let n = t.length - 1; n >= 0; n--) {
- const r = t[n], c = [];
- if (r.outputs.forEach((a) => {
- const i = o[a.id];
- i != null ? c.push(i) : c.push(null);
+ function he(i, e, t, s) {
+ for (let n = e.length - 1; n >= 0; n--) {
+ const r = e[n], c = [];
+ if (r.outputs.forEach((o) => {
+ const a = i[o.id];
+ a != null ? c.push(a) : c.push(null);
  }), r.gradient == null)
  throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);
- const h = r.gradient(c);
- for (const a in r.inputs) {
- if (!(a in h))
+ const d = r.gradient(c);
+ for (const o in r.inputs) {
+ if (!(o in d))
  throw new Error(
- `Cannot backprop through input ${a}. Available gradients found: ${Object.keys(h)}.`
+ `Cannot backprop through input ${o}. Available gradients found: ${Object.keys(d)}.`
  );
- const i = e(() => h[a]()), d = w(i);
- if (i.dtype !== "float32" && (!d || i.dtype !== "int32"))
+ const a = t(() => d[o]()), h = T(a);
+ if (a.dtype !== "float32" && (!h || a.dtype !== "int32"))
  throw new Error(
- `Error in gradient for op ${r.kernelName}. The gradient of input ${a} must have 'float32' dtype, but has '${i.dtype}'`
+ `Error in gradient for op ${r.kernelName}. The gradient of input ${o} must have 'float32' dtype, but has '${a.dtype}'`
  );
- const l = r.inputs[a];
- if (!q(i.shape, l.shape))
+ const u = r.inputs[o];
+ if (!Q(a.shape, u.shape))
  throw new Error(
- `Error in gradient for op ${r.kernelName}. The gradient of input '${a}' has shape '${i.shape}', which does not match the shape of the input '${l.shape}'`
+ `Error in gradient for op ${r.kernelName}. The gradient of input '${o}' has shape '${a.shape}', which does not match the shape of the input '${u.shape}'`
  );
- if (o[l.id] == null)
- o[l.id] = i;
+ if (i[u.id] == null)
+ i[u.id] = a;
  else {
- const u = o[l.id];
- o[l.id] = s(u, i), u.dispose();
+ const l = i[u.id];
+ i[u.id] = s(l, a), l.dispose();
  }
  }
  }
  }
- function S(o) {
- return o.kernelName != null;
+ let S;
+ function U() {
+ if (S == null) {
+ let i;
+ if (typeof window < "u")
+ i = window;
+ else if (typeof $ < "u")
+ i = $;
+ else if (typeof K < "u")
+ i = K;
+ else if (typeof self < "u")
+ i = self;
+ else
+ throw new Error("Could not find a global object");
+ S = i;
+ }
+ return S;
  }
- class C {
+ const E = {
+ engine: null
+ }, D = U();
+ if (D._tfengine)
+ throw new Error("TensorFlow engine already initialized before patching.");
+ Object.defineProperty(D, "_tfengine", {
+ get: () => {
+ if (E.engine == null) {
+ const i = new L(D);
+ E.engine = new I(i);
+ }
+ return E.engine;
+ }
+ });
+ function F(i) {
+ return i.kernelName != null;
+ }
+ class _ {
  // Public since optimizers will use it.
  registeredVariables = {};
  nextTapeNodeId = 0;
@@ -96,17 +130,17 @@ class C {
  kernels: [],
  result: null,
  get kernelNames() {
- return Array.from(new Set(this.kernels.map((t) => t.name)));
+ return Array.from(new Set(this.kernels.map((e) => e.name)));
  }
  };
  dispose() {
- for (const t in this.registeredVariables)
- this.registeredVariables[t].dispose();
+ for (const e in this.registeredVariables)
+ this.registeredVariables[e].dispose();
  }
  }
- class v {
- constructor(t) {
- this.ENV = t, this.state = new C(), console.log("GenAI Patched Engine Initialized");
+ class I {
+ constructor(e) {
+ this.ENV = e, this.state = new _(), console.log("GenAI Patched Engine Initialized");
  }
  version = "GENAI_PATCHED_ENGINE";
  state;
@@ -123,9 +157,9 @@ class v {
  });
  if (this.backendInstance != null)
  return;
- const t = this.getSortedBackends();
- for (let e = 0; e < t.length; e++) {
- const s = t[e];
+ const e = this.getSortedBackends();
+ for (let t = 0; t < e.length; t++) {
+ const s = e[t];
  if (await this.initializeBackend(s).success) {
  await this.setBackend(s);
  return;
@@ -139,53 +173,53 @@ class v {
  `Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
  );
  if (this.backendInstance == null) {
- const { name: t, asyncInit: e } = this.initializeBackendsAndReturnBest();
- if (e)
+ const { name: e, asyncInit: t } = this.initializeBackendsAndReturnBest();
+ if (t)
  throw new Error(
- `The highest priority backend '${t}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
+ `The highest priority backend '${e}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
  );
- this.setBackend(t);
+ this.setBackend(e);
  }
  return this.backendInstance;
  }
  backendNames() {
  return Object.keys(this.registryFactory);
  }
- findBackend(t) {
- if (!(t in this.registry))
- if (t in this.registryFactory) {
- const { asyncInit: e } = this.initializeBackend(t);
- if (e)
+ findBackend(e) {
+ if (!(e in this.registry))
+ if (e in this.registryFactory) {
+ const { asyncInit: t } = this.initializeBackend(e);
+ if (t)
  return null;
  } else
  return null;
- return this.registry[t];
+ return this.registry[e];
  }
- findBackendFactory(t) {
- return t in this.registryFactory ? this.registryFactory[t].factory : null;
+ findBackendFactory(e) {
+ return e in this.registryFactory ? this.registryFactory[e].factory : null;
  }
- registerBackend(t, e, s = 1) {
- return t in this.registryFactory ? (T(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, console.log("Registered backend", t), !0);
+ registerBackend(e, t, s = 1) {
+ return e in this.registryFactory ? (w(`${e} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[e] = { factory: t, priority: s }, console.log("Registered backend", e), !0);
  }
- async setBackend(t) {
- if (this.registryFactory[t] == null)
- throw new Error(`Backend name '${t}' not found in registry`);
- if (this.backendName = t, this.registry[t] == null) {
+ async setBackend(e) {
+ if (this.registryFactory[e] == null)
+ throw new Error(`Backend name '${e}' not found in registry`);
+ if (this.backendName = e, this.registry[e] == null) {
  this.backendInstance = null;
- const { success: e, asyncInit: s } = this.initializeBackend(t);
- if (!(s ? await e : e))
+ const { success: t, asyncInit: s } = this.initializeBackend(e);
+ if (!(s ? await t : t))
  return !1;
  }
- return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new R(this.backendInstance), !0;
+ return this.backendInstance = this.registry[e], this.setupRegisteredKernels(), this.profiler = new W(this.backendInstance), !0;
  }
  setupRegisteredKernels() {
- x(this.backendName).forEach((e) => {
- e.setupFunc != null && e.setupFunc(this.backendInstance);
+ N(this.backendName).forEach((t) => {
+ t.setupFunc != null && t.setupFunc(this.backendInstance);
  });
  }
- disposeRegisteredKernels(t) {
- x(t).forEach((s) => {
- s.disposeFunc != null && s.disposeFunc(this.registry[t]);
+ disposeRegisteredKernels(e) {
+ N(e).forEach((s) => {
+ s.disposeFunc != null && s.disposeFunc(this.registry[e]);
  });
  }
  /**
@@ -194,82 +228,82 @@ class v {
  * whether the initialization of the backend succeeded. Throws an error if
  * there is no backend in the factory registry.
  */
- initializeBackend(t) {
- const e = this.registryFactory[t];
- if (e == null)
- throw new Error(`Cannot initialize backend ${t}, no registration found.`);
+ initializeBackend(e) {
+ const t = this.registryFactory[e];
+ if (t == null)
+ throw new Error(`Cannot initialize backend ${e}, no registration found.`);
  try {
- const s = e.factory();
- if (s && !(s instanceof j) && typeof s.then == "function") {
- const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[t] = c, this.pendingBackendInit = null, !0)).catch((c) => (n < this.pendingBackendInitId || (this.pendingBackendInit = null, T(`Initialization of backend ${t} failed`), T(c.stack || c.message)), !1));
+ const s = t.factory();
+ if (s && !(s instanceof H) && typeof s.then == "function") {
+ const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[e] = c, this.pendingBackendInit = null, !0)).catch((c) => (n < this.pendingBackendInitId || (this.pendingBackendInit = null, w(`Initialization of backend ${e} failed`), w(c.stack || c.message)), !1));
  return this.pendingBackendInit = r, { success: r, asyncInit: !0 };
  } else
- return this.registry[t] = s, { success: !0, asyncInit: !1 };
+ return this.registry[e] = s, { success: !0, asyncInit: !1 };
  } catch (s) {
- return T(`Initialization of backend ${t} failed`), T(s.stack || s.message), { success: !1, asyncInit: !1 };
+ return w(`Initialization of backend ${e} failed`), w(s.stack || s.message), { success: !1, asyncInit: !1 };
  }
  }
- removeBackend(t) {
- if (!(t in this.registryFactory))
- throw new Error(`${t} backend not found in registry`);
- this.backendName === t && this.pendingBackendInit != null && this.pendingBackendInitId++, t in this.registry && (this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]), delete this.registryFactory[t], this.backendName === t && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
+ removeBackend(e) {
+ if (!(e in this.registryFactory))
+ throw new Error(`${e} backend not found in registry`);
+ this.backendName === e && this.pendingBackendInit != null && this.pendingBackendInitId++, e in this.registry && (this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e]), delete this.registryFactory[e], this.backendName === e && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
  }
  getSortedBackends() {
  if (Object.keys(this.registryFactory).length === 0)
  throw new Error("No backend found in registry.");
- return Object.keys(this.registryFactory).sort((t, e) => this.registryFactory[e].priority - this.registryFactory[t].priority);
+ return Object.keys(this.registryFactory).sort((e, t) => this.registryFactory[t].priority - this.registryFactory[e].priority);
  }
  initializeBackendsAndReturnBest() {
- const t = this.getSortedBackends();
- for (let e = 0; e < t.length; e++) {
- const s = t[e], { success: n, asyncInit: r } = this.initializeBackend(s);
+ const e = this.getSortedBackends();
+ for (let t = 0; t < e.length; t++) {
+ const s = e[t], { success: n, asyncInit: r } = this.initializeBackend(s);
  if (r || n)
  return { name: s, asyncInit: r };
  }
  throw new Error("Could not initialize any backends, all backend initializations failed.");
  }
- moveData(t, e) {
- const s = this.state.tensorInfo.get(e);
- s || console.warn("Tried to move data that does not exist", this.state, e);
- const n = s.backend, r = this.readSync(e), c = n.refCount(e);
- n.disposeData(e, !0), s.backend = t, t.move(e, r, s.shape, s.dtype, c), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
+ moveData(e, t) {
+ const s = this.state.tensorInfo.get(t);
+ s || console.warn("Tried to move data that does not exist", this.state, t);
+ const n = s.backend, r = this.readSync(t), c = n.refCount(t);
+ n.disposeData(t, !0), s.backend = e, e.move(t, r, s.shape, s.dtype, c), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
  }
- tidy(t, e) {
+ tidy(e, t) {
  let s = null;
- if (e == null) {
- if (typeof t != "function")
+ if (t == null) {
+ if (typeof e != "function")
  throw new Error("Please provide a function to tidy()");
- e = t;
+ t = e;
  } else {
- if (typeof t != "string" && !(t instanceof String))
+ if (typeof e != "string" && !(e instanceof String))
  throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
- if (typeof e != "function")
+ if (typeof t != "function")
  throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
- s = t;
+ s = e;
  }
  let n;
  return this.scopedRun(
  () => this.startScope(s),
  () => this.endScope(n),
- () => (n = e(), n instanceof Promise && console.error("Cannot return a Promise inside of tidy."), n)
+ () => (n = t(), n instanceof Promise && console.error("Cannot return a Promise inside of tidy."), n)
  );
  }
- scopedRun(t, e, s) {
- t();
+ scopedRun(e, t, s) {
+ e();
  try {
  const n = s();
- return e(), n;
+ return t(), n;
  } catch (n) {
- throw e(), n;
+ throw t(), n;
  }
  }
  static nextTensorId = 0;
  nextTensorId() {
- return v.nextTensorId++;
+ return I.nextTensorId++;
  }
  static nextVariableId = 0;
  nextVariableId() {
- return v.nextVariableId++;
+ return I.nextVariableId++;
  }
  /**
  * This method is called instead of the public-facing tensor.clone() when
@@ -277,16 +311,16 @@ class v {
  * operation to the tape regardless of being called inside a kernel
  * execution.
  */
- clone(t) {
- const s = w(t) ? z(y.runKernel(M, { x: t })) : y.runKernel(M, { x: t }), n = { x: t }, r = (h) => ({
+ clone(e) {
+ const s = T(e) ? O(y.runKernel(P, { x: e })) : y.runKernel(P, { x: e }), n = { x: e }, r = (d) => ({
  x: () => {
- const a = "float32", i = { x: h }, d = { dtype: a }, l = w(t), u = y.runKernel(
- _,
- i,
+ const o = "float32", a = { x: d }, h = { dtype: o }, u = T(e), l = y.runKernel(
+ Z,
+ a,
  // tslint:disable-next-line: no-unnecessary-type-assertion
- d
+ h
  );
- return l && z(u), u;
+ return u && O(l), l;
  }
  }), c = [];
  return this.addTapeNode(this.state.activeScope.name, n, [s], r, c, {}), s;
@@ -304,24 +338,24 @@ class v {
  * for the backprop computation. These are booleans since the output
  * tensors are not visible to the user.
  */
- runKernel(t, e, s) {
- if (this.backendName == null && this.backend, !(V(t, this.backendName) != null))
- throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
- return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
+ runKernel(e, t, s) {
+ if (this.backendName == null && this.backend, !(A(e, this.backendName) != null))
+ throw new Error(`Kernel '${e}' not registered for backend '${this.backendName}'`);
+ return this.runKernelFunc({ kernelName: e, inputs: t, attrs: s });
  }
  shouldCheckForMemLeaks() {
  return this.ENV.getBool("IS_TEST");
  }
- checkKernelForMemLeak(t, e, s) {
+ checkKernelForMemLeak(e, t, s) {
  const n = this.backend.numDataIds();
  let r = 0;
- s.forEach((a) => {
- r += a.dtype === "complex64" ? 3 : 1;
+ s.forEach((o) => {
+ r += o.dtype === "complex64" ? 3 : 1;
  });
- const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], h = n - e - r - c;
- if (h > 0)
+ const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], d = n - t - r - c;
+ if (d > 0)
  throw new Error(
- `Backend '${this.backendName}' has an internal memory leak (${h} data ids) after running '${t}'`
+ `Backend '${this.backendName}' has an internal memory leak (${d} data ids) after running '${e}'`
  );
  }
  /**
@@ -329,81 +363,81 @@ class v {
  *
  * Use `runKernel` to execute kernels from outside of engine.
  */
- runKernelFunc(t) {
- let e, s = [];
+ runKernelFunc(e) {
+ let t, s = [];
  const n = this.isTapeOn(), r = this.state.numBytes, c = this.state.numTensors;
  this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
- let h;
+ let d;
  this.backendName == null && this.backend;
- let a;
- const i = S(t) ? t.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
- if (S(t)) {
- const { kernelName: f, inputs: I, attrs: m } = t;
+ let o;
+ const a = F(e) ? e.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
+ if (F(e)) {
+ const { kernelName: f, inputs: v, attrs: m } = e;
  this.backendName == null && this.backend;
- const g = V(f, this.backendName);
+ const k = A(f, this.backendName);
  p(
- g != null,
+ k != null,
  () => `Cannot find registered kernel '${f}' for backend '${this.backendName}'`
- ), h = () => {
- const G = this.backend.numDataIds();
- a = g.kernelFunc({ inputs: I, attrs: m, backend: this.backend });
- const F = Array.isArray(a) ? a : [a];
- this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f, G, F);
- const D = F.map((b) => b.rank != null ? b : this.makeTensorFromTensorInfo(b));
+ ), d = () => {
+ const q = this.backend.numDataIds();
+ o = k.kernelFunc({ inputs: v, attrs: m, backend: this.backend });
+ const M = Array.isArray(o) ? o : [o];
+ this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f, q, M);
+ const V = M.map((b) => b.rank != null ? b : this.makeTensorFromTensorInfo(b));
  if (n) {
- const b = this.getTensorsForGradient(f, I, D);
+ const b = this.getTensorsForGradient(f, v, V);
  s = this.saveTensorsForBackwardMode(b ?? []);
  }
- return D;
+ return V;
  };
  } else {
- const { forwardFunc: f } = t, I = (m) => {
- n && (s = m.map((g) => this.keep(this.clone(g))));
+ const { forwardFunc: f } = e, v = (m) => {
+ n && (s = m.map((k) => this.keep(this.clone(k))));
  };
- h = () => {
+ d = () => {
  const m = this.backend.numDataIds();
- a = this.tidy(() => f(this.backend, I));
- const g = Array.isArray(a) ? a : [a];
- return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(i, m, g), g;
+ o = this.tidy(() => f(this.backend, v));
+ const k = Array.isArray(o) ? o : [o];
+ return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(a, m, k), k;
  };
  }
- const { inputs: d, attrs: l } = t, u = S(t) ? null : t.backwardsFunc;
- let k;
+ const { inputs: h, attrs: u } = e, l = F(e) ? null : e.backwardsFunc;
+ let g;
  return this.scopedRun(
  // Stop recording to a tape when running a kernel.
  () => this.state.kernelDepth++,
  () => this.state.kernelDepth--,
  () => {
- !this.ENV.getBool("DEBUG") && !this.state.profiling ? e = h() : (k = this.profiler.profileKernel(i, d, () => h()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(k), e = k.outputs);
+ !this.ENV.getBool("DEBUG") && !this.state.profiling ? t = d() : (g = this.profiler.profileKernel(a, h, () => d()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(g), t = g.outputs);
  }
  ), n && this.addTapeNode(
- i,
- d,
- e,
- u,
+ a,
+ h,
+ t,
+ l,
  s,
- l ?? {}
+ u ?? {}
  ), this.state.profiling && this.state.activeProfile.kernels.push({
- name: i,
+ name: a,
  bytesAdded: this.state.numBytes - r,
  totalBytesSnapshot: this.state.numBytes,
  tensorsAdded: this.state.numTensors - c,
  totalTensorsSnapshot: this.state.numTensors,
- inputShapes: Object.keys(d).map(
- (f) => d[f] != null ? d[f].shape : null
+ inputShapes: Object.keys(h).map(
+ (f) => h[f] != null ? h[f].shape : null
  ),
- outputShapes: e.map((f) => f.shape),
- kernelTimeMs: k.timeMs,
- extraInfo: k.extraInfo
- }), Array.isArray(a) ? e : e[0];
+ outputShapes: t.map((f) => f.shape),
+ kernelTimeMs: g.timeMs,
+ extraInfo: g.extraInfo
+ }), Array.isArray(o) ? t : t[0];
  }
  /**
  * Saves tensors used in forward mode for use in backward mode.
  *
  * @param tensors the list of tensors to save.
  */
- saveTensorsForBackwardMode(t) {
- return t.map((s) => this.keep(this.clone(s)));
+ saveTensorsForBackwardMode(e) {
+ return e.map((s) => this.keep(this.clone(s)));
  }
  /**
  * Returns a list of tensors to save for a given gradient calculation.
@@ -412,14 +446,14 @@ class v {
  * @param inputs a map of input tensors.
  * @param outputs an array of output tensors from forward mode of kernel.
  */
- getTensorsForGradient(t, e, s) {
- const n = $(t);
+ getTensorsForGradient(e, t, s) {
+ const n = z(e);
  if (n != null) {
  const r = n.inputsToSave || [], c = n.outputsToSave || [];
- let h;
- n.saveAllInputs ? (p(Array.isArray(e), () => "saveAllInputs is true, expected inputs to be an array."), h = Object.keys(e).map((i) => e[i])) : h = r.map((i) => e[i]);
- const a = s.filter((i, d) => c[d]);
- return h.concat(a);
+ let d;
+ n.saveAllInputs ? (p(Array.isArray(t), () => "saveAllInputs is true, expected inputs to be an array."), d = Object.keys(t).map((a) => t[a])) : d = r.map((a) => t[a]);
+ const o = s.filter((a, h) => c[h]);
+ return d.concat(o);
  }
  return [];
  }
@@ -428,18 +462,18 @@
  * tensor with the provided shape, dtype and values. It always
  * creates a new data id and writes the values to the underlying backend.
  */
- makeTensor(t, e, s, n) {
- if (t == null)
+ makeTensor(e, t, s, n) {
+ if (e == null)
  throw new Error("Values passed to engine.makeTensor() are null");
  s = s || "float32", n = n || this.backend;
- let r = t;
- s === "string" && Y(t[0]) && (r = t.map((a) => Z(a)));
- const c = n.write(r, e, s), h = new A(e, s, c, this.nextTensorId());
- if (this.trackTensor(h, n), s === "string") {
- const a = this.state.tensorInfo.get(c), i = J(r);
- this.state.numBytes += i - a.bytes, a.bytes = i;
+ let r = e;
+ s === "string" && ee(e[0]) && (r = e.map((o) => te(o)));
+ const c = n.write(r, t, s), d = new j(t, s, c, this.nextTensorId());
+ if (this.trackTensor(d, n), s === "string") {
+ const o = this.state.tensorInfo.get(c), a = se(r);
+ this.state.numBytes += a - o.bytes, o.bytes = a;
  }
- return h;
+ return d;
  }
  /**
  * Internal method used by backends. Makes a new tensor
@@ -447,9 +481,9 @@
  * a new data id, only increments the ref count used in memory tracking.
  * @deprecated
  */
- makeTensorFromDataId(t, e, s, n) {
+ makeTensorFromDataId(e, t, s, n) {
  s = s || "float32";
- const r = { dataId: t, shape: e, dtype: s };
+ const r = { dataId: e, shape: t, dtype: s };
  return this.makeTensorFromTensorInfo(r, n);
  }
  /**
@@ -457,69 +491,69 @@
  * around an existing data id in TensorInfo. It doesn't create a new data id,
  * only increments the ref count used in memory tracking.
  */
- makeTensorFromTensorInfo(t, e) {
- const { dataId: s, shape: n, dtype: r } = t, c = new A(n, r, s, this.nextTensorId());
- if (c.packed = t.packed || !1, c.packed && r !== "int32")
+ makeTensorFromTensorInfo(e, t) {
+ const { dataId: s, shape: n, dtype: r } = e, c = new j(n, r, s, this.nextTensorId());
+ if (c.packed = e.packed || !1, c.packed && r !== "int32")
  throw new Error("Only int32 tensors can be packed.");
- return this.trackTensor(c, e ?? this.backend), c;
+ return this.trackTensor(c, t ?? this.backend), c;
  }
- makeVariable(t, e = !0, s, n) {
- s = s || this.nextVariableId().toString(), n != null && n !== t.dtype && (t = t.cast(n));
- const r = new rt(t, e, s, this.nextTensorId());
+ makeVariable(e, t = !0, s, n) {
+ s = s || this.nextVariableId().toString(), n != null && n !== e.dtype && (e = e.cast(n));
+ const r = new de(e, t, s, this.nextTensorId());
  if (this.state.registeredVariables[r.name] != null)
  throw new Error(`Variable with name ${r.name} was already registered`);
  return this.state.registeredVariables[r.name] = r, this.incRef(r, this.backend), r;
  }
- trackTensor(t, e) {
- this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++;
+ trackTensor(e, t) {
+ this.state.numTensors++, e.dtype === "string" && this.state.numStringTensors++;
  let s = 0;
- t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size * K(t.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(t.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(t.dataId, {
- backend: e || this.backend,
- dtype: t.dtype,
- shape: t.shape,
+ e.dtype !== "complex64" && e.dtype !== "string" && (s = e.size * C(e.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(e.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(e.dataId, {
+ backend: t || this.backend,
+ dtype: e.dtype,
+ shape: e.shape,
  bytes: s
- })), t instanceof Q || this.track(t);
+ })), e instanceof ne || this.track(e);
  }
  // Track the tensor by dataId and increase the refCount for the dataId in the
  // backend.
  // TODO(pyu10055): This is currently used by makeVariable method, to increase
  // refCount on the backend for the dataId. It can potentially be replaced with
  // Identity op indead of calling backend directly.
- incRef(t, e) {
- this.trackTensor(t, e), this.backend.incRef(t.dataId);
+ incRef(e, t) {
+ this.trackTensor(e, t), this.backend.incRef(e.dataId);
  }
- removeDataId(t, e) {
- this.state.tensorInfo.has(t) && this.state.tensorInfo.get(t).backend === e && (this.state.tensorInfo.delete(t), this.state.numDataBuffers--);
+ removeDataId(e, t) {
+ this.state.tensorInfo.has(e) && this.state.tensorInfo.get(e).backend === t && (this.state.tensorInfo.delete(e), this.state.numDataBuffers--);
  }
- disposeTensor(t) {
- if (!this.state.tensorInfo.has(t.dataId))
+ disposeTensor(e) {
+ if (!this.state.tensorInfo.has(e.dataId))
  return;
- const e = this.state.tensorInfo.get(t.dataId);
- if (this.state.numTensors--, t.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= e.bytes), t.dtype !== "complex64" && t.dtype !== "string") {
- const s = t.size * K(t.dtype);
+ const t = this.state.tensorInfo.get(e.dataId);
+ if (this.state.numTensors--, e.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= t.bytes), e.dtype !== "complex64" && e.dtype !== "string") {
+ const s = e.size * C(e.dtype);
  this.state.numBytes -= s;
  }
- e.backend.disposeData(t.dataId) && this.removeDataId(t.dataId, e.backend);
+ t.backend.disposeData(e.dataId) && this.removeDataId(e.dataId, t.backend);
  }
  disposeVariables() {
- for (const t in this.state.registeredVariables) {
- const e = this.state.registeredVariables[t];
- this.disposeVariable(e);
+ for (const e in this.state.registeredVariables) {
+ const t = this.state.registeredVariables[e];
+ this.disposeVariable(t);
  }
  }
- disposeVariable(t) {
- this.disposeTensor(t), this.state.registeredVariables[t.name] != null && delete this.state.registeredVariables[t.name];
+ disposeVariable(e) {
+ this.disposeTensor(e), this.state.registeredVariables[e.name] != null && delete this.state.registeredVariables[e.name];
  }
  memory() {
- const t = this.backend.memory();
- return t.numTensors = this.state.numTensors, t.numDataBuffers = this.state.numDataBuffers, t.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (t.unreliable = !0, t.reasons == null && (t.reasons = []), t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), t;
+ const e = this.backend.memory();
+ return e.numTensors = this.state.numTensors, e.numDataBuffers = this.state.numDataBuffers, e.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (e.unreliable = !0, e.reasons == null && (e.reasons = []), e.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), e;
  }
- async profile(t) {
+ async profile(e) {
  this.state.profiling = !0;
- const e = this.state.numBytes, s = this.state.numTensors;
- this.state.activeProfile.kernels = [], this.state.activeProfile.result = await t(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(
+ const t = this.state.numBytes, s = this.state.numTensors;
+ this.state.activeProfile.kernels = [], this.state.activeProfile.result = await e(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(
  ...this.state.activeProfile.kernels.map((n) => n.totalBytesSnapshot)
- ), this.state.activeProfile.newBytes = this.state.numBytes - e, this.state.activeProfile.newTensors = this.state.numTensors - s;
+ ), this.state.activeProfile.newBytes = this.state.numBytes - t, this.state.activeProfile.newTensors = this.state.numTensors - s;
  for (const n of this.state.activeProfile.kernels)
  n.kernelTimeMs = await n.kernelTimeMs, n.extraInfo = await n.extraInfo;
  return this.state.activeProfile;
@@ -527,18 +561,18 @@ class v {
  isTapeOn() {
  return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
  }
- addTapeNode(t, e, s, n, r, c) {
- const h = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: r }, a = $(t);
- a != null && (n = a.gradFunc), n != null && (h.gradient = (i) => (i = i.map((d, l) => {
- if (d == null) {
- const u = s[l], k = tt(u.size, u.dtype);
- return this.makeTensor(k, u.shape, u.dtype);
+ addTapeNode(e, t, s, n, r, c) {
+ const d = { id: this.state.nextTapeNodeId++, kernelName: e, inputs: t, outputs: s, saved: r }, o = z(e);
+ o != null && (n = o.gradFunc), n != null && (d.gradient = (a) => (a = a.map((h, u) => {
+ if (h == null) {
+ const l = s[u], g = oe(l.size, l.dtype);
+ return this.makeTensor(g, l.shape, l.dtype);
  }
- return d;
- }), n(i.length > 1 ? i : i[0], r, c))), this.state.activeTape.push(h);
+ return h;
+ }), n(a.length > 1 ? a : a[0], r, c))), this.state.activeTape.push(d);
  }
- keep(t) {
- return t.kept = !0, t;
+ keep(e) {
+ return e.kept = !0, e;
  }
  startTape() {
  this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
@@ -550,26 +584,26 @@
  * Start a scope. Use this with endScope() to achieve the same functionality
  * as scope() without the need for a function closure.
  */
- startScope(t) {
- const e = {
+ startScope(e) {
+ const t = {
  track: [],
  name: "unnamed scope",
  id: this.state.nextScopeId++
  };
- t && (e.name = t), this.state.scopeStack.push(e), this.state.activeScope = e;
+ e && (t.name = e), this.state.scopeStack.push(t), this.state.activeScope = t;
  }
  /**
  * End a scope. Use this with startScope() to achieve the same functionality
  * as scope() without the need for a function closure.
  */
- endScope(t) {
- const e = O(t), s = new Set(e.map((r) => r.id));
+ endScope(e) {
+ const t = X(e), s = new Set(t.map((r) => r.id));
  for (let r = 0; r < this.state.activeScope.track.length; r++) {
  const c = this.state.activeScope.track[r];
  !c.kept && !s.has(c.id) && c.dispose();
  }
  const n = this.state.scopeStack.pop();
- this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], e.forEach((r) => {
+ this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], t.forEach((r) => {
  !r.kept && r.scopeId === n?.id && this.track(r);
  });
  }
@@ -579,68 +613,68 @@
  * was not a function of that `x`. It also takes optional dy to multiply the
  * gradient, which defaults to `1`.
  */
- gradients(t, e, s, n = !1) {
- if (p(e.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
+ gradients(e, t, s, n = !1) {
+ if (p(t.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
  throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
  const r = this.scopedRun(
  () => this.startTape(),
  () => this.endTape(),
- () => this.tidy("forward", t)
+ () => this.tidy("forward", e)
  );
  p(r instanceof B, () => "The result y returned by f() must be a tensor.");
- const c = L(this.state.activeTape, e, r);
- if (!n && c.length === 0 && e.length > 0)
+ const c = Y(this.state.activeTape, t, r);
+ if (!n && c.length === 0 && t.length > 0)
  throw new Error(
  "Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."
  );
  return this.tidy("backward", () => {
- const h = {};
- h[r.id] = s ?? at(r.shape), it(
- h,
+ const d = {};
+ d[r.id] = s ?? le(r.shape), he(
+ d,
  c,
  // Pass the tidy function to avoid circular dep with `tape.ts`.
- (i) => this.tidy(i),
+ (a) => this.tidy(a),
  // Pass an add function to avoide a circular dep with `tape.ts`.
- ct
+ fe
  );
- const a = e.map((i) => h[i.id]);
- return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((i) => {
- if (i.saved !== void 0)
- for (const d of i.saved)
- d.dispose();
- }), this.state.activeTape = null), { value: r, grads: a };
+ const o = t.map((a) => d[a.id]);
+ return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((a) => {
+ if (a.saved !== void 0)
+ for (const h of a.saved)
+ h.dispose();
+ }), this.state.activeTape = null), { value: r, grads: o };
  });
  }
- customGrad(t) {
- return p(N(t), () => "The f passed in customGrad(f) must be a function."), (...e) => {
+ customGrad(e) {
+ return p(G(e), () => "The f passed in customGrad(f) must be a function."), (...t) => {
  p(
- e.every((h) => h instanceof B),
+ t.every((d) => d instanceof B),
  () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors"
  );
  let s;
  const n = {};
- e.forEach((h, a) => {
- n[a] = h;
+ t.forEach((d, o) => {
+ n[o] = d;
  });
- const r = (h, a) => (s = t(...e, a), p(
+ const r = (d, o) => (s = e(...t, o), p(
  s.value instanceof B,
  () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"
  ), p(
- N(s.gradFunc),
+ G(s.gradFunc),
  () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."
- ), s.value), c = (h, a) => {
- const i = s.gradFunc(h, a), d = Array.isArray(i) ? i : [i];
+ ), s.value), c = (d, o) => {
+ const a = s.gradFunc(d, o), h = Array.isArray(a) ? a : [a];
  p(
- d.length === e.length,
+ h.length === t.length,
  () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."
  ), p(
- d.every((u) => u instanceof B),
+ h.every((l) => l instanceof B),
  () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."
  );
- const l = {};
- return d.forEach((u, k) => {
- l[k] = () => u;
- }), l;
+ const u = {};
+ return h.forEach((l, g) => {
+ u[g] = () => l;
+ }), u;
  };
  return this.runKernelFunc({
  forwardFunc: r,
@@ -649,18 +683,18 @@
  });
  };
  }
- readSync(t) {
- return this.state.tensorInfo.get(t).backend.readSync(t);
+ readSync(e) {
+ return this.state.tensorInfo.get(e).backend.readSync(e);
  }
- read(t) {
- return this.state.tensorInfo.get(t).backend.read(t);
+ read(e) {
+ return this.state.tensorInfo.get(e).backend.read(e);
  }
- readToGPU(t, e) {
- return this.state.tensorInfo.get(t).backend.readToGPU(t, e);
+ readToGPU(e, t) {
+ return this.state.tensorInfo.get(e).backend.readToGPU(e, t);
  }
- async time(t) {
- const e = P(), s = await this.backend.time(t);
- return s.wallMs = P() - e, s;
+ async time(e) {
+ const t = R(), s = await this.backend.time(e);
+ return s.wallMs = R() - t, s;
  }
  /**
  * Tracks a Tensor in the current scope to be automatically cleaned up
@@ -668,8 +702,8 @@
  *
  * @param result The Tensor to track in the current scope.
  */
- track(t) {
- return this.state.activeScope != null && (t.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(t)), t;
+ track(e) {
+ return this.state.activeScope != null && (e.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(e)), e;
  }
  get registeredVariables() {
  return this.state.registeredVariables;
@@ -679,38 +713,38 @@
  * registered backend factories.
  */
  reset() {
- this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new C();
- for (const t in this.registry)
- this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
+ this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new _();
+ for (const e in this.registry)
+ this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e];
  this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
  }
  }
- function at(o) {
- const t = et(st(o), "float32");
- return y.makeTensor(t, o, "float32");
+ function le(i) {
+ const e = re(ie(i), "float32");
+ return y.makeTensor(e, i, "float32");
  }
- function ot() {
- const o = W();
- if (o._tfengine == null) {
- const t = new H(o);
- o._tfengine = new v(t);
+ function ue() {
+ const i = U();
+ if (i._tfengine == null) {
+ const e = new L(i);
+ i._tfengine = new I(e);
  }
- return X(o._tfengine.ENV), nt(() => o._tfengine), o._tfengine;
+ return ae(i._tfengine.ENV), ce(() => i._tfengine), i._tfengine;
  }
- const y = ot();
- function ct(o, t) {
- const e = w(o) || w(t), s = { a: o, b: t };
- return y.runKernel(e ? "Add16" : U, s);
+ const y = ue();
+ function fe(i, e) {
+ const t = T(i) || T(e), s = { a: i, b: e };
+ return y.runKernel(t ? "Add16" : J, s);
  }
  export {
- v as E,
+ I as E,
  y as a,
- it as b,
- ct as c,
- ot as g,
- E as isPackableTensor,
- w as isPackedTensor,
- z as packTensor,
- lt as packingSupported,
- ft as unpackTensor
+ he as b,
+ fe as c,
+ ue as g,
+ x as isPackableTensor,
+ T as isPackedTensor,
+ O as packTensor,
+ be as packingSupported,
+ we as unpackTensor
  };
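
Most of the 206 entries above are content-hash renames from rebuilding the bundle; the one substantive behavioral change visible in the diff (which, judging by the `packTensor`/`unpackTensor` exports and the 716 → 750 line counts, appears to be `package/dist/utilities/packed.js`) is that 0.10.2 now resolves the global object itself and installs `_tfengine` on it at module load as a lazy getter, throwing if an engine already exists. A readable sketch of that added logic, with descriptive names substituted for the minified ones (`U` → getGlobalObject, `L` → the environment class, `I` → the patched engine class; reading the `$`/`K` imports as bundler shims for `global` and `process` is an assumption):

// Stand-ins for the minified classes `L` (environment) and `I` (patched
// engine); the real engine constructor logs the same message.
class Environment { constructor(globalObj) { this.global = globalObj; } }
class PatchedEngine { constructor(env) { this.ENV = env; console.log("GenAI Patched Engine Initialized"); } }

let cachedGlobal;
function getGlobalObject() {
  // `typeof x < "u"` in the minified source is just `typeof x !== "undefined"`.
  if (cachedGlobal == null) {
    if (typeof window !== "undefined") cachedGlobal = window;
    else if (typeof global !== "undefined") cachedGlobal = global;
    else if (typeof process !== "undefined") cachedGlobal = process;
    else if (typeof self !== "undefined") cachedGlobal = self;
    else throw new Error("Could not find a global object");
  }
  return cachedGlobal;
}

const holder = { engine: null };
const globalObject = getGlobalObject();

// Refuse to patch over an engine that stock TensorFlow.js already installed.
if (globalObject._tfengine)
  throw new Error("TensorFlow engine already initialized before patching.");

// Expose `_tfengine` as a getter: the patched engine is constructed on first
// access rather than at import time.
Object.defineProperty(globalObject, "_tfengine", {
  get: () => {
    if (holder.engine == null) {
      holder.engine = new PatchedEngine(new Environment(globalObject));
    }
    return holder.engine;
  }
});

The effect is that the first reader of `_tfengine`, including this module's own engine getter (`ue`, exported as `g`), receives the patched engine, while loading stock TensorFlow.js before this package now fails fast with the "already initialized" error instead of silently running on an unpatched engine.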