@genai-fi/nanogpt 0.5.1 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. package/dist/Generator.js +90 -41
  2. package/dist/NanoGPTModel.d.ts +1 -0
  3. package/dist/NanoGPTModel.js +86 -73
  4. package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/TiedEmbedding-DORsPlNL.js +44 -0
  7. package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
  8. package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
  9. package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
  10. package/dist/dataset-ZHEPJmED.js +1226 -0
  11. package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
  12. package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
  13. package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
  14. package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
  15. package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
  16. package/dist/layers/BaseLayer.js +114 -3
  17. package/dist/layers/CausalSelfAttention.js +29 -28
  18. package/dist/layers/MLP.js +10 -9
  19. package/dist/layers/RMSNorm.js +12 -11
  20. package/dist/layers/RoPECache.js +3 -3
  21. package/dist/layers/TiedEmbedding.js +8 -6
  22. package/dist/layers/TransformerBlock.js +2 -2
  23. package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
  24. package/dist/main.js +1 -1
  25. package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
  26. package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
  27. package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
  28. package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
  29. package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
  30. package/dist/ops/appendCache.js +3 -3
  31. package/dist/ops/attentionMask.js +1 -1
  32. package/dist/ops/cpu/appendCache.js +2 -2
  33. package/dist/ops/cpu/attentionMask.js +5 -5
  34. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  35. package/dist/ops/cpu/gatherSub.js +5 -5
  36. package/dist/ops/cpu/gelu.js +1 -1
  37. package/dist/ops/cpu/matMulGelu.js +1 -1
  38. package/dist/ops/cpu/matMulMul.js +1 -1
  39. package/dist/ops/cpu/mulDropout.js +1 -1
  40. package/dist/ops/cpu/normRMS.js +1 -1
  41. package/dist/ops/cpu/qkv.js +3 -3
  42. package/dist/ops/cpu/rope.js +5 -5
  43. package/dist/ops/cpu/scatterSub.js +27 -27
  44. package/dist/ops/fusedSoftmax.js +1 -1
  45. package/dist/ops/gatherSub.js +1 -1
  46. package/dist/ops/gelu.js +1 -1
  47. package/dist/ops/grads/attentionMask.js +1 -1
  48. package/dist/ops/grads/fusedSoftmax.js +2 -2
  49. package/dist/ops/grads/gelu.js +1 -1
  50. package/dist/ops/grads/matMulGelu.js +1 -1
  51. package/dist/ops/grads/normRMS.js +1 -1
  52. package/dist/ops/grads/qkv.js +1 -1
  53. package/dist/ops/grads/rope.js +1 -1
  54. package/dist/ops/matMulGelu.js +1 -1
  55. package/dist/ops/matMulMul.js +1 -1
  56. package/dist/ops/mulDrop.js +1 -1
  57. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  58. package/dist/ops/normRMS.js +1 -1
  59. package/dist/ops/qkv.js +1 -1
  60. package/dist/ops/scatterSub.js +1 -1
  61. package/dist/ops/webgl/appendCache.js +1 -1
  62. package/dist/ops/webgl/attentionMask.js +1 -1
  63. package/dist/ops/webgl/fusedSoftmax.js +36 -36
  64. package/dist/ops/webgl/gatherSub.js +1 -1
  65. package/dist/ops/webgl/gelu.js +2 -2
  66. package/dist/ops/webgl/matMulGelu.js +22 -22
  67. package/dist/ops/webgl/matMulMul.js +1 -1
  68. package/dist/ops/webgl/mulDropout.js +1 -1
  69. package/dist/ops/webgl/normRMS.js +2 -2
  70. package/dist/ops/webgl/qkv.js +1 -1
  71. package/dist/ops/webgl/rope.js +1 -1
  72. package/dist/ops/webgl/scatterSub.js +1 -1
  73. package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
  74. package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
  75. package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
  76. package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
  77. package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
  78. package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
  79. package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
  80. package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
  81. package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
  82. package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
  83. package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
  84. package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
  85. package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
  86. package/dist/tokeniser/CharTokeniser.js +27 -27
  87. package/dist/tokeniser/bpe.d.ts +1 -0
  88. package/dist/tokeniser/bpe.js +38 -35
  89. package/dist/training/AdamExt.js +1 -1
  90. package/dist/training/DatasetBuilder.js +22 -1242
  91. package/dist/training/FullTrainer.js +1 -1
  92. package/dist/training/Trainer.js +5 -5
  93. package/dist/training/sparseCrossEntropy.js +4 -4
  94. package/dist/utilities/dummy.js +2 -2
  95. package/dist/utilities/generate.js +3 -3
  96. package/dist/utilities/load.js +1 -1
  97. package/dist/utilities/profile.js +1 -1
  98. package/dist/utilities/weights.js +2 -2
  99. package/dist/variable-BGvK-VN3.js +23 -0
  100. package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
  101. package/package.json +1 -1
  102. package/dist/BaseLayer-BhrMN8JO.js +0 -135
package/dist/Generator.js CHANGED
@@ -1,72 +1,121 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-iNhkcAEQ.js";
3
- import { t as d } from "./tensor2d-tSxWdFMH.js";
4
- import { c as p } from "./concat-Cxbo2sOz.js";
5
- class w extends u {
6
- constructor(i, t) {
7
- super(), this.model = i, this.tokeniser = t;
2
+ import "./index-CnHyhpKc.js";
3
+ import "./ops/cpu/attentionMask.js";
4
+ import "./ops/webgl/attentionMask.js";
5
+ import "./ops/grads/attentionMask.js";
6
+ import "./ops/cpu/qkv.js";
7
+ import "./ops/webgl/qkv.js";
8
+ import "./ops/grads/qkv.js";
9
+ import "@tensorflow/tfjs";
10
+ import "./ops/cpu/rope.js";
11
+ import "./ops/webgl/rope.js";
12
+ import "./ops/grads/rope.js";
13
+ import "./ops/cpu/appendCache.js";
14
+ import "./ops/webgl/appendCache.js";
15
+ import "./ops/cpu/fusedSoftmax.js";
16
+ import "./ops/webgl/fusedSoftmax.js";
17
+ import "./ops/grads/fusedSoftmax.js";
18
+ import "./ops/cpu/matMulGelu.js";
19
+ import "./ops/webgl/matMulGelu.js";
20
+ import "./ops/grads/matMulGelu.js";
21
+ import "./ops/cpu/normRMS.js";
22
+ import "./ops/webgl/normRMS.js";
23
+ import "./ops/grads/normRMS.js";
24
+ import "./random_width-DI2h9CMs.js";
25
+ import "./ops/cpu/gatherSub.js";
26
+ import "./ops/webgl/gatherSub.js";
27
+ import "./ops/cpu/scatterSub.js";
28
+ import "./ops/webgl/scatterSub.js";
29
+ import "./jszip.min-CjP2V1VV.js";
30
+ import f from "./tokeniser/CharTokeniser.js";
31
+ import "./dataset-ZHEPJmED.js";
32
+ import "./index-Tf7vU29b.js";
33
+ import "./papaparse.min-C8l2Kvo1.js";
34
+ import "./ops/cpu/gelu.js";
35
+ import "./ops/webgl/gelu.js";
36
+ import "./ops/grads/gelu.js";
37
+ import { t as d } from "./tensor2d-CqtBzOKq.js";
38
+ import { c as g } from "./concat-BRRtq4S2.js";
39
+ const k = [
40
+ ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
41
+ // ASCII
42
+ // Spanish accented letters and punctuation
43
+ ..."áéíóúüñ¿¡",
44
+ // Finnish accented letters
45
+ ..."äöÄÖÅå",
46
+ // Greek letters
47
+ ..."αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
48
+ // Cyrillic letters
49
+ ..."абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ"
50
+ ];
51
+ function T(a, t) {
52
+ return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
53
+ }
54
+ class ot extends u {
55
+ constructor(t, o) {
56
+ super(), this.model = t, this.tokeniser = o;
8
57
  }
9
58
  active = !1;
10
- async tokenisePrompt(i) {
11
- const t = i ? await this.tokeniser.tokenise([i], !0) : [[this.tokeniser.eosToken]];
12
- return d(t, [1, t[0].length], "int32");
59
+ async tokenisePrompt(t, o) {
60
+ const r = o ? await t.tokenise([o], !0) : [[t.eosToken]];
61
+ return d(r, [1, r[0].length], "int32");
13
62
  }
14
- async generateNoCache(i, t) {
15
- let s = await this.tokenisePrompt(i), o = i || "";
16
- const n = t?.maxLength ?? 1e3;
17
- for (let r = 0; r < n && this.active; r++) {
63
+ async generateNoCache(t, o, r) {
64
+ let i = await this.tokenisePrompt(t, o), s = o || "";
65
+ const n = r?.maxLength ?? 1e3;
66
+ for (let m = 0; m < n && this.active; m++) {
18
67
  const {
19
68
  output: e,
20
- attention: a,
69
+ attention: p,
21
70
  probabilities: c
22
- } = this.model.generate(s, void 0, t), h = s;
23
- s = p([s, e], 1), h.dispose();
24
- const l = await this.processResponse(e, a, c);
71
+ } = this.model.generate(i, void 0, r), h = i;
72
+ i = g([i, e], 1), h.dispose();
73
+ const l = await this.processResponse(t, e, p, c);
25
74
  if (e.dispose(), l === null)
26
75
  break;
27
- o += l;
76
+ s += l;
28
77
  }
29
- return s.dispose(), o;
78
+ return i.dispose(), s;
30
79
  }
31
- async processResponse(i, t, s) {
32
- const o = (await i.array())[0][0];
33
- if (o === this.tokeniser.eosToken)
80
+ async processResponse(t, o, r, i) {
81
+ const s = (await o.array())[0][0];
82
+ if (s === this.tokeniser.eosToken)
34
83
  return null;
35
- const n = await this.tokeniser.decode([o]);
36
- let r;
37
- t && (r = await Promise.all(t.map((a) => a.array().then((c) => c))), t.forEach((a) => a.dispose()));
84
+ const n = await t.decode([s]);
85
+ let m;
86
+ r && (m = await Promise.all(r.map((p) => p.array().then((c) => c))), r.forEach((p) => p.dispose()));
38
87
  let e;
39
- return s && (e = await s.array(), s.dispose()), this.emit("tokens", [o], n, r, e), n;
88
+ return i && (e = await i.array(), i.dispose()), this.emit("tokens", [s], n, m, e), n;
40
89
  }
41
- async generateCache(i, t) {
42
- let s = await this.tokenisePrompt(i), o = i || "";
90
+ async generateCache(t, o, r) {
91
+ let i = await this.tokenisePrompt(t, o), s = o || "";
43
92
  const n = new Array(this.model.config.gpt.nLayer);
44
93
  for (let e = 0; e < this.model.config.gpt.nLayer; e++)
45
94
  n[e] = { k: void 0, v: void 0, length: 0, cumulativeLength: 0 };
46
- const r = t?.maxLength ?? 1e3;
47
- for (let e = 0; e < r && this.active; e++) {
95
+ const m = r?.maxLength ?? 1e3;
96
+ for (let e = 0; e < m && this.active; e++) {
48
97
  const {
49
- output: a,
98
+ output: p,
50
99
  probabilities: c,
51
100
  attention: h
52
- } = this.model.generate(s, n, {
53
- ...t,
101
+ } = this.model.generate(i, n, {
102
+ ...r,
54
103
  usePadding: !1
55
104
  });
56
- s.dispose(), s = a;
57
- const l = await this.processResponse(a, h, c);
105
+ i.dispose(), i = p;
106
+ const l = await this.processResponse(t, p, h, c);
58
107
  if (l === null)
59
108
  break;
60
- o += l;
109
+ s += l;
61
110
  }
62
111
  return n.forEach((e) => {
63
112
  e && (e.k && e.k.dispose(), e.v && e.v.dispose());
64
- }), s.dispose(), o;
113
+ }), i.dispose(), s;
65
114
  }
66
- async generate(i, t) {
67
- const s = i && i.length > this.model.config.gpt.blockSize ? i.slice(-this.model.config.gpt.blockSize) : i;
115
+ async generate(t, o) {
116
+ const r = t && t.length > this.model.config.gpt.blockSize ? t.slice(-this.model.config.gpt.blockSize) : t;
68
117
  this.active = !0, this.emit("start");
69
- const n = await (this.model.config.gpt.useRope && !t?.noCache ? this.generateCache(s, t) : this.generateNoCache(s, t));
118
+ const i = this.tokeniser.trained ? this.tokeniser : new f(T(k, this.tokeniser.vocabSize)), n = await (this.model.config.gpt.useRope && !o?.noCache ? this.generateCache(i, r, o) : this.generateNoCache(i, r, o));
70
119
  return this.active = !1, this.emit("stop"), n;
71
120
  }
72
121
  stop() {
@@ -74,5 +123,5 @@ class w extends u {
74
123
  }
75
124
  }
76
125
  export {
77
- w as default
126
+ ot as default
78
127
  };
@@ -13,6 +13,7 @@ export interface TrainingLogEntry {
13
13
  export interface GenerateOptions {
14
14
  temperature?: number;
15
15
  topK?: number;
16
+ topP?: number;
16
17
  usePadding?: boolean;
17
18
  attentionScores?: boolean;
18
19
  includeProbabilities?: boolean;
@@ -1,16 +1,18 @@
1
- import { defaultConfig as L } from "./config.js";
2
- import v from "./layers/TransformerBlock.js";
3
- import { E as T, D as q, T as K, r as P, p as _ } from "./TiedEmbedding-DsDRvLB0.js";
4
- import F from "./layers/RoPECache.js";
5
- import D from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as O } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
8
- import { B as R } from "./BaseLayer-BhrMN8JO.js";
9
- import { o as E, i as d, q as B, E as y, aa as G, ab as V, ac as j, t as w, a9 as A, f as z, F as W } from "./index-iNhkcAEQ.js";
10
- import { r as C } from "./reshape-DxTPgnwL.js";
11
- import { r as H } from "./range-BsFU-SNG.js";
12
- import { g as J } from "./gather-Bxe1Qip8.js";
13
- import { s as Q } from "./softmax-BjsptB07.js";
1
+ import { defaultConfig as F } from "./config.js";
2
+ import O from "./layers/TransformerBlock.js";
3
+ import { T as N, r as R } from "./TiedEmbedding-DORsPlNL.js";
4
+ import A from "./layers/RoPECache.js";
5
+ import G from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as j } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as B } from "./training/sparseCrossEntropy.js";
8
+ import V from "./layers/BaseLayer.js";
9
+ import { E as H, D as W, p as J } from "./random_width-DI2h9CMs.js";
10
+ import { o as x, j as y, u as Q, E as I, aa as U, ab as X, ac as Y, t as z, a9 as Z, f as L, H as tt } from "./index-CnHyhpKc.js";
11
+ import { r as T } from "./reshape-CTIbqjwm.js";
12
+ import { r as et } from "./range-CkOJ7090.js";
13
+ import { s as q } from "./softmax-DX6qXAbm.js";
14
+ import { t as ot } from "./ops-DzQTmLIl.js";
15
+ import { g as st } from "./gather-BWyutxwi.js";
14
16
  /**
15
17
  * @license
16
18
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -27,13 +29,13 @@ import { s as Q } from "./softmax-BjsptB07.js";
27
29
  * limitations under the License.
28
30
  * =============================================================================
29
31
  */
30
- function U(h, t) {
31
- let e = d(h, "a", "mod"), o = d(t, "b", "mod");
32
- [e, o] = B(e, o);
32
+ function nt(l, t) {
33
+ let e = y(l, "a", "mod"), o = y(t, "b", "mod");
34
+ [e, o] = Q(e, o);
33
35
  const n = { a: e, b: o };
34
- return y.runKernel(G, n);
36
+ return I.runKernel(U, n);
35
37
  }
36
- const X = /* @__PURE__ */ E({ mod_: U });
38
+ const it = /* @__PURE__ */ x({ mod_: nt });
37
39
  /**
38
40
  * @license
39
41
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -50,17 +52,17 @@ const X = /* @__PURE__ */ E({ mod_: U });
50
52
  * limitations under the License.
51
53
  * =============================================================================
52
54
  */
53
- function Y(h, t, e, o = !1) {
54
- const n = d(h, "logits", "multinomial"), s = n.size, i = n.rank;
55
+ function rt(l, t, e, o = !1) {
56
+ const n = y(l, "logits", "multinomial"), s = n.size, i = n.rank;
55
57
  if (s < 2)
56
58
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
57
59
  if (i > 2)
58
60
  throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
59
61
  e = e || Math.random();
60
- const c = { logits: i === 1 ? C(n, [1, -1]) : n }, l = { numSamples: t, seed: e, normalized: o }, a = y.runKernel(V, c, l);
61
- return i === 1 ? C(a, [a.size]) : a;
62
+ const r = { logits: i === 1 ? T(n, [1, -1]) : n }, u = { numSamples: t, seed: e, normalized: o }, c = I.runKernel(X, r, u);
63
+ return i === 1 ? T(c, [c.size]) : c;
62
64
  }
63
- const I = /* @__PURE__ */ E({ multinomial_: Y });
65
+ const C = /* @__PURE__ */ x({ multinomial_: rt });
64
66
  /**
65
67
  * @license
66
68
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -77,8 +79,8 @@ const I = /* @__PURE__ */ E({ multinomial_: Y });
77
79
  * limitations under the License.
78
80
  * =============================================================================
79
81
  */
80
- function Z(h, t = 1, e = !0) {
81
- const o = d(h, "x", "topk");
82
+ function ct(l, t = 1, e = !0) {
83
+ const o = y(l, "x", "topk");
82
84
  if (o.rank === 0)
83
85
  throw new Error("topk() expects the input to be of rank 1 or higher");
84
86
  const n = o.shape[o.shape.length - 1];
@@ -86,10 +88,10 @@ function Z(h, t = 1, e = !0) {
86
88
  throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
87
89
  if (t > n)
88
90
  throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
89
- const s = { x: o }, i = { k: t, sorted: e }, [r, c] = y.runKernel(j, s, i);
90
- return { values: r, indices: c };
91
+ const s = { x: o }, i = { k: t, sorted: e }, [p, r] = I.runKernel(Y, s, i);
92
+ return { values: p, indices: r };
91
93
  }
92
- const tt = /* @__PURE__ */ E({ topk_: Z });
94
+ const at = /* @__PURE__ */ x({ topk_: ct });
93
95
  /**
94
96
  * @license
95
97
  * Copyright 2018 Google LLC
@@ -99,13 +101,13 @@ const tt = /* @__PURE__ */ E({ topk_: Z });
99
101
  * https://opensource.org/licenses/MIT.
100
102
  * =============================================================================
101
103
  */
102
- function et(h) {
103
- return new q(h);
104
+ function lt(l) {
105
+ return new W(l);
104
106
  }
105
- function ot(h) {
106
- return new T(h);
107
+ function pt(l) {
108
+ return new H(l);
107
109
  }
108
- class dt extends R {
110
+ class xt extends V {
109
111
  wte;
110
112
  // Token embeddings
111
113
  wpe;
@@ -119,15 +121,15 @@ class dt extends R {
119
121
  log = [];
120
122
  // Training log
121
123
  constructor(t = {}) {
122
- super({ gpt: { ...L, ...t }, layerConfig: {} }), this.wte = new K(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = ot({
124
+ super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new N(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = pt({
123
125
  inputDim: this.config.gpt.blockSize,
124
126
  outputDim: this.config.gpt.nEmbed,
125
127
  name: "positional_embedding",
126
- embeddingsInitializer: P({ mean: 0, stddev: 0.02 })
127
- }) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = et({ rate: this.config.gpt.dropout }), this.blocks = [];
128
+ embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
129
+ }) : (this.ropeCache = new A(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = lt({ rate: this.config.gpt.dropout }), this.blocks = [];
128
130
  for (let e = 0; e < this.config.gpt.nLayer; e++)
129
- this.blocks.push(new v(e, this.config, this));
130
- this.lnF = new D(this.config, "final_rms_norm", this);
131
+ this.blocks.push(new O(e, this.config, this));
132
+ this.lnF = new G(this.config, "final_rms_norm", this);
131
133
  }
132
134
  get checkpointing() {
133
135
  return this.config.layerConfig.checkpointing === !0;
@@ -136,11 +138,11 @@ class dt extends R {
136
138
  this.config.layerConfig.checkpointing = t;
137
139
  }
138
140
  inputPhase(t, e, o = !1) {
139
- return w(() => {
141
+ return z(() => {
140
142
  const n = this.wte.embed(t);
141
143
  if (this.config.gpt.useRope === !1) {
142
- const [, s] = t.shape, i = this.config.gpt.blockSize, r = H(0, s, 1, "int32"), c = X(A(r, z(e, "int32")), z(i, "int32")), l = this.wpe.apply(c), a = n.add(l);
143
- return this.drop.apply(a, { training: o });
144
+ const [, s] = t.shape, i = this.config.gpt.blockSize, p = et(0, s, 1, "int32"), r = it(Z(p, L(e, "int32")), L(i, "int32")), u = this.wpe.apply(r), c = n.add(u);
145
+ return this.drop.apply(c, { training: o });
144
146
  } else
145
147
  return this.drop.apply(n, { training: o });
146
148
  });
@@ -167,7 +169,7 @@ class dt extends R {
167
169
  }
168
170
  calculateLoss(t, e) {
169
171
  try {
170
- return N()(t, e).mean();
172
+ return B()(t, e).mean();
171
173
  } catch (o) {
172
174
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
173
175
  }
@@ -205,7 +207,7 @@ class dt extends R {
205
207
  });
206
208
  }*/
207
209
  forward(t, e, o) {
208
- return this.validateInput(e), w(() => {
210
+ return this.validateInput(e), z(() => {
209
211
  this.startMemory();
210
212
  const n = t.cache?.[0]?.length ?? 0;
211
213
  let s = this.inputPhase(e, n, t.training);
@@ -213,61 +215,72 @@ class dt extends R {
213
215
  throw console.error("Cache", t.cache), new Error(
214
216
  `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
215
217
  );
216
- for (let c = 0; c < this.blocks.length; c++) {
217
- const l = this.blocks[c], a = Math.random() * 1e9, u = {
218
+ for (let r = 0; r < this.blocks.length; r++) {
219
+ const u = this.blocks[r], c = Math.random() * 1e9, g = {
218
220
  training: t.training,
219
- seed: a,
221
+ seed: c,
220
222
  attentionScores: t.attentionScores,
221
- pastKV: t.cache ? t.cache[c] : void 0
222
- }, p = this.config.layerConfig.checkpointing && t.training ? l.callCheckpoint(u, s) : l.call(u, s);
223
- s.dispose(), s = p;
223
+ pastKV: t.cache ? t.cache[r] : void 0
224
+ }, S = this.config.layerConfig.checkpointing && t.training ? u.callCheckpoint(g, s) : u.call(g, s);
225
+ s.dispose(), s = S;
224
226
  }
225
227
  s = this.lnF.call(t, s);
226
228
  const i = this.wte.project(s);
227
229
  s.dispose();
228
- let r;
229
- return o && (r = this.calculateLoss(i, o)), this.endMemory("Forward"), r ? [i, r] : [i];
230
+ let p;
231
+ return o && (p = this.calculateLoss(i, o)), this.endMemory("Forward"), p ? [i, p] : [i];
230
232
  });
231
233
  }
232
234
  generate(t, e, o) {
233
- const n = o?.temperature ?? 1, s = o?.topK, i = o?.usePadding ?? !1;
234
- return w(() => {
235
- const r = t, c = r.shape[1], l = c <= this.config.gpt.blockSize ? r : r.slice(
236
- [0, c - this.config.gpt.blockSize],
235
+ const n = o?.temperature ?? 1, s = o?.topK, i = o?.topP, p = o?.usePadding ?? !1;
236
+ return z(() => {
237
+ const r = t, u = r.shape[1], c = u <= this.config.gpt.blockSize ? r : r.slice(
238
+ [0, u - this.config.gpt.blockSize],
237
239
  [r.shape[0], this.config.gpt.blockSize]
238
- ), a = i ? this.config.gpt.blockSize - l.shape[1] : 0, u = a > 0 ? _(l, [
240
+ ), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, S = g > 0 ? J(c, [
239
241
  [0, 0],
240
- [0, a]
241
- ]) : l, p = {
242
+ [0, g]
243
+ ]) : c, f = {
242
244
  training: !1,
243
245
  attentionScores: o?.attentionScores ? {
244
246
  attentionOut: []
245
247
  } : void 0,
246
248
  cache: e
247
- }, [f] = this.forward(p, u), S = f.shape[1] - 1 - a, M = f.slice([0, S, 0], [f.shape[0], 1, f.shape[2]]);
248
- p.attentionScores?.attentionOut && p.attentionScores.attentionOut.forEach((g, k) => {
249
- g.shape[1] !== 1 && (p.attentionScores.attentionOut[k] = W(
250
- g.slice([0, S, 0], [g.shape[0], 1, g.shape[2]])
251
- ), g.dispose());
252
- }), f.dispose();
253
- const b = M.div(n);
249
+ }, [d] = this.forward(f, S), M = d.shape[1] - 1 - g, K = d.slice([0, M, 0], [d.shape[0], 1, d.shape[2]]);
250
+ f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((h, b) => {
251
+ h.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = tt(
252
+ h.slice([0, M, 0], [h.shape[0], 1, h.shape[2]])
253
+ ), h.dispose());
254
+ }), d.dispose();
255
+ const w = K.div(n);
254
256
  let m;
255
- if (s) {
256
- const { values: g, indices: k } = tt(b, s), x = I(g.squeeze([1]), 1);
257
- m = J(k.squeeze([1]), x, 1);
257
+ if (i) {
258
+ const h = q(w.squeeze([1])), b = h.arraySync()[0];
259
+ h.dispose();
260
+ const E = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
261
+ let v = 0;
262
+ const $ = new Array(E.length).fill(0);
263
+ for (const a of E)
264
+ if (v += a.prob, $[a.index] = a.prob, v >= i)
265
+ break;
266
+ const _ = $.reduce((a, k) => a + k, 0), D = $.map((a) => a / _);
267
+ m = C(ot(D), 1, void 0, !0);
268
+ } else if (s) {
269
+ const { values: h, indices: b } = at(w, s), E = C(h.squeeze([1]), 1);
270
+ m = st(b.squeeze([1]), E, 1);
258
271
  } else
259
- m = I(b.squeeze([1]), 1);
260
- let $;
261
- return o?.includeProbabilities && ($ = Q(b.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: $, attention: p.attentionScores?.attentionOut };
272
+ m = C(w.squeeze([1]), 1);
273
+ let P;
274
+ return o?.includeProbabilities && (P = q(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: P, attention: f.attentionScores?.attentionOut };
262
275
  });
263
276
  }
264
277
  getNumParams() {
265
- return O(this.config.gpt);
278
+ return j(this.config.gpt);
266
279
  }
267
280
  dispose() {
268
281
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
269
282
  }
270
283
  }
271
284
  export {
272
- dt as default
285
+ xt as default
273
286
  };
@@ -1,5 +1,5 @@
1
- import { ad as $, ae as g, p, af as C, k as x } from "./index-iNhkcAEQ.js";
2
- import { u as I } from "./gpgpu_math-C0zyxKFi.js";
1
+ import { ad as $, ae as g, q as p, af as C, l as x } from "./index-CnHyhpKc.js";
2
+ import { u as I } from "./gpgpu_math-Df7gzJWH.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -201,12 +201,12 @@ function D(t, e, o) {
201
201
  * limitations under the License.
202
202
  * =============================================================================
203
203
  */
204
- function U(t) {
204
+ function k(t) {
205
205
  const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
206
206
  x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
207
207
  const l = i.texData.get(n.dataId);
208
208
  return l.isPacked && !f(n.shape, a) && !(l.texture !== null && f(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
209
209
  }
210
210
  export {
211
- U as r
211
+ k as r
212
212
  };
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
11
11
  import "./papaparse.min-C8l2Kvo1.js";
12
12
  import "./index-Tf7vU29b.js";
13
13
  import "./jszip.min-CjP2V1VV.js";
14
- import "./index-iNhkcAEQ.js";
14
+ import "./index-CnHyhpKc.js";
15
15
  import "./ops/cpu/scatterSub.js";
16
16
  import "./ops/webgl/scatterSub.js";
17
17
  import "./ops/cpu/gatherSub.js";
@@ -0,0 +1,44 @@
1
+ import { R as a } from "./random_width-DI2h9CMs.js";
2
+ import "./index-CnHyhpKc.js";
3
+ import { d as s } from "./tfjs_backend-DX9yVvwk.js";
4
+ import o from "./layers/BaseLayer.js";
5
+ import { v as m } from "./variable-BGvK-VN3.js";
6
+ import { g as d } from "./gather-BWyutxwi.js";
7
+ /**
8
+ * @license
9
+ * Copyright 2018 Google LLC
10
+ *
11
+ * Use of this source code is governed by an MIT-style
12
+ * license that can be found in the LICENSE file or at
13
+ * https://opensource.org/licenses/MIT.
14
+ * =============================================================================
15
+ */
16
+ function n(e) {
17
+ return new a(e);
18
+ }
19
+ class c extends o {
20
+ vocabSize;
21
+ embedDim;
22
+ initializer;
23
+ WEIGHTS;
24
+ constructor(t, i, r) {
25
+ super(t, r), this.WEIGHTS = i, this.vocabSize = t.gpt.vocabSize, this.embedDim = t.gpt.nEmbed, this.initializer = n({
26
+ mean: 0,
27
+ stddev: 0.02
28
+ }), this.addVariable(this.WEIGHTS, m(this.initializer.apply([this.vocabSize, this.embedDim]), !0));
29
+ }
30
+ embed(t) {
31
+ return d(this.getVariable(this.WEIGHTS), t, 0);
32
+ }
33
+ project(t) {
34
+ return s(t, this.getVariable(this.WEIGHTS).transpose());
35
+ }
36
+ // Dummy, should not be used.
37
+ forward(t, i) {
38
+ return this.project(i);
39
+ }
40
+ }
41
+ export {
42
+ c as T,
43
+ n as r
44
+ };
@@ -1,4 +1,4 @@
1
- import { k as c } from "./index-iNhkcAEQ.js";
1
+ import { l as c } from "./index-CnHyhpKc.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2017 Google LLC. All Rights Reserved.
@@ -21,7 +21,7 @@ function i(e, n) {
21
21
  return !1;
22
22
  return !0;
23
23
  }
24
- function p(e, n, t) {
24
+ function l(e, n, t) {
25
25
  const r = e.length + n.length, s = [];
26
26
  let o = 0, f = 0;
27
27
  for (let u = 0; u < r; u++)
@@ -37,7 +37,7 @@ function a(e, n) {
37
37
  }
38
38
  function m(e, n) {
39
39
  const t = n.map((r) => 1);
40
- return p(e, t, n);
40
+ return l(e, t, n);
41
41
  }
42
42
  function d(e, n, t) {
43
43
  c(i(n, t), () => `${e} supports only inner-most axes for now. Got axes ${n} and rank-${t} input.`);
@@ -1,5 +1,5 @@
1
- import { o as h, i as f, l as p, x as g, E as u, T } from "./index-iNhkcAEQ.js";
2
- import { r as b } from "./reshape-DxTPgnwL.js";
1
+ import { o as h, j as f, n as p, y as g, E as u, L as b } from "./index-CnHyhpKc.js";
2
+ import { r as T } from "./reshape-CTIbqjwm.js";
3
3
  /**
4
4
  * @license
5
5
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -25,7 +25,7 @@ function m(e, r) {
25
25
  const t = n.shape.slice();
26
26
  for (; t.length < r.length; )
27
27
  t.unshift(1);
28
- n = b(n, t);
28
+ n = T(n, t);
29
29
  }
30
30
  const s = n.shape, o = Array.from(r);
31
31
  for (let t = r.length - 1; t >= 0; t--)
@@ -36,7 +36,7 @@ function m(e, r) {
36
36
  if (o.map((t, l) => t > 1 ? l : -1).filter((t) => t >= 0).length === 0)
37
37
  return g(n);
38
38
  const i = { x: n }, c = { reps: o };
39
- return u.runKernel(T, i, c);
39
+ return u.runKernel(b, i, c);
40
40
  }
41
41
  const E = /* @__PURE__ */ h({ broadcastTo_: m });
42
42
  export {
@@ -1,4 +1,4 @@
1
- import { o as s, k as a, j as p, x as i, E as l, C as f } from "./index-iNhkcAEQ.js";
1
+ import { o as s, l as a, k as p, y as i, E as l, C as f } from "./index-CnHyhpKc.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.