@genai-fi/nanogpt 0.4.5 → 0.5.1

This diff compares the contents of two package versions as published to their public registries. It is provided for informational purposes only and reflects the packages exactly as they appear in those registries.
Files changed (111)
  1. package/dist/BaseLayer-BhrMN8JO.js +135 -0
  2. package/dist/Generator.js +52 -49
  3. package/dist/NanoGPTModel.d.ts +13 -17
  4. package/dist/NanoGPTModel.js +128 -136
  5. package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
  6. package/dist/TeachableLLM.js +1 -1
  7. package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
  8. package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
  9. package/dist/broadcast_to-CMlkG8NS.js +44 -0
  10. package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
  11. package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
  12. package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
  13. package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
  14. package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
  15. package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
  16. package/dist/layers/BaseLayer.d.ts +28 -4
  17. package/dist/layers/BaseLayer.js +3 -16
  18. package/dist/layers/CausalSelfAttention.d.ts +21 -24
  19. package/dist/layers/CausalSelfAttention.js +73 -128
  20. package/dist/layers/MLP.d.ts +8 -15
  21. package/dist/layers/MLP.js +43 -81
  22. package/dist/layers/RMSNorm.d.ts +5 -10
  23. package/dist/layers/RMSNorm.js +13 -29
  24. package/dist/layers/RoPECache.js +14 -12
  25. package/dist/layers/TiedEmbedding.d.ts +6 -16
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.d.ts +12 -16
  28. package/dist/layers/TransformerBlock.js +20 -41
  29. package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
  30. package/dist/main.js +1 -1
  31. package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
  32. package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
  33. package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
  34. package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
  35. package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
  36. package/dist/ops/appendCache.js +4 -4
  37. package/dist/ops/attentionMask.d.ts +1 -1
  38. package/dist/ops/attentionMask.js +4 -4
  39. package/dist/ops/cpu/appendCache.js +2 -2
  40. package/dist/ops/cpu/attentionMask.js +14 -15
  41. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  42. package/dist/ops/cpu/gatherSub.js +5 -5
  43. package/dist/ops/cpu/gelu.js +1 -1
  44. package/dist/ops/cpu/matMulGelu.js +1 -1
  45. package/dist/ops/cpu/matMulMul.d.ts +1 -0
  46. package/dist/ops/cpu/matMulMul.js +17 -0
  47. package/dist/ops/cpu/mulDropout.js +1 -1
  48. package/dist/ops/cpu/normRMS.js +1 -1
  49. package/dist/ops/cpu/qkv.js +3 -3
  50. package/dist/ops/cpu/rope.js +5 -5
  51. package/dist/ops/cpu/scatterSub.js +8 -8
  52. package/dist/ops/fusedSoftmax.js +1 -1
  53. package/dist/ops/gatherSub.js +1 -1
  54. package/dist/ops/gelu.js +1 -1
  55. package/dist/ops/grads/attentionMask.js +13 -9
  56. package/dist/ops/grads/fusedSoftmax.js +12 -9
  57. package/dist/ops/grads/gelu.js +1 -1
  58. package/dist/ops/grads/matMulGelu.js +1 -1
  59. package/dist/ops/grads/normRMS.js +1 -1
  60. package/dist/ops/grads/qkv.js +19 -9
  61. package/dist/ops/grads/rope.js +1 -1
  62. package/dist/ops/matMulGelu.js +1 -1
  63. package/dist/ops/matMulMul.d.ts +2 -0
  64. package/dist/ops/matMulMul.js +9 -0
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  67. package/dist/ops/normRMS.js +1 -1
  68. package/dist/ops/qkv.js +1 -1
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +13 -12
  72. package/dist/ops/webgl/fusedSoftmax.js +43 -40
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/matMulGelu.js +17 -17
  76. package/dist/ops/webgl/matMulMul.d.ts +14 -0
  77. package/dist/ops/webgl/matMulMul.js +28 -0
  78. package/dist/ops/webgl/mulDropout.js +1 -1
  79. package/dist/ops/webgl/normRMS.js +29 -21
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops-ObfXLHYQ.js +1269 -0
  84. package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
  85. package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
  86. package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
  87. package/dist/slice_util-D-kaD4ZV.js +49 -0
  88. package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
  89. package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
  90. package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
  91. package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
  92. package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
  93. package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
  94. package/dist/tfjs_backend-NucKez4s.js +1010 -0
  95. package/dist/training/AdamExt.js +1 -1
  96. package/dist/training/DatasetBuilder.js +44 -44
  97. package/dist/training/Evaluator.js +6 -6
  98. package/dist/training/FullTrainer.js +1 -1
  99. package/dist/training/Trainer.js +7 -7
  100. package/dist/training/sparseCrossEntropy.js +4 -4
  101. package/dist/utilities/dummy.js +10 -10
  102. package/dist/utilities/generate.js +3 -3
  103. package/dist/utilities/load.js +1 -1
  104. package/dist/utilities/profile.js +1 -1
  105. package/dist/utilities/save.js +13 -11
  106. package/dist/utilities/weights.js +2 -2
  107. package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
  108. package/package.json +1 -1
  109. package/dist/slice_util-BdhYwFY_.js +0 -90
  110. package/dist/tfjs_backend-DuKis_xG.js +0 -2271
  111. package/dist/variable-BJTZ3jOy.js +0 -23
@@ -1,17 +1,16 @@
- import { defaultConfig as x } from "./config.js";
- import W from "./layers/TransformerBlock.js";
- import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-DznFwzcB.js";
- import K from "./layers/RoPECache.js";
- import N from "./layers/RMSNorm.js";
- import { estimateParameterCount as R } from "./utilities/parameters.js";
- import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
- import B from "./layers/BaseLayer.js";
- import { o as $, h as E, p as G, E as v, a9 as O, aa as j, ab as Q, t as w, a8 as V, f as C } from "./index--6vO-cOz.js";
- import { r as _ } from "./reshape-z51Eu-re.js";
- import { r as X } from "./range-C_vpUjBu.js";
- import { e as H } from "./tfjs_backend-DuKis_xG.js";
- import { g as J } from "./gather-C5D8PxwA.js";
- import { s as U } from "./softmax-Dsxflvdl.js";
+ import { defaultConfig as L } from "./config.js";
+ import v from "./layers/TransformerBlock.js";
+ import { E as T, D as q, T as K, r as P, p as _ } from "./TiedEmbedding-DsDRvLB0.js";
+ import F from "./layers/RoPECache.js";
+ import D from "./layers/RMSNorm.js";
+ import { estimateParameterCount as O } from "./utilities/parameters.js";
+ import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
+ import { B as R } from "./BaseLayer-BhrMN8JO.js";
+ import { o as E, i as d, q as B, E as y, aa as G, ab as V, ac as j, t as w, a9 as A, f as z, F as W } from "./index-iNhkcAEQ.js";
+ import { r as C } from "./reshape-DxTPgnwL.js";
+ import { r as H } from "./range-BsFU-SNG.js";
+ import { g as J } from "./gather-Bxe1Qip8.js";
+ import { s as Q } from "./softmax-BjsptB07.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -28,13 +27,13 @@ import { s as U } from "./softmax-Dsxflvdl.js";
  * limitations under the License.
  * =============================================================================
  */
- function Y(c, t) {
- let e = E(c, "a", "mod"), o = E(t, "b", "mod");
- [e, o] = G(e, o);
- const i = { a: e, b: o };
- return v.runKernel(O, i);
+ function U(h, t) {
+ let e = d(h, "a", "mod"), o = d(t, "b", "mod");
+ [e, o] = B(e, o);
+ const n = { a: e, b: o };
+ return y.runKernel(G, n);
  }
- const Z = /* @__PURE__ */ $({ mod_: Y });
+ const X = /* @__PURE__ */ E({ mod_: U });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -51,17 +50,17 @@ const Z = /* @__PURE__ */ $({ mod_: Y });
  * limitations under the License.
  * =============================================================================
  */
- function tt(c, t, e, o = !1) {
- const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
+ function Y(h, t, e, o = !1) {
+ const n = d(h, "logits", "multinomial"), s = n.size, i = n.rank;
  if (s < 2)
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
- if (l > 2)
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
+ if (i > 2)
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
  e = e || Math.random();
- const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
- return l === 1 ? _(a, [a.size]) : a;
+ const c = { logits: i === 1 ? C(n, [1, -1]) : n }, l = { numSamples: t, seed: e, normalized: o }, a = y.runKernel(V, c, l);
+ return i === 1 ? C(a, [a.size]) : a;
  }
- const M = /* @__PURE__ */ $({ multinomial_: tt });
+ const I = /* @__PURE__ */ E({ multinomial_: Y });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -78,19 +77,19 @@ const M = /* @__PURE__ */ $({ multinomial_: tt });
  * limitations under the License.
  * =============================================================================
  */
- function et(c, t = 1, e = !0) {
- const o = E(c, "x", "topk");
+ function Z(h, t = 1, e = !0) {
+ const o = d(h, "x", "topk");
  if (o.rank === 0)
  throw new Error("topk() expects the input to be of rank 1 or higher");
- const i = o.shape[o.shape.length - 1];
+ const n = o.shape[o.shape.length - 1];
  if (t < 0)
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
- if (t > i)
- throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
- const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
- return { values: r, indices: n };
+ if (t > n)
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
+ const s = { x: o }, i = { k: t, sorted: e }, [r, c] = y.runKernel(j, s, i);
+ return { values: r, indices: c };
  }
- const ot = /* @__PURE__ */ $({ topk_: et });
+ const tt = /* @__PURE__ */ E({ topk_: Z });
  /**
  * @license
  * Copyright 2018 Google LLC
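The three vendored TFJS op wrappers above (mod, multinomial, topk) change only in their minified identifiers and chunk imports; their logic is untouched. They exist to serve the sampling path in generate() later in this file. As a rough illustration of how they combine, here is a hypothetical temperature-plus-top-k sampler written against the public @tensorflow/tfjs API (the function and variable names are mine, not the package's):

import * as tf from "@tensorflow/tfjs";

// Sketch: pick the next token id from a (1, vocab) row of logits.
// temperature < 1 sharpens the distribution, > 1 flattens it.
function sampleNextToken(logits: tf.Tensor2D, temperature = 1, topK?: number): tf.Tensor {
  return tf.tidy(() => {
    const scaled = logits.div(temperature);
    if (topK) {
      // Restrict sampling to the k largest logits, then map the sampled
      // position back to its vocabulary id.
      const { values, indices } = tf.topk(scaled, topK);
      const pick = tf.multinomial(values as tf.Tensor2D, 1);
      return tf.gather(indices.squeeze([0]), pick.squeeze([0])).reshape([1, 1]);
    }
    return tf.multinomial(scaled as tf.Tensor2D, 1); // sample over the full vocabulary
  });
}

A call such as sampleNextToken(lastLogits, 0.8, 40) mirrors the o?.temperature and o?.topK options read at the top of generate() in the hunks below.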
@@ -100,13 +99,13 @@ const ot = /* @__PURE__ */ $({ topk_: et });
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
- function st(c) {
- return new P(c);
+ function et(h) {
+ return new q(h);
  }
- function it(c) {
- return new F(c);
+ function ot(h) {
+ return new T(h);
  }
- class wt extends B {
+ class dt extends R {
  wte;
  // Token embeddings
  wpe;
@@ -120,55 +119,30 @@ class wt extends B {
  log = [];
  // Training log
  constructor(t = {}) {
- super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new q({
- vocabSize: this.config.gpt.vocabSize,
- embedDim: this.config.gpt.nEmbed,
- name: "token_embedding"
- }), this.config.gpt.useRope === !1 ? this.wpe = it({
+ super({ gpt: { ...L, ...t }, layerConfig: {} }), this.wte = new K(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = ot({
  inputDim: this.config.gpt.blockSize,
  outputDim: this.config.gpt.nEmbed,
  name: "positional_embedding",
- embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
- }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
+ embeddingsInitializer: P({ mean: 0, stddev: 0.02 })
+ }) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = et({ rate: this.config.gpt.dropout }), this.blocks = [];
  for (let e = 0; e < this.config.gpt.nLayer; e++)
- this.blocks.push(new W(e, this.config));
- this.lnF = new N(this.config, "final_rms_norm");
+ this.blocks.push(new v(e, this.config, this));
+ this.lnF = new D(this.config, "final_rms_norm", this);
  }
  get checkpointing() {
- return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
+ return this.config.layerConfig.checkpointing === !0;
  }
  set checkpointing(t) {
- this.config.layerConfig.checkpointAttention = t, this.config.layerConfig.checkpointMLP = t;
- }
- get variables() {
- return [
- //...this.wpe.trainableWeights.map((v) => v.read() as TF.Variable),
- ...this.blocks.flatMap((t) => t.variables),
- ...this.lnF.trainableWeights.map((t) => t),
- ...this.wte.variables
- ];
- }
- saveWeights() {
- const t = /* @__PURE__ */ new Map();
- t.set("token_embedding", this.wte.getWeights()), this.wpe && t.set("positional_embedding", this.wpe.getWeights());
- for (let e = 0; e < this.blocks.length; e++)
- this.blocks[e].saveWeights(t);
- return t.set("final_rms_norm", this.lnF.getWeights()), t;
- }
- loadWeights(t) {
- this.wte.setWeights(t.get("token_embedding") || []), this.wpe && this.wpe.setWeights(t.get("positional_embedding") || []);
- for (let e = 0; e < this.blocks.length; e++)
- this.blocks[e].loadWeights(t);
- this.lnF.setWeights(t.get("final_rms_norm") || []);
+ this.config.layerConfig.checkpointing = t;
  }
  inputPhase(t, e, o = !1) {
  return w(() => {
- const i = this.wte.embed(t);
+ const n = this.wte.embed(t);
  if (this.config.gpt.useRope === !1) {
- const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
+ const [, s] = t.shape, i = this.config.gpt.blockSize, r = H(0, s, 1, "int32"), c = X(A(r, z(e, "int32")), z(i, "int32")), l = this.wpe.apply(c), a = n.add(l);
  return this.drop.apply(a, { training: o });
  } else
- return this.drop.apply(i, { training: o });
+ return this.drop.apply(n, { training: o });
  });
  }
  setSkipMask(t) {
@@ -183,11 +157,6 @@ class wt extends B {
  for (let e = 0; e < this.blocks.length; e++)
  this.blocks[e].trainable = t[e];
  }
- set trainable(t) {
- for (const e of this.blocks)
- e.trainable = t;
- this.lnF.trainable = t;
- }
  validateInput(t) {
  if (t.shape.length !== 2)
  throw new Error(`Invalid input shape: expected [batch_size, sequence_length], got ${t.shape}`);
@@ -198,84 +167,107 @@ class wt extends B {
  }
  calculateLoss(t, e) {
  try {
- return A()(t, e).mean();
+ return N()(t, e).mean();
  } catch (o) {
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
  }
  }
  // Attention rollout per Abnar & Zuidema (2020)
  // Expects list of (B, T, T) attention matrices already averaged over heads.
- computeAttentionRollout(t) {
- return w(() => {
- if (t.length === 0)
- throw new Error("No attentions for rollout");
- const [e, o, i] = t[0].shape;
- for (const n of t) {
- const [h, a, p] = n.shape;
- if (h !== e || a !== o || p !== i)
- throw new Error(
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
- );
- }
- const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
- let r = l.tile([e, 1, 1]);
- for (const n of s) {
- const h = n.add(l);
- r = h.div(h.sum(-1, !0)).matMul(r);
- }
- return r;
- });
- }
- forward(t, e, o = !1, i = !1, s) {
- return this.validateInput(t), w(() => {
+ /*private computeAttentionRollout(attentions: Tensor[]): Tensor {
+ return tidy(() => {
+ if (attentions.length === 0) {
+ throw new Error('No attentions for rollout');
+ }
+ const [B, Q, K] = attentions[0].shape as number[];
+
+ // Validate shapes are consistent
+ for (const a of attentions) {
+ const [b2, q2, k2] = a.shape as number[];
+ if (b2 !== B || q2 !== Q || k2 !== K) {
+ throw new Error(
+ `Inconsistent attention shapes in rollout: expected [${B},${Q},${K}] got [${b2},${q2},${k2}]`
+ );
+ }
+ }
+
+ // Always slice to [B, Q, Q] for rollout
+ const attentionsSliced = attentions.map((att) => att.slice([0, 0, 0], [B, Q, Q]));
+
+ const ey = eye(Q, Q).expandDims(0); // (1,Q,Q)
+ let rollout = ey.tile([B, 1, 1]); // (B,Q,Q)
+ for (const att of attentionsSliced) {
+ const a = att.add(ey);
+ const aNorm = a.div(a.sum(-1, true)); // (B,Q,Q)
+ rollout = aNorm.matMul(rollout); // (B,Q,Q)
+ }
+ return rollout;
+ });
+ }*/
+ forward(t, e, o) {
+ return this.validateInput(e), w(() => {
  this.startMemory();
- const l = s?.[0]?.length ?? 0;
- let r = this.inputPhase(t, l, o);
- const n = [];
- if (s && s.length !== this.blocks.length)
- throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
- for (let g = 0; g < this.blocks.length; g++) {
- const u = r, m = this.blocks[g], {
- output: b,
- attention: k,
- cache: f
- } = m.call(r, o, i, s ? s[g] : void 0);
- r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
+ const n = t.cache?.[0]?.length ?? 0;
+ let s = this.inputPhase(e, n, t.training);
+ if (t.cache && t.cache.length !== this.blocks.length)
+ throw console.error("Cache", t.cache), new Error(
+ `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
+ );
+ for (let c = 0; c < this.blocks.length; c++) {
+ const l = this.blocks[c], a = Math.random() * 1e9, u = {
+ training: t.training,
+ seed: a,
+ attentionScores: t.attentionScores,
+ pastKV: t.cache ? t.cache[c] : void 0
+ }, p = this.config.layerConfig.checkpointing && t.training ? l.callCheckpoint(u, s) : l.call(u, s);
+ s.dispose(), s = p;
  }
- let h;
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
- const a = this.wte.project(r);
- let p;
- return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
+ s = this.lnF.call(t, s);
+ const i = this.wte.project(s);
+ s.dispose();
+ let r;
+ return o && (r = this.calculateLoss(i, o)), this.endMemory("Forward"), r ? [i, r] : [i];
  });
  }
  generate(t, e, o) {
- const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
+ const n = o?.temperature ?? 1, s = o?.topK, i = o?.usePadding ?? !1;
  return w(() => {
- const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
- [0, h - this.config.gpt.blockSize],
- [n.shape[0], this.config.gpt.blockSize]
- ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
+ const r = t, c = r.shape[1], l = c <= this.config.gpt.blockSize ? r : r.slice(
+ [0, c - this.config.gpt.blockSize],
+ [r.shape[0], this.config.gpt.blockSize]
+ ), a = i ? this.config.gpt.blockSize - l.shape[1] : 0, u = a > 0 ? _(l, [
  [0, 0],
- [0, p]
- ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
- let d;
+ [0, a]
+ ]) : l, p = {
+ training: !1,
+ attentionScores: o?.attentionScores ? {
+ attentionOut: []
+ } : void 0,
+ cache: e
+ }, [f] = this.forward(p, u), S = f.shape[1] - 1 - a, M = f.slice([0, S, 0], [f.shape[0], 1, f.shape[2]]);
+ p.attentionScores?.attentionOut && p.attentionScores.attentionOut.forEach((g, k) => {
+ g.shape[1] !== 1 && (p.attentionScores.attentionOut[k] = W(
+ g.slice([0, S, 0], [g.shape[0], 1, g.shape[2]])
+ ), g.dispose());
+ }), f.dispose();
+ const b = M.div(n);
+ let m;
  if (s) {
- const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
- d = J(I.squeeze([1]), L, 1);
+ const { values: g, indices: k } = tt(b, s), x = I(g.squeeze([1]), 1);
+ m = J(k.squeeze([1]), x, 1);
  } else
- d = M(y.squeeze([1]), 1);
- let z;
- return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
+ m = I(b.squeeze([1]), 1);
+ let $;
+ return o?.includeProbabilities && ($ = Q(b.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: $, attention: p.attentionScores?.attentionOut };
  });
  }
  getNumParams() {
- return R(this.config.gpt);
+ return O(this.config.gpt);
  }
  dispose() {
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
  }
  }
  export {
- wt as default
+ dt as default
  };
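Read together, these hunks carry the substantive API change in this release: the variables getter, saveWeights()/loadWeights(), and the trainable setter are removed; the live attention-rollout path (Abnar & Zuidema, 2020) is disabled and survives only as a commented-out TypeScript block; the separate checkpointAttention/checkpointMLP flags collapse into a single layerConfig.checkpointing switch, which forward() uses to route each block through callCheckpoint() while training; and forward() itself now takes a context object first and returns a tuple instead of an object. A rough de-minified sketch of the new calling convention, inferred from the minified code (the type and property names below are reconstructions, not the package's published typings):

import type { Tensor } from "@tensorflow/tfjs";

// Hypothetical reconstruction of the context consumed by forward().
type KVCache = { k: Tensor; v: Tensor };          // per-block past keys/values
interface ForwardContext {
  training: boolean;
  attentionScores?: { attentionOut: Tensor[] };   // populated when scores are requested
  cache?: KVCache[];                              // one entry per transformer block
}

// Old: forward(input, targets?, training = false, includeAttention = false, cache?)
//        -> { logits, loss?, attention? }
// New: forward(ctx, input, targets?) -> [logits] | [logits, loss]
// Per block, forward() derives { training, seed, attentionScores, pastKV } from the
// context, drawing a fresh Math.random() dropout seed for every block.
declare const model: {
  forward(ctx: ForwardContext, input: Tensor, targets?: Tensor): [Tensor] | [Tensor, Tensor];
};
declare const inputIds: Tensor, targetIds: Tensor;
const [logits, loss] = model.forward({ training: true }, inputIds, targetIds);

generate() builds this context itself: it reads temperature, topK, usePadding, attentionScores, and includeProbabilities from its options, calls forward() with training: false, and now returns the raw per-block score list in attention rather than a computed rollout.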
@@ -1,5 +1,5 @@
- import { ac as f, ad as g, n as p, ae as C, j as x } from "./index--6vO-cOz.js";
- import { u as I } from "./gpgpu_math-CUzjlO9A.js";
+ import { ad as $, ae as g, p, af as C, k as x } from "./index-iNhkcAEQ.js";
+ import { u as I } from "./gpgpu_math-C0zyxKFi.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -17,7 +17,7 @@ import { u as I } from "./gpgpu_math-CUzjlO9A.js";
  * =============================================================================
  */
  function R(t, e, o = "index") {
- const s = f(e);
+ const s = $(e);
  return s.map((n, r) => {
  const i = `int ${t[r]} = ${o} / ${n}`, u = r === s.length - 1 ? `int ${t[r + 1]} = ${o} - ${t[r]} * ${n}` : `index -= ${t[r]} * ${n}`;
  return `${i}; ${u};`;
@@ -38,7 +38,7 @@ function S(t, e, o = "index") {
  }).join("");
  }
  function F(t) {
- const e = f(t).map((o) => o.toString());
+ const e = $(t).map((o) => o.toString());
  return `
  int getFlatIndex(ivec3 coords) {
  return coords.x * ${e[0]} + coords.y * ${e[1]} + coords.z;
@@ -82,7 +82,7 @@ function m(t) {
  function d(t) {
  return t % 2 === 0;
  }
- function $(t, e) {
+ function f(t, e) {
  if (t = t.slice(-2), e = e.slice(-2), g(t, e) || !t.length || !e.length || t[0] === 0 || t[1] === 0 || e[0] === 0 || e[1] === 0)
  return !0;
  if (t.length !== e.length) {
@@ -201,12 +201,12 @@ function D(t, e, o) {
  * limitations under the License.
  * =============================================================================
  */
- function k(t) {
+ function U(t) {
  const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
  x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
  const l = i.texData.get(n.dataId);
- return l.isPacked && !$(n.shape, a) && !(l.texture !== null && $(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
+ return l.isPacked && !f(n.shape, a) && !(l.texture !== null && f(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
  }
  export {
- k as r
+ U as r
  };
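The edits in this Reshape chunk are confined to renamed minified identifiers ($ and f swap roles, k becomes U) and re-hashed chunk imports; the behaviour is identical. For orientation: the generated getFlatIndex GLSL above is plain row-major stride arithmetic, and the surrounding even-dimension checks (t % 2 === 0) decide whether a packed texture can be reshaped without a repack pass. A host-side equivalent of the index math (illustrative only, not part of the package):

// For a 3-D shape [d0, d1, d2] the row-major strides are [d1 * d2, d2, 1],
// so a coordinate (x, y, z) flattens to x * d1 * d2 + y * d2 + z, which is
// exactly what the generated shader emits as
// coords.x * strides[0] + coords.y * strides[1] + coords.z.
function flatIndex3D(shape: [number, number, number], coords: [number, number, number]): number {
  const [, d1, d2] = shape;
  const [x, y, z] = coords;
  return x * (d1 * d2) + y * d2 + z;
}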
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./index-Tf7vU29b.js";
  import "./jszip.min-CjP2V1VV.js";
- import "./index--6vO-cOz.js";
+ import "./index-iNhkcAEQ.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./ops/cpu/gatherSub.js";