@genai-fi/nanogpt 0.10.1 → 0.10.2
- package/dist/Generator.js +14 -14
- package/dist/{RealDiv-DgA3z9oO.js → RealDiv-zz7FpkKX.js} +17 -17
- package/dist/{Reshape-CF6odzV4.js → Reshape-CDVLyVfz.js} +3 -3
- package/dist/{Reshape-_kILl6tK.js → Reshape-CHdUjC72.js} +4 -4
- package/dist/TeachableLLM.js +8 -8
- package/dist/{axis_util-BvHEw88j.js → axis_util-BsIr9ZNu.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-D-rUb2ty.js → backend_util-B1XRLuq9.js} +31 -31
- package/dist/{backend_webgpu-B0u2ndUn.js → backend_webgpu-CqpfEImu.js} +5 -5
- package/dist/{broadcast_to-CwF7XIeu.js → broadcast_to-B0ChcDaz.js} +4 -4
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +5 -5
- package/dist/checks/normRMS.js +4 -4
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +2 -2
- package/dist/checks/qkv.js +3 -3
- package/dist/checks/rope.js +2 -2
- package/dist/{complex-CSlYz-2T.js → complex-BBiRlsVq.js} +3 -3
- package/dist/{concat-BHlIJeyT.js → concat-DmBLPVGC.js} +3 -3
- package/dist/{concat_util-DcJk7YHS.js → concat_util-iBYIyuQe.js} +1 -1
- package/dist/{dataset-0xP8GjwI.js → dataset-D2P7rHAw.js} +5 -5
- package/dist/{dropout-C1pM3f11.js → dropout-B1x1kYMa.js} +3 -3
- package/dist/{expand_dims-BPG4fwBP.js → expand_dims-ouvfxQ1n.js} +3 -3
- package/dist/{exports_initializers-xuidcwI4.js → exports_initializers-CZSUJoVE.js} +1 -1
- package/dist/{gather-DykLGqmW.js → gather-CH9sdacz.js} +2 -2
- package/dist/{gelu-CNLFZWea.js → gelu-Bmhopi0J.js} +2 -2
- package/dist/{gpgpu_math-DDVJCn6-.js → gpgpu_math-DsCcikas.js} +3 -3
- package/dist/{index-ZyQhjEPo.js → index-D6Q1lPZO.js} +55 -55
- package/dist/{index-CjOj7j-u.js → index-DRyE072i.js} +15 -15
- package/dist/{kernel_funcs_utils-Dg_-E44D.js → kernel_funcs_utils-CWfOAPGO.js} +9 -9
- package/dist/layers/BaseLayer.js +10 -10
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +8 -8
- package/dist/{log_sum_exp-DWI-76TI.js → log_sum_exp-D3ftBNY5.js} +6 -6
- package/dist/main.js +8 -8
- package/dist/{matMul16--R5hOwDG.js → matMul16-fEAJ4smh.js} +4 -4
- package/dist/{mat_mul-DeAh4uTH.js → mat_mul-C59XWcJd.js} +2 -2
- package/dist/{mod-Gt1rMB4n.js → mod-DESSvHIU.js} +2 -2
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/model.js +8 -8
- package/dist/{mulmat_packed_gpu-BMFhLwta.js → mulmat_packed_gpu-Coh6qbJk.js} +1 -1
- package/dist/{ones-CAMiP4I2.js → ones-jU9jlQvM.js} +4 -4
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/concat16.js +2 -2
- package/dist/ops/cpu/adamAdjust.js +2 -2
- package/dist/ops/cpu/adamMoments.js +3 -3
- package/dist/ops/cpu/appendCache.js +3 -3
- package/dist/ops/cpu/attentionMask.js +6 -6
- package/dist/ops/cpu/fusedSoftmax.js +3 -3
- package/dist/ops/cpu/gatherSub.js +4 -4
- package/dist/ops/cpu/gelu.js +2 -2
- package/dist/ops/cpu/matMul16.js +3 -3
- package/dist/ops/cpu/matMulGelu.js +4 -4
- package/dist/ops/cpu/matMulMul.js +2 -2
- package/dist/ops/cpu/mulDropout.js +2 -2
- package/dist/ops/cpu/normRMS.js +2 -2
- package/dist/ops/cpu/qkv.js +4 -4
- package/dist/ops/cpu/rope.js +6 -6
- package/dist/ops/cpu/scatterSub.js +7 -7
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.js +2 -2
- package/dist/ops/grads/attentionMask.js +3 -3
- package/dist/ops/grads/gelu.js +3 -3
- package/dist/ops/grads/matMul16.js +4 -4
- package/dist/ops/grads/matMulGelu.js +2 -2
- package/dist/ops/grads/normRMS.js +2 -2
- package/dist/ops/grads/pack16.js +4 -4
- package/dist/ops/grads/qkv.js +4 -4
- package/dist/ops/grads/rope.js +3 -3
- package/dist/ops/grads/softmax16.js +2 -2
- package/dist/ops/grads/unpack16.js +3 -3
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.js +2 -2
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/reshape16.js +3 -3
- package/dist/ops/rope.js +5 -5
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +1 -1
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +2 -2
- package/dist/ops/transpose16.js +4 -4
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +3 -3
- package/dist/ops/webgl/adamMoments.js +2 -2
- package/dist/ops/webgl/appendCache.js +2 -2
- package/dist/ops/webgl/attentionMask.js +2 -2
- package/dist/ops/webgl/fusedSoftmax.js +6 -6
- package/dist/ops/webgl/gatherSub.js +2 -2
- package/dist/ops/webgl/gelu.js +3 -3
- package/dist/ops/webgl/log.js +4 -4
- package/dist/ops/webgl/matMul16.js +5 -5
- package/dist/ops/webgl/matMulGelu.js +6 -6
- package/dist/ops/webgl/matMulMul.js +2 -2
- package/dist/ops/webgl/mulDropout.js +2 -2
- package/dist/ops/webgl/normRMS.js +3 -3
- package/dist/ops/webgl/qkv.js +2 -2
- package/dist/ops/webgl/rope.js +2 -2
- package/dist/ops/webgl/scatterSub.js +2 -2
- package/dist/ops/webgpu/adamAdjust.js +5 -5
- package/dist/ops/webgpu/adamMoments.js +5 -5
- package/dist/ops/webgpu/add16.js +2 -2
- package/dist/ops/webgpu/appendCache.js +5 -5
- package/dist/ops/webgpu/attentionMask.js +4 -4
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/concat16.js +7 -7
- package/dist/ops/webgpu/gatherSub.js +5 -5
- package/dist/ops/webgpu/gelu.js +4 -4
- package/dist/ops/webgpu/matMul16.js +6 -6
- package/dist/ops/webgpu/matMul16_program.js +3 -3
- package/dist/ops/webgpu/mul16.js +2 -2
- package/dist/ops/webgpu/normRMS.js +4 -4
- package/dist/ops/webgpu/normRMSGrad.js +6 -6
- package/dist/ops/webgpu/pack16.js +2 -2
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +4 -4
- package/dist/ops/webgpu/rope.js +5 -5
- package/dist/ops/webgpu/scatterSub.js +5 -5
- package/dist/ops/webgpu/slice16.js +6 -6
- package/dist/ops/webgpu/softmax16.js +4 -4
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +2 -2
- package/dist/ops/webgpu/sub16.js +2 -2
- package/dist/ops/webgpu/sum16.js +5 -5
- package/dist/ops/webgpu/transpose16.js +3 -3
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +4 -4
- package/dist/ops/webgpu/unpack16.js +4 -4
- package/dist/ops/webgpu/utils/binary_op.js +4 -4
- package/dist/ops/webgpu/utils/reductions.js +5 -5
- package/dist/{ops-CNI3TwqM.js → ops-BFDtP6th.js} +24 -24
- package/dist/{pack16-CFUqumar.js → pack16-CmVZs6af.js} +3 -3
- package/dist/patches/PackedTensor.js +1 -1
- package/dist/patches/engine.js +7 -5
- package/dist/patches/tape.js +1 -1
- package/dist/patches/webgpu_backend.js +5 -5
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +3 -3
- package/dist/{random_width-DY6Kk2Dl.js → random_width-BVV9HveY.js} +31 -31
- package/dist/{range-BMS52eQi.js → range-ZZZD60Fx.js} +2 -2
- package/dist/{reciprocal-CTmshQ9J.js → reciprocal-CrYlsAGD.js} +2 -2
- package/dist/{register_all_kernels-Bwu1PTuU.js → register_all_kernels-nvj2k7OC.js} +41 -41
- package/dist/{relu-yZ2-7WxU.js → relu-BYDneVPn.js} +2 -2
- package/dist/{reshape-DevtBWtf.js → reshape-CaPQzFvz.js} +2 -2
- package/dist/{rope-B5UUMsPi.js → rope-s4W2XO9B.js} +5 -5
- package/dist/{scatter_nd_util-5EL-8VAQ.js → scatter_nd_util-C7zXRT_h.js} +1 -1
- package/dist/{selu_util-D1w6yyTO.js → selu_util-BGPXmd4B.js} +16 -16
- package/dist/{shared-BRksrJb3.js → shared-CHhxz-O5.js} +1 -1
- package/dist/{shared-BuAXb4CI.js → shared-D2NP_CpY.js} +8 -8
- package/dist/{sin-BGfy2HZo.js → sin-Djs4aQiu.js} +2 -2
- package/dist/{slice-D_gkkqZK.js → slice-DvovR5wq.js} +2 -2
- package/dist/{slice_util-DtEldBfK.js → slice_util-DyjSAD0u.js} +1 -1
- package/dist/{softmax-ZHVebtR1.js → softmax-C9JQEtnO.js} +2 -2
- package/dist/{split-DrfihRpZ.js → split-DBck65sX.js} +2 -2
- package/dist/{squeeze-DZEpeblb.js → squeeze-C00Ipm_7.js} +3 -3
- package/dist/{stack-yOIAalTq.js → stack-ChnHwRpX.js} +3 -3
- package/dist/{sum-_fzj5ZTB.js → sum-ywRJj3Zr.js} +2 -2
- package/dist/{tensor-f35l8Odg.js → tensor-0r5yOo2R.js} +1 -1
- package/dist/{tensor-DdQUJZlz.js → tensor-CzmOBsdf.js} +21 -21
- package/dist/{tensor1d-CeZuc-Rv.js → tensor1d-BlUT89BP.js} +2 -2
- package/dist/{tensor2d-G4Ys2GxX.js → tensor2d-CSB4KOb0.js} +2 -2
- package/dist/{tensor4d-B8roDgtc.js → tensor4d-D7bLqGqz.js} +2 -2
- package/dist/{tensor_util-DV-FP5Q3.js → tensor_util-DfwaWayG.js} +12 -12
- package/dist/{tfjs_backend-kNyO5L2d.js → tfjs_backend-CNkSTL0c.js} +38 -38
- package/dist/{tile-BzyEiF-F.js → tile-CR074jmp.js} +3 -3
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +2 -2
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/{transpose-DKELTqhe.js → transpose-DH4gmHvu.js} +4 -4
- package/dist/utilities/dummy.js +3 -3
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.js +338 -304
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-Bhn5bHYv.js → variable-DzfrwYuP.js} +1 -1
- package/dist/{webgpu_program-Cigz-7RF.js → webgpu_program-DzaQiqel.js} +2 -2
- package/dist/{webgpu_util-BBCnKm2X.js → webgpu_util-0_ubCEHJ.js} +2 -2
- package/dist/{zeros-2gldETuK.js → zeros-DBFVbpv5.js} +3 -3
- package/package.json +1 -1
package/dist/utilities/packed.js
CHANGED
@@ -1,68 +1,102 @@
-import {
-import {
-import {
-
+import { g as $ } from "../index-D5v913EJ.js";
+import { p as K } from "../index-xuotMAFm.js";
+import { w, o as W, p as N, K as H, I as P, q as A, s as z, t as X, v as Y, x as Z, A as J } from "../tensor_util-DfwaWayG.js";
+import { b as Q, E as L, a as p, o as ee, q as te, r as se, t as C, V as ne, u as G, T as B, v as R, m as re, s as ie, w as ae, f as oe, x as ce } from "../tensor-CzmOBsdf.js";
+import { PackableTensor as j, PackableVariable as de } from "../patches/PackedTensor.js";
+function be() {
 return y.backendName === "webgpu";
 }
-function
-return
+function x(i) {
+return i.packed !== void 0;
 }
-function
-return
+function T(i) {
+return x(i) && i.packed;
 }
-function
-if (
-if (
+function O(i) {
+if (x(i)) {
+if (i.dtype !== "int32")
 throw new Error("packTensor: only int32 tensors can be packed.");
-return
+return i.packed = !0, i;
 } else
-throw console.error("Tensor:",
+throw console.error("Tensor:", i), new Error("Tensor is not packable");
 }
-function
-if (
-if (
+function we(i) {
+if (x(i)) {
+if (i.dtype !== "float32")
 throw new Error("unpackTensor: only float32 tensors can be unpacked.");
-
+i.packed = !1;
 }
-return
+return i;
 }
-function
-for (let n =
-const r =
-if (r.outputs.forEach((
-const
-
+function he(i, e, t, s) {
+for (let n = e.length - 1; n >= 0; n--) {
+const r = e[n], c = [];
+if (r.outputs.forEach((o) => {
+const a = i[o.id];
+a != null ? c.push(a) : c.push(null);
 }), r.gradient == null)
 throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);
-const
-for (const
-if (!(
+const d = r.gradient(c);
+for (const o in r.inputs) {
+if (!(o in d))
 throw new Error(
-`Cannot backprop through input ${
+`Cannot backprop through input ${o}. Available gradients found: ${Object.keys(d)}.`
 );
-const
-if (
+const a = t(() => d[o]()), h = T(a);
+if (a.dtype !== "float32" && (!h || a.dtype !== "int32"))
 throw new Error(
-`Error in gradient for op ${r.kernelName}. The gradient of input ${
+`Error in gradient for op ${r.kernelName}. The gradient of input ${o} must have 'float32' dtype, but has '${a.dtype}'`
 );
-const
-if (!
+const u = r.inputs[o];
+if (!Q(a.shape, u.shape))
 throw new Error(
-`Error in gradient for op ${r.kernelName}. The gradient of input '${
+`Error in gradient for op ${r.kernelName}. The gradient of input '${o}' has shape '${a.shape}', which does not match the shape of the input '${u.shape}'`
 );
-if (
-
+if (i[u.id] == null)
+i[u.id] = a;
 else {
-const
-
+const l = i[u.id];
+i[u.id] = s(l, a), l.dispose();
 }
 }
 }
 }
-
-
+let S;
+function U() {
+if (S == null) {
+let i;
+if (typeof window < "u")
+i = window;
+else if (typeof $ < "u")
+i = $;
+else if (typeof K < "u")
+i = K;
+else if (typeof self < "u")
+i = self;
+else
+throw new Error("Could not find a global object");
+S = i;
+}
+return S;
 }
-
+const E = {
+engine: null
+}, D = U();
+if (D._tfengine)
+throw new Error("TensorFlow engine already initialized before patching.");
+Object.defineProperty(D, "_tfengine", {
+get: () => {
+if (E.engine == null) {
+const i = new L(D);
+E.engine = new I(i);
+}
+return E.engine;
+}
+});
+function F(i) {
+return i.kernelName != null;
+}
+class _ {
 // Public since optimizers will use it.
 registeredVariables = {};
 nextTapeNodeId = 0;
@@ -96,17 +130,17 @@ class C {
 kernels: [],
 result: null,
 get kernelNames() {
-return Array.from(new Set(this.kernels.map((
+return Array.from(new Set(this.kernels.map((e) => e.name)));
 }
 };
 dispose() {
-for (const
-this.registeredVariables[
+for (const e in this.registeredVariables)
+this.registeredVariables[e].dispose();
 }
 }
-class
-constructor(
-this.ENV =
+class I {
+constructor(e) {
+this.ENV = e, this.state = new _(), console.log("GenAI Patched Engine Initialized");
 }
 version = "GENAI_PATCHED_ENGINE";
 state;
@@ -123,9 +157,9 @@ class v {
 });
 if (this.backendInstance != null)
 return;
-const
-for (let
-const s = t
+const e = this.getSortedBackends();
+for (let t = 0; t < e.length; t++) {
+const s = e[t];
 if (await this.initializeBackend(s).success) {
 await this.setBackend(s);
 return;
@@ -139,53 +173,53 @@ class v {
 `Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
 );
 if (this.backendInstance == null) {
-const { name:
-if (
+const { name: e, asyncInit: t } = this.initializeBackendsAndReturnBest();
+if (t)
 throw new Error(
-`The highest priority backend '${
+`The highest priority backend '${e}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
 );
-this.setBackend(
+this.setBackend(e);
 }
 return this.backendInstance;
 }
 backendNames() {
 return Object.keys(this.registryFactory);
 }
-findBackend(
-if (!(
-if (
-const { asyncInit:
-if (
+findBackend(e) {
+if (!(e in this.registry))
+if (e in this.registryFactory) {
+const { asyncInit: t } = this.initializeBackend(e);
+if (t)
 return null;
 } else
 return null;
-return this.registry[
+return this.registry[e];
 }
-findBackendFactory(
-return
+findBackendFactory(e) {
+return e in this.registryFactory ? this.registryFactory[e].factory : null;
 }
-registerBackend(
-return
+registerBackend(e, t, s = 1) {
+return e in this.registryFactory ? (w(`${e} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[e] = { factory: t, priority: s }, console.log("Registered backend", e), !0);
 }
-async setBackend(
-if (this.registryFactory[
-throw new Error(`Backend name '${
-if (this.backendName =
+async setBackend(e) {
+if (this.registryFactory[e] == null)
+throw new Error(`Backend name '${e}' not found in registry`);
+if (this.backendName = e, this.registry[e] == null) {
 this.backendInstance = null;
-const { success:
-if (!(s ? await
+const { success: t, asyncInit: s } = this.initializeBackend(e);
+if (!(s ? await t : t))
 return !1;
 }
-return this.backendInstance = this.registry[
+return this.backendInstance = this.registry[e], this.setupRegisteredKernels(), this.profiler = new W(this.backendInstance), !0;
 }
 setupRegisteredKernels() {
-
-
+N(this.backendName).forEach((t) => {
+t.setupFunc != null && t.setupFunc(this.backendInstance);
 });
 }
-disposeRegisteredKernels(
-
-s.disposeFunc != null && s.disposeFunc(this.registry[
+disposeRegisteredKernels(e) {
+N(e).forEach((s) => {
+s.disposeFunc != null && s.disposeFunc(this.registry[e]);
 });
 }
 /**
@@ -194,82 +228,82 @@ class v {
 * whether the initialization of the backend succeeded. Throws an error if
 * there is no backend in the factory registry.
 */
-initializeBackend(
-const
-if (
-throw new Error(`Cannot initialize backend ${
+initializeBackend(e) {
+const t = this.registryFactory[e];
+if (t == null)
+throw new Error(`Cannot initialize backend ${e}, no registration found.`);
 try {
-const s =
-if (s && !(s instanceof
-const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[
+const s = t.factory();
+if (s && !(s instanceof H) && typeof s.then == "function") {
+const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[e] = c, this.pendingBackendInit = null, !0)).catch((c) => (n < this.pendingBackendInitId || (this.pendingBackendInit = null, w(`Initialization of backend ${e} failed`), w(c.stack || c.message)), !1));
 return this.pendingBackendInit = r, { success: r, asyncInit: !0 };
 } else
-return this.registry[
+return this.registry[e] = s, { success: !0, asyncInit: !1 };
 } catch (s) {
-return
+return w(`Initialization of backend ${e} failed`), w(s.stack || s.message), { success: !1, asyncInit: !1 };
 }
 }
-removeBackend(
-if (!(
-throw new Error(`${
-this.backendName ===
+removeBackend(e) {
+if (!(e in this.registryFactory))
+throw new Error(`${e} backend not found in registry`);
+this.backendName === e && this.pendingBackendInit != null && this.pendingBackendInitId++, e in this.registry && (this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e]), delete this.registryFactory[e], this.backendName === e && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
 }
 getSortedBackends() {
 if (Object.keys(this.registryFactory).length === 0)
 throw new Error("No backend found in registry.");
-return Object.keys(this.registryFactory).sort((
+return Object.keys(this.registryFactory).sort((e, t) => this.registryFactory[t].priority - this.registryFactory[e].priority);
 }
 initializeBackendsAndReturnBest() {
-const
-for (let
-const s = t
+const e = this.getSortedBackends();
+for (let t = 0; t < e.length; t++) {
+const s = e[t], { success: n, asyncInit: r } = this.initializeBackend(s);
 if (r || n)
 return { name: s, asyncInit: r };
 }
 throw new Error("Could not initialize any backends, all backend initializations failed.");
 }
-moveData(
-const s = this.state.tensorInfo.get(
-s || console.warn("Tried to move data that does not exist", this.state,
-const n = s.backend, r = this.readSync(
-n.disposeData(
+moveData(e, t) {
+const s = this.state.tensorInfo.get(t);
+s || console.warn("Tried to move data that does not exist", this.state, t);
+const n = s.backend, r = this.readSync(t), c = n.refCount(t);
+n.disposeData(t, !0), s.backend = e, e.move(t, r, s.shape, s.dtype, c), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
 }
-tidy(
+tidy(e, t) {
 let s = null;
-if (
-if (typeof
+if (t == null) {
+if (typeof e != "function")
 throw new Error("Please provide a function to tidy()");
-
+t = e;
 } else {
-if (typeof
+if (typeof e != "string" && !(e instanceof String))
 throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
-if (typeof
+if (typeof t != "function")
 throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
-s =
+s = e;
 }
 let n;
 return this.scopedRun(
 () => this.startScope(s),
 () => this.endScope(n),
-() => (n =
+() => (n = t(), n instanceof Promise && console.error("Cannot return a Promise inside of tidy."), n)
 );
 }
-scopedRun(
-
+scopedRun(e, t, s) {
+e();
 try {
 const n = s();
-return
+return t(), n;
 } catch (n) {
-throw
+throw t(), n;
 }
 }
 static nextTensorId = 0;
 nextTensorId() {
-return
+return I.nextTensorId++;
 }
 static nextVariableId = 0;
 nextVariableId() {
-return
+return I.nextVariableId++;
 }
 /**
 * This method is called instead of the public-facing tensor.clone() when
@@ -277,16 +311,16 @@ class v {
 * operation to the tape regardless of being called inside a kernel
 * execution.
 */
-clone(
-const s =
+clone(e) {
+const s = T(e) ? O(y.runKernel(P, { x: e })) : y.runKernel(P, { x: e }), n = { x: e }, r = (d) => ({
 x: () => {
-const
-
-
+const o = "float32", a = { x: d }, h = { dtype: o }, u = T(e), l = y.runKernel(
+Z,
+a,
 // tslint:disable-next-line: no-unnecessary-type-assertion
-
+h
 );
-return
+return u && O(l), l;
 }
 }), c = [];
 return this.addTapeNode(this.state.activeScope.name, n, [s], r, c, {}), s;
@@ -304,24 +338,24 @@ class v {
 * for the backprop computation. These are booleans since the output
 * tensors are not visible to the user.
 */
-runKernel(
-if (this.backendName == null && this.backend, !(
-throw new Error(`Kernel '${
-return this.runKernelFunc({ kernelName:
+runKernel(e, t, s) {
+if (this.backendName == null && this.backend, !(A(e, this.backendName) != null))
+throw new Error(`Kernel '${e}' not registered for backend '${this.backendName}'`);
+return this.runKernelFunc({ kernelName: e, inputs: t, attrs: s });
 }
 shouldCheckForMemLeaks() {
 return this.ENV.getBool("IS_TEST");
 }
-checkKernelForMemLeak(
+checkKernelForMemLeak(e, t, s) {
 const n = this.backend.numDataIds();
 let r = 0;
-s.forEach((
-r +=
+s.forEach((o) => {
+r += o.dtype === "complex64" ? 3 : 1;
 });
-const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1],
-if (
+const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], d = n - t - r - c;
+if (d > 0)
 throw new Error(
-`Backend '${this.backendName}' has an internal memory leak (${
+`Backend '${this.backendName}' has an internal memory leak (${d} data ids) after running '${e}'`
 );
 }
 /**
@@ -329,81 +363,81 @@ class v {
 *
 * Use `runKernel` to execute kernels from outside of engine.
 */
-runKernelFunc(
-let
+runKernelFunc(e) {
+let t, s = [];
 const n = this.isTapeOn(), r = this.state.numBytes, c = this.state.numTensors;
 this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
-let
+let d;
 this.backendName == null && this.backend;
-let
-const
-if (
-const { kernelName: f, inputs:
+let o;
+const a = F(e) ? e.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
+if (F(e)) {
+const { kernelName: f, inputs: v, attrs: m } = e;
 this.backendName == null && this.backend;
-const
+const k = A(f, this.backendName);
 p(
-
+k != null,
 () => `Cannot find registered kernel '${f}' for backend '${this.backendName}'`
-),
-const
-
-const
-this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f,
-const
+), d = () => {
+const q = this.backend.numDataIds();
+o = k.kernelFunc({ inputs: v, attrs: m, backend: this.backend });
+const M = Array.isArray(o) ? o : [o];
+this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f, q, M);
+const V = M.map((b) => b.rank != null ? b : this.makeTensorFromTensorInfo(b));
 if (n) {
-const b = this.getTensorsForGradient(f,
+const b = this.getTensorsForGradient(f, v, V);
 s = this.saveTensorsForBackwardMode(b ?? []);
 }
-return
+return V;
 };
 } else {
-const { forwardFunc: f } =
-n && (s = m.map((
+const { forwardFunc: f } = e, v = (m) => {
+n && (s = m.map((k) => this.keep(this.clone(k))));
 };
-
+d = () => {
 const m = this.backend.numDataIds();
-
-const
-return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(
+o = this.tidy(() => f(this.backend, v));
+const k = Array.isArray(o) ? o : [o];
+return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(a, m, k), k;
 };
 }
-const { inputs:
-let
+const { inputs: h, attrs: u } = e, l = F(e) ? null : e.backwardsFunc;
+let g;
 return this.scopedRun(
 // Stop recording to a tape when running a kernel.
 () => this.state.kernelDepth++,
 () => this.state.kernelDepth--,
 () => {
-!this.ENV.getBool("DEBUG") && !this.state.profiling ?
+!this.ENV.getBool("DEBUG") && !this.state.profiling ? t = d() : (g = this.profiler.profileKernel(a, h, () => d()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(g), t = g.outputs);
 }
 ), n && this.addTapeNode(
-
-
-
-
+a,
+h,
+t,
+l,
 s,
-
+u ?? {}
 ), this.state.profiling && this.state.activeProfile.kernels.push({
-name:
+name: a,
 bytesAdded: this.state.numBytes - r,
 totalBytesSnapshot: this.state.numBytes,
 tensorsAdded: this.state.numTensors - c,
 totalTensorsSnapshot: this.state.numTensors,
-inputShapes: Object.keys(
-(f) =>
+inputShapes: Object.keys(h).map(
+(f) => h[f] != null ? h[f].shape : null
 ),
-outputShapes:
-kernelTimeMs:
-extraInfo:
-}), Array.isArray(
+outputShapes: t.map((f) => f.shape),
+kernelTimeMs: g.timeMs,
+extraInfo: g.extraInfo
+}), Array.isArray(o) ? t : t[0];
 }
 /**
 * Saves tensors used in forward mode for use in backward mode.
 *
 * @param tensors the list of tensors to save.
 */
-saveTensorsForBackwardMode(
-return
+saveTensorsForBackwardMode(e) {
+return e.map((s) => this.keep(this.clone(s)));
 }
 /**
 * Returns a list of tensors to save for a given gradient calculation.
@@ -412,14 +446,14 @@ class v {
 * @param inputs a map of input tensors.
 * @param outputs an array of output tensors from forward mode of kernel.
 */
-getTensorsForGradient(
-const n =
+getTensorsForGradient(e, t, s) {
+const n = z(e);
 if (n != null) {
 const r = n.inputsToSave || [], c = n.outputsToSave || [];
-let
-n.saveAllInputs ? (p(Array.isArray(
-const
-return
+let d;
+n.saveAllInputs ? (p(Array.isArray(t), () => "saveAllInputs is true, expected inputs to be an array."), d = Object.keys(t).map((a) => t[a])) : d = r.map((a) => t[a]);
+const o = s.filter((a, h) => c[h]);
+return d.concat(o);
 }
 return [];
 }
@@ -428,18 +462,18 @@ class v {
 * tensor with the provided shape, dtype and values. It always
 * creates a new data id and writes the values to the underlying backend.
 */
-makeTensor(
-if (
+makeTensor(e, t, s, n) {
+if (e == null)
 throw new Error("Values passed to engine.makeTensor() are null");
 s = s || "float32", n = n || this.backend;
-let r =
-s === "string" &&
-const c = n.write(r,
-if (this.trackTensor(
-const
-this.state.numBytes +=
+let r = e;
+s === "string" && ee(e[0]) && (r = e.map((o) => te(o)));
+const c = n.write(r, t, s), d = new j(t, s, c, this.nextTensorId());
+if (this.trackTensor(d, n), s === "string") {
+const o = this.state.tensorInfo.get(c), a = se(r);
+this.state.numBytes += a - o.bytes, o.bytes = a;
 }
-return
+return d;
 }
 /**
 * Internal method used by backends. Makes a new tensor
@@ -447,9 +481,9 @@ class v {
 * a new data id, only increments the ref count used in memory tracking.
 * @deprecated
 */
-makeTensorFromDataId(
+makeTensorFromDataId(e, t, s, n) {
 s = s || "float32";
-const r = { dataId:
+const r = { dataId: e, shape: t, dtype: s };
 return this.makeTensorFromTensorInfo(r, n);
 }
 /**
@@ -457,69 +491,69 @@ class v {
 * around an existing data id in TensorInfo. It doesn't create a new data id,
 * only increments the ref count used in memory tracking.
 */
-makeTensorFromTensorInfo(
-const { dataId: s, shape: n, dtype: r } =
-if (c.packed =
+makeTensorFromTensorInfo(e, t) {
+const { dataId: s, shape: n, dtype: r } = e, c = new j(n, r, s, this.nextTensorId());
+if (c.packed = e.packed || !1, c.packed && r !== "int32")
 throw new Error("Only int32 tensors can be packed.");
-return this.trackTensor(c,
+return this.trackTensor(c, t ?? this.backend), c;
 }
-makeVariable(
-s = s || this.nextVariableId().toString(), n != null && n !==
-const r = new
+makeVariable(e, t = !0, s, n) {
+s = s || this.nextVariableId().toString(), n != null && n !== e.dtype && (e = e.cast(n));
+const r = new de(e, t, s, this.nextTensorId());
 if (this.state.registeredVariables[r.name] != null)
 throw new Error(`Variable with name ${r.name} was already registered`);
 return this.state.registeredVariables[r.name] = r, this.incRef(r, this.backend), r;
 }
-trackTensor(
-this.state.numTensors++,
+trackTensor(e, t) {
+this.state.numTensors++, e.dtype === "string" && this.state.numStringTensors++;
 let s = 0;
-
-backend:
-dtype:
-shape:
+e.dtype !== "complex64" && e.dtype !== "string" && (s = e.size * C(e.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(e.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(e.dataId, {
+backend: t || this.backend,
+dtype: e.dtype,
+shape: e.shape,
 bytes: s
-})),
+})), e instanceof ne || this.track(e);
 }
 // Track the tensor by dataId and increase the refCount for the dataId in the
 // backend.
 // TODO(pyu10055): This is currently used by makeVariable method, to increase
 // refCount on the backend for the dataId. It can potentially be replaced with
 // Identity op indead of calling backend directly.
-incRef(
-this.trackTensor(
+incRef(e, t) {
+this.trackTensor(e, t), this.backend.incRef(e.dataId);
 }
-removeDataId(
-this.state.tensorInfo.has(
+removeDataId(e, t) {
+this.state.tensorInfo.has(e) && this.state.tensorInfo.get(e).backend === t && (this.state.tensorInfo.delete(e), this.state.numDataBuffers--);
 }
-disposeTensor(
-if (!this.state.tensorInfo.has(
+disposeTensor(e) {
+if (!this.state.tensorInfo.has(e.dataId))
 return;
-const
-if (this.state.numTensors--,
-const s =
+const t = this.state.tensorInfo.get(e.dataId);
+if (this.state.numTensors--, e.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= t.bytes), e.dtype !== "complex64" && e.dtype !== "string") {
+const s = e.size * C(e.dtype);
 this.state.numBytes -= s;
 }
-
+t.backend.disposeData(e.dataId) && this.removeDataId(e.dataId, t.backend);
 }
 disposeVariables() {
-for (const
-const
-this.disposeVariable(
+for (const e in this.state.registeredVariables) {
+const t = this.state.registeredVariables[e];
+this.disposeVariable(t);
 }
 }
-disposeVariable(
-this.disposeTensor(
+disposeVariable(e) {
+this.disposeTensor(e), this.state.registeredVariables[e.name] != null && delete this.state.registeredVariables[e.name];
 }
 memory() {
-const
-return
+const e = this.backend.memory();
+return e.numTensors = this.state.numTensors, e.numDataBuffers = this.state.numDataBuffers, e.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (e.unreliable = !0, e.reasons == null && (e.reasons = []), e.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), e;
 }
-async profile(
+async profile(e) {
 this.state.profiling = !0;
-const
-this.state.activeProfile.kernels = [], this.state.activeProfile.result = await
+const t = this.state.numBytes, s = this.state.numTensors;
+this.state.activeProfile.kernels = [], this.state.activeProfile.result = await e(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(
 ...this.state.activeProfile.kernels.map((n) => n.totalBytesSnapshot)
-), this.state.activeProfile.newBytes = this.state.numBytes -
+), this.state.activeProfile.newBytes = this.state.numBytes - t, this.state.activeProfile.newTensors = this.state.numTensors - s;
 for (const n of this.state.activeProfile.kernels)
 n.kernelTimeMs = await n.kernelTimeMs, n.extraInfo = await n.extraInfo;
 return this.state.activeProfile;
@@ -527,18 +561,18 @@ class v {
 isTapeOn() {
 return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
 }
-addTapeNode(
-const
-
-if (
-const
-return this.makeTensor(
+addTapeNode(e, t, s, n, r, c) {
+const d = { id: this.state.nextTapeNodeId++, kernelName: e, inputs: t, outputs: s, saved: r }, o = z(e);
+o != null && (n = o.gradFunc), n != null && (d.gradient = (a) => (a = a.map((h, u) => {
+if (h == null) {
+const l = s[u], g = oe(l.size, l.dtype);
+return this.makeTensor(g, l.shape, l.dtype);
 }
-return
-}), n(
+return h;
+}), n(a.length > 1 ? a : a[0], r, c))), this.state.activeTape.push(d);
 }
-keep(
-return
+keep(e) {
+return e.kept = !0, e;
 }
 startTape() {
 this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
@@ -550,26 +584,26 @@ class v {
 * Start a scope. Use this with endScope() to achieve the same functionality
 * as scope() without the need for a function closure.
 */
-startScope(
-const
+startScope(e) {
+const t = {
 track: [],
 name: "unnamed scope",
 id: this.state.nextScopeId++
 };
-
+e && (t.name = e), this.state.scopeStack.push(t), this.state.activeScope = t;
 }
 /**
 * End a scope. Use this with startScope() to achieve the same functionality
 * as scope() without the need for a function closure.
 */
-endScope(
-const
+endScope(e) {
+const t = X(e), s = new Set(t.map((r) => r.id));
 for (let r = 0; r < this.state.activeScope.track.length; r++) {
 const c = this.state.activeScope.track[r];
 !c.kept && !s.has(c.id) && c.dispose();
 }
 const n = this.state.scopeStack.pop();
-this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1],
+this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], t.forEach((r) => {
 !r.kept && r.scopeId === n?.id && this.track(r);
 });
 }
@@ -579,68 +613,68 @@ class v {
 * was not a function of that `x`. It also takes optional dy to multiply the
 * gradient, which defaults to `1`.
 */
-gradients(
-if (p(
+gradients(e, t, s, n = !1) {
+if (p(t.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
 throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
 const r = this.scopedRun(
 () => this.startTape(),
 () => this.endTape(),
-() => this.tidy("forward",
+() => this.tidy("forward", e)
 );
 p(r instanceof B, () => "The result y returned by f() must be a tensor.");
-const c =
-if (!n && c.length === 0 &&
+const c = Y(this.state.activeTape, t, r);
+if (!n && c.length === 0 && t.length > 0)
 throw new Error(
 "Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."
 );
 return this.tidy("backward", () => {
-const
-
-
+const d = {};
+d[r.id] = s ?? le(r.shape), he(
+d,
 c,
 // Pass the tidy function to avoid circular dep with `tape.ts`.
-(
+(a) => this.tidy(a),
 // Pass an add function to avoide a circular dep with `tape.ts`.
-
+fe
 );
-const
-return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((
-if (
-for (const
-
-}), this.state.activeTape = null), { value: r, grads:
+const o = t.map((a) => d[a.id]);
+return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((a) => {
+if (a.saved !== void 0)
+for (const h of a.saved)
+h.dispose();
+}), this.state.activeTape = null), { value: r, grads: o };
 });
 }
-customGrad(
-return p(
+customGrad(e) {
+return p(G(e), () => "The f passed in customGrad(f) must be a function."), (...t) => {
 p(
-
+t.every((d) => d instanceof B),
 () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors"
 );
 let s;
 const n = {};
-
-n[
+t.forEach((d, o) => {
+n[o] = d;
 });
-const r = (
+const r = (d, o) => (s = e(...t, o), p(
 s.value instanceof B,
 () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"
 ), p(
-
+G(s.gradFunc),
 () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."
-), s.value), c = (
-const
+), s.value), c = (d, o) => {
+const a = s.gradFunc(d, o), h = Array.isArray(a) ? a : [a];
 p(
-
+h.length === t.length,
 () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."
 ), p(
-
+h.every((l) => l instanceof B),
 () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."
 );
-const
-return
-
-}),
+const u = {};
+return h.forEach((l, g) => {
+u[g] = () => l;
+}), u;
 };
 return this.runKernelFunc({
 forwardFunc: r,
@@ -649,18 +683,18 @@ class v {
 });
 };
 }
-readSync(
-return this.state.tensorInfo.get(
+readSync(e) {
+return this.state.tensorInfo.get(e).backend.readSync(e);
 }
-read(
-return this.state.tensorInfo.get(
+read(e) {
+return this.state.tensorInfo.get(e).backend.read(e);
 }
-readToGPU(
-return this.state.tensorInfo.get(
+readToGPU(e, t) {
+return this.state.tensorInfo.get(e).backend.readToGPU(e, t);
 }
-async time(
-const
-return s.wallMs =
+async time(e) {
+const t = R(), s = await this.backend.time(e);
+return s.wallMs = R() - t, s;
 }
 /**
 * Tracks a Tensor in the current scope to be automatically cleaned up
@@ -668,8 +702,8 @@ class v {
 *
 * @param result The Tensor to track in the current scope.
 */
-track(
-return this.state.activeScope != null && (
+track(e) {
+return this.state.activeScope != null && (e.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(e)), e;
 }
 get registeredVariables() {
 return this.state.registeredVariables;
@@ -679,38 +713,38 @@ class v {
 * registered backend factories.
 */
 reset() {
-this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new
-for (const
-this.disposeRegisteredKernels(
+this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new _();
+for (const e in this.registry)
+this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e];
 this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
 }
 }
-function
-const
-return y.makeTensor(
+function le(i) {
+const e = re(ie(i), "float32");
+return y.makeTensor(e, i, "float32");
 }
-function
-const
-if (
-const
-
+function ue() {
+const i = U();
+if (i._tfengine == null) {
+const e = new L(i);
+i._tfengine = new I(e);
 }
-return
+return ae(i._tfengine.ENV), ce(() => i._tfengine), i._tfengine;
 }
-const y =
-function
-const
-return y.runKernel(
+const y = ue();
+function fe(i, e) {
+const t = T(i) || T(e), s = { a: i, b: e };
+return y.runKernel(t ? "Add16" : J, s);
 }
 export {
-
+I as E,
 y as a,
-
-
-
-
-
-
-
-
+he as b,
+fe as c,
+ue as g,
+x as isPackableTensor,
+T as isPackedTensor,
+O as packTensor,
+be as packingSupported,
+we as unpackTensor
 };