@genai-fi/nanogpt 0.4.1 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. package/dist/Generator.js +3 -3
  2. package/dist/NanoGPTModel.js +83 -70
  3. package/dist/TeachableLLM.js +1 -1
  4. package/dist/{random_width-CMHmdbSu.js → TiedEmbedding-CnJ1bx4q.js} +760 -719
  5. package/dist/{axis_util-DeydwOoC.js → axis_util-BgTGy5w8.js} +1 -1
  6. package/dist/{concat-DS_qH7MI.js → concat-CuRsVY-K.js} +1 -1
  7. package/dist/dropout-DfDdklfL.js +193 -0
  8. package/dist/{gather-BUmJIS8n.js → gather-ZYRWhmXR.js} +1 -1
  9. package/dist/gelu-CnCt17Lk.js +26 -0
  10. package/dist/{index-XjBAhiFO.js → index-C4JCoBvj.js} +61 -61
  11. package/dist/kernel_funcs_utils-CAd1h9X1.js +388 -0
  12. package/dist/layers/CausalSelfAttention.js +71 -70
  13. package/dist/layers/MLP.d.ts +3 -1
  14. package/dist/layers/MLP.js +93 -5
  15. package/dist/layers/RMSNorm.js +3 -3
  16. package/dist/layers/RoPECache.js +3 -3
  17. package/dist/layers/TiedEmbedding.js +6 -46
  18. package/dist/layers/TransformerBlock.js +2 -2
  19. package/dist/{log_sum_exp-DJPkVZZn.js → log_sum_exp-BswFnwOb.js} +5 -5
  20. package/dist/main.js +1 -1
  21. package/dist/{mat_mul-CKwFEV1Q.js → mat_mul-415y5Qn2.js} +1 -1
  22. package/dist/{max-DJvEiCAJ.js → max-CP_9O2Yd.js} +1 -1
  23. package/dist/{moments-CrWRPcR3.js → moments-CjeIaVdp.js} +3 -3
  24. package/dist/{norm-BzY929B_.js → norm-CZM380I3.js} +5 -5
  25. package/dist/{ones-BO01zpJG.js → ones-Bf3YR48P.js} +2 -2
  26. package/dist/ops/appendCache.js +1 -1
  27. package/dist/ops/attentionMask.d.ts +1 -1
  28. package/dist/ops/attentionMask.js +4 -4
  29. package/dist/ops/cpu/appendCache.js +2 -2
  30. package/dist/ops/cpu/attentionMask.js +13 -9
  31. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  32. package/dist/ops/cpu/gatherSub.js +3 -3
  33. package/dist/ops/cpu/gelu.d.ts +1 -0
  34. package/dist/ops/cpu/gelu.js +40 -0
  35. package/dist/ops/cpu/mulDropout.js +1 -1
  36. package/dist/ops/cpu/qkv.js +3 -3
  37. package/dist/ops/cpu/rope.js +5 -5
  38. package/dist/ops/cpu/scatterSub.js +4 -4
  39. package/dist/ops/fusedSoftmax.js +1 -1
  40. package/dist/ops/gatherSub.js +1 -1
  41. package/dist/ops/gelu.d.ts +3 -0
  42. package/dist/ops/gelu.js +8 -0
  43. package/dist/ops/grads/attentionMask.js +1 -1
  44. package/dist/ops/grads/fusedSoftmax.js +2 -2
  45. package/dist/ops/grads/gelu.d.ts +2 -0
  46. package/dist/ops/grads/gelu.js +5 -0
  47. package/dist/ops/grads/qkv.js +1 -1
  48. package/dist/ops/grads/rope.js +1 -1
  49. package/dist/ops/mulDrop.js +1 -1
  50. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  51. package/dist/ops/qkv.js +1 -1
  52. package/dist/ops/scatterSub.js +1 -1
  53. package/dist/ops/webgl/appendCache.js +1 -1
  54. package/dist/ops/webgl/attentionMask.js +19 -18
  55. package/dist/ops/webgl/fusedSoftmax.js +483 -782
  56. package/dist/ops/webgl/gatherSub.js +1 -1
  57. package/dist/ops/webgl/gelu.d.ts +2 -0
  58. package/dist/ops/webgl/gelu.js +50 -0
  59. package/dist/ops/webgl/mulDropout.js +1 -1
  60. package/dist/ops/webgl/qkv.js +1 -1
  61. package/dist/ops/webgl/rope.js +1 -1
  62. package/dist/ops/webgl/scatterSub.js +1 -1
  63. package/dist/{range-DQMNzBWs.js → range-9AzeApCc.js} +1 -1
  64. package/dist/{reshape-DFzh97Sc.js → reshape-Boe4DuIO.js} +1 -1
  65. package/dist/{sin-BYM-U4Ut.js → sin-KmhiDuMa.js} +1 -1
  66. package/dist/{slice_util-CnVNPQI-.js → slice_util-19zDNNSn.js} +2 -2
  67. package/dist/{softmax-4DOn6cPq.js → softmax-Cujsg4ay.js} +1 -1
  68. package/dist/{split-CkbeVdF8.js → split-DbcNm1-i.js} +1 -1
  69. package/dist/{stack-DaIMO5iX.js → stack-D1YjmgKN.js} +1 -1
  70. package/dist/{sum-C6u3xMi3.js → sum-R28pucR5.js} +1 -1
  71. package/dist/{tensor-Cu1fU7H7.js → tensor-BVeHdl7V.js} +1 -1
  72. package/dist/{tensor2d-D0CKdG6B.js → tensor2d-DqFGNs_K.js} +1 -1
  73. package/dist/{tfjs_backend-Bzl2SrRo.js → tfjs_backend-Cug-PH75.js} +826 -1015
  74. package/dist/training/AdamExt.js +1 -1
  75. package/dist/training/DatasetBuilder.js +3 -3
  76. package/dist/training/FullTrainer.js +1 -1
  77. package/dist/training/Trainer.js +5 -5
  78. package/dist/training/sparseCrossEntropy.js +4 -4
  79. package/dist/utilities/dummy.js +2 -2
  80. package/dist/utilities/generate.js +3 -3
  81. package/dist/utilities/load.js +1 -1
  82. package/dist/utilities/profile.js +1 -1
  83. package/dist/utilities/weights.js +2 -2
  84. package/dist/{variable-BS4AKqNU.js → variable-LJT9Ld63.js} +1 -1
  85. package/dist/{zeros-CmJFiC84.js → zeros-dnQxFgAD.js} +1 -1
  86. package/package.json +1 -1
  87. package/dist/MLP-KHhikThU.js +0 -83
package/dist/Generator.js CHANGED
@@ -1,7 +1,7 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-XjBAhiFO.js";
3
- import { t as d } from "./tensor2d-D0CKdG6B.js";
4
- import { c as p } from "./concat-DS_qH7MI.js";
2
+ import "./index-C4JCoBvj.js";
3
+ import { t as d } from "./tensor2d-DqFGNs_K.js";
4
+ import { c as p } from "./concat-CuRsVY-K.js";
5
5
  class w extends u {
6
6
  constructor(s, e) {
7
7
  super(), this.model = s, this.tokeniser = e;
@@ -1,19 +1,17 @@
1
1
  import { defaultConfig as x } from "./config.js";
2
2
  import W from "./layers/TransformerBlock.js";
3
- import F from "./layers/TiedEmbedding.js";
4
- import P from "./layers/RoPECache.js";
5
- import q from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as K } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
8
- import T from "./layers/BaseLayer.js";
9
- import { r as R, p as D } from "./random_width-CMHmdbSu.js";
10
- import { o as y, h as $, p as A, E as v, a6 as B, a7 as G, a8 as O, t as w, a5 as Q, f as C } from "./index-XjBAhiFO.js";
11
- import { e as j, d as U } from "./MLP-KHhikThU.js";
12
- import { r as _ } from "./reshape-DFzh97Sc.js";
13
- import { r as V } from "./range-DQMNzBWs.js";
14
- import { e as X } from "./tfjs_backend-Bzl2SrRo.js";
15
- import { g as H } from "./gather-BUmJIS8n.js";
16
- import { s as J } from "./softmax-4DOn6cPq.js";
3
+ import { E as F, D as P, T as q, r as K, p as T } from "./TiedEmbedding-CnJ1bx4q.js";
4
+ import D from "./layers/RoPECache.js";
5
+ import N from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as R } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
+ import B from "./layers/BaseLayer.js";
9
+ import { o as y, h as E, p as G, E as v, a6 as O, a7 as Q, a8 as j, t as w, a5 as U, f as C } from "./index-C4JCoBvj.js";
10
+ import { r as _ } from "./reshape-Boe4DuIO.js";
11
+ import { r as V } from "./range-9AzeApCc.js";
12
+ import { e as X } from "./tfjs_backend-Cug-PH75.js";
13
+ import { g as H } from "./gather-ZYRWhmXR.js";
14
+ import { s as J } from "./softmax-Cujsg4ay.js";
17
15
  /**
18
16
  * @license
19
17
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -30,11 +28,11 @@ import { s as J } from "./softmax-4DOn6cPq.js";
30
28
  * limitations under the License.
31
29
  * =============================================================================
32
30
  */
33
- function Y(f, t) {
34
- let e = $(f, "a", "mod"), o = $(t, "b", "mod");
35
- [e, o] = A(e, o);
36
- const i = { a: e, b: o };
37
- return v.runKernel(B, i);
31
+ function Y(h, t) {
32
+ let e = E(h, "a", "mod"), o = E(t, "b", "mod");
33
+ [e, o] = G(e, o);
34
+ const n = { a: e, b: o };
35
+ return v.runKernel(O, n);
38
36
  }
39
37
  const Z = /* @__PURE__ */ y({ mod_: Y });
40
38
  /**
@@ -53,14 +51,14 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
53
51
  * limitations under the License.
54
52
  * =============================================================================
55
53
  */
56
- function tt(f, t, e, o = !1) {
57
- const i = $(f, "logits", "multinomial"), s = i.size, r = i.rank;
54
+ function tt(h, t, e, o = !1) {
55
+ const n = E(h, "logits", "multinomial"), s = n.size, r = n.rank;
58
56
  if (s < 2)
59
57
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
60
58
  if (r > 2)
61
59
  throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
62
60
  e = e || Math.random();
63
- const n = { logits: r === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(G, n, h);
61
+ const i = { logits: r === 1 ? _(n, [1, -1]) : n }, p = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(Q, i, p);
64
62
  return r === 1 ? _(l, [l.size]) : l;
65
63
  }
66
64
  const M = /* @__PURE__ */ y({ multinomial_: tt });
@@ -80,20 +78,35 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
80
78
  * limitations under the License.
81
79
  * =============================================================================
82
80
  */
83
- function et(f, t = 1, e = !0) {
84
- const o = $(f, "x", "topk");
81
+ function et(h, t = 1, e = !0) {
82
+ const o = E(h, "x", "topk");
85
83
  if (o.rank === 0)
86
84
  throw new Error("topk() expects the input to be of rank 1 or higher");
87
- const i = o.shape[o.shape.length - 1];
85
+ const n = o.shape[o.shape.length - 1];
88
86
  if (t < 0)
89
87
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
90
- if (t > i)
91
- throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
92
- const s = { x: o }, r = { k: t, sorted: e }, [a, n] = v.runKernel(O, s, r);
93
- return { values: a, indices: n };
88
+ if (t > n)
89
+ throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
90
+ const s = { x: o }, r = { k: t, sorted: e }, [a, i] = v.runKernel(j, s, r);
91
+ return { values: a, indices: i };
94
92
  }
95
93
  const ot = /* @__PURE__ */ y({ topk_: et });
96
- class wt extends T {
94
+ /**
95
+ * @license
96
+ * Copyright 2018 Google LLC
97
+ *
98
+ * Use of this source code is governed by an MIT-style
99
+ * license that can be found in the LICENSE file or at
100
+ * https://opensource.org/licenses/MIT.
101
+ * =============================================================================
102
+ */
103
+ function st(h) {
104
+ return new P(h);
105
+ }
106
+ function nt(h) {
107
+ return new F(h);
108
+ }
109
+ class wt extends B {
97
110
  wte;
98
111
  // Token embeddings
99
112
  wpe;
@@ -107,19 +120,19 @@ class wt extends T {
107
120
  log = [];
108
121
  // Training log
109
122
  constructor(t = {}) {
110
- super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new F({
123
+ super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new q({
111
124
  vocabSize: this.config.gpt.vocabSize,
112
125
  embedDim: this.config.gpt.nEmbed,
113
126
  name: "token_embedding"
114
- }), this.config.gpt.useRope === !1 ? this.wpe = j({
127
+ }), this.config.gpt.useRope === !1 ? this.wpe = nt({
115
128
  inputDim: this.config.gpt.blockSize,
116
129
  outputDim: this.config.gpt.nEmbed,
117
130
  name: "positional_embedding",
118
- embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
119
- }) : (this.ropeCache = new P(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = U({ rate: this.config.gpt.dropout }), this.blocks = [];
131
+ embeddingsInitializer: K({ mean: 0, stddev: 0.02 })
132
+ }) : (this.ropeCache = new D(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
120
133
  for (let e = 0; e < this.config.gpt.nLayer; e++)
121
134
  this.blocks.push(new W(e, this.config));
122
- this.lnF = new q(this.config, 1e-8, "final_rms_norm");
135
+ this.lnF = new N(this.config, 1e-8, "final_rms_norm");
123
136
  }
124
137
  get checkpointing() {
125
138
  return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
@@ -150,12 +163,12 @@ class wt extends T {
150
163
  }
151
164
  inputPhase(t, e, o = !1) {
152
165
  return w(() => {
153
- const i = this.wte.embed(t);
166
+ const n = this.wte.embed(t);
154
167
  if (this.config.gpt.useRope === !1) {
155
- const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), n = Z(Q(a, C(e, "int32")), C(r, "int32")), h = this.wpe.apply(n), l = i.add(h);
168
+ const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), i = Z(U(a, C(e, "int32")), C(r, "int32")), p = this.wpe.apply(i), l = n.add(p);
156
169
  return this.drop.apply(l, { training: o });
157
170
  } else
158
- return this.drop.apply(i, { training: o });
171
+ return this.drop.apply(n, { training: o });
159
172
  });
160
173
  }
161
174
  setSkipMask(t) {
@@ -185,7 +198,7 @@ class wt extends T {
185
198
  }
186
199
  calculateLoss(t, e) {
187
200
  try {
188
- return N()(t, e).mean();
201
+ return A()(t, e).mean();
189
202
  } catch (o) {
190
203
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
191
204
  }
@@ -196,71 +209,71 @@ class wt extends T {
196
209
  return w(() => {
197
210
  if (t.length === 0)
198
211
  throw new Error("No attentions for rollout");
199
- const [e, o, i] = t[0].shape;
212
+ const [e, o, n] = t[0].shape;
200
213
  for (const s of t) {
201
- const [r, a, n] = s.shape;
202
- if (r !== e || a !== o || n !== i)
214
+ const [r, a, i] = s.shape;
215
+ if (r !== e || a !== o || i !== n)
203
216
  throw new Error(
204
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${a},${n}]`
217
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${n}] got [${r},${a},${i}]`
205
218
  );
206
219
  }
207
- if (o === i) {
208
- const s = X(i, i).expandDims(0);
220
+ if (o === n) {
221
+ const s = X(n, n).expandDims(0);
209
222
  let r = s.tile([e, 1, 1]);
210
223
  for (const a of t) {
211
- const n = a.add(s);
212
- r = n.div(n.sum(-1, !0)).matMul(r);
224
+ const i = a.add(s);
225
+ r = i.div(i.sum(-1, !0)).matMul(r);
213
226
  }
214
227
  return r;
215
228
  }
216
- throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${i}]`);
229
+ throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${n}]`);
217
230
  });
218
231
  }
219
- forward(t, e, o = !1, i = !1, s) {
232
+ forward(t, e, o = !1, n = !1, s) {
220
233
  return this.validateInput(t), w(() => {
221
234
  this.startMemory();
222
235
  const r = s?.[0]?.length ?? 0;
223
236
  let a = this.inputPhase(t, r, o);
224
- const n = [];
237
+ const i = [];
225
238
  if (s && s.length !== this.blocks.length)
226
239
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
227
- for (let p = 0; p < this.blocks.length; p++) {
228
- const m = a, u = this.blocks[p], {
240
+ for (let c = 0; c < this.blocks.length; c++) {
241
+ const u = a, m = this.blocks[c], {
229
242
  output: b,
230
243
  attention: k,
231
- cache: g
232
- } = u.call(a, o, i, s ? s[p] : void 0);
233
- a = b, m.dispose(), i && k && n.push(k), s && g ? (s[p]?.k.dispose(), s[p]?.v.dispose(), s[p] = g) : g && (g.k.dispose(), g.v.dispose());
244
+ cache: f
245
+ } = m.call(a, o, n, s ? s[c] : void 0);
246
+ a = b, u.dispose(), n && k && i.push(k), s && f ? (s[c]?.k.dispose(), s[c]?.v.dispose(), s[c] = f) : f && (f.k.dispose(), f.v.dispose());
234
247
  }
235
- let h;
236
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), a = this.lnF.apply(a);
248
+ let p;
249
+ n && i.length > 0 && (p = this.computeAttentionRollout(i)), a = this.lnF.apply(a);
237
250
  const l = this.wte.project(a);
238
- let c;
239
- return e && (c = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: c, attention: i ? h : void 0 };
251
+ let g;
252
+ return e && (g = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: g, attention: n ? p : void 0 };
240
253
  });
241
254
  }
242
255
  generate(t, e, o) {
243
- const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
256
+ const n = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
244
257
  return w(() => {
245
- const n = t, h = n.shape[1], l = h <= this.config.gpt.blockSize ? n : n.slice(
246
- [0, h - this.config.gpt.blockSize],
247
- [n.shape[0], this.config.gpt.blockSize]
248
- ), c = r ? this.config.gpt.blockSize - l.shape[1] : 0, p = c > 0 ? D(l, [
258
+ const i = t, p = i.shape[1], l = p <= this.config.gpt.blockSize ? i : i.slice(
259
+ [0, p - this.config.gpt.blockSize],
260
+ [i.shape[0], this.config.gpt.blockSize]
261
+ ), g = r ? this.config.gpt.blockSize - l.shape[1] : 0, c = g > 0 ? T(l, [
249
262
  [0, 0],
250
- [0, c]
251
- ]) : l, { logits: m, attention: u } = this.forward(p, void 0, !1, a, e), b = m.shape[1] - 1 - c, k = m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]), g = u ? u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]) : void 0, E = k.div(i);
263
+ [0, g]
264
+ ]) : l, { logits: u, attention: m } = this.forward(c, void 0, !1, a, e), b = u.shape[1] - 1 - g, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, $ = k.div(n);
252
265
  let d;
253
266
  if (s) {
254
- const { values: S, indices: I } = ot(E, s), L = M(S.squeeze([1]), 1);
267
+ const { values: S, indices: I } = ot($, s), L = M(S.squeeze([1]), 1);
255
268
  d = H(I.squeeze([1]), L, 1);
256
269
  } else
257
- d = M(E.squeeze([1]), 1);
270
+ d = M($.squeeze([1]), 1);
258
271
  let z;
259
- return o?.includeProbabilities && (z = J(E.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: g?.squeeze([1]), probabilities: z };
272
+ return o?.includeProbabilities && (z = J($.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
260
273
  });
261
274
  }
262
275
  getNumParams() {
263
- return K(this.config.gpt);
276
+ return R(this.config.gpt);
264
277
  }
265
278
  dispose() {
266
279
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-XjBAhiFO.js";
14
+ import "./index-C4JCoBvj.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";