@genai-fi/nanogpt 0.4.5 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +52 -49
- package/dist/NanoGPTModel.d.ts +13 -17
- package/dist/NanoGPTModel.js +128 -136
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +1 -1
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +21 -24
- package/dist/layers/CausalSelfAttention.js +73 -128
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -10
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +1 -1
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +17 -17
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +29 -21
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +13 -11
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
package/dist/NanoGPTModel.js
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
|
-
import { defaultConfig as
|
|
2
|
-
import
|
|
3
|
-
import { E as
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import { estimateParameterCount as
|
|
7
|
-
import { createSoftmaxCrossEntropyWithGrad as
|
|
8
|
-
import B from "./
|
|
9
|
-
import { o as
|
|
10
|
-
import { r as
|
|
11
|
-
import { r as
|
|
12
|
-
import {
|
|
13
|
-
import {
|
|
14
|
-
import { s as U } from "./softmax-Dsxflvdl.js";
|
|
1
|
+
import { defaultConfig as L } from "./config.js";
|
|
2
|
+
import v from "./layers/TransformerBlock.js";
|
|
3
|
+
import { E as T, D as q, T as K, r as P, p as _ } from "./TiedEmbedding-DsDRvLB0.js";
|
|
4
|
+
import F from "./layers/RoPECache.js";
|
|
5
|
+
import D from "./layers/RMSNorm.js";
|
|
6
|
+
import { estimateParameterCount as O } from "./utilities/parameters.js";
|
|
7
|
+
import { createSoftmaxCrossEntropyWithGrad as N } from "./training/sparseCrossEntropy.js";
|
|
8
|
+
import { B as R } from "./BaseLayer-BhrMN8JO.js";
|
|
9
|
+
import { o as E, i as d, q as B, E as y, aa as G, ab as V, ac as j, t as w, a9 as A, f as z, F as W } from "./index-iNhkcAEQ.js";
|
|
10
|
+
import { r as C } from "./reshape-DxTPgnwL.js";
|
|
11
|
+
import { r as H } from "./range-BsFU-SNG.js";
|
|
12
|
+
import { g as J } from "./gather-Bxe1Qip8.js";
|
|
13
|
+
import { s as Q } from "./softmax-BjsptB07.js";
|
|
15
14
|
/**
|
|
16
15
|
* @license
|
|
17
16
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -28,13 +27,13 @@ import { s as U } from "./softmax-Dsxflvdl.js";
|
|
|
28
27
|
* limitations under the License.
|
|
29
28
|
* =============================================================================
|
|
30
29
|
*/
|
|
31
|
-
function
|
|
32
|
-
let e =
|
|
33
|
-
[e, o] =
|
|
34
|
-
const
|
|
35
|
-
return
|
|
30
|
+
function U(h, t) {
|
|
31
|
+
let e = d(h, "a", "mod"), o = d(t, "b", "mod");
|
|
32
|
+
[e, o] = B(e, o);
|
|
33
|
+
const n = { a: e, b: o };
|
|
34
|
+
return y.runKernel(G, n);
|
|
36
35
|
}
|
|
37
|
-
const
|
|
36
|
+
const X = /* @__PURE__ */ E({ mod_: U });
|
|
38
37
|
/**
|
|
39
38
|
* @license
|
|
40
39
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -51,17 +50,17 @@ const Z = /* @__PURE__ */ $({ mod_: Y });
|
|
|
51
50
|
* limitations under the License.
|
|
52
51
|
* =============================================================================
|
|
53
52
|
*/
|
|
54
|
-
function
|
|
55
|
-
const
|
|
53
|
+
function Y(h, t, e, o = !1) {
|
|
54
|
+
const n = d(h, "logits", "multinomial"), s = n.size, i = n.rank;
|
|
56
55
|
if (s < 2)
|
|
57
56
|
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
|
|
58
|
-
if (
|
|
59
|
-
throw new Error(`Rank of probabilities must be 1 or 2, but is ${
|
|
57
|
+
if (i > 2)
|
|
58
|
+
throw new Error(`Rank of probabilities must be 1 or 2, but is ${i}`);
|
|
60
59
|
e = e || Math.random();
|
|
61
|
-
const
|
|
62
|
-
return
|
|
60
|
+
const c = { logits: i === 1 ? C(n, [1, -1]) : n }, l = { numSamples: t, seed: e, normalized: o }, a = y.runKernel(V, c, l);
|
|
61
|
+
return i === 1 ? C(a, [a.size]) : a;
|
|
63
62
|
}
|
|
64
|
-
const
|
|
63
|
+
const I = /* @__PURE__ */ E({ multinomial_: Y });
|
|
65
64
|
/**
|
|
66
65
|
* @license
|
|
67
66
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -78,19 +77,19 @@ const M = /* @__PURE__ */ $({ multinomial_: tt });
|
|
|
78
77
|
* limitations under the License.
|
|
79
78
|
* =============================================================================
|
|
80
79
|
*/
|
|
81
|
-
function
|
|
82
|
-
const o =
|
|
80
|
+
function Z(h, t = 1, e = !0) {
|
|
81
|
+
const o = d(h, "x", "topk");
|
|
83
82
|
if (o.rank === 0)
|
|
84
83
|
throw new Error("topk() expects the input to be of rank 1 or higher");
|
|
85
|
-
const
|
|
84
|
+
const n = o.shape[o.shape.length - 1];
|
|
86
85
|
if (t < 0)
|
|
87
86
|
throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
|
|
88
|
-
if (t >
|
|
89
|
-
throw new Error(`'k' passed to topk() must be <= the last dimension (${
|
|
90
|
-
const s = { x: o },
|
|
91
|
-
return { values: r, indices:
|
|
87
|
+
if (t > n)
|
|
88
|
+
throw new Error(`'k' passed to topk() must be <= the last dimension (${n}) but got ${t}`);
|
|
89
|
+
const s = { x: o }, i = { k: t, sorted: e }, [r, c] = y.runKernel(j, s, i);
|
|
90
|
+
return { values: r, indices: c };
|
|
92
91
|
}
|
|
93
|
-
const
|
|
92
|
+
const tt = /* @__PURE__ */ E({ topk_: Z });
|
|
94
93
|
/**
|
|
95
94
|
* @license
|
|
96
95
|
* Copyright 2018 Google LLC
|
|
@@ -100,13 +99,13 @@ const ot = /* @__PURE__ */ $({ topk_: et });
|
|
|
100
99
|
* https://opensource.org/licenses/MIT.
|
|
101
100
|
* =============================================================================
|
|
102
101
|
*/
|
|
103
|
-
function
|
|
104
|
-
return new
|
|
102
|
+
function et(h) {
|
|
103
|
+
return new q(h);
|
|
105
104
|
}
|
|
106
|
-
function
|
|
107
|
-
return new
|
|
105
|
+
function ot(h) {
|
|
106
|
+
return new T(h);
|
|
108
107
|
}
|
|
109
|
-
class
|
|
108
|
+
class dt extends R {
|
|
110
109
|
wte;
|
|
111
110
|
// Token embeddings
|
|
112
111
|
wpe;
|
|
@@ -120,55 +119,30 @@ class wt extends B {
|
|
|
120
119
|
log = [];
|
|
121
120
|
// Training log
|
|
122
121
|
constructor(t = {}) {
|
|
123
|
-
super({ gpt: { ...
|
|
124
|
-
vocabSize: this.config.gpt.vocabSize,
|
|
125
|
-
embedDim: this.config.gpt.nEmbed,
|
|
126
|
-
name: "token_embedding"
|
|
127
|
-
}), this.config.gpt.useRope === !1 ? this.wpe = it({
|
|
122
|
+
super({ gpt: { ...L, ...t }, layerConfig: {} }), this.wte = new K(this.config, "token_embedding", this), this.config.gpt.useRope === !1 ? this.wpe = ot({
|
|
128
123
|
inputDim: this.config.gpt.blockSize,
|
|
129
124
|
outputDim: this.config.gpt.nEmbed,
|
|
130
125
|
name: "positional_embedding",
|
|
131
|
-
embeddingsInitializer:
|
|
132
|
-
}) : (this.ropeCache = new
|
|
126
|
+
embeddingsInitializer: P({ mean: 0, stddev: 0.02 })
|
|
127
|
+
}) : (this.ropeCache = new F(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = et({ rate: this.config.gpt.dropout }), this.blocks = [];
|
|
133
128
|
for (let e = 0; e < this.config.gpt.nLayer; e++)
|
|
134
|
-
this.blocks.push(new
|
|
135
|
-
this.lnF = new
|
|
129
|
+
this.blocks.push(new v(e, this.config, this));
|
|
130
|
+
this.lnF = new D(this.config, "final_rms_norm", this);
|
|
136
131
|
}
|
|
137
132
|
get checkpointing() {
|
|
138
|
-
return this.config.layerConfig.
|
|
133
|
+
return this.config.layerConfig.checkpointing === !0;
|
|
139
134
|
}
|
|
140
135
|
set checkpointing(t) {
|
|
141
|
-
this.config.layerConfig.
|
|
142
|
-
}
|
|
143
|
-
get variables() {
|
|
144
|
-
return [
|
|
145
|
-
//...this.wpe.trainableWeights.map((v) => v.read() as TF.Variable),
|
|
146
|
-
...this.blocks.flatMap((t) => t.variables),
|
|
147
|
-
...this.lnF.trainableWeights.map((t) => t),
|
|
148
|
-
...this.wte.variables
|
|
149
|
-
];
|
|
150
|
-
}
|
|
151
|
-
saveWeights() {
|
|
152
|
-
const t = /* @__PURE__ */ new Map();
|
|
153
|
-
t.set("token_embedding", this.wte.getWeights()), this.wpe && t.set("positional_embedding", this.wpe.getWeights());
|
|
154
|
-
for (let e = 0; e < this.blocks.length; e++)
|
|
155
|
-
this.blocks[e].saveWeights(t);
|
|
156
|
-
return t.set("final_rms_norm", this.lnF.getWeights()), t;
|
|
157
|
-
}
|
|
158
|
-
loadWeights(t) {
|
|
159
|
-
this.wte.setWeights(t.get("token_embedding") || []), this.wpe && this.wpe.setWeights(t.get("positional_embedding") || []);
|
|
160
|
-
for (let e = 0; e < this.blocks.length; e++)
|
|
161
|
-
this.blocks[e].loadWeights(t);
|
|
162
|
-
this.lnF.setWeights(t.get("final_rms_norm") || []);
|
|
136
|
+
this.config.layerConfig.checkpointing = t;
|
|
163
137
|
}
|
|
164
138
|
inputPhase(t, e, o = !1) {
|
|
165
139
|
return w(() => {
|
|
166
|
-
const
|
|
140
|
+
const n = this.wte.embed(t);
|
|
167
141
|
if (this.config.gpt.useRope === !1) {
|
|
168
|
-
const [, s] = t.shape,
|
|
142
|
+
const [, s] = t.shape, i = this.config.gpt.blockSize, r = H(0, s, 1, "int32"), c = X(A(r, z(e, "int32")), z(i, "int32")), l = this.wpe.apply(c), a = n.add(l);
|
|
169
143
|
return this.drop.apply(a, { training: o });
|
|
170
144
|
} else
|
|
171
|
-
return this.drop.apply(
|
|
145
|
+
return this.drop.apply(n, { training: o });
|
|
172
146
|
});
|
|
173
147
|
}
|
|
174
148
|
setSkipMask(t) {
|
|
@@ -183,11 +157,6 @@ class wt extends B {
|
|
|
183
157
|
for (let e = 0; e < this.blocks.length; e++)
|
|
184
158
|
this.blocks[e].trainable = t[e];
|
|
185
159
|
}
|
|
186
|
-
set trainable(t) {
|
|
187
|
-
for (const e of this.blocks)
|
|
188
|
-
e.trainable = t;
|
|
189
|
-
this.lnF.trainable = t;
|
|
190
|
-
}
|
|
191
160
|
validateInput(t) {
|
|
192
161
|
if (t.shape.length !== 2)
|
|
193
162
|
throw new Error(`Invalid input shape: expected [batch_size, sequence_length], got ${t.shape}`);
|
|
@@ -198,84 +167,107 @@ class wt extends B {
|
|
|
198
167
|
}
|
|
199
168
|
calculateLoss(t, e) {
|
|
200
169
|
try {
|
|
201
|
-
return
|
|
170
|
+
return N()(t, e).mean();
|
|
202
171
|
} catch (o) {
|
|
203
172
|
throw console.error("Error computing loss:", o), new Error(`Loss computation failed: ${o}`);
|
|
204
173
|
}
|
|
205
174
|
}
|
|
206
175
|
// Attention rollout per Abnar & Zuidema (2020)
|
|
207
176
|
// Expects list of (B, T, T) attention matrices already averaged over heads.
|
|
208
|
-
computeAttentionRollout(
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
177
|
+
/*private computeAttentionRollout(attentions: Tensor[]): Tensor {
|
|
178
|
+
return tidy(() => {
|
|
179
|
+
if (attentions.length === 0) {
|
|
180
|
+
throw new Error('No attentions for rollout');
|
|
181
|
+
}
|
|
182
|
+
const [B, Q, K] = attentions[0].shape as number[];
|
|
183
|
+
|
|
184
|
+
// Validate shapes are consistent
|
|
185
|
+
for (const a of attentions) {
|
|
186
|
+
const [b2, q2, k2] = a.shape as number[];
|
|
187
|
+
if (b2 !== B || q2 !== Q || k2 !== K) {
|
|
188
|
+
throw new Error(
|
|
189
|
+
`Inconsistent attention shapes in rollout: expected [${B},${Q},${K}] got [${b2},${q2},${k2}]`
|
|
190
|
+
);
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
// Always slice to [B, Q, Q] for rollout
|
|
195
|
+
const attentionsSliced = attentions.map((att) => att.slice([0, 0, 0], [B, Q, Q]));
|
|
196
|
+
|
|
197
|
+
const ey = eye(Q, Q).expandDims(0); // (1,Q,Q)
|
|
198
|
+
let rollout = ey.tile([B, 1, 1]); // (B,Q,Q)
|
|
199
|
+
for (const att of attentionsSliced) {
|
|
200
|
+
const a = att.add(ey);
|
|
201
|
+
const aNorm = a.div(a.sum(-1, true)); // (B,Q,Q)
|
|
202
|
+
rollout = aNorm.matMul(rollout); // (B,Q,Q)
|
|
203
|
+
}
|
|
204
|
+
return rollout;
|
|
205
|
+
});
|
|
206
|
+
}*/
|
|
207
|
+
forward(t, e, o) {
|
|
208
|
+
return this.validateInput(e), w(() => {
|
|
231
209
|
this.startMemory();
|
|
232
|
-
const
|
|
233
|
-
let
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
210
|
+
const n = t.cache?.[0]?.length ?? 0;
|
|
211
|
+
let s = this.inputPhase(e, n, t.training);
|
|
212
|
+
if (t.cache && t.cache.length !== this.blocks.length)
|
|
213
|
+
throw console.error("Cache", t.cache), new Error(
|
|
214
|
+
`Cache length ${t.cache.length} does not match number of blocks ${this.blocks.length}`
|
|
215
|
+
);
|
|
216
|
+
for (let c = 0; c < this.blocks.length; c++) {
|
|
217
|
+
const l = this.blocks[c], a = Math.random() * 1e9, u = {
|
|
218
|
+
training: t.training,
|
|
219
|
+
seed: a,
|
|
220
|
+
attentionScores: t.attentionScores,
|
|
221
|
+
pastKV: t.cache ? t.cache[c] : void 0
|
|
222
|
+
}, p = this.config.layerConfig.checkpointing && t.training ? l.callCheckpoint(u, s) : l.call(u, s);
|
|
223
|
+
s.dispose(), s = p;
|
|
244
224
|
}
|
|
245
|
-
|
|
246
|
-
i
|
|
247
|
-
|
|
248
|
-
let
|
|
249
|
-
return
|
|
225
|
+
s = this.lnF.call(t, s);
|
|
226
|
+
const i = this.wte.project(s);
|
|
227
|
+
s.dispose();
|
|
228
|
+
let r;
|
|
229
|
+
return o && (r = this.calculateLoss(i, o)), this.endMemory("Forward"), r ? [i, r] : [i];
|
|
250
230
|
});
|
|
251
231
|
}
|
|
252
232
|
generate(t, e, o) {
|
|
253
|
-
const
|
|
233
|
+
const n = o?.temperature ?? 1, s = o?.topK, i = o?.usePadding ?? !1;
|
|
254
234
|
return w(() => {
|
|
255
|
-
const
|
|
256
|
-
[0,
|
|
257
|
-
[
|
|
258
|
-
),
|
|
235
|
+
const r = t, c = r.shape[1], l = c <= this.config.gpt.blockSize ? r : r.slice(
|
|
236
|
+
[0, c - this.config.gpt.blockSize],
|
|
237
|
+
[r.shape[0], this.config.gpt.blockSize]
|
|
238
|
+
), a = i ? this.config.gpt.blockSize - l.shape[1] : 0, u = a > 0 ? _(l, [
|
|
259
239
|
[0, 0],
|
|
260
|
-
[0,
|
|
261
|
-
]) :
|
|
262
|
-
|
|
240
|
+
[0, a]
|
|
241
|
+
]) : l, p = {
|
|
242
|
+
training: !1,
|
|
243
|
+
attentionScores: o?.attentionScores ? {
|
|
244
|
+
attentionOut: []
|
|
245
|
+
} : void 0,
|
|
246
|
+
cache: e
|
|
247
|
+
}, [f] = this.forward(p, u), S = f.shape[1] - 1 - a, M = f.slice([0, S, 0], [f.shape[0], 1, f.shape[2]]);
|
|
248
|
+
p.attentionScores?.attentionOut && p.attentionScores.attentionOut.forEach((g, k) => {
|
|
249
|
+
g.shape[1] !== 1 && (p.attentionScores.attentionOut[k] = W(
|
|
250
|
+
g.slice([0, S, 0], [g.shape[0], 1, g.shape[2]])
|
|
251
|
+
), g.dispose());
|
|
252
|
+
}), f.dispose();
|
|
253
|
+
const b = M.div(n);
|
|
254
|
+
let m;
|
|
263
255
|
if (s) {
|
|
264
|
-
const { values:
|
|
265
|
-
|
|
256
|
+
const { values: g, indices: k } = tt(b, s), x = I(g.squeeze([1]), 1);
|
|
257
|
+
m = J(k.squeeze([1]), x, 1);
|
|
266
258
|
} else
|
|
267
|
-
|
|
268
|
-
let
|
|
269
|
-
return o?.includeProbabilities && (
|
|
259
|
+
m = I(b.squeeze([1]), 1);
|
|
260
|
+
let $;
|
|
261
|
+
return o?.includeProbabilities && ($ = Q(b.squeeze([1]))), m = m.reshape([1, 1]), { output: m, probabilities: $, attention: p.attentionScores?.attentionOut };
|
|
270
262
|
});
|
|
271
263
|
}
|
|
272
264
|
getNumParams() {
|
|
273
|
-
return
|
|
265
|
+
return O(this.config.gpt);
|
|
274
266
|
}
|
|
275
267
|
dispose() {
|
|
276
268
|
this.wte.dispose(), this.wpe && this.wpe.dispose(), this.drop.dispose(), this.blocks.forEach((t) => t.dispose()), this.lnF.dispose();
|
|
277
269
|
}
|
|
278
270
|
}
|
|
279
271
|
export {
|
|
280
|
-
|
|
272
|
+
dt as default
|
|
281
273
|
};
|
|
package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js}
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { u as I } from "./gpgpu_math-
|
|
1
|
+
import { ad as $, ae as g, p, af as C, k as x } from "./index-iNhkcAEQ.js";
|
|
2
|
+
import { u as I } from "./gpgpu_math-C0zyxKFi.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
5
5
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -17,7 +17,7 @@ import { u as I } from "./gpgpu_math-CUzjlO9A.js";
|
|
|
17
17
|
* =============================================================================
|
|
18
18
|
*/
|
|
19
19
|
function R(t, e, o = "index") {
|
|
20
|
-
const s =
|
|
20
|
+
const s = $(e);
|
|
21
21
|
return s.map((n, r) => {
|
|
22
22
|
const i = `int ${t[r]} = ${o} / ${n}`, u = r === s.length - 1 ? `int ${t[r + 1]} = ${o} - ${t[r]} * ${n}` : `index -= ${t[r]} * ${n}`;
|
|
23
23
|
return `${i}; ${u};`;
|
|
@@ -38,7 +38,7 @@ function S(t, e, o = "index") {
|
|
|
38
38
|
}).join("");
|
|
39
39
|
}
|
|
40
40
|
function F(t) {
|
|
41
|
-
const e =
|
|
41
|
+
const e = $(t).map((o) => o.toString());
|
|
42
42
|
return `
|
|
43
43
|
int getFlatIndex(ivec3 coords) {
|
|
44
44
|
return coords.x * ${e[0]} + coords.y * ${e[1]} + coords.z;
|
|
@@ -82,7 +82,7 @@ function m(t) {
|
|
|
82
82
|
function d(t) {
|
|
83
83
|
return t % 2 === 0;
|
|
84
84
|
}
|
|
85
|
-
function
|
|
85
|
+
function f(t, e) {
|
|
86
86
|
if (t = t.slice(-2), e = e.slice(-2), g(t, e) || !t.length || !e.length || t[0] === 0 || t[1] === 0 || e[0] === 0 || e[1] === 0)
|
|
87
87
|
return !0;
|
|
88
88
|
if (t.length !== e.length) {
|
|
@@ -201,12 +201,12 @@ function D(t, e, o) {
|
|
|
201
201
|
* limitations under the License.
|
|
202
202
|
* =============================================================================
|
|
203
203
|
*/
|
|
204
|
-
function
|
|
204
|
+
function U(t) {
|
|
205
205
|
const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
|
|
206
206
|
x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
|
|
207
207
|
const l = i.texData.get(n.dataId);
|
|
208
|
-
return l.isPacked &&
|
|
208
|
+
return l.isPacked && !f(n.shape, a) && !(l.texture !== null && f(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
|
|
209
209
|
}
|
|
210
210
|
export {
|
|
211
|
-
|
|
211
|
+
U as r
|
|
212
212
|
};
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
|
|
|
11
11
|
import "./papaparse.min-C8l2Kvo1.js";
|
|
12
12
|
import "./index-Tf7vU29b.js";
|
|
13
13
|
import "./jszip.min-CjP2V1VV.js";
|
|
14
|
-
import "./index
|
|
14
|
+
import "./index-iNhkcAEQ.js";
|
|
15
15
|
import "./ops/cpu/scatterSub.js";
|
|
16
16
|
import "./ops/webgl/scatterSub.js";
|
|
17
17
|
import "./ops/cpu/gatherSub.js";
|