@genai-fi/nanogpt 0.4.2 → 0.4.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/NanoGPTModel.js +72 -75
- package/dist/layers/CausalSelfAttention.js +37 -37
- package/dist/ops/appendCache.d.ts +1 -1
- package/dist/ops/appendCache.js +10 -4
- package/dist/ops/cpu/appendCache.d.ts +1 -2
- package/dist/ops/cpu/appendCache.js +15 -20
- package/dist/ops/cpu/attentionMask.js +10 -10
- package/dist/ops/webgl/appendCache.js +14 -13
- package/package.json +1 -1
package/dist/NanoGPTModel.js
CHANGED
@@ -1,17 +1,17 @@
 import { defaultConfig as x } from "./config.js";
 import W from "./layers/TransformerBlock.js";
-import { E as F, D as P, T as q, r as …
-import …
+import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-CnJ1bx4q.js";
+import K from "./layers/RoPECache.js";
 import N from "./layers/RMSNorm.js";
 import { estimateParameterCount as R } from "./utilities/parameters.js";
 import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
 import B from "./layers/BaseLayer.js";
-import { o as …
+import { o as $, h as E, p as G, E as v, a6 as O, a7 as j, a8 as Q, t as w, a5 as V, f as C } from "./index-C4JCoBvj.js";
 import { r as _ } from "./reshape-Boe4DuIO.js";
-import { r as …
-import { e as …
-import { g as …
-import { s as …
+import { r as X } from "./range-9AzeApCc.js";
+import { e as H } from "./tfjs_backend-Cug-PH75.js";
+import { g as J } from "./gather-ZYRWhmXR.js";
+import { s as U } from "./softmax-Cujsg4ay.js";
 /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -28,13 +28,13 @@ import { s as J } from "./softmax-Cujsg4ay.js";
  * limitations under the License.
  * =============================================================================
  */
-function Y(…
-let e = E(…
+function Y(c, t) {
+  let e = E(c, "a", "mod"), o = E(t, "b", "mod");
   [e, o] = G(e, o);
-const …
-return v.runKernel(O, …
+  const i = { a: e, b: o };
+  return v.runKernel(O, i);
 }
-const Z = /* @__PURE__ */ …
+const Z = /* @__PURE__ */ $({ mod_: Y });
 /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -51,17 +51,17 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
  * limitations under the License.
  * =============================================================================
  */
-function tt(…
-const …
+function tt(c, t, e, o = !1) {
+  const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
   if (s < 2)
     throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
-if (…
-throw new Error(`Rank of probabilities must be 1 or 2, but is ${…
+  if (l > 2)
+    throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
   e = e || Math.random();
-const …
-return …
+  const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
+  return l === 1 ? _(a, [a.size]) : a;
 }
-const M = /* @__PURE__ */ …
+const M = /* @__PURE__ */ $({ multinomial_: tt });
 /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -78,19 +78,19 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
  * limitations under the License.
  * =============================================================================
  */
-function et(…
-const o = E(…
+function et(c, t = 1, e = !0) {
+  const o = E(c, "x", "topk");
   if (o.rank === 0)
     throw new Error("topk() expects the input to be of rank 1 or higher");
-const …
+  const i = o.shape[o.shape.length - 1];
   if (t < 0)
     throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
-if (t > …
-throw new Error(`'k' passed to topk() must be <= the last dimension (${…
-const s = { x: o }, …
-return { values: …
+  if (t > i)
+    throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
+  const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
+  return { values: r, indices: n };
 }
-const ot = /* @__PURE__ */ …
+const ot = /* @__PURE__ */ $({ topk_: et });
 /**
  * @license
  * Copyright 2018 Google LLC
@@ -100,11 +100,11 @@ const ot = /* @__PURE__ */ y({ topk_: et });
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
-function st(…
-return new P(…
+function st(c) {
+  return new P(c);
 }
-function …
-return new F(…
+function it(c) {
+  return new F(c);
 }
 class wt extends B {
   wte;
@@ -124,12 +124,12 @@ class wt extends B {
       vocabSize: this.config.gpt.vocabSize,
       embedDim: this.config.gpt.nEmbed,
       name: "token_embedding"
-}), this.config.gpt.useRope === !1 ? this.wpe = …
+    }), this.config.gpt.useRope === !1 ? this.wpe = it({
       inputDim: this.config.gpt.blockSize,
       outputDim: this.config.gpt.nEmbed,
       name: "positional_embedding",
-embeddingsInitializer: …
-}) : (this.ropeCache = new …
+      embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
+    }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
     for (let e = 0; e < this.config.gpt.nLayer; e++)
       this.blocks.push(new W(e, this.config));
     this.lnF = new N(this.config, 1e-8, "final_rms_norm");
@@ -163,12 +163,12 @@ class wt extends B {
   }
   inputPhase(t, e, o = !1) {
     return w(() => {
-const …
+      const i = this.wte.embed(t);
       if (this.config.gpt.useRope === !1) {
-const [, s] = t.shape, …
-return this.drop.apply(…
+        const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
+        return this.drop.apply(a, { training: o });
       } else
-return this.drop.apply(…
+        return this.drop.apply(i, { training: o });
     });
   }
   setSkipMask(t) {
@@ -209,67 +209,64 @@ class wt extends B {
     return w(() => {
       if (t.length === 0)
        throw new Error("No attentions for rollout");
-const [e, o, …
-for (const …
-const […
-if (…
+      const [e, o, i] = t[0].shape;
+      for (const n of t) {
+        const [h, a, p] = n.shape;
+        if (h !== e || a !== o || p !== i)
           throw new Error(
-`Inconsistent attention shapes in rollout: expected [${e},${o},${…
+            `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
           );
       }
-
-
-
-
-
-r = i.div(i.sum(-1, !0)).matMul(r);
-}
-return r;
+      const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
+      let r = l.tile([e, 1, 1]);
+      for (const n of s) {
+        const h = n.add(l);
+        r = h.div(h.sum(-1, !0)).matMul(r);
       }
-
+      return r;
     });
   }
-forward(t, e, o = !1, …
+  forward(t, e, o = !1, i = !1, s) {
     return this.validateInput(t), w(() => {
       this.startMemory();
-const …
-let …
-const …
+      const l = s?.[0]?.length ?? 0;
+      let r = this.inputPhase(t, l, o);
+      const n = [];
       if (s && s.length !== this.blocks.length)
         throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
-for (let …
-const u = …
+      for (let g = 0; g < this.blocks.length; g++) {
+        const u = r, m = this.blocks[g], {
           output: b,
           attention: k,
           cache: f
-} = m.call(…
-
+        } = m.call(r, o, i, s ? s[g] : void 0);
+        r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
       }
+      let h;
+      i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
+      const a = this.wte.project(r);
       let p;
-
-const l = this.wte.project(a);
-let g;
-return e && (g = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: g, attention: n ? p : void 0 };
+      return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
     });
   }
   generate(t, e, o) {
-const …
+    const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
     return w(() => {
-const …
-[0, …
-[…
-), …
+      const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
+        [0, h - this.config.gpt.blockSize],
+        [n.shape[0], this.config.gpt.blockSize]
+      ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
        [0, 0],
-[0, …
-]) : …
+        [0, p]
+      ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
       let d;
       if (s) {
-const { values: S, indices: I } = ot(…
-d = …
+        const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
+        d = J(I.squeeze([1]), L, 1);
       } else
-d = M(…
+        d = M(y.squeeze([1]), 1);
       let z;
-return o?.includeProbabilities && (z = …
+      return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
     });
   }
   getNumParams() {
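The restructured rollout above replaces in-place accumulation: the new code slices each per-layer attention map to [B, T, T], adds the identity (the residual path), row-normalizes, and multiplies the maps together, seeding the product with an identity matrix. This matches the attention-rollout method of Abnar & Zuidema (2020). A de-minified TypeScript sketch of what the new code appears to compute (identifier names are mine, not the package's):

import * as tf from "@tensorflow/tfjs-core";

// Attention rollout: multiply residual-augmented, row-normalized attention
// maps across layers. `attentions` holds one [B, T, T] map per layer.
function attentionRollout(attentions: tf.Tensor3D[]): tf.Tensor {
  const [batch, rows] = attentions[0].shape;
  const eye = tf.eye(rows).expandDims(0); // [1, T, T] identity = residual path
  let rollout = eye.tile([batch, 1, 1]);  // seed the product with I
  for (const layerAttn of attentions) {
    const withResidual = layerAttn.add(eye);                         // A + I
    const normalized = withResidual.div(withResidual.sum(-1, true)); // row-normalize
    rollout = normalized.matMul(rollout);                            // accumulate across layers
  }
  return rollout;
}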
package/dist/layers/CausalSelfAttention.js
CHANGED
@@ -7,8 +7,8 @@ import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index-C4JCoBv…
 import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
 import { l as W, w as M, d as x } from "../tfjs_backend-Cug-PH75.js";
 import { o as N } from "../ones-Bf3YR48P.js";
+import { v as A } from "../variable-LJT9Ld63.js";
 import { z as q } from "../zeros-dnQxFgAD.js";
-import { v as k } from "../variable-LJT9Ld63.js";
 import { r as C, d as I } from "../dropout-DfDdklfL.js";
 import { r as B } from "../reshape-Boe4DuIO.js";
 import { m as F } from "../mat_mul-415y5Qn2.js";
@@ -24,15 +24,15 @@ class nt extends T {
   projUnits;
   constructor(t, s) {
     super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
-const …
-this.maskInf = M(this.bias, …
+    const o = q([s.gpt.blockSize, s.gpt.blockSize]), e = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
+    this.maskInf = M(this.bias, o, e);
   }
   build() {
-this.cAttn === null && (this.cAttn = …
+    this.cAttn === null && (this.cAttn = A(
       C([this.config.gpt.nEmbed, this.units], 0, 0.02),
       !0
       //`block_${this.index}_attn_cAttn_kernel`
-)), this.cProj === null && (this.cProj = …
+    )), this.cProj === null && (this.cProj = A(
       C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
       !0
       //`block_${this.index}_attn_cProj_kernel`
@@ -53,74 +53,74 @@ class nt extends T {
     t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
   }
   loadWeights(t) {
-const s = t.get(`block_${this.index}_cAttn`)?.[0], …
+    const s = t.get(`block_${this.index}_cAttn`)?.[0], o = t.get(`block_${this.index}_cProj`)?.[0];
     if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
-if (!…
-this.cAttn ? this.cAttn.assign(s) : this.cAttn = …
+    if (!o) throw new Error(`Weights for block_${this.index}_cProj not found`);
+    this.cAttn ? this.cAttn.assign(s) : this.cAttn = A(s, !0), this.cProj ? this.cProj.assign(o) : this.cProj = A(o, !0);
   }
-getAttentionScores(t, s, …
-const …
-return _(…
+  getAttentionScores(t, s, o, e) {
+    const i = P(t, s, this.divisor, this.maskInf);
+    return _(i, o ? this.config.gpt.dropout : 0, e);
   }
   // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
-getAttentionScoresWithPast(t, s, …
-const …
-return _(…
+  getAttentionScoresWithPast(t, s, o) {
+    const e = P(t, s, this.divisor, void 0, o);
+    return _(e, 0, 0);
   }
   getQKV(t) {
     return y(t, this.cAttn, this.config.gpt.nHead);
   }
   getOutputProjection(t) {
-const s = t.shape[0], …
+    const s = t.shape[0], o = t.shape[2], e = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = B(i, [s, o, e]);
     return x(n, this.cProj);
   }
-updateCache(t, s, e) {
-const i = this.config.gpt.blockSize, …
+  updateCache(t, s, o, e) {
+    const i = this.config.gpt.blockSize, n = t.shape[2], r = e?.length || 0, a = o ? t : E(t, i, r, e?.k), p = o ? s : E(s, i, r, e?.v);
     return {
-k: S(…
-v: S(…
-length: …
-cumulativeLength: e ? e.cumulativeLength + …
+      k: S(a),
+      v: S(p),
+      length: Math.min(r + n, i),
+      cumulativeLength: e ? e.cumulativeLength + n : n
     };
   }
-forward(t, s = !1, …
+  forward(t, s = !1, o, e = !1, i) {
     return $(() => {
       this.startMemory();
-const [n, r, a] = this.getQKV(t), p = …
+      const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
       c && (n.dispose(), r.dispose());
-const g = …
-
+      const g = i ? i.length : 0, d = this.updateCache(f, a, s, i), l = d.k, m = d.v;
+      i && (f.dispose(), a.dispose());
       let h;
-g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, …
+      g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, o), u.dispose(), s && l.dispose();
       const b = F(h, m);
-
-const …
+      e || h.dispose(), s && m.dispose();
+      const k = this.getOutputProjection(b);
       b.dispose();
-const v = …
-return this.endMemory("CausalSelfAttention"), { output: …
+      const v = e ? h.mean(1) : void 0;
+      return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
     });
   }
-call(t, s = !1, …
-if (…
+  call(t, s = !1, o = !1, e) {
+    if (e && !this.config.gpt.useRope)
       throw new Error("Cannot use pastKV without RoPE enabled");
-if (s && …
+    if (s && e)
       throw new Error("Cannot use pastKV during training");
     if (t.shape.length !== 3)
       throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
     if (t.shape[2] !== this.config.gpt.nEmbed)
       throw new Error(`Input tensor last dimension must be ${this.config.gpt.nEmbed}, got ${t.shape[2]}`);
     this.build();
-const …
+    const i = Math.random() * 1e9;
     if (s && this.config.layerConfig.checkpointAttention) {
       const r = L(
         // @ts-expect-error Invalid params
         (a, p, c, u) => {
-const f = this.forward(a, !0, …
+          const f = this.forward(a, !0, i);
           u([a]);
           const g = (d, l) => {
             const [m] = l, h = j().state.activeTape;
             j().state.activeTape = [];
-const b = O((…
+            const b = O((k, v, R) => this.forward(k, !0, i).output)([m, p, c], d);
             return j().state.activeTape = h, b;
           };
           return { value: f.output, gradFunc: g };
@@ -132,7 +132,7 @@ class nt extends T {
     } else
       return { output: r };
   } else {
-const n = this.forward(t, s, o, e…
+    const n = this.forward(t, s, i, o, e);
     if (this.config.gpt.dropout > 0) {
       const r = I(n.output, this.config.gpt.dropout);
       return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
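Two threads run through the CausalSelfAttention changes: updateCache gained an explicit training flag and now derives the insert position from the cache's own bookkeeping, and the cache record carries both a window-clamped length and an ever-growing cumulativeLength that feeds the RoPE rotation offset. A sketch of that record and the bookkeeping rule, with assumed names (the package ships this minified):

import { Tensor } from "@tensorflow/tfjs-core";

// Inferred shape of the per-block KV cache (my naming). `length` is clamped
// to the attention window, while `cumulativeLength` keeps growing and is
// used as the RoPE rotation offset.
interface KVCache {
  k: Tensor;                // keys   [B, nHead, T_window, headDim]
  v: Tensor;                // values [B, nHead, T_window, headDim]
  length: number;           // valid timesteps currently held in the window
  cumulativeLength: number; // total timesteps ever appended
}

// Bookkeeping rule matching `length: Math.min(r + n, i)` and
// `cumulativeLength: e ? e.cumulativeLength + n : n` in the diff above.
function nextCounters(prev: KVCache | undefined, newSteps: number, blockSize: number) {
  return {
    length: Math.min((prev?.length ?? 0) + newSteps, blockSize),
    cumulativeLength: (prev?.cumulativeLength ?? 0) + newSteps,
  };
}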
package/dist/ops/appendCache.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import { Tensor } from '@tensorflow/tfjs-core';
-export declare function appendCache(…
+export declare function appendCache(item: Tensor, maxSize: number, pastLen: number, cache?: Tensor): Tensor;
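The public op signature changed shape: the incoming item now comes first, the cache is optional, and the caller passes pastLen explicitly. A hypothetical call sequence against the new declaration (the deep-import path and the shape comments are assumptions, not documented API):

import * as tf from "@tensorflow/tfjs-core";
// Hypothetical deep import; the package may re-export this elsewhere.
import { appendCache } from "@genai-fi/nanogpt/dist/ops/appendCache.js";

const blockSize = 128;              // attention window, i.e. maxSize
const k0 = tf.zeros([1, 4, 1, 64]); // first step's keys: [B, H, T=1, headDim]

// First step: no cache yet, so the op zero-pads k0 along time to maxSize.
let kCache = appendCache(k0, blockSize, 0);

// Later steps: pass the running cache and the count of already-valid timesteps.
const k1 = tf.zeros([1, 4, 1, 64]);
kCache = appendCache(k1, blockSize, 1, kCache);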
package/dist/ops/appendCache.js
CHANGED
@@ -1,9 +1,15 @@
-import { e as …
+import { e as a } from "../index-C4JCoBvj.js";
 import "./cpu/appendCache.js";
 import "./webgl/appendCache.js";
-
-
+import { z as s } from "../zeros-dnQxFgAD.js";
+import { c } from "../concat-CuRsVY-K.js";
+function i(r, p, n, o) {
+  if (!o) {
+    const e = r.shape[2];
+    return c([r, s([r.shape[0], r.shape[1], p - e, r.shape[3]])], 2);
+  }
+  return a().runKernel("AppendCache", { cache: o, item: r }, { maxSize: p, pastLen: n });
 }
 export {
-
+  i as appendCache
 };
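Worth noting in the façade above: when no cache tensor is supplied yet, the op never reaches the registered kernel; it simply right-pads the item with zeros along the time axis to the full window. A standalone re-implementation of that branch under those assumed semantics:

import * as tf from "@tensorflow/tfjs-core";

// Re-implementation of the no-cache branch, under the assumed [B, H, T, D]
// layout: right-pad the first item with zeros out to the window size.
function seedCache(item: tf.Tensor, maxSize: number): tf.Tensor {
  const [b, h, t, d] = item.shape;
  const pad = tf.zeros([b, h, maxSize - t, d]); // empty future slots
  return tf.concat([item, pad], 2);             // concatenate on the time axis
}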
package/dist/ops/cpu/appendCache.d.ts
CHANGED
@@ -1,2 +1 @@
-
-export declare function appendCache(cache: Tensor, item: Tensor, maxSize: number): Tensor;
+export {};
package/dist/ops/cpu/appendCache.js
CHANGED
@@ -1,28 +1,23 @@
-import { r as …
-import { c as …
-function …
-const { cache: …
-if (…
-const …
-return …
+import { r as d } from "../../index-C4JCoBvj.js";
+import { c as h } from "../../concat-CuRsVY-K.js";
+function u(p) {
+  const { cache: n, item: s } = p.inputs, { maxSize: r, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
+  if (c + e <= r) {
+    const f = n.slice([0, 0, 0, 0], [t, o, c, a]), m = n.slice([0, 0, c + e, 0], [t, o, r - c - e, a]), i = e < e ? s.slice([0, 0, 0, 0], [t, o, e, a]) : s, k = h([f, i, m], 2);
+    return f.dispose(), m.dispose(), i !== s && i.dispose(), k;
   }
-
+  const l = n.slice([0, 0, e, 0], [t, o, r - e, a]), C = h([l, s], 2);
+  return l.dispose(), C;
 }
-const …
+const w = {
   kernelName: "AppendCache",
   backendName: "cpu",
-kernelFunc: …
+  kernelFunc: u
 };
-
-const …
+d(w);
+const N = {
   kernelName: "AppendCache",
   backendName: "tensorflow",
-kernelFunc: …
-};
-a(C);
-function N(n, c, t) {
-return m().runKernel("AppendCache", { cache: n, item: c }, { maxSize: t });
-}
-export {
-N as appendCache
+  kernelFunc: u
 };
+d(N);
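The CPU/Node kernel now distinguishes two regimes: while pastLen + itemLen still fits in the window it splices the item in at position pastLen and keeps the tail; once the window is full it drops the oldest itemLen timesteps and appends at the end. (As published, the `e < e` guard in the in-window branch is always false, so the item is never re-sliced there.) A de-minified sketch with my own names:

import * as tf from "@tensorflow/tfjs-core";

// Assumed semantics of the CPU AppendCache kernel.
// cache: [B, H, maxSize, D]; item: [B, H, itemLen, D].
function appendCacheCpu(cache: tf.Tensor, item: tf.Tensor, maxSize: number, pastLen: number): tf.Tensor {
  const [b, h, , d] = cache.shape;
  const itemLen = item.shape[2];
  if (pastLen + itemLen <= maxSize) {
    // Still room: overwrite the slots starting at pastLen, keep the tail.
    const head = cache.slice([0, 0, 0, 0], [b, h, pastLen, d]);
    const tail = cache.slice([0, 0, pastLen + itemLen, 0], [b, h, maxSize - pastLen - itemLen, d]);
    return tf.concat([head, item, tail], 2);
  }
  // Window full: shift left by itemLen and append the new timesteps.
  const kept = cache.slice([0, 0, itemLen, 0], [b, h, maxSize - itemLen, d]);
  return tf.concat([kept, item], 2);
}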
package/dist/ops/cpu/attentionMask.js
CHANGED
@@ -1,22 +1,22 @@
-import { r as o, f as …
-import { m as …
+import { r as o, f as k } from "../../index-C4JCoBvj.js";
+import { m as d } from "../../mat_mul-415y5Qn2.js";
 function r(t) {
-const { q: e, k: …
-if (…
-const …
-return a.add(…
+  const { q: e, k: n, mask: s } = t.inputs, { divisor: c } = t.attrs, m = e.shape[2], i = n.shape[2], a = d(e, n, !1, !0).mul(k(c));
+  if (s) {
+    const l = s.slice([0, 0], [m, i]).expandDims(0).expandDims(0);
+    return a.add(l);
   }
   return a;
 }
-const …
+const u = {
   kernelName: "AttentionMask",
   backendName: "cpu",
   kernelFunc: r
 };
-o(…
-const …
+o(u);
+const f = {
   kernelName: "AttentionMask",
   backendName: "tensorflow",
   kernelFunc: r
 };
-o(…
+o(f);
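The AttentionMask kernel's behavior is unchanged; this release only renames its minified bindings. For reference, the op computes scaled q·kᵀ scores and, when a mask is present, adds a [T_q, T_kv] slice of it broadcast over batch and heads. A sketch with assumed names:

import * as tf from "@tensorflow/tfjs-core";

// Assumed semantics of the AttentionMask op: scores = (q @ kᵀ) * divisor,
// plus an optional additive causal mask sliced to the current sequence sizes.
function attentionMask(q: tf.Tensor, k: tf.Tensor, divisor: number, mask?: tf.Tensor): tf.Tensor {
  const scores = tf.matMul(q, k, false, true).mul(tf.scalar(divisor));
  if (!mask) return scores;
  const sliced = mask
    .slice([0, 0], [q.shape[2], k.shape[2]]) // [T_q, T_kv]
    .expandDims(0)
    .expandDims(0);                          // broadcast over [B, H]
  return scores.add(sliced);
}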
package/dist/ops/webgl/appendCache.js
CHANGED
@@ -1,12 +1,12 @@
-import { r as …
+import { r as p } from "../../index-C4JCoBvj.js";
 class m {
   variableNames = ["cache", "item"];
   outputShape;
   userCode;
   customUniforms = [{ name: "cacheT", type: "int" }];
-constructor(t, a, …
-const …
-this.outputShape = [t, a, …
+  constructor(t, a, n, o, c) {
+    const s = Math.min(n + 1, c);
+    this.outputShape = [t, a, s, o], this.userCode = `
       void main() {
         ivec4 coords = getOutputCoords(); // [b, h, t, d]
         int b = coords.x;
@@ -15,7 +15,7 @@ class m {
         int d = coords.w;

         int itemT = 1;
-int maxSize = ${…
+        int maxSize = ${c};
         int totalT = cacheT + itemT;
         int start = totalT >= maxSize ? 1 : 0;

@@ -23,21 +23,22 @@ class m {
         float val = 0.0;
         if (srcT < cacheT) {
           val = getCache(b, h, srcT, d);
-} else {
+        } else if (srcT == cacheT) {
           val = getItem(b, h, 0, d);
-}
+        } else {
+          val = 0.0;}
         setOutput(val);
       }
     `;
   }
 }
-function …
-const { cache: t, item: a } = e.inputs, { maxSize: o } = e.attrs, …
-return …
+function d(e) {
+  const { cache: t, item: a } = e.inputs, { maxSize: n, pastLen: o } = e.attrs, c = e.backend, s = t.shape[0], r = t.shape[2], i = t.shape[1], h = new m(s, i, r, a.shape[3], n);
+  return c.runWebGLProgram(h, [t, a], "float32", [[o]]);
 }
-const …
+const l = {
   kernelName: "AppendCache",
   backendName: "webgl",
-kernelFunc: …
+  kernelFunc: d
 };
-
+p(l);
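The WebGL change mirrors the CPU kernel: the output time length is now min(T_cache + 1, maxSize), the newly appended timestep reads from the item at slot cacheT, and slots past it are written as explicit zeros rather than falling through, with pastLen fed in through the cacheT uniform. A host-side model of the index mapping the shader appears to implement (the srcT computation sits in an unchanged region not shown in this diff, so the `t + start` mapping below is inferred):

// Host-side model of the shader's per-slot source selection (assumed).
// When the window is full, every output slot t reads source slot t + 1,
// dropping the oldest timestep; the new item always lands at cacheT.
function sourceIndex(t: number, cacheT: number, maxSize: number): number | "item" | "zero" {
  const itemT = 1;
  const start = cacheT + itemT >= maxSize ? 1 : 0;
  const srcT = t + start;
  if (srcT < cacheT) return srcT;      // copy from the existing cache
  if (srcT === cacheT) return "item";  // the newly appended timestep
  return "zero";                       // beyond the valid region
}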