@genai-fi/nanogpt 0.4.5 → 0.5.0

This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (111)
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +44 -41
  3. package/dist/NanoGPTModel.d.ts +12 -16
  4. package/dist/NanoGPTModel.js +128 -138
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +1 -1
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +22 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -128
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -10
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +1 -1
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.js +1 -1
  49. package/dist/ops/cpu/qkv.js +3 -3
  50. package/dist/ops/cpu/rope.js +5 -5
  51. package/dist/ops/cpu/scatterSub.js +8 -8
  52. package/dist/ops/fusedSoftmax.js +1 -1
  53. package/dist/ops/gatherSub.js +1 -1
  54. package/dist/ops/gelu.js +1 -1
  55. package/dist/ops/grads/attentionMask.js +13 -9
  56. package/dist/ops/grads/fusedSoftmax.js +12 -9
  57. package/dist/ops/grads/gelu.js +1 -1
  58. package/dist/ops/grads/matMulGelu.js +1 -1
  59. package/dist/ops/grads/normRMS.js +1 -1
  60. package/dist/ops/grads/qkv.js +19 -9
  61. package/dist/ops/grads/rope.js +1 -1
  62. package/dist/ops/matMulGelu.js +1 -1
  63. package/dist/ops/matMulMul.d.ts +2 -0
  64. package/dist/ops/matMulMul.js +9 -0
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +13 -12
  72. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/matMulGelu.js +17 -17
  76. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  77. package/dist/ops/webgl/matMulMul.js +28 -0
  78. package/dist/ops/webgl/mulDropout.js +1 -1
  79. package/dist/ops/webgl/normRMS.js +29 -21
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops-ObfXLHYQ.js +1269 -0
  84. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  85. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  86. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  87. package/dist/slice_util-D-kaD4ZV.js +49 -0
  88. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  89. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  90. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  91. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  92. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  93. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  94. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +44 -44
  97. package/dist/training/Evaluator.js +6 -6
  98. package/dist/training/FullTrainer.js +1 -1
  99. package/dist/training/Trainer.js +7 -7
  100. package/dist/training/sparseCrossEntropy.js +4 -4
  101. package/dist/utilities/dummy.js +10 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.js +1 -1
  104. package/dist/utilities/profile.js +1 -1
  105. package/dist/utilities/save.js +10 -8
  106. package/dist/utilities/weights.js +2 -2
  107. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  108. package/package.json +1 -1
  109. package/dist/slice_util-BdhYwFY_.js +0 -90
  110. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  111. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,17 +1,16 @@
- import { defaultConfig as x } from "./config.js";
- import W from "./layers/TransformerBlock.js";
- import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-DznFwzcB.js";
- import K from "./layers/RoPECache.js";
- import N from "./layers/RMSNorm.js";
- import { estimateParameterCount as R } from "./utilities/parameters.js";
- import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
- import B from "./layers/BaseLayer.js";
- import { o as $, h as E, p as G, E as v, a9 as O, aa as j, ab as Q, t as w, a8 as V, f as C } from "./index--6vO-cOz.js";
- import { r as _ } from "./reshape-z51Eu-re.js";
- import { r as X } from "./range-C_vpUjBu.js";
- import { e as H } from "./tfjs_backend-DuKis_xG.js";
- import { g as J } from "./gather-C5D8PxwA.js";
- import { s as U } from "./softmax-Dsxflvdl.js";
+ import { defaultConfig as L } from "./config.js";
+ import q from "./layers/TransformerBlock.js";
+ import { E as O, D as T, T as K, r as P, p as _ } from "./TiedEmbedding-DsDRvLB0.js";
+ import F from "./layers/RoPECache.js";
+ import D from "./layers/RMSNorm.js";
+ import { estimateParameterCount as N } from "./utilities/parameters.js";
+ import { createSoftmaxCrossEntropyWithGrad as R } from "./training/sparseCrossEntropy.js";
+ import { B } from "./BaseLayer-BhrMN8JO.js";
+ import { o as k, i as m, q as G, E as w, aa as A, ab as V, ac as j, t as b, a9 as W, f as y, F as H } from "./index-iNhkcAEQ.js";
+ import { r as $ } from "./reshape-DxTPgnwL.js";
+ import { r as J } from "./range-BsFU-SNG.js";
+ import { g as Q } from "./gather-Bxe1Qip8.js";
+ import { s as U } from "./softmax-BjsptB07.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -28,13 +27,13 @@ import { s as U } from "./softmax-Dsxflvdl.js";
  * limitations under the License.
  * =============================================================================
  */
- function Y(c, t) {
- let e = E(c, "a", "mod"), o = E(t, "b", "mod");
+ function X(h, t) {
+ let e = m(h, "a", "mod"), o = m(t, "b", "mod");
  [e, o] = G(e, o);
- const i = { a: e, b: o };
- return v.runKernel(O, i);
+ const n = { a: e, b: o };
+ return w.runKernel(A, n);
  }
- const Z = /* @__PURE__ */ $({ mod_: Y });
+ const Y = /* @__PURE__ */ k({ mod_: X });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -51,17 +50,17 @@ const Z = /* @__PURE__ */ $({ mod_: Y });
  * limitations under the License.
  * =============================================================================
  */
- function tt(c, t, e, o = !1) {
- const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
+ function Z(h, t, e, o = !1) {
+ const n = m(h, "logits", "multinomial"), s = n.size, i = n.rank;
  if (s < 2)
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
- if (l > 2)
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
+ if (i > 2)
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
  e = e || Math.random();
- const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
- return l === 1 ? _(a, [a.size]) : a;
+ const a = { logits: i === 1 ? $(n, [1, -1]) : n }, c = { numSamples: t, seed: e, normalized: o }, l = w.runKernel(V, a, c);
+ return i === 1 ? $(l, [l.size]) : l;
  }
- const M = /* @__PURE__ */ $({ multinomial_: tt });
+ const z = /* @__PURE__ */ k({ multinomial_: Z });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -78,19 +77,19 @@ const M = /* @__PURE__ */ $({ multinomial_: tt });
  * limitations under the License.
  * =============================================================================
  */
- function et(c, t = 1, e = !0) {
- const o = E(c, "x", "topk");
+ function tt(h, t = 1, e = !0) {
+ const o = m(h, "x", "topk");
  if (o.rank === 0)
  throw new Error("topk() expects the input to be of rank 1 or higher");
- const i = o.shape[o.shape.length - 1];
+ const n = o.shape[o.shape.length - 1];
  if (t < 0)
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
- if (t > i)
- throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
- const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
- return { values: r, indices: n };
+ if (t > n)
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
+ const s = { x: o }, i = { k: t, sorted: e }, [r, a] = w.runKernel(j, s, i);
+ return { values: r, indices: a };
  }
- const ot = /* @__PURE__ */ $({ topk_: et });
+ const et = /* @__PURE__ */ k({ topk_: tt });
  /**
  * @license
  * Copyright 2018 Google LLC
@@ -100,13 +99,13 @@ const ot = /* @__PURE__ */ $({ topk_: et });
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
- function st(c) {
- return new P(c);
+ function ot(h) {
+ return new T(h);
  }
- function it(c) {
- return new F(c);
+ function st(h) {
+ return new O(h);
  }
- class wt extends B {
+ class bt extends B {
  wte;
  // Token embeddings
  wpe;
@@ -120,55 +119,30 @@ class wt extends B {
  log = [];
  // Training log
  constructor(t = {}) {
- super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new q({
- vocabSize: this.config.gpt.vocabSize,
- embedDim: this.config.gpt.nEmbed,
- name: "token_embedding"
- }), this.config.gpt.useRope === !1 ? this.wpe = it({
+ super({ gpt: { ...L, ...t }, layerConfig: {} }), this.wte = new K(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = st({
  inputDim: this.config.gpt.blockSize,
  outputDim: this.config.gpt.nEmbed,
  name: "positional_embedding",
- embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
- }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
+ embeddingsInitializer: P({ mean: 0, stddev: 0.02 })
+ }) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = ot({ rate: this.config.gpt.dropout }), this.blocks = [];
  for (let e = 0; e < this.config.gpt.nLayer; e++)
- this.blocks.push(new W(e, this.config));
- this.lnF = new N(this.config, "final_rms_norm");
+ this.blocks.push(new q(e, this.config, this));
+ this.lnF = new D(this.config, "final_rms_norm", this);
  }
  get checkpointing() {
- return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
+ return this.config.layerConfig.checkpointing === !0;
  }
  set checkpointing(t) {
- this.config.layerConfig.checkpointAttention = t, this.config.layerConfig.checkpointMLP = t;
- }
- get variables() {
- return [
- //...this.wpe.trainableWeights.map((v) => v.read() as TF.Variable),
- ...this.blocks.flatMap((t) => t.variables),
- ...this.lnF.trainableWeights.map((t) => t),
- ...this.wte.variables
- ];
- }
- saveWeights() {
- const t = /* @__PURE__ */ new Map();
- t.set("token_embedding", this.wte.getWeights()), this.wpe && t.set("positional_embedding", this.wpe.getWeights());
- for (let e = 0; e < this.blocks.length; e++)
- this.blocks[e].saveWeights(t);
- return t.set("final_rms_norm", this.lnF.getWeights()), t;
- }
- loadWeights(t) {
- this.wte.setWeights(t.get("token_embedding") || []), this.wpe && this.wpe.setWeights(t.get("positional_embedding") || []);
- for (let e = 0; e < this.blocks.length; e++)
- this.blocks[e].loadWeights(t);
- this.lnF.setWeights(t.get("final_rms_norm") || []);
+ this.config.layerConfig.checkpointing = t;
  }
  inputPhase(t, e, o = !1) {
- return w(() => {
- const i = this.wte.embed(t);
+ return b(() => {
+ const n = this.wte.embed(t);
  if (this.config.gpt.useRope === !1) {
- const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
- return this.drop.apply(a, { training: o });
+ const [, s] = t.shape, i = this.config.gpt.blockSize, r = J(0, s, 1, "int32"), a = Y(W(r, y(e, "int32")), y(i, "int32")), c = this.wpe.apply(a), l = n.add(c);
+ return this.drop.apply(l, { training: o });
  } else
- return this.drop.apply(i, { training: o });
+ return this.drop.apply(n, { training: o });
  });
  }
  setSkipMask(t) {
@@ -183,11 +157,6 @@ class wt extends B {
  for (let e = 0; e < this.blocks.length; e++)
  this.blocks[e].trainable = t[e];
  }
- set trainable(t) {
- for (const e of this.blocks)
- e.trainable = t;
- this.lnF.trainable = t;
- }
  validateInput(t) {
  if (t.shape.length !== 2)
  throw new Error(`Invalid input shape: expected [batch_size, sequence_length], got ${t.shape}`);
@@ -198,84 +167,105 @@ class wt extends B {
  }
  calculateLoss(t, e) {
  try {
- return A()(t, e).mean();
+ return R()(t, e).mean();
  } catch (o) {
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
  }
  }
  // Attention rollout per Abnar & Zuidema (2020)
  // Expects list of (B, T, T) attention matrices already averaged over heads.
- computeAttentionRollout(t) {
- return w(() => {
- if (t.length === 0)
- throw new Error("No attentions for rollout");
- const [e, o, i] = t[0].shape;
- for (const n of t) {
- const [h, a, p] = n.shape;
- if (h !== e || a !== o || p !== i)
- throw new Error(
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
- );
- }
- const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
- let r = l.tile([e, 1, 1]);
- for (const n of s) {
- const h = n.add(l);
- r = h.div(h.sum(-1, !0)).matMul(r);
- }
- return r;
- });
- }
- forward(t, e, o = !1, i = !1, s) {
- return this.validateInput(t), w(() => {
+ /*private computeAttentionRollout(attentions: Tensor[]): Tensor {
+ return tidy(() => {
+ if (attentions.length === 0) {
+ throw new Error('No attentions for rollout');
+ }
+ const [B, Q, K] = attentions[0].shape as number[];
+
+ // Validate shapes are consistent
+ for (const a of attentions) {
+ const [b2, q2, k2] = a.shape as number[];
+ if (b2 !== B || q2 !== Q || k2 !== K) {
+ throw new Error(
+ `Inconsistent attention shapes in rollout: expected [${B},${Q},${K}] got [${b2},${q2},${k2}]`
+ );
+ }
+ }
+
+ // Always slice to [B, Q, Q] for rollout
+ const attentionsSliced = attentions.map((att) => att.slice([0, 0, 0], [B, Q, Q]));
+
+ const ey = eye(Q, Q).expandDims(0); // (1,Q,Q)
+ let rollout = ey.tile([B, 1, 1]); // (B,Q,Q)
+ for (const att of attentionsSliced) {
+ const a = att.add(ey);
+ const aNorm = a.div(a.sum(-1, true)); // (B,Q,Q)
+ rollout = aNorm.matMul(rollout); // (B,Q,Q)
+ }
+ return rollout;
+ });
+ }*/
+ forward(t, e, o) {
+ return this.validateInput(e), b(() => {
  this.startMemory();
- const l = s?.[0]?.length ?? 0;
- let r = this.inputPhase(t, l, o);
- const n = [];
- if (s && s.length !== this.blocks.length)
- throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
- for (let g = 0; g < this.blocks.length; g++) {
- const u = r, m = this.blocks[g], {
- output: b,
- attention: k,
- cache: f
- } = m.call(r, o, i, s ? s[g] : void 0);
- r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
+ const n = t.cache?.[0]?.length ?? 0;
+ let s = this.inputPhase(e, n, t.training);
+ if (t.cache && t.cache.length !== this.blocks.length)
+ throw console.error("Cache", t.cache), new Error(
+ `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
+ );
+ let i;
+ for (let c = 0; c < this.blocks.length; c++) {
+ const l = this.blocks[c], f = Math.random() * 1e9, p = {
+ training: t.training,
+ seed: f,
+ attentionScores: t.attentionScores,
+ pastKV: t.cache ? t.cache[c] : void 0
+ }, u = this.config.layerConfig.checkpointing && t.training ? l.callCheckpoint(p, s) : l.call(p, s);
+ s.dispose(), s = u, p.attentionScores?.attentionOut && (i = p.attentionScores.attentionOut);
  }
- let h;
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
- const a = this.wte.project(r);
- let p;
- return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
+ s = this.lnF.call(t, s);
+ const r = this.wte.project(s);
+ s.dispose();
+ let a;
+ return o && (a = this.calculateLoss(r, o)), this.endMemory("Forward"), t.attentionScores && (t.attentionScores.attentionOut = i ? H(i) : void 0), a ? [r, a] : [r];
  });
  }
  generate(t, e, o) {
- const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
- return w(() => {
- const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
- [0, h - this.config.gpt.blockSize],
- [n.shape[0], this.config.gpt.blockSize]
- ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
+ const n = o?.temperature ?? 1, s = o?.topK, i = o?.usePadding ?? !1;
+ return b(() => {
+ const r = t, a = r.shape[1], c = a <= this.config.gpt.blockSize ? r : r.slice(
+ [0, a - this.config.gpt.blockSize],
+ [r.shape[0], this.config.gpt.blockSize]
+ ), l = i ? this.config.gpt.blockSize - c.shape[1] : 0, f = l > 0 ? _(c, [
  [0, 0],
- [0, p]
- ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
- let d;
+ [0, l]
+ ]) : c, p = {
+ training: !1,
+ attentionScores: o?.attentionScores,
+ cache: e
+ }, [u] = this.forward(p, f), E = u.shape[1] - 1 - l, C = u.slice([0, E, 0], [u.shape[0], 1, u.shape[2]]), I = p.attentionScores?.attentionOut ? p.attentionScores.attentionOut.slice(
+ [0, E, 0],
+ [p.attentionScores.attentionOut.shape[0], 1, p.attentionScores.attentionOut.shape[2]]
+ ) : void 0;
+ u.dispose();
+ const d = C.div(n);
+ let g;
  if (s) {
- const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
- d = J(I.squeeze([1]), L, 1);
+ const { values: v, indices: M } = et(d, s), x = z(v.squeeze([1]), 1);
+ g = Q(M.squeeze([1]), x, 1);
  } else
- d = M(y.squeeze([1]), 1);
- let z;
- return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
+ g = z(d.squeeze([1]), 1);
+ let S;
+ return o?.includeProbabilities && (S = U(d.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: I?.squeeze([1]), probabilities: S };
  });
  }
  getNumParams() {
- return R(this.config.gpt);
+ return N(this.config.gpt);
  }
  dispose() {
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
  }
  }
  export {
- wt as default
+ bt as default
  };
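Note on the hunks above (what appears to be NanoGPTModel.js): the substantive API change is that forward() now takes a context object rather than positional flags, and returns a tuple instead of a result object. Because the bundle is minified, the TypeScript sketch below reconstructs the new call shape from the call sites visible in the diff; any name not literally present in the diff (ForwardContext, AttentionScores, KVCache) is a hypothetical label, not the package's published typings.

// Hypothetical typings inferred from the minified diff above; treat the
// interface names as illustrative assumptions.
import type { Tensor } from "@tensorflow/tfjs-core";

interface AttentionScores {            // hypothetical name
  attentionOut?: Tensor;               // written back by forward() when requested
}

interface KVCache {                    // hypothetical name
  k: Tensor;
  v: Tensor;
  length: number;                      // tokens already cached
}

interface ForwardContext {             // hypothetical name
  training: boolean;
  cache?: KVCache[];                   // one entry per transformer block
  attentionScores?: AttentionScores;   // out-parameter for attention maps
}

// 0.4.x: forward(input, targets?, training?, includeAttention?, cache?)
//          -> { logits, loss?, attention? }
// 0.5.x: forward(ctx, input, targets?) -> [logits] | [logits, loss]
declare function forward(
  ctx: ForwardContext,
  input: Tensor,
  targets?: Tensor
): [Tensor] | [Tensor, Tensor];

The same hunks also collapse the separate checkpointAttention/checkpointMLP flags into a single layerConfig.checkpointing switch, applied via callCheckpoint() only during training, and comment out the computeAttentionRollout helper, which implemented the attention-rollout recurrence of Abnar & Zuidema (2020): R_l = rownorm(A_l + I) · R_{l-1}, with R_0 = I, chaining head-averaged attention maps across layers.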
@@ -1,5 +1,5 @@
- import { ac as f, ad as g, n as p, ae as C, j as x } from "./index--6vO-cOz.js";
- import { u as I } from "./gpgpu_math-CUzjlO9A.js";
+ import { ad as $, ae as g, p, af as C, k as x } from "./index-iNhkcAEQ.js";
+ import { u as I } from "./gpgpu_math-C0zyxKFi.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -17,7 +17,7 @@ import { u as I } from "./gpgpu_math-CUzjlO9A.js";
  * =============================================================================
  */
  function R(t, e, o = "index") {
- const s = f(e);
+ const s = $(e);
  return s.map((n, r) => {
  const i = `int ${t[r]} = ${o} / ${n}`, u = r === s.length - 1 ? `int ${t[r + 1]} = ${o} - ${t[r]} * ${n}` : `index -= ${t[r]} * ${n}`;
  return `${i}; ${u};`;
@@ -38,7 +38,7 @@ function S(t, e, o = "index") {
  }).join("");
  }
  function F(t) {
- const e = f(t).map((o) => o.toString());
+ const e = $(t).map((o) => o.toString());
  return `
  int getFlatIndex(ivec3 coords) {
  return coords.x * ${e[0]} + coords.y * ${e[1]} + coords.z;
@@ -82,7 +82,7 @@ function m(t) {
  function d(t) {
  return t % 2 === 0;
  }
- function $(t, e) {
+ function f(t, e) {
  if (t = t.slice(-2), e = e.slice(-2), g(t, e) || !t.length || !e.length || t[0] === 0 || t[1] === 0 || e[0] === 0 || e[1] === 0)
  return !0;
  if (t.length !== e.length) {
@@ -201,12 +201,12 @@ function D(t, e, o) {
  * limitations under the License.
  * =============================================================================
  */
- function k(t) {
+ function U(t) {
  const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
  x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
  const l = i.texData.get(n.dataId);
- return l.isPacked && !$(n.shape, a) && !(l.texture !== null && $(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
+ return l.isPacked && !f(n.shape, a) && !(l.texture !== null && f(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
  }
  export {
- k as r
+ U as r
  };
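The reshape-kernel hunk above is pure identifier churn from re-minification, but the logic it carries is worth spelling out: the WebGL backend only copies data when the packed texture layout becomes incompatible; otherwise a reshape is metadata-only. The following is a descriptive TypeScript paraphrase of that decision, under the assumption that the minified f/D correspond to a layout-compatibility check and a physical repack; the helper and type names are stand-ins, not real tfjs-backend-webgl exports.

// Paraphrase of the minified reshape kernel; names are illustrative.
interface TensorInfo { dataId: object; shape: number[]; dtype: string }
interface TexData { isPacked: boolean; texture: unknown; shape: number[] }
interface WebGLBackendLike {
  texData: Map<object, TexData>;
  incRef(dataId: object): void;
}

// Stand-ins for the bundle's "f" (layout check) and "D" (physical repack).
declare function packedLayoutCompatible(from: number[], to: number[]): boolean;
declare function packedReshape(x: TensorInfo, shape: number[], b: WebGLBackendLike): TensorInfo;

function reshapeWebGL(x: TensorInfo, newShape: number[], backend: WebGLBackendLike): TensorInfo {
  const texData = backend.texData.get(x.dataId)!;
  const mustRepack =
    texData.isPacked &&
    !packedLayoutCompatible(x.shape, newShape) &&
    !(texData.texture !== null && packedLayoutCompatible(texData.shape, newShape));
  if (mustRepack) return packedReshape(x, newShape, backend); // physical copy
  backend.incRef(x.dataId); // metadata-only: same texture, new shape
  return { dataId: x.dataId, shape: newShape, dtype: x.dtype };
}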
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./index-Tf7vU29b.js";
  import "./jszip.min-CjP2V1VV.js";
- import "./index--6vO-cOz.js";
+ import "./index-iNhkcAEQ.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./ops/cpu/gatherSub.js";