@genai-fi/nanogpt 0.4.2 → 0.4.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +3 -3
- package/dist/NanoGPTModel.js +73 -76
- package/dist/Reshape-CiAY8ltP.js +212 -0
- package/dist/TeachableLLM.js +7 -1
- package/dist/{TiedEmbedding-CnJ1bx4q.js → TiedEmbedding-DznFwzcB.js} +244 -244
- package/dist/{axis_util-BgTGy5w8.js → axis_util-QP0LdI1v.js} +1 -1
- package/dist/{concat-CuRsVY-K.js → concat-DvWM7HGZ.js} +1 -1
- package/dist/data/parquet.js +9 -6
- package/dist/data/textLoader.js +6 -5
- package/dist/{dropout-DfDdklfL.js → dropout-DFEXTPV0.js} +4 -4
- package/dist/{gather-ZYRWhmXR.js → gather-C5D8PxwA.js} +1 -1
- package/dist/gpgpu_math-CUzjlO9A.js +23 -0
- package/dist/{index-C4JCoBvj.js → index--6vO-cOz.js} +87 -87
- package/dist/{kernel_funcs_utils-CAd1h9X1.js → kernel_funcs_utils-C6YBCuOt.js} +72 -91
- package/dist/layers/CausalSelfAttention.js +44 -44
- package/dist/layers/MLP.js +31 -33
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} +7 -7
- package/dist/main.js +25 -19
- package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} +1 -1
- package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} +1 -1
- package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} +5 -5
- package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} +13 -13
- package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} +2 -2
- package/dist/ops/appendCache.d.ts +1 -1
- package/dist/ops/appendCache.js +10 -4
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.d.ts +1 -2
- package/dist/ops/cpu/appendCache.js +15 -20
- package/dist/ops/cpu/attentionMask.js +10 -10
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +4 -4
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.d.ts +1 -0
- package/dist/ops/cpu/matMulGelu.js +40 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +24 -3
- package/dist/ops/grads/matMulGelu.d.ts +1 -0
- package/dist/ops/grads/matMulGelu.js +17 -0
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.d.ts +3 -0
- package/dist/ops/matMulGelu.js +14 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +14 -13
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +689 -895
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +20 -0
- package/dist/ops/webgl/matMulGelu.js +166 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{range-9AzeApCc.js → range-C_vpUjBu.js} +1 -1
- package/dist/{reshape-Boe4DuIO.js → reshape-z51Eu-re.js} +1 -1
- package/dist/{sin-KmhiDuMa.js → sin-H567uayl.js} +1 -1
- package/dist/{slice_util-19zDNNSn.js → slice_util-BdhYwFY_.js} +2 -2
- package/dist/{softmax-Cujsg4ay.js → softmax-Dsxflvdl.js} +1 -1
- package/dist/{split-DbcNm1-i.js → split-B_k_jwud.js} +1 -1
- package/dist/{stack-D1YjmgKN.js → stack-CmqSdsfs.js} +1 -1
- package/dist/{sum-R28pucR5.js → sum-DdkDf2MG.js} +1 -1
- package/dist/{tensor-BVeHdl7V.js → tensor-BGYi41cj.js} +1 -1
- package/dist/{tensor2d-DqFGNs_K.js → tensor2d-DUr_htjt.js} +1 -1
- package/dist/{tfjs_backend-Cug-PH75.js → tfjs_backend-DuKis_xG.js} +46 -46
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +18 -18
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-LJT9Ld63.js → variable-BJTZ3jOy.js} +1 -1
- package/dist/{zeros-dnQxFgAD.js → zeros-8xl-W2DC.js} +1 -1
- package/package.json +1 -1
- package/dist/gelu-CnCt17Lk.js +0 -26
package/dist/Generator.js
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { E as u } from "./index-Dwqa6Zy2.js";
|
|
2
|
-
import "./index-
|
|
3
|
-
import { t as d } from "./tensor2d-
|
|
4
|
-
import { c as p } from "./concat-
|
|
2
|
+
import "./index--6vO-cOz.js";
|
|
3
|
+
import { t as d } from "./tensor2d-DUr_htjt.js";
|
|
4
|
+
import { c as p } from "./concat-DvWM7HGZ.js";
|
|
5
5
|
class w extends u {
|
|
6
6
|
constructor(s, e) {
|
|
7
7
|
super(), this.model = s, this.tokeniser = e;
|
package/dist/NanoGPTModel.js
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
import { defaultConfig as x } from "./config.js";
|
|
2
2
|
import W from "./layers/TransformerBlock.js";
|
|
3
|
-
import { E as F, D as P, T as q, r as
|
|
4
|
-
import
|
|
3
|
+
import { E as F, D as P, T as q, r as T, p as D } from "./TiedEmbedding-DznFwzcB.js";
|
|
4
|
+
import K from "./layers/RoPECache.js";
|
|
5
5
|
import N from "./layers/RMSNorm.js";
|
|
6
6
|
import { estimateParameterCount as R } from "./utilities/parameters.js";
|
|
7
7
|
import { createSoftmaxCrossEntropyWithGrad as A } from "./training/sparseCrossEntropy.js";
|
|
8
8
|
import B from "./layers/BaseLayer.js";
|
|
9
|
-
import { o as
|
|
10
|
-
import { r as _ } from "./reshape-
|
|
11
|
-
import { r as
|
|
12
|
-
import { e as
|
|
13
|
-
import { g as
|
|
14
|
-
import { s as
|
|
9
|
+
import { o as $, h as E, p as G, E as v, a9 as O, aa as j, ab as Q, t as w, a8 as V, f as C } from "./index--6vO-cOz.js";
|
|
10
|
+
import { r as _ } from "./reshape-z51Eu-re.js";
|
|
11
|
+
import { r as X } from "./range-C_vpUjBu.js";
|
|
12
|
+
import { e as H } from "./tfjs_backend-DuKis_xG.js";
|
|
13
|
+
import { g as J } from "./gather-C5D8PxwA.js";
|
|
14
|
+
import { s as U } from "./softmax-Dsxflvdl.js";
|
|
15
15
|
/**
|
|
16
16
|
* @license
|
|
17
17
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -28,13 +28,13 @@ import { s as J } from "./softmax-Cujsg4ay.js";
|
|
|
28
28
|
* limitations under the License.
|
|
29
29
|
* =============================================================================
|
|
30
30
|
*/
|
|
31
|
-
function Y(
|
|
32
|
-
let e = E(
|
|
31
|
+
function Y(c, t) {
|
|
32
|
+
let e = E(c, "a", "mod"), o = E(t, "b", "mod");
|
|
33
33
|
[e, o] = G(e, o);
|
|
34
|
-
const
|
|
35
|
-
return v.runKernel(O,
|
|
34
|
+
const i = { a: e, b: o };
|
|
35
|
+
return v.runKernel(O, i);
|
|
36
36
|
}
|
|
37
|
-
const Z = /* @__PURE__ */
|
|
37
|
+
const Z = /* @__PURE__ */ $({ mod_: Y });
|
|
38
38
|
/**
|
|
39
39
|
* @license
|
|
40
40
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -51,17 +51,17 @@ const Z = /* @__PURE__ */ y({ mod_: Y });
|
|
|
51
51
|
* limitations under the License.
|
|
52
52
|
* =============================================================================
|
|
53
53
|
*/
|
|
54
|
-
function tt(
|
|
55
|
-
const
|
|
54
|
+
function tt(c, t, e, o = !1) {
|
|
55
|
+
const i = E(c, "logits", "multinomial"), s = i.size, l = i.rank;
|
|
56
56
|
if (s < 2)
|
|
57
57
|
throw new Error(`Error in multinomial: you need at least 2 outcomes, but got ${s}.`);
|
|
58
|
-
if (
|
|
59
|
-
throw new Error(`Rank of probabilities must be 1 or 2, but is ${
|
|
58
|
+
if (l > 2)
|
|
59
|
+
throw new Error(`Rank of probabilities must be 1 or 2, but is ${l}`);
|
|
60
60
|
e = e || Math.random();
|
|
61
|
-
const
|
|
62
|
-
return
|
|
61
|
+
const n = { logits: l === 1 ? _(i, [1, -1]) : i }, h = { numSamples: t, seed: e, normalized: o }, a = v.runKernel(j, n, h);
|
|
62
|
+
return l === 1 ? _(a, [a.size]) : a;
|
|
63
63
|
}
|
|
64
|
-
const M = /* @__PURE__ */
|
|
64
|
+
const M = /* @__PURE__ */ $({ multinomial_: tt });
|
|
65
65
|
/**
|
|
66
66
|
* @license
|
|
67
67
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -78,19 +78,19 @@ const M = /* @__PURE__ */ y({ multinomial_: tt });
|
|
|
78
78
|
* limitations under the License.
|
|
79
79
|
* =============================================================================
|
|
80
80
|
*/
|
|
81
|
-
function et(
|
|
82
|
-
const o = E(
|
|
81
|
+
function et(c, t = 1, e = !0) {
|
|
82
|
+
const o = E(c, "x", "topk");
|
|
83
83
|
if (o.rank === 0)
|
|
84
84
|
throw new Error("topk() expects the input to be of rank 1 or higher");
|
|
85
|
-
const
|
|
85
|
+
const i = o.shape[o.shape.length - 1];
|
|
86
86
|
if (t < 0)
|
|
87
87
|
throw new Error(`'k' passed to topk() must be >= 0 but got ${t}`);
|
|
88
|
-
if (t >
|
|
89
|
-
throw new Error(`'k' passed to topk() must be <= the last dimension (${
|
|
90
|
-
const s = { x: o },
|
|
91
|
-
return { values:
|
|
88
|
+
if (t > i)
|
|
89
|
+
throw new Error(`'k' passed to topk() must be <= the last dimension (${i}) but got ${t}`);
|
|
90
|
+
const s = { x: o }, l = { k: t, sorted: e }, [r, n] = v.runKernel(Q, s, l);
|
|
91
|
+
return { values: r, indices: n };
|
|
92
92
|
}
|
|
93
|
-
const ot = /* @__PURE__ */
|
|
93
|
+
const ot = /* @__PURE__ */ $({ topk_: et });
|
|
94
94
|
/**
|
|
95
95
|
* @license
|
|
96
96
|
* Copyright 2018 Google LLC
|
|
@@ -100,11 +100,11 @@ const ot = /* @__PURE__ */ y({ topk_: et });
|
|
|
100
100
|
* https://opensource.org/licenses/MIT.
|
|
101
101
|
* =============================================================================
|
|
102
102
|
*/
|
|
103
|
-
function st(
|
|
104
|
-
return new P(
|
|
103
|
+
function st(c) {
|
|
104
|
+
return new P(c);
|
|
105
105
|
}
|
|
106
|
-
function
|
|
107
|
-
return new F(
|
|
106
|
+
function it(c) {
|
|
107
|
+
return new F(c);
|
|
108
108
|
}
|
|
109
109
|
class wt extends B {
|
|
110
110
|
wte;
|
|
@@ -124,12 +124,12 @@ class wt extends B {
|
|
|
124
124
|
vocabSize: this.config.gpt.vocabSize,
|
|
125
125
|
embedDim: this.config.gpt.nEmbed,
|
|
126
126
|
name: "token_embedding"
|
|
127
|
-
}), this.config.gpt.useRope === !1 ? this.wpe =
|
|
127
|
+
}), this.config.gpt.useRope === !1 ? this.wpe = it({
|
|
128
128
|
inputDim: this.config.gpt.blockSize,
|
|
129
129
|
outputDim: this.config.gpt.nEmbed,
|
|
130
130
|
name: "positional_embedding",
|
|
131
|
-
embeddingsInitializer:
|
|
132
|
-
}) : (this.ropeCache = new
|
|
131
|
+
embeddingsInitializer: T({ mean: 0, stddev: 0.02 })
|
|
132
|
+
}) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
|
|
133
133
|
for (let e = 0; e < this.config.gpt.nLayer; e++)
|
|
134
134
|
this.blocks.push(new W(e, this.config));
|
|
135
135
|
this.lnF = new N(this.config, 1e-8, "final_rms_norm");
|
|
@@ -163,12 +163,12 @@ class wt extends B {
|
|
|
163
163
|
}
|
|
164
164
|
inputPhase(t, e, o = !1) {
|
|
165
165
|
return w(() => {
|
|
166
|
-
const
|
|
166
|
+
const i = this.wte.embed(t);
|
|
167
167
|
if (this.config.gpt.useRope === !1) {
|
|
168
|
-
const [, s] = t.shape,
|
|
169
|
-
return this.drop.apply(
|
|
168
|
+
const [, s] = t.shape, l = this.config.gpt.blockSize, r = X(0, s, 1, "int32"), n = Z(V(r, C(e, "int32")), C(l, "int32")), h = this.wpe.apply(n), a = i.add(h);
|
|
169
|
+
return this.drop.apply(a, { training: o });
|
|
170
170
|
} else
|
|
171
|
-
return this.drop.apply(
|
|
171
|
+
return this.drop.apply(i, { training: o });
|
|
172
172
|
});
|
|
173
173
|
}
|
|
174
174
|
setSkipMask(t) {
|
|
@@ -209,67 +209,64 @@ class wt extends B {
|
|
|
209
209
|
return w(() => {
|
|
210
210
|
if (t.length === 0)
|
|
211
211
|
throw new Error("No attentions for rollout");
|
|
212
|
-
const [e, o,
|
|
213
|
-
for (const
|
|
214
|
-
const [
|
|
215
|
-
if (
|
|
212
|
+
const [e, o, i] = t[0].shape;
|
|
213
|
+
for (const n of t) {
|
|
214
|
+
const [h, a, p] = n.shape;
|
|
215
|
+
if (h !== e || a !== o || p !== i)
|
|
216
216
|
throw new Error(
|
|
217
|
-
`Inconsistent attention shapes in rollout: expected [${e},${o},${
|
|
217
|
+
`Inconsistent attention shapes in rollout: expected [${e},${o},${i}] got [${h},${a},${p}]`
|
|
218
218
|
);
|
|
219
219
|
}
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
r = i.div(i.sum(-1, !0)).matMul(r);
|
|
226
|
-
}
|
|
227
|
-
return r;
|
|
220
|
+
const s = t.map((n) => n.slice([0, 0, 0], [e, o, o])), l = H(o, o).expandDims(0);
|
|
221
|
+
let r = l.tile([e, 1, 1]);
|
|
222
|
+
for (const n of s) {
|
|
223
|
+
const h = n.add(l);
|
|
224
|
+
r = h.div(h.sum(-1, !0)).matMul(r);
|
|
228
225
|
}
|
|
229
|
-
|
|
226
|
+
return r;
|
|
230
227
|
});
|
|
231
228
|
}
|
|
232
|
-
forward(t, e, o = !1,
|
|
229
|
+
forward(t, e, o = !1, i = !1, s) {
|
|
233
230
|
return this.validateInput(t), w(() => {
|
|
234
231
|
this.startMemory();
|
|
235
|
-
const
|
|
236
|
-
let
|
|
237
|
-
const
|
|
232
|
+
const l = s?.[0]?.length ?? 0;
|
|
233
|
+
let r = this.inputPhase(t, l, o);
|
|
234
|
+
const n = [];
|
|
238
235
|
if (s && s.length !== this.blocks.length)
|
|
239
236
|
throw console.error("Cache", s), new Error(`Cache length ${s.length} does not match number of blocks ${this.blocks.length}`);
|
|
240
|
-
for (let
|
|
241
|
-
const u =
|
|
237
|
+
for (let g = 0; g < this.blocks.length; g++) {
|
|
238
|
+
const u = r, m = this.blocks[g], {
|
|
242
239
|
output: b,
|
|
243
240
|
attention: k,
|
|
244
241
|
cache: f
|
|
245
|
-
} = m.call(
|
|
246
|
-
|
|
242
|
+
} = m.call(r, o, i, s ? s[g] : void 0);
|
|
243
|
+
r = b, u.dispose(), i && k && n.push(k), s && f ? (s[g]?.k.dispose(), s[g]?.v.dispose(), s[g] = f) : f && (f.k.dispose(), f.v.dispose());
|
|
247
244
|
}
|
|
245
|
+
let h;
|
|
246
|
+
i && n.length > 0 && (h = this.computeAttentionRollout(n)), r = this.lnF.apply(r);
|
|
247
|
+
const a = this.wte.project(r);
|
|
248
248
|
let p;
|
|
249
|
-
|
|
250
|
-
const l = this.wte.project(a);
|
|
251
|
-
let g;
|
|
252
|
-
return e && (g = this.calculateLoss(l, e)), this.endMemory("Forward"), { logits: l, loss: g, attention: n ? p : void 0 };
|
|
249
|
+
return e && (p = this.calculateLoss(a, e)), this.endMemory("Forward"), { logits: a, loss: p, attention: i ? h : void 0 };
|
|
253
250
|
});
|
|
254
251
|
}
|
|
255
252
|
generate(t, e, o) {
|
|
256
|
-
const
|
|
253
|
+
const i = o?.temperature ?? 1, s = o?.topK, l = o?.usePadding ?? !1, r = o?.includeAttention ?? !1;
|
|
257
254
|
return w(() => {
|
|
258
|
-
const
|
|
259
|
-
[0,
|
|
260
|
-
[
|
|
261
|
-
),
|
|
255
|
+
const n = t, h = n.shape[1], a = h <= this.config.gpt.blockSize ? n : n.slice(
|
|
256
|
+
[0, h - this.config.gpt.blockSize],
|
|
257
|
+
[n.shape[0], this.config.gpt.blockSize]
|
|
258
|
+
), p = l ? this.config.gpt.blockSize - a.shape[1] : 0, g = p > 0 ? D(a, [
|
|
262
259
|
[0, 0],
|
|
263
|
-
[0,
|
|
264
|
-
]) :
|
|
260
|
+
[0, p]
|
|
261
|
+
]) : a, { logits: u, attention: m } = this.forward(g, void 0, !1, r, e), b = u.shape[1] - 1 - p, k = u.slice([0, b, 0], [u.shape[0], 1, u.shape[2]]), f = m ? m.slice([0, b, 0], [m.shape[0], 1, m.shape[2]]) : void 0, y = k.div(i);
|
|
265
262
|
let d;
|
|
266
263
|
if (s) {
|
|
267
|
-
const { values: S, indices: I } = ot(
|
|
268
|
-
d =
|
|
264
|
+
const { values: S, indices: I } = ot(y, s), L = M(S.squeeze([1]), 1);
|
|
265
|
+
d = J(I.squeeze([1]), L, 1);
|
|
269
266
|
} else
|
|
270
|
-
d = M(
|
|
267
|
+
d = M(y.squeeze([1]), 1);
|
|
271
268
|
let z;
|
|
272
|
-
return o?.includeProbabilities && (z =
|
|
269
|
+
return o?.includeProbabilities && (z = U(y.squeeze([1]))), d = d.reshape([1, 1]), { output: d, attention: f?.squeeze([1]), probabilities: z };
|
|
273
270
|
});
|
|
274
271
|
}
|
|
275
272
|
getNumParams() {
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import { ac as f, ad as g, n as p, ae as C, j as x } from "./index--6vO-cOz.js";
|
|
2
|
+
import { u as I } from "./gpgpu_math-CUzjlO9A.js";
|
|
3
|
+
/**
|
|
4
|
+
* @license
|
|
5
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
6
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
* you may not use this file except in compliance with the License.
|
|
8
|
+
* You may obtain a copy of the License at
|
|
9
|
+
*
|
|
10
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
*
|
|
12
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
* See the License for the specific language governing permissions and
|
|
16
|
+
* limitations under the License.
|
|
17
|
+
* =============================================================================
|
|
18
|
+
*/
|
|
19
|
+
function R(t, e, o = "index") {
|
|
20
|
+
const s = f(e);
|
|
21
|
+
return s.map((n, r) => {
|
|
22
|
+
const i = `int ${t[r]} = ${o} / ${n}`, u = r === s.length - 1 ? `int ${t[r + 1]} = ${o} - ${t[r]} * ${n}` : `index -= ${t[r]} * ${n}`;
|
|
23
|
+
return `${i}; ${u};`;
|
|
24
|
+
}).join("");
|
|
25
|
+
}
|
|
26
|
+
function y(t, e) {
|
|
27
|
+
const o = t.length, s = t.map((r) => `${e}[${r}]`), n = new Array(o - 1);
|
|
28
|
+
n[o - 2] = s[o - 1];
|
|
29
|
+
for (let r = o - 3; r >= 0; --r)
|
|
30
|
+
n[r] = `(${n[r + 1]} * ${s[r + 1]})`;
|
|
31
|
+
return n;
|
|
32
|
+
}
|
|
33
|
+
function S(t, e, o = "index") {
|
|
34
|
+
const s = t.map((r, i) => i), n = y(s, e);
|
|
35
|
+
return n.map((r, i) => {
|
|
36
|
+
const u = `int ${t[i]} = ${o} / ${n[i]}`, a = i === n.length - 1 ? `int ${t[i + 1]} = ${o} - ${t[i]} * ${n[i]}` : `index -= ${t[i]} * ${n[i]}`;
|
|
37
|
+
return `${u}; ${a};`;
|
|
38
|
+
}).join("");
|
|
39
|
+
}
|
|
40
|
+
function F(t) {
|
|
41
|
+
const e = f(t).map((o) => o.toString());
|
|
42
|
+
return `
|
|
43
|
+
int getFlatIndex(ivec3 coords) {
|
|
44
|
+
return coords.x * ${e[0]} + coords.y * ${e[1]} + coords.z;
|
|
45
|
+
}
|
|
46
|
+
`;
|
|
47
|
+
}
|
|
48
|
+
function v() {
|
|
49
|
+
return `
|
|
50
|
+
int getFlatIndex(ivec3 coords) {
|
|
51
|
+
return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
|
|
52
|
+
}
|
|
53
|
+
`;
|
|
54
|
+
}
|
|
55
|
+
/**
|
|
56
|
+
* @license
|
|
57
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
58
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
59
|
+
* you may not use this file except in compliance with the License.
|
|
60
|
+
* You may obtain a copy of the License at
|
|
61
|
+
*
|
|
62
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
63
|
+
*
|
|
64
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
65
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
66
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
67
|
+
* See the License for the specific language governing permissions and
|
|
68
|
+
* limitations under the License.
|
|
69
|
+
* =============================================================================
|
|
70
|
+
*/
|
|
71
|
+
function h(t, e = 2) {
|
|
72
|
+
return p(t.slice(0, t.length - e));
|
|
73
|
+
}
|
|
74
|
+
function m(t) {
|
|
75
|
+
if (t.length === 0)
|
|
76
|
+
throw Error("Cannot get rows and columns of an empty shape array.");
|
|
77
|
+
return [
|
|
78
|
+
t.length > 1 ? t[t.length - 2] : 1,
|
|
79
|
+
t[t.length - 1]
|
|
80
|
+
];
|
|
81
|
+
}
|
|
82
|
+
function d(t) {
|
|
83
|
+
return t % 2 === 0;
|
|
84
|
+
}
|
|
85
|
+
function $(t, e) {
|
|
86
|
+
if (t = t.slice(-2), e = e.slice(-2), g(t, e) || !t.length || !e.length || t[0] === 0 || t[1] === 0 || e[0] === 0 || e[1] === 0)
|
|
87
|
+
return !0;
|
|
88
|
+
if (t.length !== e.length) {
|
|
89
|
+
const o = t[t.length - 1], s = e[e.length - 1];
|
|
90
|
+
if (o === s || d(o) && d(s) && (t[0] === 1 || e[0] === 1))
|
|
91
|
+
return !0;
|
|
92
|
+
}
|
|
93
|
+
return t[1] === e[1] && d(t[0]) && d(e[0]);
|
|
94
|
+
}
|
|
95
|
+
/**
|
|
96
|
+
* @license
|
|
97
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
98
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
99
|
+
* you may not use this file except in compliance with the License.
|
|
100
|
+
* You may obtain a copy of the License at
|
|
101
|
+
*
|
|
102
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
103
|
+
*
|
|
104
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
105
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
106
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
107
|
+
* See the License for the specific language governing permissions and
|
|
108
|
+
* limitations under the License.
|
|
109
|
+
* =============================================================================
|
|
110
|
+
*/
|
|
111
|
+
class b {
|
|
112
|
+
constructor(e, o) {
|
|
113
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{ name: "inputShape", type: "ivec3" }], this.outputShape = e, this.enableShapeUniforms = I(this.outputShape.length);
|
|
114
|
+
let s = "";
|
|
115
|
+
for (let n = 0; n < 4; n++) {
|
|
116
|
+
let r = "thisRC = rc;";
|
|
117
|
+
n % 2 === 1 && (r += "thisRC.z += 1;"), n > 1 && (r += "thisRC.y += 1;"), s += `
|
|
118
|
+
${r}
|
|
119
|
+
${n > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
|
|
120
|
+
int flatIndex = getFlatIndex(thisRC);
|
|
121
|
+
|
|
122
|
+
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
|
|
123
|
+
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
|
|
124
|
+
|
|
125
|
+
result[${n}] =
|
|
126
|
+
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
|
|
127
|
+
${n > 0 ? "}" : ""}
|
|
128
|
+
`;
|
|
129
|
+
}
|
|
130
|
+
this.userCode = `
|
|
131
|
+
${w(o, this.enableShapeUniforms)}
|
|
132
|
+
${this.enableShapeUniforms ? v() : F(e)}
|
|
133
|
+
|
|
134
|
+
void main() {
|
|
135
|
+
ivec3 rc = getOutputCoords();
|
|
136
|
+
|
|
137
|
+
vec4 result = vec4(0.);
|
|
138
|
+
|
|
139
|
+
ivec3 thisRC;
|
|
140
|
+
int rows = ${this.enableShapeUniforms ? "outShape[1]" : e[1]};
|
|
141
|
+
int cols = ${this.enableShapeUniforms ? "outShape[2]" : e[2]};
|
|
142
|
+
|
|
143
|
+
${s}
|
|
144
|
+
|
|
145
|
+
setOutput(result);
|
|
146
|
+
}
|
|
147
|
+
`;
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
function w(t, e) {
|
|
151
|
+
return `
|
|
152
|
+
ivec3 inputCoordsFromReshapedOutCoords(int index) {
|
|
153
|
+
${e ? S(["r", "c", "d"], "inputShape") : R(["r", "c", "d"], t)}
|
|
154
|
+
return ivec3(r, c, d);
|
|
155
|
+
}
|
|
156
|
+
`;
|
|
157
|
+
}
|
|
158
|
+
/**
|
|
159
|
+
* @license
|
|
160
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
161
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
162
|
+
* you may not use this file except in compliance with the License.
|
|
163
|
+
* You may obtain a copy of the License at
|
|
164
|
+
*
|
|
165
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
166
|
+
*
|
|
167
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
168
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
169
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
170
|
+
* See the License for the specific language governing permissions and
|
|
171
|
+
* limitations under the License.
|
|
172
|
+
* =============================================================================
|
|
173
|
+
*/
|
|
174
|
+
function D(t, e, o) {
|
|
175
|
+
const s = [
|
|
176
|
+
h(t.shape),
|
|
177
|
+
...m(t.shape)
|
|
178
|
+
], n = {
|
|
179
|
+
dtype: t.dtype,
|
|
180
|
+
shape: s,
|
|
181
|
+
dataId: t.dataId
|
|
182
|
+
}, r = [
|
|
183
|
+
h(e),
|
|
184
|
+
...m(e)
|
|
185
|
+
], i = new b(r, s), u = !0, a = [s], c = o.runWebGLProgram(i, [n], t.dtype, a, u);
|
|
186
|
+
return { dataId: c.dataId, shape: e, dtype: c.dtype };
|
|
187
|
+
}
|
|
188
|
+
/**
|
|
189
|
+
* @license
|
|
190
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
191
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
192
|
+
* you may not use this file except in compliance with the License.
|
|
193
|
+
* You may obtain a copy of the License at
|
|
194
|
+
*
|
|
195
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
196
|
+
*
|
|
197
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
198
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
199
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
200
|
+
* See the License for the specific language governing permissions and
|
|
201
|
+
* limitations under the License.
|
|
202
|
+
* =============================================================================
|
|
203
|
+
*/
|
|
204
|
+
function k(t) {
|
|
205
|
+
const { inputs: e, backend: o, attrs: s } = t, { x: n } = e, { shape: r } = s, i = o, u = p(n.shape), a = C(r, u), c = p(a);
|
|
206
|
+
x(u === c, () => `The new shape (${a}) has ${c} elements and the old shape (${n.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
|
|
207
|
+
const l = i.texData.get(n.dataId);
|
|
208
|
+
return l.isPacked && !$(n.shape, a) && !(l.texture !== null && $(l.shape, a)) ? D(n, a, i) : (i.incRef(n.dataId), { dataId: n.dataId, shape: a, dtype: n.dtype });
|
|
209
|
+
}
|
|
210
|
+
export {
|
|
211
|
+
k as r
|
|
212
|
+
};
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -11,7 +11,7 @@ import g from "./tokeniser/bpe.js";
|
|
|
11
11
|
import "./papaparse.min-C8l2Kvo1.js";
|
|
12
12
|
import "./index-Tf7vU29b.js";
|
|
13
13
|
import "./jszip.min-CjP2V1VV.js";
|
|
14
|
-
import "./index-
|
|
14
|
+
import "./index--6vO-cOz.js";
|
|
15
15
|
import "./ops/cpu/scatterSub.js";
|
|
16
16
|
import "./ops/webgl/scatterSub.js";
|
|
17
17
|
import "./ops/cpu/gatherSub.js";
|
|
@@ -31,6 +31,12 @@ import "./ops/webgl/appendCache.js";
|
|
|
31
31
|
import "./ops/cpu/fusedSoftmax.js";
|
|
32
32
|
import "./ops/webgl/fusedSoftmax.js";
|
|
33
33
|
import "./ops/grads/fusedSoftmax.js";
|
|
34
|
+
import "./ops/cpu/matMulGelu.js";
|
|
35
|
+
import "./ops/webgl/matMulGelu.js";
|
|
36
|
+
import "./ops/grads/matMulGelu.js";
|
|
37
|
+
import "./ops/cpu/gelu.js";
|
|
38
|
+
import "./ops/webgl/gelu.js";
|
|
39
|
+
import "./ops/grads/gelu.js";
|
|
34
40
|
import w from "./utilities/profile.js";
|
|
35
41
|
class a {
|
|
36
42
|
ee = new p();
|