@genai-fi/nanogpt 0.4.2 → 0.4.3

This diff shows the published contents of the two package versions as they appear in their public registry and is provided for informational purposes only.
@@ -1,17 +1,17 @@
1
1
  import { defaultConfig as x } from "./config.js";
2
2
  import W from "./layers/TransformerBlock.js";
3
- import { E as F, D as P, T as q, r as K, p as T } from "./TiedEmbedding-CnJ1bx4q.js";
4
- import D from "./layers/RoPECache.js";
3
+ import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-CnJ1bx4q.js";
4
+ import K from "./layers/RoPECache.js";
5
5
  import N from "./layers/RMSNorm.js";
6
6
  import { estimateParameterCount as R } from "./utilities/parameters.js";
7
7
  import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
8
  import B from "./layers/BaseLayer.js";
9
- import { o as y, h as E, p as G, E as v, a6 as O, a7 as Q, a8 as j, t as w, a5 as U, f as C } from "./index-C4JCoBvj.js";
9
+ import { o as $, h as E, p as G, E as v, a6 as O, a7 as j, a8 as Q, t as w, a5 as V, f as C } from "./index-C4JCoBvj.js";
10
10
  import { r as _ } from "./reshape-Boe4DuIO.js";
11
- import { r as V } from "./range-9AzeApCc.js";
12
- import { e as X } from "./tfjs_backend-Cug-PH75.js";
13
- import { g as H } from "./gather-ZYRWhmXR.js";
14
- import { s as J } from "./softmax-Cujsg4ay.js";
11
+ import { r as X } from "./range-9AzeApCc.js";
12
+ import { e as H } from "./tfjs_backend-Cug-PH75.js";
13
+ import { g as J } from "./gather-ZYRWhmXR.js";
14
+ import { s as U } from "./softmax-Cujsg4ay.js";
15
15
  /**
16
16
  * @license
17
17
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -28,13 +28,13 @@ import { s as J } from "./softmax-Cujsg4ay.js";
28
28
  * limitations under the License.
29
29
  * =============================================================================
30
30
  */
31
- function Y(h, t) {
32
- let e = E(h, "a", "mod"), o = E(t, "b", "mod");
31
+ function Y(c, t) {
32
+ let e = E(c, "a", "mod"), o = E(t, "b", "mod");
33
33
  [e, o] = G(e, o);
34
- const n = { a: e, b: o };
35
- return v.runKernel(O, n);
34
+ const i = { a: e, b: o };
35
+ return v.runKernel(O, i);
36
36
  }
37
- const Z = /* @__PURE__ */ y({ mod_: Y });
37
+ const Z = /* @__PURE__ */ $({ mod_: Y });
38
38
  /**
39
39
  * @license
40
40
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -51,17 +51,17 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
51
51
  * limitations under the License.
52
52
  * =============================================================================
53
53
  */
54
- function tt(h, t, e, o = !1) {
55
- const n = E(h, "logits", "multinomial"), s = n.size, r = n.rank;
54
+ function tt(c, t, e, o = !1) {
55
+ const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
56
56
  if (s < 2)
57
57
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
58
- if (r > 2)
59
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
58
+ if (l > 2)
59
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
60
60
  e = e || Math.random();
61
- const i = { logits: r === 1 ? _(n, [1, -1]) : n }, p = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(Q, i, p);
62
- return r === 1 ? _(l, [l.size]) : l;
61
+ const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
62
+ return l === 1 ? _(a, [a.size]) : a;
63
63
  }
64
- const M = /* @__PURE__ */ y({ multinomial_: tt });
64
+ const M = /* @__PURE__ */ $({ multinomial_: tt });
65
65
  /**
66
66
  * @license
67
67
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -78,19 +78,19 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
78
78
  * limitations under the License.
79
79
  * =============================================================================
80
80
  */
81
- function et(h, t = 1, e = !0) {
82
- const o = E(h, "x", "topk");
81
+ function et(c, t = 1, e = !0) {
82
+ const o = E(c, "x", "topk");
83
83
  if (o.rank === 0)
84
84
  throw new Error("topk() expects the input to be of rank 1 or higher");
85
- const n = o.shape[o.shape.length - 1];
85
+ const i = o.shape[o.shape.length - 1];
86
86
  if (t < 0)
87
87
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
88
- if (t > n)
89
- throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
90
- const s = { x: o }, r = { k: t, sorted: e }, [a, i] = v.runKernel(j, s, r);
91
- return { values: a, indices: i };
88
+ if (t > i)
89
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
90
+ const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
91
+ return { values: r, indices: n };
92
92
  }
93
- const ot = /* @__PURE__ */ y({ topk_: et });
93
+ const ot = /* @__PURE__ */ $({ topk_: et });
94
94
  /**
95
95
  * @license
96
96
  * Copyright 2018 Google LLC
@@ -100,11 +100,11 @@ const ot = /* @__PURE__ */ y({ topk_: et });
100
100
  * https://opensource.org/licenses/MIT.
101
101
  * =============================================================================
102
102
  */
103
- function st(h) {
104
- return new P(h);
103
+ function st(c) {
104
+ return new P(c);
105
105
  }
106
- function nt(h) {
107
- return new F(h);
106
+ function it(c) {
107
+ return new F(c);
108
108
  }
109
109
  class wt extends B {
110
110
  wte;
@@ -124,12 +124,12 @@ class wt extends B {
124
124
  vocabSize: this.config.gpt.vocabSize,
125
125
  embedDim: this.config.gpt.nEmbed,
126
126
  name: "token_embedding"
127
- }), this.config.gpt.useRope === !1 ? this.wpe = nt({
127
+ }), this.config.gpt.useRope === !1 ? this.wpe = it({
128
128
  inputDim: this.config.gpt.blockSize,
129
129
  outputDim: this.config.gpt.nEmbed,
130
130
  name: "positional_embedding",
131
- embeddingsInitializer: K({ mean: 0, stddev: 0.02 })
132
- }) : (this.ropeCache = new D(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
131
+ embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
132
+ }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
133
133
  for (let e = 0; e < this.config.gpt.nLayer; e++)
134
134
  this.blocks.push(new W(e, this.config));
135
135
  this.lnF = new N(this.config, 1e-8, "final_rms_norm");
@@ -163,12 +163,12 @@ class wt extends B {
163
163
  }
164
164
  inputPhase(t, e, o = !1) {
165
165
  return w(() => {
166
- const n = this.wte.embed(t);
166
+ const i = this.wte.embed(t);
167
167
  if (this.config.gpt.useRope === !1) {
168
- const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), i = Z(U(a, C(e, "int32")), C(r, "int32")), p = this.wpe.apply(i), l = n.add(p);
169
- return this.drop.apply(l, { training: o });
168
+ const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
169
+ return this.drop.apply(a, { training: o });
170
170
  } else
171
- return this.drop.apply(n, { training: o });
171
+ return this.drop.apply(i, { training: o });
172
172
  });
173
173
  }
174
174
  setSkipMask(t) {
@@ -209,67 +209,64 @@ class wt extends B {
209
209
  return w(() => {
210
210
  if (t.length === 0)
211
211
  throw new Error("No attentions for rollout");
212
- const [e, o, n] = t[0].shape;
213
- for (const s of t) {
214
- const [r, a, i] = s.shape;
215
- if (r !== e || a !== o || i !== n)
212
+ const [e, o, i] = t[0].shape;
213
+ for (const n of t) {
214
+ const [h, a, p] = n.shape;
215
+ if (h !== e || a !== o || p !== i)
216
216
  throw new Error(
217
- `Inconsistent attention shapes in rollout: expected [${e},${o},${n}] got [${r},${a},${i}]`
217
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
218
218
  );
219
219
  }
220
- if (o === n) {
221
- const s = X(n, n).expandDims(0);
222
- let r = s.tile([e, 1, 1]);
223
- for (const a of t) {
224
- const i = a.add(s);
225
- r = i.div(i.sum(-1, !0)).matMul(r);
226
- }
227
- return r;
220
+ const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
221
+ let r = l.tile([e, 1, 1]);
222
+ for (const n of s) {
223
+ const h = n.add(l);
224
+ r = h.div(h.sum(-1, !0)).matMul(r);
228
225
  }
229
- throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${n}]`);
226
+ return r;
230
227
  });
231
228
  }
232
- forward(t, e, o = !1, n = !1, s) {
229
+ forward(t, e, o = !1, i = !1, s) {
233
230
  return this.validateInput(t), w(() => {
234
231
  this.startMemory();
235
- const r = s?.[0]?.length ?? 0;
236
- let a = this.inputPhase(t, r, o);
237
- const i = [];
232
+ const l = s?.[0]?.length ?? 0;
233
+ let r = this.inputPhase(t, l, o);
234
+ const n = [];
238
235
  if (s && s.length !== this.blocks.length)
239
236
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
240
- for (let c = 0; c < this.blocks.length; c++) {
241
- const u = a, m = this.blocks[c], {
237
+ for (let g = 0; g < this.blocks.length; g++) {
238
+ const u = r, m = this.blocks[g], {
242
239
  output: b,
243
240
  attention: k,
244
241
  cache: f
245
- } = m.call(a, o, n, s ? s[c] : void 0);
246
- a = b, u.dispose(), n && k && i.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
242
+ } = m.call(r, o, i, s ? s[g] : void 0);
243
+ r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
247
244
  }
245
+ let h;
246
+ i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
247
+ const a = this.wte.project(r);
248
248
  let p;
249
- n && i.length > 0 && (p = this.computeAttentionRollout(i)), a = this.lnF.apply(a);
250
- const l = this.wte.project(a);
251
- let g;
252
- return e && (g = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: g, attention: n ? p : void 0 };
249
+ return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
253
250
  });
254
251
  }
255
252
  generate(t, e, o) {
256
- const n = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
253
+ const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
257
254
  return w(() => {
258
- const i = t, p = i.shape[1], l = p <= this.config.gpt.blockSize ? i : i.slice(
259
- [0, p - this.config.gpt.blockSize],
260
- [i.shape[0], this.config.gpt.blockSize]
261
- ), g = r ? this.config.gpt.blockSize - l.shape[1] : 0, c = g > 0 ? T(l, [
255
+ const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
256
+ [0, h - this.config.gpt.blockSize],
257
+ [n.shape[0], this.config.gpt.blockSize]
258
+ ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
262
259
  [0, 0],
263
- [0, g]
264
- ]) : l, { logits: u, attention: m } = this.forward(c, void 0, !1, a, e), b = u.shape[1] - 1 - g, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, $ = k.div(n);
260
+ [0, p]
261
+ ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
265
262
  let d;
266
263
  if (s) {
267
- const { values: S, indices: I } = ot($, s), L = M(S.squeeze([1]), 1);
268
- d = H(I.squeeze([1]), L, 1);
264
+ const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
265
+ d = J(I.squeeze([1]), L, 1);
269
266
  } else
270
- d = M($.squeeze([1]), 1);
267
+ d = M(y.squeeze([1]), 1);
271
268
  let z;
272
- return o?.includeProbabilities && (z = J($.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
269
+ return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
273
270
  });
274
271
  }
275
272
  getNumParams() {
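
Note on the model changes above: computeAttentionRollout no longer throws on non-square attention maps. In 0.4.3 each layer's [B, Q, K] map is cropped to the square [B, Q, Q] window first, then combined as before (add the identity, row-normalise, chain the layers by matmul). A hedged TypeScript sketch of that logic, written against public tfjs ops rather than the package's minified helpers (all names here are illustrative):

    import * as tf from '@tensorflow/tfjs-core';

    // Sketch of the 0.4.3 rollout: crop each layer to [B, Q, Q], mix with the
    // identity, row-normalise, and chain the per-layer products.
    function attentionRolloutSketch(perLayer: tf.Tensor[]): tf.Tensor {
      const [b, q] = perLayer[0].shape;                   // [B, Q, K]
      const eye = tf.eye(q).expandDims(0).tile([b, 1, 1]);
      let rollout: tf.Tensor = eye;
      for (const att of perLayer) {
        const square = att.slice([0, 0, 0], [b, q, q]);   // crop K down to Q (new in 0.4.3)
        const mixed = square.add(eye);
        rollout = mixed.div(mixed.sum(-1, true)).matMul(rollout);
      }
      return rollout;
    }
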
@@ -7,8 +7,8 @@ import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index-C4JCoBv
7
7
  import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
8
8
  import { l as W, w as M, d as x } from "../tfjs_backend-Cug-PH75.js";
9
9
  import { o as N } from "../ones-Bf3YR48P.js";
10
+ import { v as A } from "../variable-LJT9Ld63.js";
10
11
  import { z as q } from "../zeros-dnQxFgAD.js";
11
- import { v as k } from "../variable-LJT9Ld63.js";
12
12
  import { r as C, d as I } from "../dropout-DfDdklfL.js";
13
13
  import { r as B } from "../reshape-Boe4DuIO.js";
14
14
  import { m as F } from "../mat_mul-415y5Qn2.js";
@@ -24,15 +24,15 @@ class nt extends T {
24
24
  projUnits;
25
25
  constructor(t, s) {
26
26
  super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(N([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
27
- const e = q([s.gpt.blockSize, s.gpt.blockSize]), i = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
28
- this.maskInf = M(this.bias, e, i);
27
+ const o = q([s.gpt.blockSize, s.gpt.blockSize]), e = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
28
+ this.maskInf = M(this.bias, o, e);
29
29
  }
30
30
  build() {
31
- this.cAttn === null && (this.cAttn = k(
31
+ this.cAttn === null && (this.cAttn = A(
32
32
  C([this.config.gpt.nEmbed, this.units], 0, 0.02),
33
33
  !0
34
34
  //`block_${this.index}_attn_cAttn_kernel`
35
- )), this.cProj === null && (this.cProj = k(
35
+ )), this.cProj === null && (this.cProj = A(
36
36
  C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
37
37
  !0
38
38
  //`block_${this.index}_attn_cProj_kernel`
@@ -53,74 +53,74 @@ class nt extends T {
53
53
  t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
54
54
  }
55
55
  loadWeights(t) {
56
- const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
56
+ const s = t.get(`block_${this.index}_cAttn`)?.[0], o = t.get(`block_${this.index}_cProj`)?.[0];
57
57
  if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
58
- if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
59
- this.cAttn ? this.cAttn.assign(s) : this.cAttn = k(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = k(e, !0);
58
+ if (!o) throw new Error(`Weights for block_${this.index}_cProj not found`);
59
+ this.cAttn ? this.cAttn.assign(s) : this.cAttn = A(s, !0), this.cProj ? this.cProj.assign(o) : this.cProj = A(o, !0);
60
60
  }
61
- getAttentionScores(t, s, e, i) {
62
- const o = P(t, s, this.divisor, this.maskInf);
63
- return _(o, e ? this.config.gpt.dropout : 0, i);
61
+ getAttentionScores(t, s, o, e) {
62
+ const i = P(t, s, this.divisor, this.maskInf);
63
+ return _(i, o ? this.config.gpt.dropout : 0, e);
64
64
  }
65
65
  // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
66
- getAttentionScoresWithPast(t, s, e) {
67
- const i = P(t, s, this.divisor, void 0, e);
68
- return _(i, 0, 0);
66
+ getAttentionScoresWithPast(t, s, o) {
67
+ const e = P(t, s, this.divisor, void 0, o);
68
+ return _(e, 0, 0);
69
69
  }
70
70
  getQKV(t) {
71
71
  return y(t, this.cAttn, this.config.gpt.nHead);
72
72
  }
73
73
  getOutputProjection(t) {
74
- const s = t.shape[0], e = t.shape[2], i = this.config.gpt.nEmbed, o = t.transpose([0, 2, 1, 3]), n = B(o, [s, e, i]);
74
+ const s = t.shape[0], o = t.shape[2], e = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = B(i, [s, o, e]);
75
75
  return x(n, this.cProj);
76
76
  }
77
- updateCache(t, s, e) {
78
- const i = this.config.gpt.blockSize, o = t.shape[2], n = Math.min(e?.length || 0, i - o), r = e ? E(e.k, t, i) : t, a = e ? E(e.v, s, i) : s;
77
+ updateCache(t, s, o, e) {
78
+ const i = this.config.gpt.blockSize, n = t.shape[2], r = e?.length || 0, a = o ? t : E(t, i, r, e?.k), p = o ? s : E(s, i, r, e?.v);
79
79
  return {
80
- k: S(r),
81
- v: S(a),
82
- length: n + o,
83
- cumulativeLength: e ? e.cumulativeLength + o : o
80
+ k: S(a),
81
+ v: S(p),
82
+ length: Math.min(r + n, i),
83
+ cumulativeLength: e ? e.cumulativeLength + n : n
84
84
  };
85
85
  }
86
- forward(t, s = !1, e, i = !1, o) {
86
+ forward(t, s = !1, o, e = !1, i) {
87
87
  return $(() => {
88
88
  this.startMemory();
89
- const [n, r, a] = this.getQKV(t), p = o ? o.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
89
+ const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, f = c ? w(r, c, p) : r;
90
90
  c && (n.dispose(), r.dispose());
91
- const g = o ? o.length : 0, d = this.updateCache(f, a, o), l = d.k, m = d.v;
92
- o && (f.dispose(), a.dispose());
91
+ const g = i ? i.length : 0, d = this.updateCache(f, a, s, i), l = d.k, m = d.v;
92
+ i && (f.dispose(), a.dispose());
93
93
  let h;
94
- g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
94
+ g > 0 ? h = this.getAttentionScoresWithPast(u, l, g) : h = this.getAttentionScores(u, l, s, o), u.dispose(), s && l.dispose();
95
95
  const b = F(h, m);
96
- i || h.dispose(), s && m.dispose();
97
- const A = this.getOutputProjection(b);
96
+ e || h.dispose(), s && m.dispose();
97
+ const k = this.getOutputProjection(b);
98
98
  b.dispose();
99
- const v = i ? h.mean(1) : void 0;
100
- return this.endMemory("CausalSelfAttention"), { output: A, attention: v, presentKV: s ? void 0 : d };
99
+ const v = e ? h.mean(1) : void 0;
100
+ return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
101
101
  });
102
102
  }
103
- call(t, s = !1, e = !1, i) {
104
- if (i && !this.config.gpt.useRope)
103
+ call(t, s = !1, o = !1, e) {
104
+ if (e && !this.config.gpt.useRope)
105
105
  throw new Error("Cannot use pastKV without RoPE enabled");
106
- if (s && i)
106
+ if (s && e)
107
107
  throw new Error("Cannot use pastKV during training");
108
108
  if (t.shape.length !== 3)
109
109
  throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
110
110
  if (t.shape[2] !== this.config.gpt.nEmbed)
111
111
  throw new Error(`Input tensor last dimension must be ${this.config.gpt.nEmbed}, got ${t.shape[2]}`);
112
112
  this.build();
113
- const o = Math.random() * 1e9;
113
+ const i = Math.random() * 1e9;
114
114
  if (s && this.config.layerConfig.checkpointAttention) {
115
115
  const r = L(
116
116
  // @ts-expect-error Invalid params
117
117
  (a, p, c, u) => {
118
- const f = this.forward(a, !0, o);
118
+ const f = this.forward(a, !0, i);
119
119
  u([a]);
120
120
  const g = (d, l) => {
121
121
  const [m] = l, h = j().state.activeTape;
122
122
  j().state.activeTape = [];
123
- const b = O((A, v, R) => this.forward(A, !0, o).output)([m, p, c], d);
123
+ const b = O((k, v, R) => this.forward(k, !0, i).output)([m, p, c], d);
124
124
  return j().state.activeTape = h, b;
125
125
  };
126
126
  return { value: f.output, gradFunc: g };
@@ -132,7 +132,7 @@ class nt extends T {
132
132
  } else
133
133
  return { output: r };
134
134
  } else {
135
- const n = this.forward(t, s, o, e, i);
135
+ const n = this.forward(t, s, i, o, e);
136
136
  if (this.config.gpt.dropout > 0) {
137
137
  const r = I(n.output, this.config.gpt.dropout);
138
138
  return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
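
Note on the attention-layer changes above: updateCache now takes the training flag and the past cache, clamps the stored length to the block size, and only routes K/V through the AppendCache kernel at inference time. A hedged sketch of that bookkeeping (the cache fields and the appendCache signature are taken from this diff; tf.keep standing in for the minified S(...) wrapper is an assumption):

    import * as tf from '@tensorflow/tfjs-core';

    interface KVCache { k: tf.Tensor; v: tf.Tensor; length: number; cumulativeLength: number; }
    declare function appendCache(item: tf.Tensor, maxSize: number, pastLen: number, cache?: tf.Tensor): tf.Tensor;

    function updateCacheSketch(k: tf.Tensor, v: tf.Tensor, training: boolean,
                               blockSize: number, past?: KVCache): KVCache {
      const tCur = k.shape[2];                  // tokens added this step
      const pastLen = past?.length ?? 0;
      // Training keeps the raw tensors; inference appends into the fixed-size buffer.
      const nextK = training ? k : appendCache(k, blockSize, pastLen, past?.k);
      const nextV = training ? v : appendCache(v, blockSize, pastLen, past?.v);
      return {
        k: tf.keep(nextK),                      // tf.keep assumed for the minified S(...)
        v: tf.keep(nextV),
        length: Math.min(pastLen + tCur, blockSize),   // clamped to the context window in 0.4.3
        cumulativeLength: (past?.cumulativeLength ?? 0) + tCur,
      };
    }
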
@@ -1,2 +1,2 @@
1
1
  import { Tensor } from '@tensorflow/tfjs-core';
2
- export declare function appendCache(cache: Tensor, item: Tensor, maxSize: number): Tensor;
2
+ export declare function appendCache(item: Tensor, maxSize: number, pastLen: number, cache?: Tensor): Tensor;
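
The public appendCache declaration therefore changes its calling convention in 0.4.3: the cache argument moves to the end and becomes optional, and the caller now passes the past length explicitly. A hedged usage sketch against the new declaration (the [batch, heads, time, headDim] layout is inferred from the kernels later in this diff; the concrete numbers are illustrative):

    import { Tensor } from '@tensorflow/tfjs-core';

    // 0.4.2: appendCache(cache, item, maxSize)
    // 0.4.3: appendCache(item, maxSize, pastLen, cache?)
    declare function appendCache(item: Tensor, maxSize: number, pastLen: number, cache?: Tensor): Tensor;

    declare const newKeys: Tensor;   // [B, H, 1, D] keys for the current token
    declare const pastKeys: Tensor;  // [B, H, blockSize, D] existing cache buffer
    const blockSize = 128;

    // No cache yet: the JS wrapper zero-pads the item out to blockSize along the time axis.
    const first = appendCache(newKeys, blockSize, 0);
    // Later steps: the kernel writes the new token at position pastLen (shifting once full).
    const later = appendCache(newKeys, blockSize, 17, pastKeys);
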
@@ -1,9 +1,15 @@
1
- import { e as p } from "../index-C4JCoBvj.js";
1
+ import { e as a } from "../index-C4JCoBvj.js";
2
2
  import "./cpu/appendCache.js";
3
3
  import "./webgl/appendCache.js";
4
- function a(e, n, r) {
5
- return p().runKernel("AppendCache", { cache: e, item: n }, { maxSize: r });
4
+ import { z as s } from "../zeros-dnQxFgAD.js";
5
+ import { c } from "../concat-CuRsVY-K.js";
6
+ function i(r, p, n, o) {
7
+ if (!o) {
8
+ const e = r.shape[2];
9
+ return c([r, s([r.shape[0], r.shape[1], p - e, r.shape[3]])], 2);
10
+ }
11
+ return a().runKernel("AppendCache", { cache: o, item: r }, { maxSize: p, pastLen: n });
6
12
  }
7
13
  export {
8
- a as appendCache
14
+ i as appendCache
9
15
  };
@@ -1,2 +1 @@
1
- import { Tensor } from '@tensorflow/tfjs-core';
2
- export declare function appendCache(cache: Tensor, item: Tensor, maxSize: number): Tensor;
1
+ export {};
@@ -1,28 +1,23 @@
1
- import { r as a, e as m } from "../../index-C4JCoBvj.js";
2
- import { c as d } from "../../concat-CuRsVY-K.js";
3
- function r(n) {
4
- const { cache: c, item: t } = n.inputs, { maxSize: o } = n.attrs, e = d([c, t], 2), s = e.shape[2];
5
- if (s > o) {
6
- const p = s - o, i = e.shape[0], l = e.shape[1], h = e.shape[3], u = e.slice([0, 0, p, 0], [i, l, o, h]);
7
- return e.dispose(), u;
1
+ import { r as d } from "../../index-C4JCoBvj.js";
2
+ import { c as h } from "../../concat-CuRsVY-K.js";
3
+ function u(p) {
4
+ const { cache: n, item: s } = p.inputs, { maxSize: r, pastLen: c } = p.attrs, t = n.shape[0], o = n.shape[1], a = n.shape[3], e = s.shape[2];
5
+ if (c + e <= r) {
6
+ const f = n.slice([0, 0, 0, 0], [t, o, c, a]), m = n.slice([0, 0, c + e, 0], [t, o, r - c - e, a]), i = e < e ? s.slice([0, 0, 0, 0], [t, o, e, a]) : s, k = h([f, i, m], 2);
7
+ return f.dispose(), m.dispose(), i !== s && i.dispose(), k;
8
8
  }
9
- return e;
9
+ const l = n.slice([0, 0, e, 0], [t, o, r - e, a]), C = h([l, s], 2);
10
+ return l.dispose(), C;
10
11
  }
11
- const f = {
12
+ const w = {
12
13
  kernelName: "AppendCache",
13
14
  backendName: "cpu",
14
- kernelFunc: r
15
+ kernelFunc: u
15
16
  };
16
- a(f);
17
- const C = {
17
+ d(w);
18
+ const N = {
18
19
  kernelName: "AppendCache",
19
20
  backendName: "tensorflow",
20
- kernelFunc: r
21
- };
22
- a(C);
23
- function N(n, c, t) {
24
- return m().runKernel("AppendCache", { cache: n, item: c }, { maxSize: t });
25
- }
26
- export {
27
- N as appendCache
21
+ kernelFunc: u
28
22
  };
23
+ d(N);
@@ -1,22 +1,22 @@
1
- import { r as o, f as l } from "../../index-C4JCoBvj.js";
2
- import { m as k } from "../../mat_mul-415y5Qn2.js";
1
+ import { r as o, f as k } from "../../index-C4JCoBvj.js";
2
+ import { m as d } from "../../mat_mul-415y5Qn2.js";
3
3
  function r(t) {
4
- const { q: e, k: c, mask: n } = t.inputs, { divisor: m } = t.attrs, s = e.shape[2], a = k(e, c, !1, !0).mul(l(m));
5
- if (n) {
6
- const i = n.slice([0, 0], [s, s]).expandDims(0).expandDims(0);
7
- return a.add(i);
4
+ const { q: e, k: n, mask: s } = t.inputs, { divisor: c } = t.attrs, m = e.shape[2], i = n.shape[2], a = d(e, n, !1, !0).mul(k(c));
5
+ if (s) {
6
+ const l = s.slice([0, 0], [m, i]).expandDims(0).expandDims(0);
7
+ return a.add(l);
8
8
  }
9
9
  return a;
10
10
  }
11
- const d = {
11
+ const u = {
12
12
  kernelName: "AttentionMask",
13
13
  backendName: "cpu",
14
14
  kernelFunc: r
15
15
  };
16
- o(d);
17
- const u = {
16
+ o(u);
17
+ const f = {
18
18
  kernelName: "AttentionMask",
19
19
  backendName: "tensorflow",
20
20
  kernelFunc: r
21
21
  };
22
- o(u);
22
+ o(f);
@@ -1,12 +1,12 @@
1
- import { r as h } from "../../index-C4JCoBvj.js";
1
+ import { r as p } from "../../index-C4JCoBvj.js";
2
2
  class m {
3
3
  variableNames = ["cache", "item"];
4
4
  outputShape;
5
5
  userCode;
6
6
  customUniforms = [{ name: "cacheT", type: "int" }];
7
- constructor(t, a, o, s, n) {
8
- const c = Math.min(o + 1, n);
9
- this.outputShape = [t, a, c, s], this.userCode = `
7
+ constructor(t, a, n, o, c) {
8
+ const s = Math.min(n + 1, c);
9
+ this.outputShape = [t, a, s, o], this.userCode = `
10
10
  void main() {
11
11
  ivec4 coords = getOutputCoords(); // [b, h, t, d]
12
12
  int b = coords.x;
@@ -15,7 +15,7 @@ class m {
15
15
  int d = coords.w;
16
16
 
17
17
  int itemT = 1;
18
- int maxSize = ${n};
18
+ int maxSize = ${c};
19
19
  int totalT = cacheT + itemT;
20
20
  int start = totalT >= maxSize ? 1 : 0;
21
21
 
@@ -23,21 +23,22 @@ class m {
23
23
  float val = 0.0;
24
24
  if (srcT < cacheT) {
25
25
  val = getCache(b, h, srcT, d);
26
- } else {
26
+ } else if (srcT == cacheT) {
27
27
  val = getItem(b, h, 0, d);
28
- }
28
+ } else {
29
+ val = 0.0;}
29
30
  setOutput(val);
30
31
  }
31
32
  `;
32
33
  }
33
34
  }
34
- function p(e) {
35
- const { cache: t, item: a } = e.inputs, { maxSize: o } = e.attrs, s = e.backend, n = t.shape[0], c = t.shape[2], r = t.shape[1], i = new m(n, r, c, a.shape[3], o);
36
- return s.runWebGLProgram(i, [t, a], "float32", [[c]]);
35
+ function d(e) {
36
+ const { cache: t, item: a } = e.inputs, { maxSize: n, pastLen: o } = e.attrs, c = e.backend, s = t.shape[0], r = t.shape[2], i = t.shape[1], h = new m(s, i, r, a.shape[3], n);
37
+ return c.runWebGLProgram(h, [t, a], "float32", [[o]]);
37
38
  }
38
- const d = {
39
+ const l = {
39
40
  kernelName: "AppendCache",
40
41
  backendName: "webgl",
41
- kernelFunc: p
42
+ kernelFunc: d
42
43
  };
43
- h(d);
44
+ p(l);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.4.2",
3
+ "version": "0.4.3",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",