@genai-fi/nanogpt 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. package/dist/Generator.js +95 -46
  2. package/dist/NanoGPTModel.d.ts +3 -2
  3. package/dist/NanoGPTModel.js +91 -76
  4. package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/TiedEmbedding-DORsPlNL.js +44 -0
  7. package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
  8. package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
  9. package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
  10. package/dist/dataset-ZHEPJmED.js +1226 -0
  11. package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
  12. package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
  13. package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
  14. package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
  15. package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
  16. package/dist/layers/BaseLayer.js +114 -3
  17. package/dist/layers/CausalSelfAttention.d.ts +2 -3
  18. package/dist/layers/CausalSelfAttention.js +31 -30
  19. package/dist/layers/MLP.js +10 -9
  20. package/dist/layers/RMSNorm.js +12 -11
  21. package/dist/layers/RoPECache.js +3 -3
  22. package/dist/layers/TiedEmbedding.js +8 -6
  23. package/dist/layers/TransformerBlock.js +2 -2
  24. package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
  25. package/dist/main.js +1 -1
  26. package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
  27. package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
  28. package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
  29. package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
  30. package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
  31. package/dist/ops/appendCache.js +3 -3
  32. package/dist/ops/attentionMask.js +1 -1
  33. package/dist/ops/cpu/appendCache.js +2 -2
  34. package/dist/ops/cpu/attentionMask.js +5 -5
  35. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  36. package/dist/ops/cpu/gatherSub.js +5 -5
  37. package/dist/ops/cpu/gelu.js +1 -1
  38. package/dist/ops/cpu/matMulGelu.js +1 -1
  39. package/dist/ops/cpu/matMulMul.js +1 -1
  40. package/dist/ops/cpu/mulDropout.js +1 -1
  41. package/dist/ops/cpu/normRMS.js +1 -1
  42. package/dist/ops/cpu/qkv.js +3 -3
  43. package/dist/ops/cpu/rope.js +5 -5
  44. package/dist/ops/cpu/scatterSub.js +27 -27
  45. package/dist/ops/fusedSoftmax.js +1 -1
  46. package/dist/ops/gatherSub.js +1 -1
  47. package/dist/ops/gelu.js +1 -1
  48. package/dist/ops/grads/attentionMask.js +1 -1
  49. package/dist/ops/grads/fusedSoftmax.js +2 -2
  50. package/dist/ops/grads/gelu.js +1 -1
  51. package/dist/ops/grads/matMulGelu.js +1 -1
  52. package/dist/ops/grads/normRMS.js +1 -1
  53. package/dist/ops/grads/qkv.js +1 -1
  54. package/dist/ops/grads/rope.js +1 -1
  55. package/dist/ops/matMulGelu.js +1 -1
  56. package/dist/ops/matMulMul.js +1 -1
  57. package/dist/ops/mulDrop.js +1 -1
  58. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  59. package/dist/ops/normRMS.js +1 -1
  60. package/dist/ops/qkv.js +1 -1
  61. package/dist/ops/scatterSub.js +1 -1
  62. package/dist/ops/webgl/appendCache.js +1 -1
  63. package/dist/ops/webgl/attentionMask.js +1 -1
  64. package/dist/ops/webgl/fusedSoftmax.js +36 -36
  65. package/dist/ops/webgl/gatherSub.js +1 -1
  66. package/dist/ops/webgl/gelu.js +2 -2
  67. package/dist/ops/webgl/matMulGelu.js +22 -22
  68. package/dist/ops/webgl/matMulMul.js +1 -1
  69. package/dist/ops/webgl/mulDropout.js +1 -1
  70. package/dist/ops/webgl/normRMS.js +2 -2
  71. package/dist/ops/webgl/qkv.js +1 -1
  72. package/dist/ops/webgl/rope.js +1 -1
  73. package/dist/ops/webgl/scatterSub.js +1 -1
  74. package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
  75. package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
  76. package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
  77. package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
  78. package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
  79. package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
  80. package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
  81. package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
  82. package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
  83. package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
  84. package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
  85. package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
  86. package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
  87. package/dist/tokeniser/CharTokeniser.js +27 -27
  88. package/dist/tokeniser/bpe.d.ts +1 -0
  89. package/dist/tokeniser/bpe.js +38 -35
  90. package/dist/training/AdamExt.js +1 -1
  91. package/dist/training/DatasetBuilder.js +22 -1242
  92. package/dist/training/FullTrainer.js +1 -1
  93. package/dist/training/Trainer.js +5 -5
  94. package/dist/training/sparseCrossEntropy.js +4 -4
  95. package/dist/utilities/dummy.js +2 -2
  96. package/dist/utilities/generate.js +3 -3
  97. package/dist/utilities/load.js +1 -1
  98. package/dist/utilities/profile.js +1 -1
  99. package/dist/utilities/save.js +5 -5
  100. package/dist/utilities/weights.js +2 -2
  101. package/dist/variable-BGvK-VN3.js +23 -0
  102. package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
  103. package/package.json +1 -1
  104. package/dist/BaseLayer-BhrMN8JO.js +0 -135
package/dist/Generator.js CHANGED
@@ -1,72 +1,121 @@
  import { E as u } from "./index-Dwqa6Zy2.js";
- import "./index-iNhkcAEQ.js";
- import { t as d } from "./tensor2d-tSxWdFMH.js";
- import { c as p } from "./concat-Cxbo2sOz.js";
- class w extends u {
- constructor(i, t) {
- super(), this.model = i, this.tokeniser = t;
+ import "./index-CnHyhpKc.js";
+ import "./ops/cpu/attentionMask.js";
+ import "./ops/webgl/attentionMask.js";
+ import "./ops/grads/attentionMask.js";
+ import "./ops/cpu/qkv.js";
+ import "./ops/webgl/qkv.js";
+ import "./ops/grads/qkv.js";
+ import "@tensorflow/tfjs";
+ import "./ops/cpu/rope.js";
+ import "./ops/webgl/rope.js";
+ import "./ops/grads/rope.js";
+ import "./ops/cpu/appendCache.js";
+ import "./ops/webgl/appendCache.js";
+ import "./ops/cpu/fusedSoftmax.js";
+ import "./ops/webgl/fusedSoftmax.js";
+ import "./ops/grads/fusedSoftmax.js";
+ import "./ops/cpu/matMulGelu.js";
+ import "./ops/webgl/matMulGelu.js";
+ import "./ops/grads/matMulGelu.js";
+ import "./ops/cpu/normRMS.js";
+ import "./ops/webgl/normRMS.js";
+ import "./ops/grads/normRMS.js";
+ import "./random_width-DI2h9CMs.js";
+ import "./ops/cpu/gatherSub.js";
+ import "./ops/webgl/gatherSub.js";
+ import "./ops/cpu/scatterSub.js";
+ import "./ops/webgl/scatterSub.js";
+ import "./jszip.min-CjP2V1VV.js";
+ import f from "./tokeniser/CharTokeniser.js";
+ import "./dataset-ZHEPJmED.js";
+ import "./index-Tf7vU29b.js";
+ import "./papaparse.min-C8l2Kvo1.js";
+ import "./ops/cpu/gelu.js";
+ import "./ops/webgl/gelu.js";
+ import "./ops/grads/gelu.js";
+ import { t as d } from "./tensor2d-CqtBzOKq.js";
+ import { c as g } from "./concat-BRRtq4S2.js";
+ const k = [
+ ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
+ // ASCII
+ // Spanish accented letters and punctuation
+ ..."áéíóúüñ¿¡",
+ // Finnish accented letters
+ ..."äöÄÖÅå",
+ // Greek letters
+ ..."αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
+ // Cyrillic letters
+ ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
+ ];
+ function T(a, t) {
+ return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
+ }
+ class ot extends u {
+ constructor(t, o) {
+ super(), this.model = t, this.tokeniser = o;
  }
  active = !1;
- async tokenisePrompt(i) {
- const t = i ? await this.tokeniser.tokenise([i], !0) : [[this.tokeniser.eosToken]];
- return d(t, [1, t[0].length], "int32");
+ async tokenisePrompt(t, o) {
+ const r = o ? await t.tokenise([o], !0) : [[t.eosToken]];
+ return d(r, [1, r[0].length], "int32");
  }
- async generateNoCache(i, t) {
- let s = await this.tokenisePrompt(i), o = i || "";
- const n = t?.maxLength ?? 1e3;
- for (let a = 0; a < n && this.active; a++) {
+ async generateNoCache(t, o, r) {
+ let i = await this.tokenisePrompt(t, o), s = o || "";
+ const n = r?.maxLength ?? 1e3;
+ for (let m = 0; m < n && this.active; m++) {
  const {
  output: e,
- attention: c,
- probabilities: l
- } = this.model.generate(s, void 0, t), h = s;
- s = p([s, e], 1), h.dispose();
- const r = await this.processResponse(e, c, l);
- if (e.dispose(), r === null)
+ attention: p,
+ probabilities: c
+ } = this.model.generate(i, void 0, r), h = i;
+ i = g([i, e], 1), h.dispose();
+ const l = await this.processResponse(t, e, p, c);
+ if (e.dispose(), l === null)
  break;
- o += r;
+ s += l;
  }
- return s.dispose(), o;
+ return i.dispose(), s;
  }
- async processResponse(i, t, s) {
- const o = (await i.array())[0][0];
- if (o === this.tokeniser.eosToken)
+ async processResponse(t, o, r, i) {
+ const s = (await o.array())[0][0];
+ if (s === this.tokeniser.eosToken)
  return null;
- const n = await this.tokeniser.decode([o]);
- let a;
- t && (a = await t.array(), t.dispose());
+ const n = await t.decode([s]);
+ let m;
+ r && (m = await Promise.all(r.map((p) => p.array().then((c) => c))), r.forEach((p) => p.dispose()));
  let e;
- return s && (e = await s.array(), s.dispose()), this.emit("tokens", [o], n, a, e), n;
+ return i && (e = await i.array(), i.dispose()), this.emit("tokens", [s], n, m, e), n;
  }
- async generateCache(i, t) {
- let s = await this.tokenisePrompt(i), o = i || "";
+ async generateCache(t, o, r) {
+ let i = await this.tokenisePrompt(t, o), s = o || "";
  const n = new Array(this.model.config.gpt.nLayer);
  for (let e = 0; e < this.model.config.gpt.nLayer; e++)
  n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
- const a = t?.maxLength ?? 1e3;
- for (let e = 0; e < a && this.active; e++) {
+ const m = r?.maxLength ?? 1e3;
+ for (let e = 0; e < m && this.active; e++) {
  const {
- output: c,
- attention: l,
- probabilities: h
- } = this.model.generate(s, n, {
- ...t,
+ output: p,
+ probabilities: c,
+ attention: h
+ } = this.model.generate(i, n, {
+ ...r,
  usePadding: !1
  });
- s.dispose(), s = c;
- const r = await this.processResponse(c, l, h);
- if (r === null)
+ i.dispose(), i = p;
+ const l = await this.processResponse(t, p, h, c);
+ if (l === null)
  break;
- o += r;
+ s += l;
  }
  return n.forEach((e) => {
  e && (e.k && e.k.dispose(), e.v && e.v.dispose());
- }), s.dispose(), o;
+ }), i.dispose(), s;
  }
- async generate(i, t) {
- const s = i && i.length > this.model.config.gpt.blockSize ? i.slice(-this.model.config.gpt.blockSize) : i;
+ async generate(t, o) {
+ const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
  this.active = !0, this.emit("start");
- const n = await (this.model.config.gpt.useRope && !t?.noCache ? this.generateCache(s, t) : this.generateNoCache(s, t));
+ const i = this.tokeniser.trained ? this.tokeniser : new f(T(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
  return this.active = !1, this.emit("stop"), n;
  }
  stop() {
@@ -74,5 +123,5 @@ class w extends u {
  }
  }
  export {
- w as default
+ ot as default
  };
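Beyond the re-hashed chunk imports, the behavioural change above is that generate() now picks a tokeniser before decoding: if the attached tokeniser is untrained it falls back to a CharTokeniser built from a default character set (printable ASCII plus Spanish, Finnish, Greek and Cyrillic letters) resized to the model's vocabulary size, and tokenisePrompt/processResponse now receive that tokeniser as an argument. A readable sketch of the minified helper T and the fallback selection, with illustrative names:

    // De-minified sketch of the helper T above: resize the default character
    // list to exactly the tokeniser's vocabulary size.
    function padOrTruncate(chars: string[], size: number): string[] {
      if (chars.length === size) return chars;
      if (chars.length > size) return chars.slice(0, size);
      return chars.concat(Array(size - chars.length).fill(""));
    }

    // Fallback selection as performed inside generate() (names illustrative):
    // const tok = this.tokeniser.trained
    //   ? this.tokeniser
    //   : new CharTokeniser(padOrTruncate(DEFAULT_CHARS, this.tokeniser.vocabSize));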
package/dist/NanoGPTModel.d.ts CHANGED
@@ -13,8 +13,9 @@ export interface TrainingLogEntry {
  export interface GenerateOptions {
  temperature?: number;
  topK?: number;
+ topP?: number;
  usePadding?: boolean;
- attentionScores?: AttentionScores;
+ attentionScores?: boolean;
  includeProbabilities?: boolean;
  }
  export interface ModelForwardAttributes extends ForwardAttributes {
@@ -41,8 +42,8 @@ export default class NanoGPT extends BaseLayer<ModelForwardAttributes> {
  forward(attrs: ModelForwardAttributes, idx: Tensor, targets?: Tensor): Tensor[];
  generate(idx: Tensor, cache?: KVCache[], options?: GenerateOptions): {
  output: Tensor;
- attention?: Tensor;
  probabilities?: Tensor;
+ attention?: Tensor[];
  };
  getNumParams(): number;
  dispose(): void;
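The declaration changes add a topP option and turn attentionScores into a plain boolean flag, with attention now coming back from generate() as an array of per-layer tensors. A hedged usage sketch; only the option names come from the .d.ts above, the values and the commented call are illustrative:

    const options = {
      temperature: 0.8,
      topK: 40,
      topP: 0.9,              // new in 0.5.2: nucleus-sampling threshold
      usePadding: false,
      attentionScores: true,  // now a boolean flag; attention is returned as Tensor[]
      includeProbabilities: true,
    };
    // const { output, probabilities, attention } = model.generate(idx, undefined, options);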
package/dist/NanoGPTModel.js CHANGED
@@ -1,16 +1,18 @@
- import { defaultConfig as L } from "./config.js";
- import q from "./layers/TransformerBlock.js";
- import { E as O, D as T, T as K, r as P, p as _ } from "./TiedEmbedding-DsDRvLB0.js";
- import F from "./layers/RoPECache.js";
- import D from "./layers/RMSNorm.js";
- import { estimateParameterCount as N } from "./utilities/parameters.js";
- import { createSoftmaxCrossEntropyWithGrad as R } from "./training/sparseCrossEntropy.js";
- import { B } from "./BaseLayer-BhrMN8JO.js";
- import { o as k, i as m, q as G, E as w, aa as A, ab as V, ac as j, t as b, a9 as W, f as y, F as H } from "./index-iNhkcAEQ.js";
- import { r as $ } from "./reshape-DxTPgnwL.js";
- import { r as J } from "./range-BsFU-SNG.js";
- import { g as Q } from "./gather-Bxe1Qip8.js";
- import { s as U } from "./softmax-BjsptB07.js";
+ import { defaultConfig as F } from "./config.js";
+ import O from "./layers/TransformerBlock.js";
+ import { T as N, r as R } from "./TiedEmbedding-DORsPlNL.js";
+ import A from "./layers/RoPECache.js";
+ import G from "./layers/RMSNorm.js";
+ import { estimateParameterCount as j } from "./utilities/parameters.js";
+ import { createSoftmaxCrossEntropyWithGrad as B } from "./training/sparseCrossEntropy.js";
+ import V from "./layers/BaseLayer.js";
+ import { E as H, D as W, p as J } from "./random_width-DI2h9CMs.js";
+ import { o as x, j as y, u as Q, E as I, aa as U, ab as X, ac as Y, t as z, a9 as Z, f as L, H as tt } from "./index-CnHyhpKc.js";
+ import { r as T } from "./reshape-CTIbqjwm.js";
+ import { r as et } from "./range-CkOJ7090.js";
+ import { s as q } from "./softmax-DX6qXAbm.js";
+ import { t as ot } from "./ops-DzQTmLIl.js";
+ import { g as st } from "./gather-BWyutxwi.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -27,13 +29,13 @@ import { s as U } from "./softmax-BjsptB07.js";
  * limitations under the License.
  * =============================================================================
  */
- function X(h, t) {
- let e = m(h, "a", "mod"), o = m(t, "b", "mod");
- [e, o] = G(e, o);
+ function nt(l, t) {
+ let e = y(l, "a", "mod"), o = y(t, "b", "mod");
+ [e, o] = Q(e, o);
  const n = { a: e, b: o };
- return w.runKernel(A, n);
+ return I.runKernel(U, n);
  }
- const Y = /* @__PURE__ */ k({ mod_: X });
+ const it = /* @__PURE__ */ x({ mod_: nt });
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -50,17 +52,17 @@ const Y = /* @__PURE__ */ k({ mod_: X });
  * limitations under the License.
  * =============================================================================
  */
- function Z(h, t, e, o = !1) {
- const n = m(h, "logits", "multinomial"), s = n.size, i = n.rank;
+ function rt(l, t, e, o = !1) {
+ const n = y(l, "logits", "multinomial"), s = n.size, i = n.rank;
  if (s < 2)
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
  if (i > 2)
  throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
  e = e || Math.random();
- const a = { logits: i === 1 ? $(n, [1, -1]) : n }, c = { numSamples: t, seed: e, normalized: o }, l = w.runKernel(V, a, c);
- return i === 1 ? $(l, [l.size]) : l;
+ const r = { logits: i === 1 ? T(n, [1, -1]) : n }, u = { numSamples: t, seed: e, normalized: o }, c = I.runKernel(X, r, u);
+ return i === 1 ? T(c, [c.size]) : c;
  }
- const z = /* @__PURE__ */ k({ multinomial_: Z });
+ const C = /* @__PURE__ */ x({ multinomial_: rt });
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -77,8 +79,8 @@ const z = /* @__PURE__ */ k({ multinomial_: Z });
  * limitations under the License.
  * =============================================================================
  */
- function tt(h, t = 1, e = !0) {
- const o = m(h, "x", "topk");
+ function ct(l, t = 1, e = !0) {
+ const o = y(l, "x", "topk");
  if (o.rank === 0)
  throw new Error("topk() expects the input to be of rank 1 or higher");
  const n = o.shape[o.shape.length - 1];
@@ -86,10 +88,10 @@ function tt(h, t = 1, e = !0) {
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
  if (t > n)
  throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
- const s = { x: o }, i = { k: t, sorted: e }, [r, a] = w.runKernel(j, s, i);
- return { values: r, indices: a };
+ const s = { x: o }, i = { k: t, sorted: e }, [p, r] = I.runKernel(Y, s, i);
+ return { values: p, indices: r };
  }
- const et = /* @__PURE__ */ k({ topk_: tt });
+ const at = /* @__PURE__ */ x({ topk_: ct });
  /**
  * @license
  * Copyright 2018 Google LLC
@@ -99,13 +101,13 @@ const et = /* @__PURE__ */ k({ topk_: tt });
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
- function ot(h) {
- return new T(h);
+ function lt(l) {
+ return new W(l);
  }
- function st(h) {
- return new O(h);
+ function pt(l) {
+ return new H(l);
  }
- class bt extends B {
+ class xt extends V {
  wte;
  // Token embeddings
  wpe;
@@ -119,15 +121,15 @@ class bt extends B {
  log = [];
  // Training log
  constructor(t = {}) {
- super({ gpt: { ...L, ...t }, layerConfig: {} }), this.wte = new K(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = st({
+ super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new N(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = pt({
  inputDim: this.config.gpt.blockSize,
  outputDim: this.config.gpt.nEmbed,
  name: "positional_embedding",
- embeddingsInitializer: P({ mean: 0, stddev: 0.02 })
- }) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = ot({ rate: this.config.gpt.dropout }), this.blocks = [];
+ embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
+ }) : (this.ropeCache = new A(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = lt({ rate: this.config.gpt.dropout }), this.blocks = [];
  for (let e = 0; e < this.config.gpt.nLayer; e++)
- this.blocks.push(new q(e, this.config, this));
- this.lnF = new D(this.config, "final_rms_norm", this);
+ this.blocks.push(new O(e, this.config, this));
+ this.lnF = new G(this.config, "final_rms_norm", this);
  }
  get checkpointing() {
  return this.config.layerConfig.checkpointing === !0;
@@ -136,11 +138,11 @@ class bt extends B {
  this.config.layerConfig.checkpointing = t;
  }
  inputPhase(t, e, o = !1) {
- return b(() => {
+ return z(() => {
  const n = this.wte.embed(t);
  if (this.config.gpt.useRope === !1) {
- const [, s] = t.shape, i = this.config.gpt.blockSize, r = J(0, s, 1, "int32"), a = Y(W(r, y(e, "int32")), y(i, "int32")), c = this.wpe.apply(a), l = n.add(c);
- return this.drop.apply(l, { training: o });
+ const [, s] = t.shape, i = this.config.gpt.blockSize, p = et(0, s, 1, "int32"), r = it(Z(p, L(e, "int32")), L(i, "int32")), u = this.wpe.apply(r), c = n.add(u);
+ return this.drop.apply(c, { training: o });
  } else
  return this.drop.apply(n, { training: o });
  });
@@ -167,7 +169,7 @@ class bt extends B {
  }
  calculateLoss(t, e) {
  try {
- return R()(t, e).mean();
+ return B()(t, e).mean();
  } catch (o) {
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
  }
@@ -205,7 +207,7 @@ class bt extends B {
  });
  }*/
  forward(t, e, o) {
- return this.validateInput(e), b(() => {
+ return this.validateInput(e), z(() => {
  this.startMemory();
  const n = t.cache?.[0]?.length ?? 0;
  let s = this.inputPhase(e, n, t.training);
@@ -213,59 +215,72 @@ class bt extends B {
  throw console.error("Cache", t.cache), new Error(
  `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
  );
- let i;
- for (let c = 0; c < this.blocks.length; c++) {
- const l = this.blocks[c], f = Math.random() * 1e9, p = {
+ for (let r = 0; r < this.blocks.length; r++) {
+ const u = this.blocks[r], c = Math.random() * 1e9, g = {
  training: t.training,
- seed: f,
+ seed: c,
  attentionScores: t.attentionScores,
- pastKV: t.cache ? t.cache[c] : void 0
- }, u = this.config.layerConfig.checkpointing && t.training ? l.callCheckpoint(p, s) : l.call(p, s);
- s.dispose(), s = u, p.attentionScores?.attentionOut && (i = p.attentionScores.attentionOut);
+ pastKV: t.cache ? t.cache[r] : void 0
+ }, S = this.config.layerConfig.checkpointing && t.training ? u.callCheckpoint(g, s) : u.call(g, s);
+ s.dispose(), s = S;
  }
  s = this.lnF.call(t, s);
- const r = this.wte.project(s);
+ const i = this.wte.project(s);
  s.dispose();
- let a;
- return o && (a = this.calculateLoss(r, o)), this.endMemory("Forward"), t.attentionScores && (t.attentionScores.attentionOut = i ? H(i) : void 0), a ? [r, a] : [r];
+ let p;
+ return o && (p = this.calculateLoss(i, o)), this.endMemory("Forward"), p ? [i, p] : [i];
  });
  }
  generate(t, e, o) {
- const n = o?.temperature ?? 1, s = o?.topK, i = o?.usePadding ?? !1;
- return b(() => {
- const r = t, a = r.shape[1], c = a <= this.config.gpt.blockSize ? r : r.slice(
- [0, a - this.config.gpt.blockSize],
+ const n = o?.temperature ?? 1, s = o?.topK, i = o?.topP, p = o?.usePadding ?? !1;
+ return z(() => {
+ const r = t, u = r.shape[1], c = u <= this.config.gpt.blockSize ? r : r.slice(
+ [0, u - this.config.gpt.blockSize],
  [r.shape[0], this.config.gpt.blockSize]
- ), l = i ? this.config.gpt.blockSize - c.shape[1] : 0, f = l > 0 ? _(c, [
+ ), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, S = g > 0 ? J(c, [
  [0, 0],
- [0, l]
- ]) : c, p = {
+ [0, g]
+ ]) : c, f = {
  training: !1,
- attentionScores: o?.attentionScores,
+ attentionScores: o?.attentionScores ? {
+ attentionOut: []
+ } : void 0,
  cache: e
- }, [u] = this.forward(p, f), E = u.shape[1] - 1 - l, C = u.slice([0, E, 0], [u.shape[0], 1, u.shape[2]]), I = p.attentionScores?.attentionOut ? p.attentionScores.attentionOut.slice(
- [0, E, 0],
- [p.attentionScores.attentionOut.shape[0], 1, p.attentionScores.attentionOut.shape[2]]
- ) : void 0;
- u.dispose();
- const d = C.div(n);
- let g;
- if (s) {
- const { values: v, indices: M } = et(d, s), x = z(v.squeeze([1]), 1);
- g = Q(M.squeeze([1]), x, 1);
+ }, [d] = this.forward(f, S), M = d.shape[1] - 1 - g, K = d.slice([0, M, 0], [d.shape[0], 1, d.shape[2]]);
+ f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((h, b) => {
+ h.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = tt(
+ h.slice([0, M, 0], [h.shape[0], 1, h.shape[2]])
+ ), h.dispose());
+ }), d.dispose();
+ const w = K.div(n);
+ let m;
+ if (i) {
+ const h = q(w.squeeze([1])), b = h.arraySync()[0];
+ h.dispose();
+ const E = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
+ let v = 0;
+ const $ = new Array(E.length).fill(0);
+ for (const a of E)
+ if (v += a.prob, $[a.index] = a.prob, v >= i)
+ break;
+ const _ = $.reduce((a, k) => a + k, 0), D = $.map((a) => a / _);
+ m = C(ot(D), 1, void 0, !0);
+ } else if (s) {
+ const { values: h, indices: b } = at(w, s), E = C(h.squeeze([1]), 1);
+ m = st(b.squeeze([1]), E, 1);
  } else
- g = z(d.squeeze([1]), 1);
- let S;
- return o?.includeProbabilities && (S = U(d.squeeze([1]))), g = g.reshape([1, 1]), { output: g, attention: I?.squeeze([1]), probabilities: S };
+ m = C(w.squeeze([1]), 1);
+ let P;
+ return o?.includeProbabilities && (P = q(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: P, attention: f.attentionScores?.attentionOut };
  });
  }
  getNumParams() {
- return N(this.config.gpt);
+ return j(this.config.gpt);
  }
  dispose() {
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
  }
  }
  export {
- bt as default
+ xt as default
  };
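The main functional addition in generate() above is a top-p (nucleus) sampling branch: the temperature-scaled logits are softmaxed, sorted by descending probability, accumulated until the cumulative mass reaches topP, renormalised, and then sampled with a multinomial draw. A de-minified sketch of that filtering step on a plain probability array (names are illustrative; the real code feeds the result back into the multinomial kernel):

    function topPFilter(probs: number[], topP: number): number[] {
      // Rank tokens by descending probability.
      const ranked = probs
        .map((prob, index) => ({ prob, index }))
        .sort((a, b) => b.prob - a.prob);

      // Keep tokens until the cumulative probability reaches the nucleus threshold.
      const kept = new Array(probs.length).fill(0);
      let cumulative = 0;
      for (const { prob, index } of ranked) {
        cumulative += prob;
        kept[index] = prob;
        if (cumulative >= topP) break;
      }

      // Renormalise the surviving mass so it sums to 1 again.
      const total = kept.reduce((sum, p) => sum + p, 0);
      return kept.map((p) => p / total);
    }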
package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} RENAMED
@@ -1,5 +1,5 @@
- import { ad as $, ae as g, p, af as C, k as x } from "./index-iNhkcAEQ.js";
- import { u as I } from "./gpgpu_math-C0zyxKFi.js";
+ import { ad as $, ae as g, q as p, af as C, l as x } from "./index-CnHyhpKc.js";
+ import { u as I } from "./gpgpu_math-Df7gzJWH.js";
  /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -201,12 +201,12 @@ function D(t, e, o) {
  * limitations under the License.
  * =============================================================================
  */
- function U(t) {
+ function k(t) {
  const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
  x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
  const l = i.texData.get(n.dataId);
  return l.isPacked && !f(n.shape, a) && !(l.texture !== null && f(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
  }
  export {
- U as r
+ k as r
  };
package/dist/TeachableLLM.js CHANGED
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
  import "./papaparse.min-C8l2Kvo1.js";
  import "./index-Tf7vU29b.js";
  import "./jszip.min-CjP2V1VV.js";
- import "./index-iNhkcAEQ.js";
+ import "./index-CnHyhpKc.js";
  import "./ops/cpu/scatterSub.js";
  import "./ops/webgl/scatterSub.js";
  import "./ops/cpu/gatherSub.js";
package/dist/TiedEmbedding-DORsPlNL.js ADDED
@@ -0,0 +1,44 @@
+ import { R as a } from "./random_width-DI2h9CMs.js";
+ import "./index-CnHyhpKc.js";
+ import { d as s } from "./tfjs_backend-DX9yVvwk.js";
+ import o from "./layers/BaseLayer.js";
+ import { v as m } from "./variable-BGvK-VN3.js";
+ import { g as d } from "./gather-BWyutxwi.js";
+ /**
+ * @license
+ * Copyright 2018 Google LLC
+ *
+ * Use of this source code is governed by an MIT-style
+ * license that can be found in the LICENSE file or at
+ * https://opensource.org/licenses/MIT.
+ * =============================================================================
+ */
+ function n(e) {
+ return new a(e);
+ }
+ class c extends o {
+ vocabSize;
+ embedDim;
+ initializer;
+ WEIGHTS;
+ constructor(t, i, r) {
+ super(t, r), this.WEIGHTS = i, this.vocabSize = t.gpt.vocabSize, this.embedDim = t.gpt.nEmbed, this.initializer = n({
+ mean: 0,
+ stddev: 0.02
+ }), this.addVariable(this.WEIGHTS, m(this.initializer.apply([this.vocabSize, this.embedDim]), !0));
+ }
+ embed(t) {
+ return d(this.getVariable(this.WEIGHTS), t, 0);
+ }
+ project(t) {
+ return s(t, this.getVariable(this.WEIGHTS).transpose());
+ }
+ // Dummy, should not be used.
+ forward(t, i) {
+ return this.project(i);
+ }
+ }
+ export {
+ c as T,
+ n as r
+ };
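This new chunk is a weight-tied embedding: a single [vocabSize, nEmbed] matrix both embeds token ids (a gather along axis 0) and projects hidden states back to vocabulary logits (a matmul against its transpose), so the input and output embeddings share parameters. A minimal TensorFlow.js sketch of the same idea, with assumed illustrative sizes:

    import * as tf from "@tensorflow/tfjs";

    const vocabSize = 256;                         // illustrative
    const nEmbed = 64;                             // illustrative
    const weights = tf.variable(tf.randomNormal([vocabSize, nEmbed], 0, 0.02));

    // embed(): look up rows of the shared matrix for each token id.
    const ids = tf.tensor1d([1, 5, 42], "int32");
    const embedded = tf.gather(weights, ids, 0);   // shape [3, nEmbed]

    // project(): reuse the same matrix, transposed, to produce logits.
    const logits = tf.matMul(embedded, weights.transpose()); // shape [3, vocabSize]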
package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} RENAMED
@@ -1,4 +1,4 @@
- import { k as c } from "./index-iNhkcAEQ.js";
+ import { l as c } from "./index-CnHyhpKc.js";
  /**
  * @license
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -21,7 +21,7 @@ function i(e, n) {
  return !1;
  return !0;
  }
- function p(e, n, t) {
+ function l(e, n, t) {
  const r = e.length + n.length, s = [];
  let o = 0, f = 0;
  for (let u = 0; u < r; u++)
@@ -37,7 +37,7 @@ function a(e, n) {
  }
  function m(e, n) {
  const t = n.map((r) => 1);
- return p(e, t, n);
+ return l(e, t, n);
  }
  function d(e, n, t) {
  c(i(n, t), () => `${e} supports only inner-most axes for now. Got axes ${n} and rank-${t} input.`);
package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} RENAMED
@@ -1,5 +1,5 @@
- import { o as h, i as f, l as p, x as g, E as u, T } from "./index-iNhkcAEQ.js";
- import { r as b } from "./reshape-DxTPgnwL.js";
+ import { o as h, j as f, n as p, y as g, E as u, L as b } from "./index-CnHyhpKc.js";
+ import { r as T } from "./reshape-CTIbqjwm.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -25,7 +25,7 @@ function m(e, r) {
  const t = n.shape.slice();
  for (; t.length < r.length; )
  t.unshift(1);
- n = b(n, t);
+ n = T(n, t);
  }
  const s = n.shape, o = Array.from(r);
  for (let t = r.length - 1; t >= 0; t--)
@@ -36,7 +36,7 @@ function m(e, r) {
  if (o.map((t, l) => t > 1 ? l : -1).filter((t) => t >= 0).length === 0)
  return g(n);
  const i = { x: n }, c = { reps: o };
- return u.runKernel(T, i, c);
+ return u.runKernel(b, i, c);
  }
  const E = /* @__PURE__ */ h({ broadcastTo_: m });
  export {
package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} RENAMED
@@ -1,4 +1,4 @@
- import { o as s, k as a, j as p, x as i, E as l, C as f } from "./index-iNhkcAEQ.js";
+ import { o as s, l as a, k as p, y as i, E as l, C as f } from "./index-CnHyhpKc.js";
  /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.