@genai-fi/nanogpt 0.10.2 → 0.10.3
This diff compares the published contents of two versions of this package as released to one of the supported registries. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registry.
- package/dist/Generator.js +11761 -171
- package/dist/{RealDiv-zz7FpkKX.js → RealDiv-KAPDe8zB.js} +23 -25
- package/dist/Reshape-BYkmUnAv.js +14 -0
- package/dist/{Reshape-CHdUjC72.js → Reshape-Zt6eb7yh.js} +18 -20
- package/dist/TeachableLLM.js +10 -11
- package/dist/{axis_util-BsIr9ZNu.js → axis_util-BaG7mf5A.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-B1XRLuq9.js → backend_util-RCe-rHaj.js} +72 -73
- package/dist/{backend_webgpu-CqpfEImu.js → backend_webgpu-DE3ACOLx.js} +45 -47
- package/dist/broadcast_to-B3eYlZm7.js +28 -0
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +7 -11
- package/dist/checks/normRMS.js +9 -9
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/packUnpack.js +2 -2
- package/dist/checks/qkv.js +12 -13
- package/dist/checks/rope.js +2 -2
- package/dist/clip_by_value-BnO7-a88.js +12 -0
- package/dist/complex-DjxcVmoX.js +11 -0
- package/dist/concat-BV8bt5H-.js +17 -0
- package/dist/{concat_util-iBYIyuQe.js → concat_util-DpW8mL_l.js} +1 -1
- package/dist/{dataset-D2P7rHAw.js → dataset-BcwmTGYc.js} +137 -139
- package/dist/dropout-BcvN9JYi.js +92 -0
- package/dist/expand_dims-DT4tEPwA.js +11 -0
- package/dist/{exports_initializers-CZSUJoVE.js → exports_initializers-Hta_rEnm.js} +1 -1
- package/dist/floor-D5QdR_le.js +9 -0
- package/dist/gather-D3JcZUaI.js +9 -0
- package/dist/{gelu-Bmhopi0J.js → gelu-CjNPL4OH.js} +10 -11
- package/dist/{gpgpu_math-DsCcikas.js → gpgpu_math-DAOmgtXR.js} +841 -1015
- package/dist/{index-DRyE072i.js → index-BwexR4lA.js} +262 -263
- package/dist/index-DOvlwCh-.js +3520 -0
- package/dist/{kernel_funcs_utils-CWfOAPGO.js → kernel_funcs_utils-CCzYdUZg.js} +130 -132
- package/dist/layers/BaseLayer.js +15 -16
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +4 -4
- package/dist/layers/PositionEmbedding.js +7 -7
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +9 -9
- package/dist/layers/TiedEmbedding.js +6 -6
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +13 -14
- package/dist/log_sum_exp-ngO0-4pK.js +39 -0
- package/dist/main.js +49 -50
- package/dist/{matMul16-fEAJ4smh.js → matMul16-BWRSOCWB.js} +14 -15
- package/dist/matMulGelu-CzfgT6Wq.js +163 -0
- package/dist/mat_mul-SjpJRLyL.js +11 -0
- package/dist/mod-AnXEvvpo.js +11 -0
- package/dist/models/NanoGPTV1.js +2 -2
- package/dist/models/model.js +13 -14
- package/dist/ones-D2rT0xk2.js +14 -0
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/add16.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/concat16.js +2 -2
- package/dist/ops/cpu/adamAdjust.js +13 -14
- package/dist/ops/cpu/adamMoments.js +6 -7
- package/dist/ops/cpu/appendCache.js +7 -8
- package/dist/ops/cpu/attentionMask.js +7 -7
- package/dist/ops/cpu/fusedSoftmax.js +10 -11
- package/dist/ops/cpu/gatherSub.js +9 -10
- package/dist/ops/cpu/gelu.js +9 -10
- package/dist/ops/cpu/matMul16.js +6 -7
- package/dist/ops/cpu/matMulGelu.js +5 -6
- package/dist/ops/cpu/matMulMul.js +3 -4
- package/dist/ops/cpu/mulDropout.js +3 -4
- package/dist/ops/cpu/normRMS.js +10 -11
- package/dist/ops/cpu/qkv.js +8 -9
- package/dist/ops/cpu/rope.js +5 -6
- package/dist/ops/cpu/scatterSub.js +17 -19
- package/dist/ops/dot16.js +2 -2
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/add16.js +11 -12
- package/dist/ops/grads/attentionMask.js +5 -6
- package/dist/ops/grads/gelu.js +3 -4
- package/dist/ops/grads/matMul16.js +4 -5
- package/dist/ops/grads/matMulGelu.js +9 -10
- package/dist/ops/grads/normRMS.js +7 -8
- package/dist/ops/grads/pack16.js +4 -5
- package/dist/ops/grads/qkv.js +17 -19
- package/dist/ops/grads/rope.js +3 -5
- package/dist/ops/grads/softmax16.js +3 -4
- package/dist/ops/grads/unpack16.js +3 -4
- package/dist/ops/grads/utils.d.ts +1 -0
- package/dist/ops/grads/utils.js +8 -4
- package/dist/ops/matMul16.js +3 -3
- package/dist/ops/matMulGelu.js +2 -2
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mul16.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/pack16.js +3 -4
- package/dist/ops/qkv.js +4 -8
- package/dist/ops/reshape16.js +14 -16
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +3 -8
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/slice16.js +2 -2
- package/dist/ops/softmax16.js +5 -8
- package/dist/ops/sub16.js +1 -1
- package/dist/ops/sum16.js +2 -2
- package/dist/ops/transpose16.js +23 -24
- package/dist/ops/unpack16.js +2 -2
- package/dist/ops/webgl/adamAdjust.js +2 -3
- package/dist/ops/webgl/adamMoments.js +1 -2
- package/dist/ops/webgl/appendCache.js +1 -2
- package/dist/ops/webgl/attentionMask.js +4 -5
- package/dist/ops/webgl/fusedSoftmax.js +4 -6
- package/dist/ops/webgl/gatherSub.js +6 -7
- package/dist/ops/webgl/gelu.js +2 -3
- package/dist/ops/webgl/log.js +11 -12
- package/dist/ops/webgl/matMul16.js +10 -11
- package/dist/ops/webgl/matMulGelu.js +7 -111
- package/dist/ops/webgl/matMulMul.js +9 -10
- package/dist/ops/webgl/mulDropout.js +8 -9
- package/dist/ops/webgl/normRMS.js +2 -3
- package/dist/ops/webgl/qkv.js +5 -6
- package/dist/ops/webgl/rope.js +7 -8
- package/dist/ops/webgl/scatterSub.js +5 -6
- package/dist/ops/webgpu/adamAdjust.js +10 -12
- package/dist/ops/webgpu/adamMoments.js +8 -10
- package/dist/ops/webgpu/add16.js +8 -9
- package/dist/ops/webgpu/appendCache.js +23 -25
- package/dist/ops/webgpu/attentionMask.js +8 -10
- package/dist/ops/webgpu/attentionMask32_program.js +2 -2
- package/dist/ops/webgpu/concat16.js +12 -14
- package/dist/ops/webgpu/gatherSub.js +11 -13
- package/dist/ops/webgpu/gelu.js +28 -29
- package/dist/ops/webgpu/matMul16.js +26 -28
- package/dist/ops/webgpu/matMul16_program.js +4 -5
- package/dist/ops/webgpu/mul16.js +9 -10
- package/dist/ops/webgpu/normRMS.js +15 -17
- package/dist/ops/webgpu/normRMSGrad.js +21 -28
- package/dist/ops/webgpu/pack16.js +12 -13
- package/dist/ops/webgpu/pack16_program.js +2 -2
- package/dist/ops/webgpu/qkv.js +16 -18
- package/dist/ops/webgpu/rope.js +25 -27
- package/dist/ops/webgpu/scatterSub.js +7 -9
- package/dist/ops/webgpu/slice16.js +21 -23
- package/dist/ops/webgpu/softmax16.js +17 -19
- package/dist/ops/webgpu/softmax16_program.js +2 -2
- package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
- package/dist/ops/webgpu/softmax16grad.js +7 -8
- package/dist/ops/webgpu/sub16.js +7 -8
- package/dist/ops/webgpu/sum16.js +18 -20
- package/dist/ops/webgpu/transpose16.js +19 -20
- package/dist/ops/webgpu/transpose16_program.js +2 -2
- package/dist/ops/webgpu/transpose16_shared_program.js +11 -12
- package/dist/ops/webgpu/unpack16.js +3 -4
- package/dist/ops/webgpu/utils/binary_op.js +7 -8
- package/dist/ops/webgpu/utils/reductions.js +14 -22
- package/dist/ops-B5yanEdW.js +476 -0
- package/dist/pack16-nQ6JaLo-.js +39 -0
- package/dist/patches/webgpu_backend.js +19 -20
- package/dist/patches/webgpu_base.js +1 -1
- package/dist/patches/webgpu_program.js +21 -22
- package/dist/{random_width-BVV9HveY.js → random_width-or-CEftb.js} +2506 -2761
- package/dist/range-BklejeeW.js +10 -0
- package/dist/relu-CP0ZcxWO.js +9 -0
- package/dist/reshape-ByE68wS9.js +9 -0
- package/dist/resize_nearest_neighbor-B19mCEg2.js +175 -0
- package/dist/rope-Ir4mTyD1.js +24 -0
- package/dist/{scatter_nd_util-C7zXRT_h.js → scatter_nd_util-lvSiX8q4.js} +1 -1
- package/dist/selu_util-kbhpTdYD.js +44 -0
- package/dist/{shared-CHhxz-O5.js → shared-DT1TkE6w.js} +1 -1
- package/dist/{shared-D2NP_CpY.js → shared-dntlHIDQ.js} +343 -345
- package/dist/slice-BfEGSH82.js +12 -0
- package/dist/{slice_util-DyjSAD0u.js → slice_util-uTKwiEpW.js} +1 -1
- package/dist/{softmax-C9JQEtnO.js → softmax-CA5jFsLR.js} +4 -5
- package/dist/split-CVLc0w--.js +9 -0
- package/dist/squeeze-C7Z2srUo.js +10 -0
- package/dist/stack-Cf4n9h0N.js +11 -0
- package/dist/step-CINUs5QB.js +261 -0
- package/dist/sum-DWAtNGez.js +11 -0
- package/dist/tensor-DJoc7gJU.js +8 -0
- package/dist/tensor1d-D11P_7Dp.js +11 -0
- package/dist/{tensor2d-CSB4KOb0.js → tensor2d-Bs9wZRc7.js} +6 -7
- package/dist/{tensor4d-D7bLqGqz.js → tensor4d-BARPdTaS.js} +6 -7
- package/dist/{tfjs_backend-CNkSTL0c.js → tfjs_backend-y1cvNhLA.js} +255 -264
- package/dist/tile-mbfagpsB.js +11 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +2 -2
- package/dist/training/sparseCrossEntropy.js +5 -5
- package/dist/transpose-ClWiBS_b.js +36 -0
- package/dist/unsorted_segment_sum-BDDhB_E6.js +277 -0
- package/dist/utilities/dummy.js +3 -3
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/packed.d.ts +1 -4
- package/dist/utilities/packed.js +10 -745
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/sentences.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-DzfrwYuP.js → variable-WawDEaAb.js} +1 -1
- package/dist/{webgpu_program-DzaQiqel.js → webgpu_program-DuOXPQol.js} +178 -172
- package/dist/{webgpu_util-0_ubCEHJ.js → webgpu_util-RxEF33Rj.js} +34 -35
- package/dist/zeros-KnWaWf-X.js +13 -0
- package/dist/zeros_like-DvE73F4e.js +721 -0
- package/package.json +4 -2
- package/dist/Reshape-CDVLyVfz.js +0 -16
- package/dist/broadcast_to-B0ChcDaz.js +0 -30
- package/dist/complex-BBiRlsVq.js +0 -13
- package/dist/concat-DmBLPVGC.js +0 -19
- package/dist/dropout-B1x1kYMa.js +0 -99
- package/dist/expand_dims-ouvfxQ1n.js +0 -13
- package/dist/gather-CH9sdacz.js +0 -10
- package/dist/index-D6Q1lPZO.js +0 -2157
- package/dist/log_sum_exp-D3ftBNY5.js +0 -41
- package/dist/mat_mul-C59XWcJd.js +0 -12
- package/dist/mod-DESSvHIU.js +0 -12
- package/dist/mulmat_packed_gpu-Coh6qbJk.js +0 -55
- package/dist/ones-jU9jlQvM.js +0 -15
- package/dist/ops-BFDtP6th.js +0 -645
- package/dist/pack16-CmVZs6af.js +0 -41
- package/dist/patches/PackedTensor.d.ts +0 -12
- package/dist/patches/PackedTensor.js +0 -11
- package/dist/patches/engine.d.ts +0 -261
- package/dist/patches/engine.js +0 -12
- package/dist/patches/tape.d.ts +0 -12
- package/dist/patches/tape.js +0 -5
- package/dist/range-ZZZD60Fx.js +0 -11
- package/dist/reciprocal-CrYlsAGD.js +0 -10
- package/dist/register_all_kernels-nvj2k7OC.js +0 -12307
- package/dist/relu-BYDneVPn.js +0 -10
- package/dist/reshape-CaPQzFvz.js +0 -10
- package/dist/rope-s4W2XO9B.js +0 -32
- package/dist/selu_util-BGPXmd4B.js +0 -303
- package/dist/sin-Djs4aQiu.js +0 -16
- package/dist/slice-DvovR5wq.js +0 -13
- package/dist/split-DBck65sX.js +0 -10
- package/dist/squeeze-C00Ipm_7.js +0 -11
- package/dist/stack-ChnHwRpX.js +0 -13
- package/dist/sum-ywRJj3Zr.js +0 -12
- package/dist/tensor-0r5yOo2R.js +0 -8
- package/dist/tensor-CzmOBsdf.js +0 -909
- package/dist/tensor1d-BlUT89BP.js +0 -12
- package/dist/tensor_util-DfwaWayG.js +0 -523
- package/dist/tile-CR074jmp.js +0 -13
- package/dist/transpose-DH4gmHvu.js +0 -38
- package/dist/zeros-DBFVbpv5.js +0 -14
package/dist/utilities/packed.js
CHANGED
@@ -1,750 +1,15 @@
-import {
-
-
-import { b as Q, E as L, a as p, o as ee, q as te, r as se, t as C, V as ne, u as G, T as B, v as R, m as re, s as ie, w as ae, f as oe, x as ce } from "../tensor-CzmOBsdf.js";
-import { PackableTensor as j, PackableVariable as de } from "../patches/PackedTensor.js";
-function be() {
-  return y.backendName === "webgpu";
+import { e as n } from "../index-DOvlwCh-.js";
+function o() {
+  return n().backendName === "webgpu";
 }
-function
-  return
+function r(e) {
+  return e.dtype === "packedF16";
 }
-function
-  return
-}
-function O(i) {
-  if (x(i)) {
-    if (i.dtype !== "int32")
-      throw new Error("packTensor: only int32 tensors can be packed.");
-    return i.packed = !0, i;
-  } else
-    throw console.error("Tensor:", i), new Error("Tensor is not packable");
-}
-function we(i) {
-  if (x(i)) {
-    if (i.dtype !== "float32")
-      throw new Error("unpackTensor: only float32 tensors can be unpacked.");
-    i.packed = !1;
-  }
-  return i;
-}
-function he(i, e, t, s) {
-  for (let n = e.length - 1; n >= 0; n--) {
-    const r = e[n], c = [];
-    if (r.outputs.forEach((o) => {
-      const a = i[o.id];
-      a != null ? c.push(a) : c.push(null);
-    }), r.gradient == null)
-      throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);
-    const d = r.gradient(c);
-    for (const o in r.inputs) {
-      if (!(o in d))
-        throw new Error(
-          `Cannot backprop through input ${o}. Available gradients found: ${Object.keys(d)}.`
-        );
-      const a = t(() => d[o]()), h = T(a);
-      if (a.dtype !== "float32" && (!h || a.dtype !== "int32"))
-        throw new Error(
-          `Error in gradient for op ${r.kernelName}. The gradient of input ${o} must have 'float32' dtype, but has '${a.dtype}'`
-        );
-      const u = r.inputs[o];
-      if (!Q(a.shape, u.shape))
-        throw new Error(
-          `Error in gradient for op ${r.kernelName}. The gradient of input '${o}' has shape '${a.shape}', which does not match the shape of the input '${u.shape}'`
-        );
-      if (i[u.id] == null)
-        i[u.id] = a;
-      else {
-        const l = i[u.id];
-        i[u.id] = s(l, a), l.dispose();
-      }
-    }
-  }
-}
-let S;
-function U() {
-  if (S == null) {
-    let i;
-    if (typeof window < "u")
-      i = window;
-    else if (typeof $ < "u")
-      i = $;
-    else if (typeof K < "u")
-      i = K;
-    else if (typeof self < "u")
-      i = self;
-    else
-      throw new Error("Could not find a global object");
-    S = i;
-  }
-  return S;
-}
-const E = {
-  engine: null
-}, D = U();
-if (D._tfengine)
-  throw new Error("TensorFlow engine already initialized before patching.");
-Object.defineProperty(D, "_tfengine", {
-  get: () => {
-    if (E.engine == null) {
-      const i = new L(D);
-      E.engine = new I(i);
-    }
-    return E.engine;
-  }
-});
-function F(i) {
-  return i.kernelName != null;
-}
-class _ {
-  // Public since optimizers will use it.
-  registeredVariables = {};
-  nextTapeNodeId = 0;
-  numBytes = 0;
-  numTensors = 0;
-  numStringTensors = 0;
-  numDataBuffers = 0;
-  activeTape;
-  // Number of nested tf.grad() statements when computing higher-order
-  // gradients. E.g. `1` for first-order gradients and `2` for second-order
-  // gradients. Used to track if the tape should be removed after a backprop.
-  gradientDepth = 0;
-  // Number of nested kernel calls. When kernel depth is greater than 1, we turn
-  // off the tape.
-  kernelDepth = 0;
-  // Keep Tensors that parallel the tapes.
-  activeScope;
-  scopeStack = [];
-  /**
-   * Keeps track of the number of data moves during a kernel execution. We
-   * maintain a stack since kernels can call other kernels, recursively.
-   */
-  numDataMovesStack = [];
-  nextScopeId = 0;
-  tensorInfo = /* @__PURE__ */ new WeakMap();
-  profiling = !1;
-  activeProfile = {
-    newBytes: 0,
-    newTensors: 0,
-    peakBytes: 0,
-    kernels: [],
-    result: null,
-    get kernelNames() {
-      return Array.from(new Set(this.kernels.map((e) => e.name)));
-    }
-  };
-  dispose() {
-    for (const e in this.registeredVariables)
-      this.registeredVariables[e].dispose();
-  }
-}
-class I {
-  constructor(e) {
-    this.ENV = e, this.state = new _(), console.log("GenAI Patched Engine Initialized");
-  }
-  version = "GENAI_PATCHED_ENGINE";
-  state;
-  backendName;
-  registry = {};
-  registryFactory = {};
-  profiler;
-  backendInstance = null;
-  pendingBackendInit;
-  pendingBackendInitId = 0;
-  async ready() {
-    if (this.pendingBackendInit != null)
-      return this.pendingBackendInit.then(() => {
-      });
-    if (this.backendInstance != null)
-      return;
-    const e = this.getSortedBackends();
-    for (let t = 0; t < e.length; t++) {
-      const s = e[t];
-      if (await this.initializeBackend(s).success) {
-        await this.setBackend(s);
-        return;
-      }
-    }
-    throw new Error("Could not initialize any backends, all backend initializations failed.");
-  }
-  get backend() {
-    if (this.pendingBackendInit != null)
-      throw new Error(
-        `Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
-      );
-    if (this.backendInstance == null) {
-      const { name: e, asyncInit: t } = this.initializeBackendsAndReturnBest();
-      if (t)
-        throw new Error(
-          `The highest priority backend '${e}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
-        );
-      this.setBackend(e);
-    }
-    return this.backendInstance;
-  }
-  backendNames() {
-    return Object.keys(this.registryFactory);
-  }
-  findBackend(e) {
-    if (!(e in this.registry))
-      if (e in this.registryFactory) {
-        const { asyncInit: t } = this.initializeBackend(e);
-        if (t)
-          return null;
-      } else
-        return null;
-    return this.registry[e];
-  }
-  findBackendFactory(e) {
-    return e in this.registryFactory ? this.registryFactory[e].factory : null;
-  }
-  registerBackend(e, t, s = 1) {
-    return e in this.registryFactory ? (w(`${e} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[e] = { factory: t, priority: s }, console.log("Registered backend", e), !0);
-  }
-  async setBackend(e) {
-    if (this.registryFactory[e] == null)
-      throw new Error(`Backend name '${e}' not found in registry`);
-    if (this.backendName = e, this.registry[e] == null) {
-      this.backendInstance = null;
-      const { success: t, asyncInit: s } = this.initializeBackend(e);
-      if (!(s ? await t : t))
-        return !1;
-    }
-    return this.backendInstance = this.registry[e], this.setupRegisteredKernels(), this.profiler = new W(this.backendInstance), !0;
-  }
-  setupRegisteredKernels() {
-    N(this.backendName).forEach((t) => {
-      t.setupFunc != null && t.setupFunc(this.backendInstance);
-    });
-  }
-  disposeRegisteredKernels(e) {
-    N(e).forEach((s) => {
-      s.disposeFunc != null && s.disposeFunc(this.registry[e]);
-    });
-  }
-  /**
-   * Initializes a backend by looking up the backend name in the factory
-   * registry and calling the factory method. Returns a boolean representing
-   * whether the initialization of the backend succeeded. Throws an error if
-   * there is no backend in the factory registry.
-   */
-  initializeBackend(e) {
-    const t = this.registryFactory[e];
-    if (t == null)
-      throw new Error(`Cannot initialize backend ${e}, no registration found.`);
-    try {
-      const s = t.factory();
-      if (s && !(s instanceof H) && typeof s.then == "function") {
-        const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[e] = c, this.pendingBackendInit = null, !0)).catch((c) => (n < this.pendingBackendInitId || (this.pendingBackendInit = null, w(`Initialization of backend ${e} failed`), w(c.stack || c.message)), !1));
-        return this.pendingBackendInit = r, { success: r, asyncInit: !0 };
-      } else
-        return this.registry[e] = s, { success: !0, asyncInit: !1 };
-    } catch (s) {
-      return w(`Initialization of backend ${e} failed`), w(s.stack || s.message), { success: !1, asyncInit: !1 };
-    }
-  }
-  removeBackend(e) {
-    if (!(e in this.registryFactory))
-      throw new Error(`${e} backend not found in registry`);
-    this.backendName === e && this.pendingBackendInit != null && this.pendingBackendInitId++, e in this.registry && (this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e]), delete this.registryFactory[e], this.backendName === e && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
-  }
-  getSortedBackends() {
-    if (Object.keys(this.registryFactory).length === 0)
-      throw new Error("No backend found in registry.");
-    return Object.keys(this.registryFactory).sort((e, t) => this.registryFactory[t].priority - this.registryFactory[e].priority);
-  }
-  initializeBackendsAndReturnBest() {
-    const e = this.getSortedBackends();
-    for (let t = 0; t < e.length; t++) {
-      const s = e[t], { success: n, asyncInit: r } = this.initializeBackend(s);
-      if (r || n)
-        return { name: s, asyncInit: r };
-    }
-    throw new Error("Could not initialize any backends, all backend initializations failed.");
-  }
-  moveData(e, t) {
-    const s = this.state.tensorInfo.get(t);
-    s || console.warn("Tried to move data that does not exist", this.state, t);
-    const n = s.backend, r = this.readSync(t), c = n.refCount(t);
-    n.disposeData(t, !0), s.backend = e, e.move(t, r, s.shape, s.dtype, c), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
-  }
-  tidy(e, t) {
-    let s = null;
-    if (t == null) {
-      if (typeof e != "function")
-        throw new Error("Please provide a function to tidy()");
-      t = e;
-    } else {
-      if (typeof e != "string" && !(e instanceof String))
-        throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
-      if (typeof t != "function")
-        throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
-      s = e;
-    }
-    let n;
-    return this.scopedRun(
-      () => this.startScope(s),
-      () => this.endScope(n),
-      () => (n = t(), n instanceof Promise && console.error("Cannot return a Promise inside of tidy."), n)
-    );
-  }
-  scopedRun(e, t, s) {
-    e();
-    try {
-      const n = s();
-      return t(), n;
-    } catch (n) {
-      throw t(), n;
-    }
-  }
-  static nextTensorId = 0;
-  nextTensorId() {
-    return I.nextTensorId++;
-  }
-  static nextVariableId = 0;
-  nextVariableId() {
-    return I.nextVariableId++;
-  }
-  /**
-   * This method is called instead of the public-facing tensor.clone() when
-   * saving a tensor for backwards pass. It makes sure to add the clone
-   * operation to the tape regardless of being called inside a kernel
-   * execution.
-   */
-  clone(e) {
-    const s = T(e) ? O(y.runKernel(P, { x: e })) : y.runKernel(P, { x: e }), n = { x: e }, r = (d) => ({
-      x: () => {
-        const o = "float32", a = { x: d }, h = { dtype: o }, u = T(e), l = y.runKernel(
-          Z,
-          a,
-          // tslint:disable-next-line: no-unnecessary-type-assertion
-          h
-        );
-        return u && O(l), l;
-      }
-    }), c = [];
-    return this.addTapeNode(this.state.activeScope.name, n, [s], r, c, {}), s;
-  }
-  /**
-   * Execute a kernel with the given name and return the output tensor.
-   *
-   * @param kernelName The name of the kernel to execute.
-   * @param inputs A map of input names to tensors.
-   * @param attrs A map of attribute names to their values. An attribute is a
-   * primitive (non-tensor) input to the kernel.
-   * @param inputsToSave A list of tensors, inputs to save for the backprop
-   * computation.
-   * @param outputsToSave A list of booleans, specifying which output to save
-   * for the backprop computation. These are booleans since the output
-   * tensors are not visible to the user.
-   */
-  runKernel(e, t, s) {
-    if (this.backendName == null && this.backend, !(A(e, this.backendName) != null))
-      throw new Error(`Kernel '${e}' not registered for backend '${this.backendName}'`);
-    return this.runKernelFunc({ kernelName: e, inputs: t, attrs: s });
-  }
-  shouldCheckForMemLeaks() {
-    return this.ENV.getBool("IS_TEST");
-  }
-  checkKernelForMemLeak(e, t, s) {
-    const n = this.backend.numDataIds();
-    let r = 0;
-    s.forEach((o) => {
-      r += o.dtype === "complex64" ? 3 : 1;
-    });
-    const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], d = n - t - r - c;
-    if (d > 0)
-      throw new Error(
-        `Backend '${this.backendName}' has an internal memory leak (${d} data ids) after running '${e}'`
-      );
-  }
-  /**
-   * Internal helper method to execute a kernel Func
-   *
-   * Use `runKernel` to execute kernels from outside of engine.
-   */
-  runKernelFunc(e) {
-    let t, s = [];
-    const n = this.isTapeOn(), r = this.state.numBytes, c = this.state.numTensors;
-    this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
-    let d;
-    this.backendName == null && this.backend;
-    let o;
-    const a = F(e) ? e.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
-    if (F(e)) {
-      const { kernelName: f, inputs: v, attrs: m } = e;
-      this.backendName == null && this.backend;
-      const k = A(f, this.backendName);
-      p(
-        k != null,
-        () => `Cannot find registered kernel '${f}' for backend '${this.backendName}'`
-      ), d = () => {
-        const q = this.backend.numDataIds();
-        o = k.kernelFunc({ inputs: v, attrs: m, backend: this.backend });
-        const M = Array.isArray(o) ? o : [o];
-        this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f, q, M);
-        const V = M.map((b) => b.rank != null ? b : this.makeTensorFromTensorInfo(b));
-        if (n) {
-          const b = this.getTensorsForGradient(f, v, V);
-          s = this.saveTensorsForBackwardMode(b ?? []);
-        }
-        return V;
-      };
-    } else {
-      const { forwardFunc: f } = e, v = (m) => {
-        n && (s = m.map((k) => this.keep(this.clone(k))));
-      };
-      d = () => {
-        const m = this.backend.numDataIds();
-        o = this.tidy(() => f(this.backend, v));
-        const k = Array.isArray(o) ? o : [o];
-        return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(a, m, k), k;
-      };
-    }
-    const { inputs: h, attrs: u } = e, l = F(e) ? null : e.backwardsFunc;
-    let g;
-    return this.scopedRun(
-      // Stop recording to a tape when running a kernel.
-      () => this.state.kernelDepth++,
-      () => this.state.kernelDepth--,
-      () => {
-        !this.ENV.getBool("DEBUG") && !this.state.profiling ? t = d() : (g = this.profiler.profileKernel(a, h, () => d()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(g), t = g.outputs);
-      }
-    ), n && this.addTapeNode(
-      a,
-      h,
-      t,
-      l,
-      s,
-      u ?? {}
-    ), this.state.profiling && this.state.activeProfile.kernels.push({
-      name: a,
-      bytesAdded: this.state.numBytes - r,
-      totalBytesSnapshot: this.state.numBytes,
-      tensorsAdded: this.state.numTensors - c,
-      totalTensorsSnapshot: this.state.numTensors,
-      inputShapes: Object.keys(h).map(
-        (f) => h[f] != null ? h[f].shape : null
-      ),
-      outputShapes: t.map((f) => f.shape),
-      kernelTimeMs: g.timeMs,
-      extraInfo: g.extraInfo
-    }), Array.isArray(o) ? t : t[0];
-  }
-  /**
-   * Saves tensors used in forward mode for use in backward mode.
-   *
-   * @param tensors the list of tensors to save.
-   */
-  saveTensorsForBackwardMode(e) {
-    return e.map((s) => this.keep(this.clone(s)));
-  }
-  /**
-   * Returns a list of tensors to save for a given gradient calculation.
-   *
-   * @param kernelName name of kernel to look up gradient for.
-   * @param inputs a map of input tensors.
-   * @param outputs an array of output tensors from forward mode of kernel.
-   */
-  getTensorsForGradient(e, t, s) {
-    const n = z(e);
-    if (n != null) {
-      const r = n.inputsToSave || [], c = n.outputsToSave || [];
-      let d;
-      n.saveAllInputs ? (p(Array.isArray(t), () => "saveAllInputs is true, expected inputs to be an array."), d = Object.keys(t).map((a) => t[a])) : d = r.map((a) => t[a]);
-      const o = s.filter((a, h) => c[h]);
-      return d.concat(o);
-    }
-    return [];
-  }
-  /**
-   * Internal method used by public APIs for tensor creation. Makes a new
-   * tensor with the provided shape, dtype and values. It always
-   * creates a new data id and writes the values to the underlying backend.
-   */
-  makeTensor(e, t, s, n) {
-    if (e == null)
-      throw new Error("Values passed to engine.makeTensor() are null");
-    s = s || "float32", n = n || this.backend;
-    let r = e;
-    s === "string" && ee(e[0]) && (r = e.map((o) => te(o)));
-    const c = n.write(r, t, s), d = new j(t, s, c, this.nextTensorId());
-    if (this.trackTensor(d, n), s === "string") {
-      const o = this.state.tensorInfo.get(c), a = se(r);
-      this.state.numBytes += a - o.bytes, o.bytes = a;
-    }
-    return d;
-  }
-  /**
-   * Internal method used by backends. Makes a new tensor
-   * that is a wrapper around an existing data id. It doesn't create
-   * a new data id, only increments the ref count used in memory tracking.
-   * @deprecated
-   */
-  makeTensorFromDataId(e, t, s, n) {
-    s = s || "float32";
-    const r = { dataId: e, shape: t, dtype: s };
-    return this.makeTensorFromTensorInfo(r, n);
-  }
-  /**
-   * Internal method used by backends. Makes a new tensor that is a wrapper
-   * around an existing data id in TensorInfo. It doesn't create a new data id,
-   * only increments the ref count used in memory tracking.
-   */
-  makeTensorFromTensorInfo(e, t) {
-    const { dataId: s, shape: n, dtype: r } = e, c = new j(n, r, s, this.nextTensorId());
-    if (c.packed = e.packed || !1, c.packed && r !== "int32")
-      throw new Error("Only int32 tensors can be packed.");
-    return this.trackTensor(c, t ?? this.backend), c;
-  }
-  makeVariable(e, t = !0, s, n) {
-    s = s || this.nextVariableId().toString(), n != null && n !== e.dtype && (e = e.cast(n));
-    const r = new de(e, t, s, this.nextTensorId());
-    if (this.state.registeredVariables[r.name] != null)
-      throw new Error(`Variable with name ${r.name} was already registered`);
-    return this.state.registeredVariables[r.name] = r, this.incRef(r, this.backend), r;
-  }
-  trackTensor(e, t) {
-    this.state.numTensors++, e.dtype === "string" && this.state.numStringTensors++;
-    let s = 0;
-    e.dtype !== "complex64" && e.dtype !== "string" && (s = e.size * C(e.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(e.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(e.dataId, {
-      backend: t || this.backend,
-      dtype: e.dtype,
-      shape: e.shape,
-      bytes: s
-    })), e instanceof ne || this.track(e);
-  }
-  // Track the tensor by dataId and increase the refCount for the dataId in the
-  // backend.
-  // TODO(pyu10055): This is currently used by makeVariable method, to increase
-  // refCount on the backend for the dataId. It can potentially be replaced with
-  // Identity op indead of calling backend directly.
-  incRef(e, t) {
-    this.trackTensor(e, t), this.backend.incRef(e.dataId);
-  }
-  removeDataId(e, t) {
-    this.state.tensorInfo.has(e) && this.state.tensorInfo.get(e).backend === t && (this.state.tensorInfo.delete(e), this.state.numDataBuffers--);
-  }
-  disposeTensor(e) {
-    if (!this.state.tensorInfo.has(e.dataId))
-      return;
-    const t = this.state.tensorInfo.get(e.dataId);
-    if (this.state.numTensors--, e.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= t.bytes), e.dtype !== "complex64" && e.dtype !== "string") {
-      const s = e.size * C(e.dtype);
-      this.state.numBytes -= s;
-    }
-    t.backend.disposeData(e.dataId) && this.removeDataId(e.dataId, t.backend);
-  }
-  disposeVariables() {
-    for (const e in this.state.registeredVariables) {
-      const t = this.state.registeredVariables[e];
-      this.disposeVariable(t);
-    }
-  }
-  disposeVariable(e) {
-    this.disposeTensor(e), this.state.registeredVariables[e.name] != null && delete this.state.registeredVariables[e.name];
-  }
-  memory() {
-    const e = this.backend.memory();
-    return e.numTensors = this.state.numTensors, e.numDataBuffers = this.state.numDataBuffers, e.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (e.unreliable = !0, e.reasons == null && (e.reasons = []), e.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), e;
-  }
-  async profile(e) {
-    this.state.profiling = !0;
-    const t = this.state.numBytes, s = this.state.numTensors;
-    this.state.activeProfile.kernels = [], this.state.activeProfile.result = await e(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(
-      ...this.state.activeProfile.kernels.map((n) => n.totalBytesSnapshot)
-    ), this.state.activeProfile.newBytes = this.state.numBytes - t, this.state.activeProfile.newTensors = this.state.numTensors - s;
-    for (const n of this.state.activeProfile.kernels)
-      n.kernelTimeMs = await n.kernelTimeMs, n.extraInfo = await n.extraInfo;
-    return this.state.activeProfile;
-  }
-  isTapeOn() {
-    return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
-  }
-  addTapeNode(e, t, s, n, r, c) {
-    const d = { id: this.state.nextTapeNodeId++, kernelName: e, inputs: t, outputs: s, saved: r }, o = z(e);
-    o != null && (n = o.gradFunc), n != null && (d.gradient = (a) => (a = a.map((h, u) => {
-      if (h == null) {
-        const l = s[u], g = oe(l.size, l.dtype);
-        return this.makeTensor(g, l.shape, l.dtype);
-      }
-      return h;
-    }), n(a.length > 1 ? a : a[0], r, c))), this.state.activeTape.push(d);
-  }
-  keep(e) {
-    return e.kept = !0, e;
-  }
-  startTape() {
-    this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
-  }
-  endTape() {
-    this.state.gradientDepth--;
-  }
-  /**
-   * Start a scope. Use this with endScope() to achieve the same functionality
-   * as scope() without the need for a function closure.
-   */
-  startScope(e) {
-    const t = {
-      track: [],
-      name: "unnamed scope",
-      id: this.state.nextScopeId++
-    };
-    e && (t.name = e), this.state.scopeStack.push(t), this.state.activeScope = t;
-  }
-  /**
-   * End a scope. Use this with startScope() to achieve the same functionality
-   * as scope() without the need for a function closure.
-   */
-  endScope(e) {
-    const t = X(e), s = new Set(t.map((r) => r.id));
-    for (let r = 0; r < this.state.activeScope.track.length; r++) {
-      const c = this.state.activeScope.track[r];
-      !c.kept && !s.has(c.id) && c.dispose();
-    }
-    const n = this.state.scopeStack.pop();
-    this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], t.forEach((r) => {
-      !r.kept && r.scopeId === n?.id && this.track(r);
-    });
-  }
-  /**
-   * Returns gradients of `f` with respect to each of the `xs`. The gradients
-   * returned are of the same length as `xs`, but some might be null if `f`
-   * was not a function of that `x`. It also takes optional dy to multiply the
-   * gradient, which defaults to `1`.
-   */
-  gradients(e, t, s, n = !1) {
-    if (p(t.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
-      throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
-    const r = this.scopedRun(
-      () => this.startTape(),
-      () => this.endTape(),
-      () => this.tidy("forward", e)
-    );
-    p(r instanceof B, () => "The result y returned by f() must be a tensor.");
-    const c = Y(this.state.activeTape, t, r);
-    if (!n && c.length === 0 && t.length > 0)
-      throw new Error(
-        "Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."
-      );
-    return this.tidy("backward", () => {
-      const d = {};
-      d[r.id] = s ?? le(r.shape), he(
-        d,
-        c,
-        // Pass the tidy function to avoid circular dep with `tape.ts`.
-        (a) => this.tidy(a),
-        // Pass an add function to avoide a circular dep with `tape.ts`.
-        fe
-      );
-      const o = t.map((a) => d[a.id]);
-      return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((a) => {
-        if (a.saved !== void 0)
-          for (const h of a.saved)
-            h.dispose();
-      }), this.state.activeTape = null), { value: r, grads: o };
-    });
-  }
-  customGrad(e) {
-    return p(G(e), () => "The f passed in customGrad(f) must be a function."), (...t) => {
-      p(
-        t.every((d) => d instanceof B),
-        () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors"
-      );
-      let s;
-      const n = {};
-      t.forEach((d, o) => {
-        n[o] = d;
-      });
-      const r = (d, o) => (s = e(...t, o), p(
-        s.value instanceof B,
-        () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"
-      ), p(
-        G(s.gradFunc),
-        () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."
-      ), s.value), c = (d, o) => {
-        const a = s.gradFunc(d, o), h = Array.isArray(a) ? a : [a];
-        p(
-          h.length === t.length,
-          () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."
-        ), p(
-          h.every((l) => l instanceof B),
-          () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."
-        );
-        const u = {};
-        return h.forEach((l, g) => {
-          u[g] = () => l;
-        }), u;
-      };
-      return this.runKernelFunc({
-        forwardFunc: r,
-        backwardsFunc: c,
-        inputs: n
-      });
-    };
-  }
-  readSync(e) {
-    return this.state.tensorInfo.get(e).backend.readSync(e);
-  }
-  read(e) {
-    return this.state.tensorInfo.get(e).backend.read(e);
-  }
-  readToGPU(e, t) {
-    return this.state.tensorInfo.get(e).backend.readToGPU(e, t);
-  }
-  async time(e) {
-    const t = R(), s = await this.backend.time(e);
-    return s.wallMs = R() - t, s;
-  }
-  /**
-   * Tracks a Tensor in the current scope to be automatically cleaned up
-   * when the current scope ends, and returns the value.
-   *
-   * @param result The Tensor to track in the current scope.
-   */
-  track(e) {
-    return this.state.activeScope != null && (e.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(e)), e;
-  }
-  get registeredVariables() {
-    return this.state.registeredVariables;
-  }
-  /**
-   * Resets the engine state. Removes all backends but does not remove
-   * registered backend factories.
-   */
-  reset() {
-    this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new _();
-    for (const e in this.registry)
-      this.disposeRegisteredKernels(e), this.registry[e].dispose(), delete this.registry[e];
-    this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
-  }
-}
-function le(i) {
-  const e = re(ie(i), "float32");
-  return y.makeTensor(e, i, "float32");
-}
-function ue() {
-  const i = U();
-  if (i._tfengine == null) {
-    const e = new L(i);
-    i._tfengine = new I(e);
-  }
-  return ae(i._tfengine.ENV), ce(() => i._tfengine), i._tfengine;
-}
-const y = ue();
-function fe(i, e) {
-  const t = T(i) || T(e), s = { a: i, b: e };
-  return y.runKernel(t ? "Add16" : J, s);
+function a(e) {
+  return r(e);
 }
 export {
-
-
-
-  fe as c,
-  ue as g,
-  x as isPackableTensor,
-  T as isPackedTensor,
-  O as packTensor,
-  be as packingSupported,
-  we as unpackTensor
+  r as isPackableTensor,
+  a as isPackedTensor,
+  o as packingSupported
};
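The net effect is easiest to read from the exports: in 0.10.2 this module bundled its own patched copy of the TensorFlow.js engine (the removed engine class and global `_tfengine` property above, now living in the shared `index-DOvlwCh-.js` chunk), while 0.10.3 keeps only three small predicates. The `packTensor`/`unpackTensor` helpers and the `c`/`g` engine exports are gone from this surface, and "packed" now means the tensor's dtype is `"packedF16"` rather than a mutable `packed` flag on int32/float32 tensors. A minimal usage sketch follows; the deep-import path and the `usePackedPath` helper are illustrative assumptions, not a documented entry point of the package:

```js
// Sketch only: the named exports below match the new packed.js shown above;
// the import path is an assumption for illustration.
import {
  isPackedTensor,
  packingSupported,
} from "@genai-fi/nanogpt/dist/utilities/packed.js";

// Decide whether a tensor can take the packed-f16 fast path.
// In 0.10.3 a tensor counts as packed when tensor.dtype === "packedF16";
// in 0.10.2 the module instead toggled a `packed` flag via packTensor/unpackTensor.
export function usePackedPath(tensor) {
  // packingSupported() is true only when the active backend is "webgpu".
  return packingSupported() && isPackedTensor(tensor);
}
```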