@genai-fi/nanogpt 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +90 -41
- package/dist/NanoGPTModel.d.ts +1 -0
- package/dist/NanoGPTModel.js +86 -73
- package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
- package/dist/TeachableLLM.js +1 -1
- package/dist/TiedEmbedding-DORsPlNL.js +44 -0
- package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
- package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
- package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
- package/dist/dataset-ZHEPJmED.js +1226 -0
- package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
- package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
- package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
- package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
- package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
- package/dist/layers/BaseLayer.js +114 -3
- package/dist/layers/CausalSelfAttention.js +29 -28
- package/dist/layers/MLP.js +10 -9
- package/dist/layers/RMSNorm.js +12 -11
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +8 -6
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
- package/dist/main.js +1 -1
- package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
- package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
- package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
- package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
- package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +27 -27
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +36 -36
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +22 -22
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
- package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
- package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
- package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
- package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
- package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
- package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
- package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
- package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
- package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
- package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
- package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
- package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
- package/dist/tokeniser/CharTokeniser.js +27 -27
- package/dist/tokeniser/bpe.d.ts +1 -0
- package/dist/tokeniser/bpe.js +38 -35
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.d.ts +4 -1
- package/dist/training/DatasetBuilder.js +49 -1244
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +33 -24
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/tokenParse.js +17 -8
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-BGvK-VN3.js +23 -0
- package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
- package/package.json +1 -1
- package/dist/BaseLayer-BhrMN8JO.js +0 -135
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { generateText as v } from "../utilities/generate.js";
|
|
2
2
|
import L from "./Trainer.js";
|
|
3
3
|
import x from "./Evaluator.js";
|
|
4
|
-
import { a as h } from "../index-
|
|
4
|
+
import { a as h } from "../index-CnHyhpKc.js";
|
|
5
5
|
const D = {
|
|
6
6
|
desiredLoss: 0.01,
|
|
7
7
|
logInterval: 1,
|
package/dist/training/Trainer.js
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
import { DatasetBuilder as d } from "./DatasetBuilder.js";
|
|
2
|
-
import
|
|
3
|
-
import { t as
|
|
4
|
-
import { m as
|
|
5
|
-
import { m as
|
|
6
|
-
import { m as
|
|
7
|
-
import { z as
|
|
1
|
+
import { DatasetBuilder as h, flattenTokens as d, PAGE_FACTOR as g } from "./DatasetBuilder.js";
|
|
2
|
+
import u from "./AdamExt.js";
|
|
3
|
+
import { t as f, v as y, a as m } from "../index-CnHyhpKc.js";
|
|
4
|
+
import { m as S, n as z } from "../norm-BpWsOapl.js";
|
|
5
|
+
import { m as w, a as T } from "../moments-DLTE6-1p.js";
|
|
6
|
+
import { m as x } from "../max-CcnEArWK.js";
|
|
7
|
+
import { z as p } from "../zeros-CYMicyqz.js";
|
|
8
8
|
class G {
|
|
9
9
|
constructor(t, s, e = 1e-3) {
|
|
10
|
-
this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new
|
|
10
|
+
this.tokenizer = s, this.model = t, this.learningRate = e, this.resetOptimizer(), this.datasetBuilder = new h(s, t.config.gpt.blockSize);
|
|
11
11
|
}
|
|
12
12
|
model;
|
|
13
13
|
optimizer;
|
|
@@ -29,7 +29,7 @@ class G {
|
|
|
29
29
|
}
|
|
30
30
|
resetOptimizer(t = { learningRateFactor: 1, beta1: 0.9, beta2: 0.99, epsilon: 1e-8 }) {
|
|
31
31
|
this.optimizer && this.optimizer.dispose();
|
|
32
|
-
const s = new
|
|
32
|
+
const s = new u(
|
|
33
33
|
t.learningRateFactor * this.learningRate,
|
|
34
34
|
t.beta1,
|
|
35
35
|
t.beta2,
|
|
@@ -46,21 +46,21 @@ class G {
|
|
|
46
46
|
printGradients(t) {
|
|
47
47
|
Object.keys(t).forEach((s) => {
|
|
48
48
|
const e = t[s];
|
|
49
|
-
console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${
|
|
49
|
+
console.log(`${s}:`), console.log(` Shape: ${e.shape}`), console.log(` Mean: ${w(e).dataSync()[0]}`), console.log(` Std: ${T(e).variance.sqrt().dataSync()[0]}`), console.log(` Min: ${S(e).dataSync()[0]}`), console.log(` Max: ${x(e).dataSync()[0]}`), console.log(` Norm: ${z(e).dataSync()[0]}`);
|
|
50
50
|
});
|
|
51
51
|
}
|
|
52
52
|
trainStep(t, s = !1, e = !1) {
|
|
53
|
-
return
|
|
53
|
+
return f(() => {
|
|
54
54
|
this.model.getProfiler()?.startMemory();
|
|
55
|
-
const { xs: a, ys:
|
|
56
|
-
const [
|
|
57
|
-
return
|
|
58
|
-
}, { value:
|
|
59
|
-
return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(
|
|
55
|
+
const { xs: a, ys: i } = t, o = () => {
|
|
56
|
+
const [l, c] = this.model.forward({ training: !0 }, a, i);
|
|
57
|
+
return l.dispose(), c;
|
|
58
|
+
}, { value: n, grads: r } = y(o);
|
|
59
|
+
return s ? this.model.getProfiler()?.endMemory("Training") : (e && (console.log("-------"), this.printGradients(r), console.log("-------")), this.optimizer.applyGradients(r), this.model.getProfiler()?.endMemory("Training"), m(r)), n;
|
|
60
60
|
});
|
|
61
61
|
}
|
|
62
62
|
dummyPass() {
|
|
63
|
-
const t =
|
|
63
|
+
const t = p([1, this.model.config.gpt.blockSize], "int32"), s = p([1, this.model.config.gpt.blockSize], "int32");
|
|
64
64
|
try {
|
|
65
65
|
const e = this.trainStep({ xs: t, ys: s }, !0);
|
|
66
66
|
e.dataSync(), e.dispose();
|
|
@@ -75,20 +75,29 @@ class G {
|
|
|
75
75
|
const e = this.trainStep(s, !1, !1);
|
|
76
76
|
return s.xs.dispose(), s.ys.dispose(), t.step++, t.totalSteps++, e.array().then((a) => (t.lastLoss = a, t.losses.push(t.lastLoss), e.dispose(), t.lastLoss));
|
|
77
77
|
} catch (e) {
|
|
78
|
-
throw console.error(`Error processing batch at step ${t.step}:`, e),
|
|
78
|
+
throw console.error(`Error processing batch at step ${t.step}:`, e), m(), e;
|
|
79
79
|
}
|
|
80
80
|
}
|
|
81
81
|
async createTrainValidationSplit(t, s = 32, e = 0.1) {
|
|
82
|
-
const a = await
|
|
83
|
-
|
|
82
|
+
const a = await d(t, this.tokenizer), i = /* @__PURE__ */ new Set();
|
|
83
|
+
if (e > 0) {
|
|
84
|
+
const r = Math.floor(a.length / (this.datasetBuilder.blockSize * g)), l = Math.max(1, Math.floor(r * e));
|
|
85
|
+
for (; i.size < l; ) {
|
|
86
|
+
const c = Math.floor(Math.random() * r);
|
|
87
|
+
i.add(c);
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
const o = await this.datasetBuilder.createTextDataset(a, s, i, !1), n = await this.datasetBuilder.createTextDataset(
|
|
91
|
+
a,
|
|
84
92
|
s,
|
|
85
|
-
|
|
86
|
-
|
|
93
|
+
i,
|
|
94
|
+
!0
|
|
87
95
|
);
|
|
88
|
-
return { trainDataset:
|
|
96
|
+
return { trainDataset: o, validationDataset: n };
|
|
89
97
|
}
|
|
90
98
|
async createDataset(t, s = 32) {
|
|
91
|
-
|
|
99
|
+
const e = await d(t, this.tokenizer);
|
|
100
|
+
return await this.datasetBuilder.createTextDataset(e, s);
|
|
92
101
|
}
|
|
93
102
|
dispose() {
|
|
94
103
|
this.optimizer && this.optimizer.dispose();
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import { gatherSub as L } from "../ops/gatherSub.js";
|
|
2
2
|
import { scatterSub as y } from "../ops/scatterSub.js";
|
|
3
|
-
import { e as u, c as i, z as S, t as f, s as G } from "../index-
|
|
4
|
-
import { s as v } from "../softmax-
|
|
5
|
-
import { m as z } from "../max-
|
|
6
|
-
import { l as k } from "../log_sum_exp-
|
|
3
|
+
import { e as u, c as i, z as S, t as f, s as G } from "../index-CnHyhpKc.js";
|
|
4
|
+
import { s as v } from "../softmax-DX6qXAbm.js";
|
|
5
|
+
import { m as z } from "../max-CcnEArWK.js";
|
|
6
|
+
import { l as k } from "../log_sum_exp-CRH7Np9v.js";
|
|
7
7
|
function F(a, s) {
|
|
8
8
|
return f(() => {
|
|
9
9
|
const e = a.shape[a.shape.length - 1], o = a.shape.slice(0, -1).reduce((d, c) => d * c, 1), p = a.shape.length > 2 ? a.reshape([o, e]) : a, n = s.shape.length > 1 ? s.reshape([o]).cast("int32") : s.cast("int32"), t = z(p, -1, !0), r = G(p, t), h = k(r, -1);
|
package/dist/utilities/dummy.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { z as n } from "../zeros-
|
|
1
|
+
import "../index-CnHyhpKc.js";
|
|
2
|
+
import { z as n } from "../zeros-CYMicyqz.js";
|
|
3
3
|
async function c(s) {
|
|
4
4
|
const i = n([1, s.config.gpt.blockSize], "int32"), [t, o] = s.forward({ training: !1 }, i);
|
|
5
5
|
await t.data(), t.dispose(), o && o.dispose(), i.dispose();
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { t as y } from "../index-
|
|
2
|
-
import { t as x } from "../tensor2d-
|
|
3
|
-
import { c as f } from "../concat-
|
|
1
|
+
import { t as y } from "../index-CnHyhpKc.js";
|
|
2
|
+
import { t as x } from "../tensor2d-CqtBzOKq.js";
|
|
3
|
+
import { c as f } from "../concat-BRRtq4S2.js";
|
|
4
4
|
async function A(o, r, a, c, T) {
|
|
5
5
|
if (c <= 0)
|
|
6
6
|
throw new Error("Length must be a positive integer");
|
package/dist/utilities/load.js
CHANGED
|
@@ -3,7 +3,7 @@ import { importWeights as b } from "./weights.js";
|
|
|
3
3
|
import u from "../tokeniser/CharTokeniser.js";
|
|
4
4
|
import F from "../NanoGPTModel.js";
|
|
5
5
|
import { dummyPassAsync as j } from "./dummy.js";
|
|
6
|
-
import { d as T } from "../index-
|
|
6
|
+
import { d as T } from "../index-CnHyhpKc.js";
|
|
7
7
|
import E from "../tokeniser/bpe.js";
|
|
8
8
|
async function A(t) {
|
|
9
9
|
const o = await fetch(t);
|
|
@@ -1,12 +1,21 @@
|
|
|
1
|
-
function
|
|
2
|
-
const r = Array.from(
|
|
3
|
-
let
|
|
4
|
-
for (let
|
|
5
|
-
const
|
|
6
|
-
|
|
1
|
+
function c(l) {
|
|
2
|
+
const r = Array.from(l), s = [], o = new RegExp("(\\p{P}|\\p{S}|\\s)", "gu");
|
|
3
|
+
let t = "";
|
|
4
|
+
for (let e = 0; e < r.length; e++) {
|
|
5
|
+
const n = r[e];
|
|
6
|
+
if (n === " ")
|
|
7
|
+
(r[e + 1] ?? "") !== " " ? (s.push(t), t = n) : t += n;
|
|
8
|
+
else if (n.match(o)) {
|
|
9
|
+
s.push(t);
|
|
10
|
+
let h = n;
|
|
11
|
+
for (; e + 1 < r.length && r[e + 1] === n; )
|
|
12
|
+
h += r[e + 1], e++;
|
|
13
|
+
s.push(h), t = "";
|
|
14
|
+
} else
|
|
15
|
+
t += n;
|
|
7
16
|
}
|
|
8
|
-
return
|
|
17
|
+
return t.length > 0 && s.push(t), s.filter((e) => e.length > 0);
|
|
9
18
|
}
|
|
10
19
|
export {
|
|
11
|
-
|
|
20
|
+
c as default
|
|
12
21
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import { t as p } from "../tensor-
|
|
1
|
+
import "../index-CnHyhpKc.js";
|
|
2
|
+
import { t as p } from "../tensor-IZex6Bwp.js";
|
|
3
3
|
function h(n) {
|
|
4
4
|
const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
|
|
5
5
|
let t = 0;
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import { E as i } from "./index-CnHyhpKc.js";
|
|
2
|
+
/**
|
|
3
|
+
* @license
|
|
4
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
* =============================================================================
|
|
17
|
+
*/
|
|
18
|
+
function m(r, a = !0, e, t) {
|
|
19
|
+
return i.makeVariable(r, a, e, t);
|
|
20
|
+
}
|
|
21
|
+
export {
|
|
22
|
+
m as v
|
|
23
|
+
};
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as
|
|
1
|
+
import { o as m, j as r, X as l, E as c, Y as i, n as p, Z as u, q as f } from "./index-CnHyhpKc.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -17,11 +17,11 @@ import { o as l, i as r, X as m, E as c, Y as i, l as p, Z as u, p as f } from "
|
|
|
17
17
|
*/
|
|
18
18
|
function x(a, e) {
|
|
19
19
|
const o = r(a, "real", "complex"), s = r(e, "imag", "complex");
|
|
20
|
-
|
|
20
|
+
l(o.shape, s.shape, `real and imag shapes, ${o.shape} and ${s.shape}, must match in call to tf.complex().`);
|
|
21
21
|
const n = { real: o, imag: s };
|
|
22
22
|
return c.runKernel(i, n);
|
|
23
23
|
}
|
|
24
|
-
const g = /* @__PURE__ */
|
|
24
|
+
const g = /* @__PURE__ */ m({ complex_: x });
|
|
25
25
|
/**
|
|
26
26
|
* @license
|
|
27
27
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
package/package.json
CHANGED
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
import { E as p, V as v, c as _, e as o, W as V } from "./index-iNhkcAEQ.js";
|
|
2
|
-
/**
|
|
3
|
-
* @license
|
|
4
|
-
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
5
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
* you may not use this file except in compliance with the License.
|
|
7
|
-
* You may obtain a copy of the License at
|
|
8
|
-
*
|
|
9
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
*
|
|
11
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
* See the License for the specific language governing permissions and
|
|
15
|
-
* limitations under the License.
|
|
16
|
-
* =============================================================================
|
|
17
|
-
*/
|
|
18
|
-
function E(l, t = !0, r, e) {
|
|
19
|
-
return p.makeVariable(l, t, r, e);
|
|
20
|
-
}
|
|
21
|
-
class m {
|
|
22
|
-
parent;
|
|
23
|
-
config;
|
|
24
|
-
_variables = /* @__PURE__ */ new Map();
|
|
25
|
-
_trainable = !0;
|
|
26
|
-
children = [];
|
|
27
|
-
constructor(t, r) {
|
|
28
|
-
this.config = t, this.parent = r, this.parent && this.parent.children.push(this);
|
|
29
|
-
}
|
|
30
|
-
getProfiler() {
|
|
31
|
-
return this.config.layerConfig.profiler;
|
|
32
|
-
}
|
|
33
|
-
startMemory() {
|
|
34
|
-
this.config.layerConfig.profiler?.startMemory();
|
|
35
|
-
}
|
|
36
|
-
endMemory(t) {
|
|
37
|
-
this.config.layerConfig.profiler?.endMemory(t);
|
|
38
|
-
}
|
|
39
|
-
addVariable(t, r) {
|
|
40
|
-
this._variables.set(t, r || null);
|
|
41
|
-
}
|
|
42
|
-
get variables() {
|
|
43
|
-
const t = Array.from(this._variables.values()).filter((e) => e !== null), r = this.children.flatMap((e) => e.variables);
|
|
44
|
-
return [...t, ...r];
|
|
45
|
-
}
|
|
46
|
-
get trainableVariables() {
|
|
47
|
-
const t = Array.from(this._variables.values()).filter(
|
|
48
|
-
(e) => e !== null && e.trainable
|
|
49
|
-
), r = this.children.flatMap((e) => e.trainableVariables);
|
|
50
|
-
return [...t, ...r];
|
|
51
|
-
}
|
|
52
|
-
get trainable() {
|
|
53
|
-
return this._trainable;
|
|
54
|
-
}
|
|
55
|
-
set trainable(t) {
|
|
56
|
-
this._trainable = t, this._variables.forEach((r) => {
|
|
57
|
-
r && (r.trainable = t);
|
|
58
|
-
}), this.children.forEach((r) => {
|
|
59
|
-
r.trainable = t;
|
|
60
|
-
});
|
|
61
|
-
}
|
|
62
|
-
getVariable(t) {
|
|
63
|
-
const r = this._variables.get(t);
|
|
64
|
-
if (!r)
|
|
65
|
-
throw new Error(`Variable ${t} not found`);
|
|
66
|
-
return r;
|
|
67
|
-
}
|
|
68
|
-
hasVariable(t) {
|
|
69
|
-
return this._variables.get(t) !== null;
|
|
70
|
-
}
|
|
71
|
-
setVariable(t, r) {
|
|
72
|
-
if (!this._variables.has(t))
|
|
73
|
-
throw new Error(`Variable ${t} not found`);
|
|
74
|
-
this._variables.set(t, r);
|
|
75
|
-
}
|
|
76
|
-
saveWeights(t) {
|
|
77
|
-
this._variables.forEach((r, e) => {
|
|
78
|
-
r && t.set(e, [r.clone()]);
|
|
79
|
-
}), this.children.forEach((r) => {
|
|
80
|
-
r.saveWeights(t);
|
|
81
|
-
});
|
|
82
|
-
}
|
|
83
|
-
loadWeights(t) {
|
|
84
|
-
this._variables.forEach((r, e) => {
|
|
85
|
-
const i = t.get(e)?.[0];
|
|
86
|
-
if (!i)
|
|
87
|
-
throw new Error(`Weights for ${e} not found`);
|
|
88
|
-
r ? r.assign(i) : this._variables.set(e, E(i, this._trainable));
|
|
89
|
-
}), this.children.forEach((r) => {
|
|
90
|
-
r.loadWeights(t);
|
|
91
|
-
});
|
|
92
|
-
}
|
|
93
|
-
dispose() {
|
|
94
|
-
this._variables.forEach((t) => {
|
|
95
|
-
t?.dispose();
|
|
96
|
-
}), this._variables.clear();
|
|
97
|
-
}
|
|
98
|
-
build() {
|
|
99
|
-
}
|
|
100
|
-
dropout(t) {
|
|
101
|
-
return t;
|
|
102
|
-
}
|
|
103
|
-
call(t, ...r) {
|
|
104
|
-
this.build();
|
|
105
|
-
const e = this.forward(t, ...r);
|
|
106
|
-
if (t.training && e instanceof v) {
|
|
107
|
-
const i = this.dropout(e);
|
|
108
|
-
return i !== e && e.dispose(), i;
|
|
109
|
-
} else
|
|
110
|
-
return e;
|
|
111
|
-
}
|
|
112
|
-
callCheckpoint(t, ...r) {
|
|
113
|
-
return this.build(), this.checkpointingFn(t, ...r);
|
|
114
|
-
}
|
|
115
|
-
checkpointingFn(t, ...r) {
|
|
116
|
-
const e = this.trainableVariables, s = _((...a) => {
|
|
117
|
-
const h = a[a.length - 1], n = a.slice(0, r.length), c = this.forward(t, ...n);
|
|
118
|
-
return h(n), { value: c, gradFunc: (f, u) => {
|
|
119
|
-
const b = o().state.activeTape;
|
|
120
|
-
o().state.activeTape = [];
|
|
121
|
-
const d = V((...g) => this.forward(t, ...g.slice(0, n.length)))([...u, ...e], f);
|
|
122
|
-
return o().state.activeTape = b, d;
|
|
123
|
-
} };
|
|
124
|
-
})(...r, ...e);
|
|
125
|
-
if (t.training) {
|
|
126
|
-
const a = this.dropout(s);
|
|
127
|
-
return a !== s && s.dispose(), a;
|
|
128
|
-
} else
|
|
129
|
-
return s;
|
|
130
|
-
}
|
|
131
|
-
}
|
|
132
|
-
export {
|
|
133
|
-
m as B,
|
|
134
|
-
E as v
|
|
135
|
-
};
|