@genai-fi/nanogpt 0.6.0 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. package/dist/Generator.js +7 -7
  2. package/dist/NanoGPTModel.js +70 -121
  3. package/dist/RealDiv-BYViZwhN.js +540 -0
  4. package/dist/Reshape-t7Kcikjk.js +127 -0
  5. package/dist/TeachableLLM.d.ts +2 -0
  6. package/dist/TeachableLLM.js +34 -27
  7. package/dist/{TiedEmbedding-BhxWO8QR.js → TiedEmbedding-9WeDwvjO.js} +12 -13
  8. package/dist/{axis_util-D17qZRQm.js → axis_util-Bu4h7XWV.js} +14 -12
  9. package/dist/{broadcast_to-BMQLjvt_.js → broadcast_to-DARN-DBD.js} +2 -2
  10. package/dist/{concat-DhZfF1GY.js → concat-5aPGqw3Z.js} +3 -3
  11. package/dist/{dataset-oilnemHf.js → dataset-pgqp-YfL.js} +3 -3
  12. package/dist/{dropout-CrMQPCeG.js → dropout-Bciw46HT.js} +7 -7
  13. package/dist/{gather-DZCMHZuN.js → gather-DjyCjmOD.js} +1 -1
  14. package/dist/gpgpu_math-CNslybmD.js +3115 -0
  15. package/dist/{index-bMBtI-WR.js → index-BAzbokzv.js} +846 -649
  16. package/dist/{kernel_funcs_utils-CNmjLWnB.js → kernel_funcs_utils-CUxJCg0g.js} +232 -138
  17. package/dist/layers/BaseLayer.js +2 -2
  18. package/dist/layers/CausalSelfAttention.js +6 -6
  19. package/dist/layers/MLP.js +5 -5
  20. package/dist/layers/RMSNorm.js +3 -3
  21. package/dist/layers/RoPECache.js +13 -33
  22. package/dist/layers/TiedEmbedding.js +6 -7
  23. package/dist/layers/TransformerBlock.js +1 -1
  24. package/dist/loader/load.d.ts +13 -0
  25. package/dist/loader/load.js +27 -0
  26. package/dist/loader/loadHF.d.ts +7 -0
  27. package/dist/loader/loadHF.js +22 -0
  28. package/dist/{utilities/load.d.ts → loader/loadTransformers.d.ts} +11 -11
  29. package/dist/loader/loadTransformers.js +28 -0
  30. package/dist/loader/newZipLoad.d.ts +8 -0
  31. package/dist/loader/newZipLoad.js +21 -0
  32. package/dist/loader/oldZipLoad.d.ts +7 -0
  33. package/dist/loader/oldZipLoad.js +76 -0
  34. package/dist/{log_sum_exp-BHdkCb4s.js → log_sum_exp-YEo2h3gb.js} +14 -14
  35. package/dist/main.js +23 -20
  36. package/dist/{mat_mul-BsrLfy81.js → mat_mul-7121rsJk.js} +1 -1
  37. package/dist/{max-DechV4Bc.js → max-DtlIuVeW.js} +1 -1
  38. package/dist/mulmat_packed_gpu-D4nKF7Je.js +71 -0
  39. package/dist/{norm-B9hWHZH1.js → norm-CzltS9Fz.js} +16 -16
  40. package/dist/{ones-g0K8jVwm.js → ones-BBlSRqn1.js} +2 -2
  41. package/dist/ops/appendCache.js +3 -3
  42. package/dist/ops/attentionMask.js +1 -1
  43. package/dist/ops/cpu/appendCache.js +2 -2
  44. package/dist/ops/cpu/attentionMask.js +6 -6
  45. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  46. package/dist/ops/cpu/gatherSub.js +9 -9
  47. package/dist/ops/cpu/gelu.js +1 -1
  48. package/dist/ops/cpu/matMulGelu.js +1 -1
  49. package/dist/ops/cpu/matMulMul.js +1 -1
  50. package/dist/ops/cpu/mulDropout.js +1 -1
  51. package/dist/ops/cpu/normRMS.js +1 -1
  52. package/dist/ops/cpu/qkv.js +3 -3
  53. package/dist/ops/cpu/rope.js +5 -5
  54. package/dist/ops/cpu/scatterSub.js +17 -48
  55. package/dist/ops/fusedSoftmax.js +1 -1
  56. package/dist/ops/gatherSub.js +1 -1
  57. package/dist/ops/gelu.js +1 -1
  58. package/dist/ops/grads/attentionMask.js +1 -1
  59. package/dist/ops/grads/fusedSoftmax.js +4 -4
  60. package/dist/ops/grads/gelu.js +1 -1
  61. package/dist/ops/grads/matMulGelu.js +1 -1
  62. package/dist/ops/grads/normRMS.js +1 -1
  63. package/dist/ops/grads/qkv.js +1 -1
  64. package/dist/ops/grads/rope.js +1 -1
  65. package/dist/ops/matMulGelu.js +1 -1
  66. package/dist/ops/matMulMul.js +1 -1
  67. package/dist/ops/mulDrop.js +1 -1
  68. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  69. package/dist/ops/normRMS.js +1 -1
  70. package/dist/ops/qkv.js +1 -1
  71. package/dist/ops/rope.js +8 -4
  72. package/dist/ops/scatterSub.js +1 -1
  73. package/dist/ops/webgl/appendCache.js +1 -1
  74. package/dist/ops/webgl/attentionMask.js +1 -1
  75. package/dist/ops/webgl/fusedSoftmax.js +29 -560
  76. package/dist/ops/webgl/gatherSub.js +1 -1
  77. package/dist/ops/webgl/gelu.js +2 -2
  78. package/dist/ops/webgl/log.js +3 -3
  79. package/dist/ops/webgl/matMulGelu.js +46 -113
  80. package/dist/ops/webgl/matMulMul.js +1 -1
  81. package/dist/ops/webgl/mulDropout.js +1 -1
  82. package/dist/ops/webgl/normRMS.js +2 -2
  83. package/dist/ops/webgl/qkv.js +1 -1
  84. package/dist/ops/webgl/rope.js +1 -1
  85. package/dist/ops/webgl/scatterSub.js +1 -1
  86. package/dist/{ops-Mv7Ta72x.js → ops-C0sQEcPw.js} +117 -109
  87. package/dist/{random_width-BBAWzDym.js → random_width-DWzaOgrn.js} +6925 -6291
  88. package/dist/{range-DMaG9A3G.js → range-DYsrnfiy.js} +1 -1
  89. package/dist/{gpgpu_math-Ctc31slO.js → reciprocal-CJQeasVa.js} +7 -5
  90. package/dist/register_all_kernels-BfFCQAqs.js +21397 -0
  91. package/dist/{reshape-T4yDEqoF.js → reshape-krWGKraP.js} +1 -1
  92. package/dist/scatter_nd_util-93ln7Hut.js +46 -0
  93. package/dist/selu_util-sntGesxr.js +740 -0
  94. package/dist/{shared-XNAoXhOa.js → shared-Ca6iDobD.js} +1462 -1089
  95. package/dist/{sin-EEhbrRO_.js → sin-D_h-qCSx.js} +1 -1
  96. package/dist/{softmax-B2_IKPDR.js → softmax-fsdtf6JC.js} +1 -1
  97. package/dist/{split-dcks18H1.js → split-eiktj-6L.js} +1 -1
  98. package/dist/{stack-lpJ5kYvE.js → stack-dfEEz2OY.js} +2 -2
  99. package/dist/{sum-CutF5lj2.js → sum-BE_Irnim.js} +1 -1
  100. package/dist/{tensor-C15NA2LA.js → tensor-Xyi595sG.js} +1 -1
  101. package/dist/{tensor2d-DZ_e5eKM.js → tensor2d-CPEkynbH.js} +1 -1
  102. package/dist/training/AdamExt.js +1 -1
  103. package/dist/training/DatasetBuilder.js +2 -2
  104. package/dist/training/FullTrainer.js +1 -1
  105. package/dist/training/Trainer.js +3 -3
  106. package/dist/training/sparseCrossEntropy.js +5 -5
  107. package/dist/utilities/dummy.d.ts +6 -0
  108. package/dist/utilities/dummy.js +31 -10
  109. package/dist/utilities/generate.js +3 -3
  110. package/dist/utilities/profile.d.ts +5 -0
  111. package/dist/utilities/profile.js +10 -7
  112. package/dist/utilities/safetensors.js +2 -2
  113. package/dist/utilities/save.js +1 -1
  114. package/dist/utilities/weights.js +2 -2
  115. package/dist/{variable-CdRKKp8x.js → variable-wSS22xj5.js} +1 -1
  116. package/dist/{zeros-CAbHfODe.js → zeros-YJDE7oRb.js} +4 -4
  117. package/package.json +2 -8
  118. package/dist/Reshape-CLOrdpve.js +0 -212
  119. package/dist/slice_util-Ddk0uxGJ.js +0 -49
  120. package/dist/tfjs_backend-BDb8r9qx.js +0 -1010
  121. package/dist/utilities/load.js +0 -99
package/dist/Generator.js CHANGED
@@ -1,12 +1,15 @@
1
1
  import { E as u } from "./index-Dwqa6Zy2.js";
2
- import "./index-bMBtI-WR.js";
2
+ import "./index-BAzbokzv.js";
3
3
  import "./ops/cpu/attentionMask.js";
4
4
  import "./ops/webgl/attentionMask.js";
5
5
  import "./ops/grads/attentionMask.js";
6
6
  import "./ops/cpu/qkv.js";
7
7
  import "./ops/webgl/qkv.js";
8
8
  import "./ops/grads/qkv.js";
9
- import "@tensorflow/tfjs";
9
+ import "./random_width-DWzaOgrn.js";
10
+ import "./register_all_kernels-BfFCQAqs.js";
11
+ import "./index-Tf7vU29b.js";
12
+ import "./dataset-pgqp-YfL.js";
10
13
  import "./ops/cpu/rope.js";
11
14
  import "./ops/webgl/rope.js";
12
15
  import "./ops/grads/rope.js";
@@ -21,22 +24,19 @@ import "./ops/grads/matMulGelu.js";
21
24
  import "./ops/cpu/normRMS.js";
22
25
  import "./ops/webgl/normRMS.js";
23
26
  import "./ops/grads/normRMS.js";
24
- import "./random_width-BBAWzDym.js";
25
27
  import "./ops/cpu/gatherSub.js";
26
28
  import "./ops/webgl/gatherSub.js";
27
29
  import "./ops/cpu/scatterSub.js";
28
30
  import "./ops/webgl/scatterSub.js";
29
31
  import "./jszip.min-CjP2V1VV.js";
30
32
  import f from "./tokeniser/CharTokeniser.js";
31
- import "./dataset-oilnemHf.js";
32
- import "./index-Tf7vU29b.js";
33
33
  import "./papaparse.min-C8l2Kvo1.js";
34
34
  import "./ops/cpu/gelu.js";
35
35
  import "./ops/webgl/gelu.js";
36
36
  import "./ops/grads/gelu.js";
37
37
  import "./ops/webgl/log.js";
38
- import { t as d } from "./tensor2d-DZ_e5eKM.js";
39
- import { c as g } from "./concat-DhZfF1GY.js";
38
+ import { t as d } from "./tensor2d-CPEkynbH.js";
39
+ import { c as g } from "./concat-5aPGqw3Z.js";
40
40
  const k = [
41
41
  ...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
42
42
  // ASCII
package/dist/NanoGPTModel.js CHANGED
@@ -1,18 +1,19 @@
1
1
  import { defaultConfig as F } from "./config.js";
2
2
  import O from "./layers/TransformerBlock.js";
3
- import { T as N, r as R } from "./TiedEmbedding-BhxWO8QR.js";
4
- import A from "./layers/RoPECache.js";
5
- import G from "./layers/RMSNorm.js";
6
- import { estimateParameterCount as j } from "./utilities/parameters.js";
7
- import { createSoftmaxCrossEntropyWithGrad as B } from "./training/sparseCrossEntropy.js";
8
- import V from "./layers/BaseLayer.js";
9
- import { E as H, D as W, p as J } from "./random_width-BBAWzDym.js";
10
- import { o as x, j as y, u as Q, E as I, a9 as U, aa as X, ab as Y, t as z, a8 as Z, f as L, H as tt } from "./index-bMBtI-WR.js";
11
- import { r as T } from "./reshape-T4yDEqoF.js";
12
- import { r as et } from "./range-DMaG9A3G.js";
13
- import { s as q } from "./softmax-B2_IKPDR.js";
14
- import { t as ot } from "./ops-Mv7Ta72x.js";
15
- import { g as st } from "./gather-DZCMHZuN.js";
3
+ import { T as _, r as D } from "./TiedEmbedding-9WeDwvjO.js";
4
+ import K from "./layers/RoPECache.js";
5
+ import N from "./layers/RMSNorm.js";
6
+ import { estimateParameterCount as R } from "./utilities/parameters.js";
7
+ import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
8
+ import G from "./layers/BaseLayer.js";
9
+ import { E as B, D as V, p as j } from "./random_width-DWzaOgrn.js";
10
+ import { o as W, q as H, E as J, a6 as Q, t as z, a7 as U, s as v, k as X } from "./index-BAzbokzv.js";
11
+ import { m as Y, t as Z } from "./register_all_kernels-BfFCQAqs.js";
12
+ import { r as L } from "./reshape-krWGKraP.js";
13
+ import { r as tt } from "./range-DYsrnfiy.js";
14
+ import { s as M } from "./softmax-fsdtf6JC.js";
15
+ import { t as et } from "./ops-C0sQEcPw.js";
16
+ import { g as ot } from "./gather-DjyCjmOD.js";
16
17
  /**
17
18
  * @license
18
19
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -29,69 +30,17 @@ import { g as st } from "./gather-DZCMHZuN.js";
29
30
  * limitations under the License.
30
31
  * =============================================================================
31
32
  */
32
- function nt(l, t) {
33
- let e = y(l, "a", "mod"), o = y(t, "b", "mod");
34
- [e, o] = Q(e, o);
35
- const n = { a: e, b: o };
36
- return I.runKernel(U, n);
37
- }
38
- const it = /* @__PURE__ */ x({ mod_: nt });
39
- /**
40
- * @license
41
- * Copyright 2020 Google LLC. All Rights Reserved.
42
- * Licensed under the Apache License, Version 2.0 (the "License");
43
- * you may not use this file except in compliance with the License.
44
- * You may obtain a copy of the License at
45
- *
46
- * http://www.apache.org/licenses/LICENSE-2.0
47
- *
48
- * Unless required by applicable law or agreed to in writing, software
49
- * distributed under the License is distributed on an "AS IS" BASIS,
50
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
51
- * See the License for the specific language governing permissions and
52
- * limitations under the License.
53
- * =============================================================================
54
- */
55
- function rt(l, t, e, o = !1) {
56
- const n = y(l, "logits", "multinomial"), s = n.size, i = n.rank;
33
+ function st(u, t, e, o = !1) {
34
+ const r = H(u, "logits", "multinomial"), s = r.size, n = r.rank;
57
35
  if (s < 2)
58
36
  throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
59
- if (i > 2)
60
- throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
37
+ if (n > 2)
38
+ throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
61
39
  e = e || Math.random();
62
- const r = { logits: i === 1 ? T(n, [1, -1]) : n }, u = { numSamples: t, seed: e, normalized: o }, c = I.runKernel(X, r, u);
63
- return i === 1 ? T(c, [c.size]) : c;
64
- }
65
- const C = /* @__PURE__ */ x({ multinomial_: rt });
66
- /**
67
- * @license
68
- * Copyright 2018 Google LLC. All Rights Reserved.
69
- * Licensed under the Apache License, Version 2.0 (the "License");
70
- * you may not use this file except in compliance with the License.
71
- * You may obtain a copy of the License at
72
- *
73
- * http://www.apache.org/licenses/LICENSE-2.0
74
- *
75
- * Unless required by applicable law or agreed to in writing, software
76
- * distributed under the License is distributed on an "AS IS" BASIS,
77
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78
- * See the License for the specific language governing permissions and
79
- * limitations under the License.
80
- * =============================================================================
81
- */
82
- function ct(l, t = 1, e = !0) {
83
- const o = y(l, "x", "topk");
84
- if (o.rank === 0)
85
- throw new Error("topk() expects the input to be of rank 1 or higher");
86
- const n = o.shape[o.shape.length - 1];
87
- if (t < 0)
88
- throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
89
- if (t > n)
90
- throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
91
- const s = { x: o }, i = { k: t, sorted: e }, [p, r] = I.runKernel(Y, s, i);
92
- return { values: p, indices: r };
40
+ const i = { logits: n === 1 ? L(r, [1, -1]) : r }, h = { numSamples: t, seed: e, normalized: o }, c = J.runKernel(Q, i, h);
41
+ return n === 1 ? L(c, [c.size]) : c;
93
42
  }
94
- const at = /* @__PURE__ */ x({ topk_: ct });
43
+ const C = /* @__PURE__ */ W({ multinomial_: st });
95
44
  /**
96
45
  * @license
97
46
  * Copyright 2018 Google LLC
@@ -101,13 +50,13 @@ const at = /* @__PURE__ */ x({ topk_: ct });
101
50
  * https://opensource.org/licenses/MIT.
102
51
  * =============================================================================
103
52
  */
104
- function lt(l) {
105
- return new W(l);
53
+ function nt(u) {
54
+ return new V(u);
106
55
  }
107
- function pt(l) {
108
- return new H(l);
56
+ function it(u) {
57
+ return new B(u);
109
58
  }
110
- class xt extends V {
59
+ class St extends G {
111
60
  wte;
112
61
  // Token embeddings
113
62
  wpe;
@@ -121,15 +70,15 @@ class xt extends V {
121
70
  log = [];
122
71
  // Training log
123
72
  constructor(t = {}) {
124
- super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new N(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = pt({
73
+ super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new _(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = it({
125
74
  inputDim: this.config.gpt.blockSize,
126
75
  outputDim: this.config.gpt.nEmbed,
127
76
  name: "positional_embedding",
128
- embeddingsInitializer: R({ mean: 0, stddev: 0.02 })
129
- }) : (this.ropeCache = new A(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = lt({ rate: this.config.gpt.dropout }), this.blocks = [];
77
+ embeddingsInitializer: D({ mean: 0, stddev: 0.02 })
78
+ }) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = nt({ rate: this.config.gpt.dropout }), this.blocks = [];
130
79
  for (let e = 0; e < this.config.gpt.nLayer; e++)
131
80
  this.blocks.push(new O(e, this.config, this));
132
- this.lnF = new G(this.config, "final_rms_norm", this);
81
+ this.lnF = new N(this.config, "final_rms_norm", this);
133
82
  }
134
83
  get checkpointing() {
135
84
  return this.config.layerConfig.checkpointing === !0;
@@ -139,12 +88,12 @@ class xt extends V {
139
88
  }
140
89
  inputPhase(t, e, o = !1) {
141
90
  return z(() => {
142
- const n = this.wte.embed(t);
91
+ const r = this.wte.embed(t);
143
92
  if (this.config.gpt.useRope === !1) {
144
- const [, s] = t.shape, i = this.config.gpt.blockSize, p = et(0, s, 1, "int32"), r = it(Z(p, L(e, "int32")), L(i, "int32")), u = this.wpe.apply(r), c = n.add(u);
93
+ const [, s] = t.shape, n = this.config.gpt.blockSize, p = tt(0, s, 1, "int32"), i = Y(U(p, v(e, "int32")), v(n, "int32")), h = this.wpe.apply(i), c = r.add(h);
145
94
  return this.drop.apply(c, { training: o });
146
95
  } else
147
- return this.drop.apply(n, { training: o });
96
+ return this.drop.apply(r, { training: o });
148
97
  });
149
98
  }
150
99
  setSkipMask(t) {
@@ -169,7 +118,7 @@ class xt extends V {
169
118
  }
170
119
  calculateLoss(t, e) {
171
120
  try {
172
- return B()(t, e).mean();
121
+ return A()(t, e).mean();
173
122
  } catch (o) {
174
123
  throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
175
124
  }
@@ -209,35 +158,35 @@ class xt extends V {
209
158
  forward(t, e, o) {
210
159
  return this.validateInput(e), z(() => {
211
160
  this.startMemory();
212
- const n = t.cache?.[0]?.length ?? 0;
213
- let s = this.inputPhase(e, n, t.training);
161
+ const r = t.cache?.[0]?.length ?? 0;
162
+ let s = this.inputPhase(e, r, t.training);
214
163
  if (t.cache && t.cache.length !== this.blocks.length)
215
164
  throw console.error("Cache", t.cache), new Error(
216
165
  `Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
217
166
  );
218
- for (let r = 0; r < this.blocks.length; r++) {
219
- const u = this.blocks[r], c = Math.random() * 1e9, g = {
167
+ for (let i = 0; i < this.blocks.length; i++) {
168
+ const h = this.blocks[i], c = Math.random() * 1e9, g = {
220
169
  training: t.training,
221
170
  seed: c,
222
171
  attentionScores: t.attentionScores,
223
- pastKV: t.cache ? t.cache[r] : void 0
224
- }, S = this.config.layerConfig.checkpointing && t.training ? u.callCheckpoint(g, s) : u.call(g, s);
225
- s.dispose(), s = S;
172
+ pastKV: t.cache ? t.cache[i] : void 0
173
+ }, E = this.config.layerConfig.checkpointing && t.training ? h.callCheckpoint(g, s) : h.call(g, s);
174
+ s.dispose(), s = E;
226
175
  }
227
176
  s = this.lnF.call(t, s);
228
- const i = this.wte.project(s);
177
+ const n = this.wte.project(s);
229
178
  s.dispose();
230
179
  let p;
231
- return o && (p = this.calculateLoss(i, o)), this.endMemory("Forward"), p ? [i, p] : [i];
180
+ return o && (p = this.calculateLoss(n, o)), this.endMemory("Forward"), p ? [n, p] : [n];
232
181
  });
233
182
  }
234
183
  generate(t, e, o) {
235
- const n = o?.temperature ?? 1, s = o?.topK, i = o?.topP, p = o?.usePadding ?? !1;
184
+ const r = o?.temperature ?? 1, s = o?.topK, n = o?.topP, p = o?.usePadding ?? !1;
236
185
  return z(() => {
237
- const r = t, u = r.shape[1], c = u <= this.config.gpt.blockSize ? r : r.slice(
238
- [0, u - this.config.gpt.blockSize],
239
- [r.shape[0], this.config.gpt.blockSize]
240
- ), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, S = g > 0 ? J(c, [
186
+ const i = t, h = i.shape[1], c = h <= this.config.gpt.blockSize ? i : i.slice(
187
+ [0, h - this.config.gpt.blockSize],
188
+ [i.shape[0], this.config.gpt.blockSize]
189
+ ), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, E = g > 0 ? j(c, [
241
190
  [0, 0],
242
191
  [0, g]
243
192
  ]) : c, f = {
@@ -246,41 +195,41 @@ class xt extends V {
246
195
  attentionOut: []
247
196
  } : void 0,
248
197
  cache: e
249
- }, [d] = this.forward(f, S), M = d.shape[1] - 1 - g, K = d.slice([0, M, 0], [d.shape[0], 1, d.shape[2]]);
250
- f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((h, b) => {
251
- h.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = tt(
252
- h.slice([0, M, 0], [h.shape[0], 1, h.shape[2]])
253
- ), h.dispose());
198
+ }, [d] = this.forward(f, E), $ = d.shape[1] - 1 - g, q = d.slice([0, $, 0], [d.shape[0], 1, d.shape[2]]);
199
+ f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((l, b) => {
200
+ l.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = X(
201
+ l.slice([0, $, 0], [l.shape[0], 1, l.shape[2]])
202
+ ), l.dispose());
254
203
  }), d.dispose();
255
- const w = K.div(n);
204
+ const w = q.div(r);
256
205
  let m;
257
- if (i) {
258
- const h = q(w.squeeze([1])), b = h.arraySync()[0];
259
- h.dispose();
260
- const E = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
261
- let v = 0;
262
- const $ = new Array(E.length).fill(0);
263
- for (const a of E)
264
- if (v += a.prob, $[a.index] = a.prob, v >= i)
206
+ if (n) {
207
+ const l = M(w.squeeze([1])), b = l.arraySync()[0];
208
+ l.dispose();
209
+ const y = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
210
+ let P = 0;
211
+ const S = new Array(y.length).fill(0);
212
+ for (const a of y)
213
+ if (P += a.prob, S[a.index] = a.prob, P >= n)
265
214
  break;
266
- const _ = $.reduce((a, k) => a + k, 0), D = $.map((a) => a / _);
267
- m = C(ot(D), 1, void 0, !0);
215
+ const x = S.reduce((a, k) => a + k, 0), T = S.map((a) => a / x);
216
+ m = C(et(T), 1, void 0, !0);
268
217
  } else if (s) {
269
- const { values: h, indices: b } = at(w, s), E = C(h.squeeze([1]), 1);
270
- m = st(b.squeeze([1]), E, 1);
218
+ const { values: l, indices: b } = Z(w, s), y = C(l.squeeze([1]), 1);
219
+ m = ot(b.squeeze([1]), y, 1);
271
220
  } else
272
221
  m = C(w.squeeze([1]), 1);
273
- let P;
274
- return o?.includeProbabilities && (P = q(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: P, attention: f.attentionScores?.attentionOut };
222
+ let I;
223
+ return o?.includeProbabilities && (I = M(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: I, attention: f.attentionScores?.attentionOut };
275
224
  });
276
225
  }
277
226
  getNumParams() {
278
- return j(this.config.gpt);
227
+ return R(this.config.gpt);
279
228
  }
280
229
  dispose() {
281
230
  this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
282
231
  }
283
232
  }
284
233
  export {
285
- xt as default
234
+ St as default
286
235
  };