@genai-fi/nanogpt 0.5.6 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +10 -9
- package/dist/NanoGPTModel.js +70 -121
- package/dist/RealDiv-7xu-pkZN.js +540 -0
- package/dist/Reshape-BYC1oUku.js +127 -0
- package/dist/TeachableLLM.d.ts +2 -0
- package/dist/TeachableLLM.js +42 -34
- package/dist/{TiedEmbedding-8S8xn8e6.js → TiedEmbedding-C1HBot-5.js} +12 -13
- package/dist/{axis_util-BczFISHz.js → axis_util-CCNL7jea.js} +14 -12
- package/dist/{broadcast_to-B7NGsBSh.js → broadcast_to-CddAF879.js} +2 -2
- package/dist/{concat-DdKPyAtw.js → concat-XOK9ANZu.js} +7 -7
- package/dist/{dataset-iqT4Otvb.js → dataset-BFFipD1c.js} +5 -5
- package/dist/{dropout-B09InSJS.js → dropout-xlKRoJyU.js} +9 -9
- package/dist/{gather-D6MsdXqc.js → gather-DKtUaTtA.js} +1 -1
- package/dist/gpgpu_math-B_ycgZ4W.js +3115 -0
- package/dist/{index-Du-bmOP8.js → index-CamYe_M8.js} +844 -647
- package/dist/{kernel_funcs_utils-DShm7-0k.js → kernel_funcs_utils-D5MS0JFg.js} +232 -136
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +13 -33
- package/dist/layers/TiedEmbedding.js +6 -7
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/{log_sum_exp-CxfBtUaG.js → log_sum_exp-CV_5-TTu.js} +15 -15
- package/dist/main.js +24 -20
- package/dist/{mat_mul-CbiqIe2d.js → mat_mul-CAbRFWUj.js} +4 -4
- package/dist/{max-0Xnlpv8k.js → max-JBBv7aUf.js} +3 -3
- package/dist/mulmat_packed_gpu-DW4doKL_.js +71 -0
- package/dist/{norm-01kY9I2B.js → norm-B9dQTFYn.js} +12 -12
- package/dist/{ones-CrutWGas.js → ones-CMHNqMr6.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +18 -49
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +15 -11
- package/dist/ops/grads/fusedSoftmax.js +12 -10
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/log.d.ts +0 -0
- package/dist/ops/log.js +1 -0
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +8 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +31 -3379
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/{gpgpu_math-BFbOyvk4.js → ops/webgl/log.d.ts} +2 -8
- package/dist/ops/webgl/log.js +39 -0
- package/dist/ops/webgl/matMulGelu.js +48 -115
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-CJNniCAV.js → ops-DqtYemmV.js} +143 -135
- package/dist/{random_width-C-v-35bY.js → random_width-CLMQG5Jn.js} +6925 -6291
- package/dist/{range-Bvs1hidm.js → range-DqYjKnuG.js} +1 -1
- package/dist/reciprocal-z49filta.js +25 -0
- package/dist/register_all_kernels-COt6wLD0.js +21397 -0
- package/dist/{reshape-BH7eBpwq.js → reshape-C45vIIRU.js} +1 -1
- package/dist/scatter_nd_util-qgtnviTE.js +46 -0
- package/dist/selu_util-4QV_GXTB.js +740 -0
- package/dist/shared-ByfrGA97.js +3199 -0
- package/dist/{sin-CPAZXNjH.js → sin-9JBrfVaB.js} +1 -1
- package/dist/{softmax-DhWoBa7r.js → softmax-DvMvui-_.js} +1 -1
- package/dist/{split-BCUhuU7B.js → split-DxrHrPFK.js} +4 -4
- package/dist/{stack-BV1v7l3S.js → stack-DgaoDmnF.js} +1 -1
- package/dist/{sum-Cvq06317.js → sum-BpcpxNEh.js} +3 -3
- package/dist/{tensor-DgTOPY6h.js → tensor-CDz5x1mP.js} +1 -1
- package/dist/{tensor2d-CRWjDyUe.js → tensor2d-jO8JY5Jd.js} +1 -1
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +3 -3
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.d.ts +6 -0
- package/dist/utilities/dummy.js +31 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.d.ts +25 -0
- package/dist/utilities/load.js +89 -37
- package/dist/utilities/profile.d.ts +5 -0
- package/dist/utilities/profile.js +12 -9
- package/dist/utilities/safetensors.d.ts +3 -0
- package/dist/utilities/safetensors.js +83 -0
- package/dist/utilities/save.js +47 -29
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-DZ3fF0R2.js → variable-CLVXjN7F.js} +1 -1
- package/dist/{zeros-BaHhQTWf.js → zeros-DUkkVccu.js} +8 -8
- package/package.json +3 -9
- package/dist/Reshape-Biok_3X1.js +0 -212
- package/dist/slice_util-DskXqRZa.js +0 -49
- package/dist/tfjs_backend-D9Ytje0G.js +0 -1010
package/dist/Generator.js
CHANGED
|
@@ -1,12 +1,15 @@
|
|
|
1
1
|
import { E as u } from "./index-Dwqa6Zy2.js";
|
|
2
|
-
import "./index-
|
|
2
|
+
import "./index-CamYe_M8.js";
|
|
3
3
|
import "./ops/cpu/attentionMask.js";
|
|
4
4
|
import "./ops/webgl/attentionMask.js";
|
|
5
5
|
import "./ops/grads/attentionMask.js";
|
|
6
6
|
import "./ops/cpu/qkv.js";
|
|
7
7
|
import "./ops/webgl/qkv.js";
|
|
8
8
|
import "./ops/grads/qkv.js";
|
|
9
|
-
import "
|
|
9
|
+
import "./random_width-CLMQG5Jn.js";
|
|
10
|
+
import "./register_all_kernels-COt6wLD0.js";
|
|
11
|
+
import "./index-Tf7vU29b.js";
|
|
12
|
+
import "./dataset-BFFipD1c.js";
|
|
10
13
|
import "./ops/cpu/rope.js";
|
|
11
14
|
import "./ops/webgl/rope.js";
|
|
12
15
|
import "./ops/grads/rope.js";
|
|
@@ -21,21 +24,19 @@ import "./ops/grads/matMulGelu.js";
|
|
|
21
24
|
import "./ops/cpu/normRMS.js";
|
|
22
25
|
import "./ops/webgl/normRMS.js";
|
|
23
26
|
import "./ops/grads/normRMS.js";
|
|
24
|
-
import "./random_width-C-v-35bY.js";
|
|
25
27
|
import "./ops/cpu/gatherSub.js";
|
|
26
28
|
import "./ops/webgl/gatherSub.js";
|
|
27
29
|
import "./ops/cpu/scatterSub.js";
|
|
28
30
|
import "./ops/webgl/scatterSub.js";
|
|
29
31
|
import "./jszip.min-CjP2V1VV.js";
|
|
30
32
|
import f from "./tokeniser/CharTokeniser.js";
|
|
31
|
-
import "./dataset-iqT4Otvb.js";
|
|
32
|
-
import "./index-Tf7vU29b.js";
|
|
33
33
|
import "./papaparse.min-C8l2Kvo1.js";
|
|
34
34
|
import "./ops/cpu/gelu.js";
|
|
35
35
|
import "./ops/webgl/gelu.js";
|
|
36
36
|
import "./ops/grads/gelu.js";
|
|
37
|
-
import
|
|
38
|
-
import {
|
|
37
|
+
import "./ops/webgl/log.js";
|
|
38
|
+
import { t as d } from "./tensor2d-jO8JY5Jd.js";
|
|
39
|
+
import { c as g } from "./concat-XOK9ANZu.js";
|
|
39
40
|
const k = [
|
|
40
41
|
...Array.from({ length: 95 }, (a, t) => String.fromCharCode(t + 32)),
|
|
41
42
|
// ASCII
|
|
@@ -51,7 +52,7 @@ const k = [
|
|
|
51
52
|
function T(a, t) {
|
|
52
53
|
return a.length === t ? a : a.length > t ? a.slice(0, t) : a.concat(Array(t - a.length).fill(""));
|
|
53
54
|
}
|
|
54
|
-
class
|
|
55
|
+
class rt extends u {
|
|
55
56
|
constructor(t, o) {
|
|
56
57
|
super(), this.model = t, this.tokeniser = o;
|
|
57
58
|
}
|
|
@@ -123,5 +124,5 @@ class ot extends u {
|
|
|
123
124
|
}
|
|
124
125
|
}
|
|
125
126
|
export {
|
|
126
|
-
|
|
127
|
+
rt as default
|
|
127
128
|
};
|
package/dist/NanoGPTModel.js
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import { defaultConfig as F } from "./config.js";
|
|
2
2
|
import O from "./layers/TransformerBlock.js";
|
|
3
|
-
import { T as
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import { estimateParameterCount as
|
|
7
|
-
import { createSoftmaxCrossEntropyWithGrad as
|
|
8
|
-
import
|
|
9
|
-
import { E as
|
|
10
|
-
import {
|
|
11
|
-
import {
|
|
12
|
-
import { r as
|
|
13
|
-
import {
|
|
14
|
-
import {
|
|
15
|
-
import {
|
|
3
|
+
import { T as _, r as D } from "./TiedEmbedding-C1HBot-5.js";
|
|
4
|
+
import K from "./layers/RoPECache.js";
|
|
5
|
+
import N from "./layers/RMSNorm.js";
|
|
6
|
+
import { estimateParameterCount as R } from "./utilities/parameters.js";
|
|
7
|
+
import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
|
|
8
|
+
import G from "./layers/BaseLayer.js";
|
|
9
|
+
import { E as B, D as V, p as j } from "./random_width-CLMQG5Jn.js";
|
|
10
|
+
import { q as W, w as H, E as J, a6 as Q, t as z, a7 as U, f as v, n as X } from "./index-CamYe_M8.js";
|
|
11
|
+
import { m as Y, t as Z } from "./register_all_kernels-COt6wLD0.js";
|
|
12
|
+
import { r as L } from "./reshape-C45vIIRU.js";
|
|
13
|
+
import { r as tt } from "./range-DqYjKnuG.js";
|
|
14
|
+
import { s as M } from "./softmax-DvMvui-_.js";
|
|
15
|
+
import { t as et } from "./ops-DqtYemmV.js";
|
|
16
|
+
import { g as ot } from "./gather-DKtUaTtA.js";
|
|
16
17
|
/**
|
|
17
18
|
* @license
|
|
18
19
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -29,69 +30,17 @@ import { g as st } from "./gather-D6MsdXqc.js";
|
|
|
29
30
|
* limitations under the License.
|
|
30
31
|
* =============================================================================
|
|
31
32
|
*/
|
|
32
|
-
function
|
|
33
|
-
|
|
34
|
-
[e, o] = Q(e, o);
|
|
35
|
-
const n = { a: e, b: o };
|
|
36
|
-
return I.runKernel(U, n);
|
|
37
|
-
}
|
|
38
|
-
const it = /* @__PURE__ */ x({ mod_: nt });
|
|
39
|
-
/**
|
|
40
|
-
* @license
|
|
41
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
42
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
43
|
-
* you may not use this file except in compliance with the License.
|
|
44
|
-
* You may obtain a copy of the License at
|
|
45
|
-
*
|
|
46
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
47
|
-
*
|
|
48
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
49
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
50
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
51
|
-
* See the License for the specific language governing permissions and
|
|
52
|
-
* limitations under the License.
|
|
53
|
-
* =============================================================================
|
|
54
|
-
*/
|
|
55
|
-
function rt(l, t, e, o = !1) {
|
|
56
|
-
const n = y(l, "logits", "multinomial"), s = n.size, i = n.rank;
|
|
33
|
+
function st(u, t, e, o = !1) {
|
|
34
|
+
const r = H(u, "logits", "multinomial"), s = r.size, n = r.rank;
|
|
57
35
|
if (s < 2)
|
|
58
36
|
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
|
|
59
|
-
if (
|
|
60
|
-
throw new Error(`Rank of probabilities must be 1 or 2, but is ${
|
|
37
|
+
if (n > 2)
|
|
38
|
+
throw new Error(`Rank of probabilities must be 1 or 2, but is ${n}`);
|
|
61
39
|
e = e || Math.random();
|
|
62
|
-
const
|
|
63
|
-
return
|
|
64
|
-
}
|
|
65
|
-
const C = /* @__PURE__ */ x({ multinomial_: rt });
|
|
66
|
-
/**
|
|
67
|
-
* @license
|
|
68
|
-
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
69
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
70
|
-
* you may not use this file except in compliance with the License.
|
|
71
|
-
* You may obtain a copy of the License at
|
|
72
|
-
*
|
|
73
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
74
|
-
*
|
|
75
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
76
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
77
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
78
|
-
* See the License for the specific language governing permissions and
|
|
79
|
-
* limitations under the License.
|
|
80
|
-
* =============================================================================
|
|
81
|
-
*/
|
|
82
|
-
function ct(l, t = 1, e = !0) {
|
|
83
|
-
const o = y(l, "x", "topk");
|
|
84
|
-
if (o.rank === 0)
|
|
85
|
-
throw new Error("topk() expects the input to be of rank 1 or higher");
|
|
86
|
-
const n = o.shape[o.shape.length - 1];
|
|
87
|
-
if (t < 0)
|
|
88
|
-
throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
|
|
89
|
-
if (t > n)
|
|
90
|
-
throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
|
|
91
|
-
const s = { x: o }, i = { k: t, sorted: e }, [p, r] = I.runKernel(Y, s, i);
|
|
92
|
-
return { values: p, indices: r };
|
|
40
|
+
const i = { logits: n === 1 ? L(r, [1, -1]) : r }, h = { numSamples: t, seed: e, normalized: o }, c = J.runKernel(Q, i, h);
|
|
41
|
+
return n === 1 ? L(c, [c.size]) : c;
|
|
93
42
|
}
|
|
94
|
-
const
|
|
43
|
+
const C = /* @__PURE__ */ W({ multinomial_: st });
|
|
95
44
|
/**
|
|
96
45
|
* @license
|
|
97
46
|
* Copyright 2018 Google LLC
|
|
@@ -101,13 +50,13 @@ const at = /* @__PURE__ */ x({ topk_: ct });
|
|
|
101
50
|
* https://opensource.org/licenses/MIT.
|
|
102
51
|
* =============================================================================
|
|
103
52
|
*/
|
|
104
|
-
function
|
|
105
|
-
return new
|
|
53
|
+
function nt(u) {
|
|
54
|
+
return new V(u);
|
|
106
55
|
}
|
|
107
|
-
function
|
|
108
|
-
return new
|
|
56
|
+
function it(u) {
|
|
57
|
+
return new B(u);
|
|
109
58
|
}
|
|
110
|
-
class
|
|
59
|
+
class St extends G {
|
|
111
60
|
wte;
|
|
112
61
|
// Token embeddings
|
|
113
62
|
wpe;
|
|
@@ -121,15 +70,15 @@ class xt extends V {
|
|
|
121
70
|
log = [];
|
|
122
71
|
// Training log
|
|
123
72
|
constructor(t = {}) {
|
|
124
|
-
super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new
|
|
73
|
+
super({ gpt: { ...F, ...t }, layerConfig: {} }), this.wte = new _(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = it({
|
|
125
74
|
inputDim: this.config.gpt.blockSize,
|
|
126
75
|
outputDim: this.config.gpt.nEmbed,
|
|
127
76
|
name: "positional_embedding",
|
|
128
|
-
embeddingsInitializer:
|
|
129
|
-
}) : (this.ropeCache = new
|
|
77
|
+
embeddingsInitializer: D({ mean: 0, stddev: 0.02 })
|
|
78
|
+
}) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = nt({ rate: this.config.gpt.dropout }), this.blocks = [];
|
|
130
79
|
for (let e = 0; e < this.config.gpt.nLayer; e++)
|
|
131
80
|
this.blocks.push(new O(e, this.config, this));
|
|
132
|
-
this.lnF = new
|
|
81
|
+
this.lnF = new N(this.config, "final_rms_norm", this);
|
|
133
82
|
}
|
|
134
83
|
get checkpointing() {
|
|
135
84
|
return this.config.layerConfig.checkpointing === !0;
|
|
@@ -139,12 +88,12 @@ class xt extends V {
|
|
|
139
88
|
}
|
|
140
89
|
inputPhase(t, e, o = !1) {
|
|
141
90
|
return z(() => {
|
|
142
|
-
const
|
|
91
|
+
const r = this.wte.embed(t);
|
|
143
92
|
if (this.config.gpt.useRope === !1) {
|
|
144
|
-
const [, s] = t.shape,
|
|
93
|
+
const [, s] = t.shape, n = this.config.gpt.blockSize, p = tt(0, s, 1, "int32"), i = Y(U(p, v(e, "int32")), v(n, "int32")), h = this.wpe.apply(i), c = r.add(h);
|
|
145
94
|
return this.drop.apply(c, { training: o });
|
|
146
95
|
} else
|
|
147
|
-
return this.drop.apply(
|
|
96
|
+
return this.drop.apply(r, { training: o });
|
|
148
97
|
});
|
|
149
98
|
}
|
|
150
99
|
setSkipMask(t) {
|
|
@@ -169,7 +118,7 @@ class xt extends V {
|
|
|
169
118
|
}
|
|
170
119
|
calculateLoss(t, e) {
|
|
171
120
|
try {
|
|
172
|
-
return
|
|
121
|
+
return A()(t, e).mean();
|
|
173
122
|
} catch (o) {
|
|
174
123
|
throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
|
|
175
124
|
}
|
|
@@ -209,35 +158,35 @@ class xt extends V {
|
|
|
209
158
|
forward(t, e, o) {
|
|
210
159
|
return this.validateInput(e), z(() => {
|
|
211
160
|
this.startMemory();
|
|
212
|
-
const
|
|
213
|
-
let s = this.inputPhase(e,
|
|
161
|
+
const r = t.cache?.[0]?.length ?? 0;
|
|
162
|
+
let s = this.inputPhase(e, r, t.training);
|
|
214
163
|
if (t.cache && t.cache.length !== this.blocks.length)
|
|
215
164
|
throw console.error("Cache", t.cache), new Error(
|
|
216
165
|
`Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
|
|
217
166
|
);
|
|
218
|
-
for (let
|
|
219
|
-
const
|
|
167
|
+
for (let i = 0; i < this.blocks.length; i++) {
|
|
168
|
+
const h = this.blocks[i], c = Math.random() * 1e9, g = {
|
|
220
169
|
training: t.training,
|
|
221
170
|
seed: c,
|
|
222
171
|
attentionScores: t.attentionScores,
|
|
223
|
-
pastKV: t.cache ? t.cache[
|
|
224
|
-
},
|
|
225
|
-
s.dispose(), s =
|
|
172
|
+
pastKV: t.cache ? t.cache[i] : void 0
|
|
173
|
+
}, E = this.config.layerConfig.checkpointing && t.training ? h.callCheckpoint(g, s) : h.call(g, s);
|
|
174
|
+
s.dispose(), s = E;
|
|
226
175
|
}
|
|
227
176
|
s = this.lnF.call(t, s);
|
|
228
|
-
const
|
|
177
|
+
const n = this.wte.project(s);
|
|
229
178
|
s.dispose();
|
|
230
179
|
let p;
|
|
231
|
-
return o && (p = this.calculateLoss(
|
|
180
|
+
return o && (p = this.calculateLoss(n, o)), this.endMemory("Forward"), p ? [n, p] : [n];
|
|
232
181
|
});
|
|
233
182
|
}
|
|
234
183
|
generate(t, e, o) {
|
|
235
|
-
const
|
|
184
|
+
const r = o?.temperature ?? 1, s = o?.topK, n = o?.topP, p = o?.usePadding ?? !1;
|
|
236
185
|
return z(() => {
|
|
237
|
-
const
|
|
238
|
-
[0,
|
|
239
|
-
[
|
|
240
|
-
), g = p ? this.config.gpt.blockSize - c.shape[1] : 0,
|
|
186
|
+
const i = t, h = i.shape[1], c = h <= this.config.gpt.blockSize ? i : i.slice(
|
|
187
|
+
[0, h - this.config.gpt.blockSize],
|
|
188
|
+
[i.shape[0], this.config.gpt.blockSize]
|
|
189
|
+
), g = p ? this.config.gpt.blockSize - c.shape[1] : 0, E = g > 0 ? j(c, [
|
|
241
190
|
[0, 0],
|
|
242
191
|
[0, g]
|
|
243
192
|
]) : c, f = {
|
|
@@ -246,41 +195,41 @@ class xt extends V {
|
|
|
246
195
|
attentionOut: []
|
|
247
196
|
} : void 0,
|
|
248
197
|
cache: e
|
|
249
|
-
}, [d] = this.forward(f,
|
|
250
|
-
f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
),
|
|
198
|
+
}, [d] = this.forward(f, E), $ = d.shape[1] - 1 - g, q = d.slice([0, $, 0], [d.shape[0], 1, d.shape[2]]);
|
|
199
|
+
f.attentionScores?.attentionOut && f.attentionScores.attentionOut.forEach((l, b) => {
|
|
200
|
+
l.shape[1] !== 1 && (f.attentionScores.attentionOut[b] = X(
|
|
201
|
+
l.slice([0, $, 0], [l.shape[0], 1, l.shape[2]])
|
|
202
|
+
), l.dispose());
|
|
254
203
|
}), d.dispose();
|
|
255
|
-
const w =
|
|
204
|
+
const w = q.div(r);
|
|
256
205
|
let m;
|
|
257
|
-
if (
|
|
258
|
-
const
|
|
259
|
-
|
|
260
|
-
const
|
|
261
|
-
let
|
|
262
|
-
const
|
|
263
|
-
for (const a of
|
|
264
|
-
if (
|
|
206
|
+
if (n) {
|
|
207
|
+
const l = M(w.squeeze([1])), b = l.arraySync()[0];
|
|
208
|
+
l.dispose();
|
|
209
|
+
const y = b.map((a, k) => ({ prob: a, index: k })).sort((a, k) => k.prob - a.prob);
|
|
210
|
+
let P = 0;
|
|
211
|
+
const S = new Array(y.length).fill(0);
|
|
212
|
+
for (const a of y)
|
|
213
|
+
if (P += a.prob, S[a.index] = a.prob, P >= n)
|
|
265
214
|
break;
|
|
266
|
-
const
|
|
267
|
-
m = C(
|
|
215
|
+
const x = S.reduce((a, k) => a + k, 0), T = S.map((a) => a / x);
|
|
216
|
+
m = C(et(T), 1, void 0, !0);
|
|
268
217
|
} else if (s) {
|
|
269
|
-
const { values:
|
|
270
|
-
m =
|
|
218
|
+
const { values: l, indices: b } = Z(w, s), y = C(l.squeeze([1]), 1);
|
|
219
|
+
m = ot(b.squeeze([1]), y, 1);
|
|
271
220
|
} else
|
|
272
221
|
m = C(w.squeeze([1]), 1);
|
|
273
|
-
let
|
|
274
|
-
return o?.includeProbabilities && (
|
|
222
|
+
let I;
|
|
223
|
+
return o?.includeProbabilities && (I = M(w.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: I, attention: f.attentionScores?.attentionOut };
|
|
275
224
|
});
|
|
276
225
|
}
|
|
277
226
|
getNumParams() {
|
|
278
|
-
return
|
|
227
|
+
return R(this.config.gpt);
|
|
279
228
|
}
|
|
280
229
|
dispose() {
|
|
281
230
|
this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
|
|
282
231
|
}
|
|
283
232
|
}
|
|
284
233
|
export {
|
|
285
|
-
|
|
234
|
+
St as default
|
|
286
235
|
};
|