@genai-fi/nanogpt 0.4.1 → 0.4.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. package/dist/Generator.js +3 -3
  2. package/dist/NanoGPTModel.js +84 -74
  3. package/dist/TeachableLLM.js +1 -1
  4. package/dist/{random_width-CMHmdbSu.js → TiedEmbedding-CnJ1bx4q.js} +760 -719
  5. package/dist/{axis_util-DeydwOoC.js → axis_util-BgTGy5w8.js} +1 -1
  6. package/dist/{concat-DS_qH7MI.js → concat-CuRsVY-K.js} +1 -1
  7. package/dist/dropout-DfDdklfL.js +193 -0
  8. package/dist/{gather-BUmJIS8n.js → gather-ZYRWhmXR.js} +1 -1
  9. package/dist/gelu-CnCt17Lk.js +26 -0
  10. package/dist/{index-XjBAhiFO.js → index-C4JCoBvj.js} +61 -61
  11. package/dist/kernel_funcs_utils-CAd1h9X1.js +388 -0
  12. package/dist/layers/CausalSelfAttention.js +74 -73
  13. package/dist/layers/MLP.d.ts +3 -1
  14. package/dist/layers/MLP.js +93 -5
  15. package/dist/layers/RMSNorm.js +3 -3
  16. package/dist/layers/RoPECache.js +3 -3
  17. package/dist/layers/TiedEmbedding.js +6 -46
  18. package/dist/layers/TransformerBlock.js +2 -2
  19. package/dist/{log_sum_exp-DJPkVZZn.js → log_sum_exp-BswFnwOb.js} +5 -5
  20. package/dist/main.js +1 -1
  21. package/dist/{mat_mul-CKwFEV1Q.js → mat_mul-415y5Qn2.js} +1 -1
  22. package/dist/{max-DJvEiCAJ.js → max-CP_9O2Yd.js} +1 -1
  23. package/dist/{moments-CrWRPcR3.js → moments-CjeIaVdp.js} +3 -3
  24. package/dist/{norm-BzY929B_.js → norm-CZM380I3.js} +5 -5
  25. package/dist/{ones-BO01zpJG.js → ones-Bf3YR48P.js} +2 -2
  26. package/dist/ops/appendCache.d.ts +1 -1
  27. package/dist/ops/appendCache.js +10 -4
  28. package/dist/ops/attentionMask.d.ts +1 -1
  29. package/dist/ops/attentionMask.js +4 -4
  30. package/dist/ops/cpu/appendCache.d.ts +1 -2
  31. package/dist/ops/cpu/appendCache.js +15 -20
  32. package/dist/ops/cpu/attentionMask.js +15 -11
  33. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  34. package/dist/ops/cpu/gatherSub.js +3 -3
  35. package/dist/ops/cpu/gelu.d.ts +1 -0
  36. package/dist/ops/cpu/gelu.js +40 -0
  37. package/dist/ops/cpu/mulDropout.js +1 -1
  38. package/dist/ops/cpu/qkv.js +3 -3
  39. package/dist/ops/cpu/rope.js +5 -5
  40. package/dist/ops/cpu/scatterSub.js +4 -4
  41. package/dist/ops/fusedSoftmax.js +1 -1
  42. package/dist/ops/gatherSub.js +1 -1
  43. package/dist/ops/gelu.d.ts +3 -0
  44. package/dist/ops/gelu.js +8 -0
  45. package/dist/ops/grads/attentionMask.js +1 -1
  46. package/dist/ops/grads/fusedSoftmax.js +2 -2
  47. package/dist/ops/grads/gelu.d.ts +2 -0
  48. package/dist/ops/grads/gelu.js +5 -0
  49. package/dist/ops/grads/qkv.js +1 -1
  50. package/dist/ops/grads/rope.js +1 -1
  51. package/dist/ops/mulDrop.js +1 -1
  52. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  53. package/dist/ops/qkv.js +1 -1
  54. package/dist/ops/scatterSub.js +1 -1
  55. package/dist/ops/webgl/appendCache.js +14 -13
  56. package/dist/ops/webgl/attentionMask.js +19 -18
  57. package/dist/ops/webgl/fusedSoftmax.js +483 -782
  58. package/dist/ops/webgl/gatherSub.js +1 -1
  59. package/dist/ops/webgl/gelu.d.ts +2 -0
  60. package/dist/ops/webgl/gelu.js +50 -0
  61. package/dist/ops/webgl/mulDropout.js +1 -1
  62. package/dist/ops/webgl/qkv.js +1 -1
  63. package/dist/ops/webgl/rope.js +1 -1
  64. package/dist/ops/webgl/scatterSub.js +1 -1
  65. package/dist/{range-DQMNzBWs.js → range-9AzeApCc.js} +1 -1
  66. package/dist/{reshape-DFzh97Sc.js → reshape-Boe4DuIO.js} +1 -1
  67. package/dist/{sin-BYM-U4Ut.js → sin-KmhiDuMa.js} +1 -1
  68. package/dist/{slice_util-CnVNPQI-.js → slice_util-19zDNNSn.js} +2 -2
  69. package/dist/{softmax-4DOn6cPq.js → softmax-Cujsg4ay.js} +1 -1
  70. package/dist/{split-CkbeVdF8.js → split-DbcNm1-i.js} +1 -1
  71. package/dist/{stack-DaIMO5iX.js → stack-D1YjmgKN.js} +1 -1
  72. package/dist/{sum-C6u3xMi3.js → sum-R28pucR5.js} +1 -1
  73. package/dist/{tensor-Cu1fU7H7.js → tensor-BVeHdl7V.js} +1 -1
  74. package/dist/{tensor2d-D0CKdG6B.js → tensor2d-DqFGNs_K.js} +1 -1
  75. package/dist/{tfjs_backend-Bzl2SrRo.js → tfjs_backend-Cug-PH75.js} +826 -1015
  76. package/dist/training/AdamExt.js +1 -1
  77. package/dist/training/DatasetBuilder.js +3 -3
  78. package/dist/training/FullTrainer.js +1 -1
  79. package/dist/training/Trainer.js +5 -5
  80. package/dist/training/sparseCrossEntropy.js +4 -4
  81. package/dist/utilities/dummy.js +2 -2
  82. package/dist/utilities/generate.js +3 -3
  83. package/dist/utilities/load.js +1 -1
  84. package/dist/utilities/profile.js +1 -1
  85. package/dist/utilities/weights.js +2 -2
  86. package/dist/{variable-BS4AKqNU.js → variable-LJT9Ld63.js} +1 -1
  87. package/dist/{zeros-CmJFiC84.js → zeros-dnQxFgAD.js} +1 -1
  88. package/package.json +1 -1
  89. package/dist/MLP-KHhikThU.js +0 -83
package/dist/Generator.js CHANGED
@@ -1,7 +1,7 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-XjBAhiFO.js";
3
- import { t as d } from "./tensor2d-D0CKdG6B.js";
4
- import { c as p } from "./concat-DS_qH7MI.js";
2
+ import "./index-C4JCoBvj.js";
3
+ import { t as d } from "./tensor2d-DqFGNs_K.js";
4
+ import { c as p } from "./concat-CuRsVY-K.js";
5
5
  class w extends u {
6
6
  constructor(s, e) {
7
7
  super(), this.model = s, this.tokeniser = e;
package/dist/NanoGPTModel.js CHANGED
@@ -1,19 +1,17 @@
1
1
  import { defaultConfig as x } from "./config.js";
2
2
  import W from "./layers/TransformerBlock.js";
3
- import F from "./layers/TiedEmbedding.js";
4
- import P from "./layers/RoPECache.js";
5
- import q from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as K } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
8
- import T from "./layers/BaseLayer.js";
9
- import { r as R, p as D } from "./random_width-CMHmdbSu.js";
10
- import { o as y, h as $, p as A, E as v, a6 as B, a7 as G, a8 as O, t as w, a5 as Q, f as C } from "./index-XjBAhiFO.js";
11
- import { e as j, d as U } from "./MLP-KHhikThU.js";
12
- import { r as _ } from "./reshape-DFzh97Sc.js";
13
- import { r as V } from "./range-DQMNzBWs.js";
14
- import { e as X } from "./tfjs_backend-Bzl2SrRo.js";
15
- import { g as H } from "./gather-BUmJIS8n.js";
16
- import { s as J } from "./softmax-4DOn6cPq.js";
3
+ import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-CnJ1bx4q.js";
4
+ import K from "./layers/RoPECache.js";
5
+ import N from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as R } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
+ import B from "./layers/BaseLayer.js";
9
+ import { o as $, h as E, p as G, E as v, a6 as O, a7 as j, a8 as Q, t as w, a5 as V, f as C } from "./index-C4JCoBvj.js";
10
+ import { r as _ } from "./reshape-Boe4DuIO.js";
11
+ import { r as X } from "./range-9AzeApCc.js";
12
+ import { e as H } from "./tfjs_backend-Cug-PH75.js";
13
+ import { g as J } from "./gather-ZYRWhmXR.js";
14
+ import { s as U } from "./softmax-Cujsg4ay.js";
17
15
  /**
18
16
  * @license
19
17
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -30,13 +28,13 @@ import { s as J } from "./softmax-4DOn6cPq.js";
30
28
  * limitations under the License.
31
29
  * =============================================================================
32
30
  */
33
- function Y(f, t) {
34
- let e = $(f, "a", "mod"), o = $(t, "b", "mod");
35
- [e, o] = A(e, o);
31
+ function Y(c, t) {
32
+ let e = E(c, "a", "mod"), o = E(t, "b", "mod");
33
+ [e, o] = G(e, o);
36
34
  const i = { a: e, b: o };
37
- return v.runKernel(B, i);
35
+ return v.runKernel(O, i);
38
36
  }
39
- const Z = /* @__PURE__ */ y({ mod_: Y });
37
+ const Z = /* @__PURE__ */ $({ mod_: Y });
40
38
  /**
41
39
  * @license
42
40
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -53,17 +51,17 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
53
51
  * limitations under the License.
54
52
  * =============================================================================
55
53
  */
56
- function tt(f, t, e, o = !1) {
57
- const i = $(f, "logits", "multinomial"), s = i.size, r = i.rank;
54
+ function tt(c, t, e, o = !1) {
55
+ const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
58
56
  if (s < 2)
59
57
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
60
- if (r > 2)
61
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${r}`);
58
+ if (l > 2)
59
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
62
60
  e = e || Math.random();
63
- const n = { logits: r === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, l = v.runKernel(G, n, h);
64
- return r === 1 ? _(l, [l.size]) : l;
61
+ const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
62
+ return l === 1 ? _(a, [a.size]) : a;
65
63
  }
66
- const M = /* @__PURE__ */ y({ multinomial_: tt });
64
+ const M = /* @__PURE__ */ $({ multinomial_: tt });
67
65
  /**
68
66
  * @license
69
67
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -80,8 +78,8 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
80
78
  * limitations under the License.
81
79
  * =============================================================================
82
80
  */
83
- function et(f, t = 1, e = !0) {
84
- const o = $(f, "x", "topk");
81
+ function et(c, t = 1, e = !0) {
82
+ const o = E(c, "x", "topk");
85
83
  if (o.rank === 0)
86
84
  throw new Error("topk() expects the input to be of rank 1 or higher");
87
85
  const i = o.shape[o.shape.length - 1];
@@ -89,11 +87,26 @@ function et(f, t = 1, e = !0) {
89
87
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
90
88
  if (t > i)
91
89
  throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
92
- const s = { x: o }, r = { k: t, sorted: e }, [a, n] = v.runKernel(O, s, r);
93
- return { values: a, indices: n };
90
+ const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
91
+ return { values: r, indices: n };
94
92
  }
95
- const ot = /* @__PURE__ */ y({ topk_: et });
96
- class wt extends T {
93
+ const ot = /* @__PURE__ */ $({ topk_: et });
94
+ /**
95
+ * @license
96
+ * Copyright 2018 Google LLC
97
+ *
98
+ * Use of this source code is governed by an MIT-style
99
+ * license that can be found in the LICENSE file or at
100
+ * https://opensource.org/licenses/MIT.
101
+ * =============================================================================
102
+ */
103
+ function st(c) {
104
+ return new P(c);
105
+ }
106
+ function it(c) {
107
+ return new F(c);
108
+ }
109
+ class wt extends B {
97
110
  wte;
98
111
  // Token embeddings
99
112
  wpe;
@@ -107,19 +120,19 @@ class wt extends T {
107
120
  log = [];
108
121
  // Training log
109
122
  constructor(t = {}) {
110
- super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new F({
123
+ super({ gpt: { ...x, ...t }, layerConfig: {} }), this.wte = new q({
111
124
  vocabSize: this.config.gpt.vocabSize,
112
125
  embedDim: this.config.gpt.nEmbed,
113
126
  name: "token_embedding"
114
- }), this.config.gpt.useRope === !1 ? this.wpe = j({
127
+ }), this.config.gpt.useRope === !1 ? this.wpe = it({
115
128
  inputDim: this.config.gpt.blockSize,
116
129
  outputDim: this.config.gpt.nEmbed,
117
130
  name: "positional_embedding",
118
- embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
119
- }) : (this.ropeCache = new P(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = U({ rate: this.config.gpt.dropout }), this.blocks = [];
131
+ embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
132
+ }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
120
133
  for (let e = 0; e < this.config.gpt.nLayer; e++)
121
134
  this.blocks.push(new W(e, this.config));
122
- this.lnF = new q(this.config, 1e-8, "final_rms_norm");
135
+ this.lnF = new N(this.config, 1e-8, "final_rms_norm");
123
136
  }
124
137
  get checkpointing() {
125
138
  return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
@@ -152,8 +165,8 @@ class wt extends T {
152
165
  return w(() => {
153
166
  const i = this.wte.embed(t);
154
167
  if (this.config.gpt.useRope === !1) {
155
- const [, s] = t.shape, r = this.config.gpt.blockSize, a = V(0, s, 1, "int32"), n = Z(Q(a, C(e, "int32")), C(r, "int32")), h = this.wpe.apply(n), l = i.add(h);
156
- return this.drop.apply(l, { training: o });
168
+ const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
169
+ return this.drop.apply(a, { training: o });
157
170
  } else
158
171
  return this.drop.apply(i, { training: o });
159
172
  });
@@ -185,7 +198,7 @@ class wt extends T {
185
198
  }
186
199
  calculateLoss(t, e) {
187
200
  try {
188
- return N()(t, e).mean();
201
+ return A()(t, e).mean();
189
202
  } catch (o) {
190
203
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
191
204
  }
@@ -197,70 +210,67 @@ class wt extends T {
197
210
  if (t.length === 0)
198
211
  throw new Error("No attentions for rollout");
199
212
  const [e, o, i] = t[0].shape;
200
- for (const s of t) {
201
- const [r, a, n] = s.shape;
202
- if (r !== e || a !== o || n !== i)
213
+ for (const n of t) {
214
+ const [h, a, p] = n.shape;
215
+ if (h !== e || a !== o || p !== i)
203
216
  throw new Error(
204
- `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${r},${a},${n}]`
217
+ `Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
205
218
  );
206
219
  }
207
- if (o === i) {
208
- const s = X(i, i).expandDims(0);
209
- let r = s.tile([e, 1, 1]);
210
- for (const a of t) {
211
- const n = a.add(s);
212
- r = n.div(n.sum(-1, !0)).matMul(r);
213
- }
214
- return r;
220
+ const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
221
+ let r = l.tile([e, 1, 1]);
222
+ for (const n of s) {
223
+ const h = n.add(l);
224
+ r = h.div(h.sum(-1, !0)).matMul(r);
215
225
  }
216
- throw new Error(`Unsupported attention shapes for rollout: [B=${e}, Q=${o}, K=${i}]`);
226
+ return r;
217
227
  });
218
228
  }
219
229
  forward(t, e, o = !1, i = !1, s) {
220
230
  return this.validateInput(t), w(() => {
221
231
  this.startMemory();
222
- const r = s?.[0]?.length ?? 0;
223
- let a = this.inputPhase(t, r, o);
232
+ const l = s?.[0]?.length ?? 0;
233
+ let r = this.inputPhase(t, l, o);
224
234
  const n = [];
225
235
  if (s && s.length !== this.blocks.length)
226
236
  throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
227
- for (let p = 0; p < this.blocks.length; p++) {
228
- const m = a, u = this.blocks[p], {
237
+ for (let g = 0; g < this.blocks.length; g++) {
238
+ const u = r, m = this.blocks[g], {
229
239
  output: b,
230
240
  attention: k,
231
- cache: g
232
- } = u.call(a, o, i, s ? s[p] : void 0);
233
- a = b, m.dispose(), i && k && n.push(k), s && g ? (s[p]?.k.dispose(), s[p]?.v.dispose(), s[p] = g) : g && (g.k.dispose(), g.v.dispose());
241
+ cache: f
242
+ } = m.call(r, o, i, s ? s[g] : void 0);
243
+ r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
234
244
  }
235
245
  let h;
236
- i && n.length > 0 && (h = this.computeAttentionRollout(n)), a = this.lnF.apply(a);
237
- const l = this.wte.project(a);
238
- let c;
239
- return e && (c = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: c, attention: i ? h : void 0 };
246
+ i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
247
+ const a = this.wte.project(r);
248
+ let p;
249
+ return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
240
250
  });
241
251
  }
242
252
  generate(t, e, o) {
243
- const i = o?.temperature ?? 1, s = o?.topK, r = o?.usePadding ?? !1, a = o?.includeAttention ?? !1;
253
+ const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
244
254
  return w(() => {
245
- const n = t, h = n.shape[1], l = h <= this.config.gpt.blockSize ? n : n.slice(
255
+ const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
246
256
  [0, h - this.config.gpt.blockSize],
247
257
  [n.shape[0], this.config.gpt.blockSize]
248
- ), c = r ? this.config.gpt.blockSize - l.shape[1] : 0, p = c > 0 ? D(l, [
258
+ ), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
249
259
  [0, 0],
250
- [0, c]
251
- ]) : l, { logits: m, attention: u } = this.forward(p, void 0, !1, a, e), b = m.shape[1] - 1 - c, k = m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]), g = u ? u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]) : void 0, E = k.div(i);
260
+ [0, p]
261
+ ]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
252
262
  let d;
253
263
  if (s) {
254
- const { values: S, indices: I } = ot(E, s), L = M(S.squeeze([1]), 1);
255
- d = H(I.squeeze([1]), L, 1);
264
+ const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
265
+ d = J(I.squeeze([1]), L, 1);
256
266
  } else
257
- d = M(E.squeeze([1]), 1);
267
+ d = M(y.squeeze([1]), 1);
258
268
  let z;
259
- return o?.includeProbabilities && (z = J(E.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: g?.squeeze([1]), probabilities: z };
269
+ return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
260
270
  });
261
271
  }
262
272
  getNumParams() {
263
- return K(this.config.gpt);
273
+ return R(this.config.gpt);
264
274
  }
265
275
  dispose() {
266
276
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
package/dist/TeachableLLM.js CHANGED
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-XjBAhiFO.js";
14
+ import "./index-C4JCoBvj.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";