@genai-fi/nanogpt 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +95 -46
- package/dist/NanoGPTModel.d.ts +3 -2
- package/dist/NanoGPTModel.js +91 -76
- package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
- package/dist/TeachableLLM.js +1 -1
- package/dist/TiedEmbedding-DORsPlNL.js +44 -0
- package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
- package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
- package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
- package/dist/dataset-ZHEPJmED.js +1226 -0
- package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
- package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
- package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
- package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
- package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
- package/dist/layers/BaseLayer.js +114 -3
- package/dist/layers/CausalSelfAttention.d.ts +2 -3
- package/dist/layers/CausalSelfAttention.js +31 -30
- package/dist/layers/MLP.js +10 -9
- package/dist/layers/RMSNorm.js +12 -11
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +8 -6
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
- package/dist/main.js +1 -1
- package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
- package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
- package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
- package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
- package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +27 -27
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +36 -36
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +22 -22
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
- package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
- package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
- package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
- package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
- package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
- package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
- package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
- package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
- package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
- package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
- package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
- package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
- package/dist/tokeniser/CharTokeniser.js +27 -27
- package/dist/tokeniser/bpe.d.ts +1 -0
- package/dist/tokeniser/bpe.js +38 -35
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +22 -1242
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +5 -5
- package/dist/utilities/weights.js +2 -2
- package/dist/variable-BGvK-VN3.js +23 -0
- package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
- package/package.json +1 -1
- package/dist/BaseLayer-BhrMN8JO.js +0 -135
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as l,
|
|
1
|
+
import { o as l, j as h, E as m, ak as p, n as c, al as d, ae as g, l as u, T as V, am as v, a9 as N, b as w } from "./index-CnHyhpKc.js";
|
|
2
2
|
import { s as f } from "./index-C4L8Cm77.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as g,
|
|
1
|
+
import { o as g, j as t, E as h, G as p } from "./index-CnHyhpKc.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -16,8 +16,8 @@ import { o as g, i as t, E as h, G as p } from "./index-iNhkcAEQ.js";
|
|
|
16
16
|
* =============================================================================
|
|
17
17
|
*/
|
|
18
18
|
function u(n, s, r = 0, e = 0) {
|
|
19
|
-
const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"),
|
|
20
|
-
return h.runKernel(p,
|
|
19
|
+
const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), c = { x: o, indices: a }, i = { axis: r, batchDims: e };
|
|
20
|
+
return h.runKernel(p, c, i);
|
|
21
21
|
}
|
|
22
22
|
const d = /* @__PURE__ */ g({ gather_: u });
|
|
23
23
|
export {
|
|
@@ -4005,26 +4005,26 @@ export {
|
|
|
4005
4005
|
Ss as A,
|
|
4006
4006
|
Zs as B,
|
|
4007
4007
|
or as C,
|
|
4008
|
-
|
|
4008
|
+
Ft as D,
|
|
4009
4009
|
g as E,
|
|
4010
|
-
|
|
4010
|
+
Wa as F,
|
|
4011
4011
|
Pr as G,
|
|
4012
|
-
|
|
4013
|
-
|
|
4014
|
-
|
|
4015
|
-
|
|
4016
|
-
|
|
4012
|
+
Bn as H,
|
|
4013
|
+
Fs as I,
|
|
4014
|
+
kn as J,
|
|
4015
|
+
En as K,
|
|
4016
|
+
Qa as L,
|
|
4017
4017
|
ta as M,
|
|
4018
|
-
|
|
4019
|
-
|
|
4018
|
+
k as N,
|
|
4019
|
+
Lr as O,
|
|
4020
4020
|
ba as P,
|
|
4021
|
-
|
|
4021
|
+
rs as Q,
|
|
4022
4022
|
Ia as R,
|
|
4023
4023
|
qa as S,
|
|
4024
|
-
|
|
4025
|
-
|
|
4026
|
-
|
|
4027
|
-
|
|
4024
|
+
D as T,
|
|
4025
|
+
de as U,
|
|
4026
|
+
Ea as V,
|
|
4027
|
+
Zt as W,
|
|
4028
4028
|
De as X,
|
|
4029
4029
|
ar as Y,
|
|
4030
4030
|
ne as Z,
|
|
@@ -4074,13 +4074,13 @@ export {
|
|
|
4074
4074
|
$t as ad,
|
|
4075
4075
|
Rt as ae,
|
|
4076
4076
|
Rs as af,
|
|
4077
|
-
|
|
4078
|
-
|
|
4079
|
-
|
|
4080
|
-
|
|
4081
|
-
|
|
4082
|
-
|
|
4083
|
-
|
|
4077
|
+
F as ag,
|
|
4078
|
+
pe as ah,
|
|
4079
|
+
fo as ai,
|
|
4080
|
+
dt as aj,
|
|
4081
|
+
xr as ak,
|
|
4082
|
+
Wn as al,
|
|
4083
|
+
x as am,
|
|
4084
4084
|
jt as an,
|
|
4085
4085
|
ue as ao,
|
|
4086
4086
|
za as ap,
|
|
@@ -4214,22 +4214,22 @@ export {
|
|
|
4214
4214
|
K as f,
|
|
4215
4215
|
ss as g,
|
|
4216
4216
|
lo as h,
|
|
4217
|
-
|
|
4218
|
-
|
|
4219
|
-
|
|
4220
|
-
|
|
4217
|
+
To as i,
|
|
4218
|
+
T as j,
|
|
4219
|
+
In as k,
|
|
4220
|
+
y as l,
|
|
4221
4221
|
po as m,
|
|
4222
|
-
|
|
4222
|
+
xt as n,
|
|
4223
4223
|
N as o,
|
|
4224
|
-
|
|
4225
|
-
q,
|
|
4224
|
+
Ge as p,
|
|
4225
|
+
z as q,
|
|
4226
4226
|
co as r,
|
|
4227
4227
|
tt as s,
|
|
4228
4228
|
E as t,
|
|
4229
|
-
|
|
4229
|
+
q as u,
|
|
4230
4230
|
ls as v,
|
|
4231
|
-
|
|
4232
|
-
|
|
4233
|
-
|
|
4231
|
+
Ba as w,
|
|
4232
|
+
Ka as x,
|
|
4233
|
+
qn as y,
|
|
4234
4234
|
C as z
|
|
4235
4235
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { an as D, ao as N,
|
|
2
|
-
import { u as g } from "./gpgpu_math-
|
|
1
|
+
import { an as D, ao as N, Q as w, q as R, U as v, N as P } from "./index-CnHyhpKc.js";
|
|
2
|
+
import { u as g } from "./gpgpu_math-Df7gzJWH.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
5
5
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -23,7 +23,7 @@ function B(t) {
|
|
|
23
23
|
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
|
|
24
24
|
}
|
|
25
25
|
}
|
|
26
|
-
function
|
|
26
|
+
function H(t) {
|
|
27
27
|
return t.map((e) => N(e));
|
|
28
28
|
}
|
|
29
29
|
/**
|
|
@@ -127,12 +127,12 @@ class C {
|
|
|
127
127
|
* =============================================================================
|
|
128
128
|
*/
|
|
129
129
|
class _ {
|
|
130
|
-
constructor(e, o, u,
|
|
130
|
+
constructor(e, o, u, d = !1) {
|
|
131
131
|
this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(o, u);
|
|
132
132
|
const a = this.outputShape.length;
|
|
133
133
|
this.enableShapeUniforms = g(a);
|
|
134
134
|
let n = "";
|
|
135
|
-
if (
|
|
135
|
+
if (d)
|
|
136
136
|
if (a === 0 || R(this.outputShape) === 1)
|
|
137
137
|
n = `
|
|
138
138
|
result.y = 0.;
|
|
@@ -225,7 +225,7 @@ function A(t) {
|
|
|
225
225
|
* =============================================================================
|
|
226
226
|
*/
|
|
227
227
|
function G(t) {
|
|
228
|
-
const { inputs: e, backend: o } = t, { real: u, imag:
|
|
228
|
+
const { inputs: e, backend: o } = t, { real: u, imag: d } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: d }, backend: o });
|
|
229
229
|
return n.complexTensorInfos = { real: l, imag: s }, a;
|
|
230
230
|
}
|
|
231
231
|
/**
|
|
@@ -260,7 +260,7 @@ class V {
|
|
|
260
260
|
`;
|
|
261
261
|
}
|
|
262
262
|
}
|
|
263
|
-
const
|
|
263
|
+
const K = "if (isnan(x)) return x;";
|
|
264
264
|
/**
|
|
265
265
|
* @license
|
|
266
266
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -310,8 +310,8 @@ class L {
|
|
|
310
310
|
* =============================================================================
|
|
311
311
|
*/
|
|
312
312
|
function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
|
|
313
|
-
return ({ inputs:
|
|
314
|
-
const { x: n } =
|
|
313
|
+
return ({ inputs: d, backend: a }) => {
|
|
314
|
+
const { x: n } = d, l = a, s = u || n.dtype;
|
|
315
315
|
if (l.shouldExecuteOnCPU([n]) && o != null) {
|
|
316
316
|
const c = l.texData.get(n.dataId), x = o(c.values, s);
|
|
317
317
|
return l.makeTensorInfo(n.shape, s, x);
|
|
@@ -321,37 +321,37 @@ function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
|
|
|
321
321
|
return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
|
|
322
322
|
};
|
|
323
323
|
}
|
|
324
|
-
function
|
|
324
|
+
function q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: d, dtype: a }) {
|
|
325
325
|
return ({ inputs: n, backend: l }) => {
|
|
326
326
|
const { a: s, b: i } = n, r = l;
|
|
327
327
|
if (u && s.dtype === "complex64") {
|
|
328
|
-
const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [
|
|
328
|
+
const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [y, O] = [
|
|
329
329
|
[h.complexTensorInfos.real, f.complexTensorInfos.real],
|
|
330
330
|
[h.complexTensorInfos.imag, f.complexTensorInfos.imag]
|
|
331
331
|
].map((S) => {
|
|
332
|
-
const [
|
|
333
|
-
dataId:
|
|
334
|
-
dtype:
|
|
332
|
+
const [p, m] = S, $ = {
|
|
333
|
+
dataId: p.dataId,
|
|
334
|
+
dtype: p.dtype,
|
|
335
335
|
shape: s.shape
|
|
336
336
|
}, T = {
|
|
337
337
|
dataId: m.dataId,
|
|
338
338
|
dtype: m.dtype,
|
|
339
339
|
shape: i.shape
|
|
340
340
|
}, U = new C(t, s.shape, i.shape);
|
|
341
|
-
return r.runWebGLProgram(U, [$, T], v(
|
|
342
|
-
}), I = G({ inputs: { real:
|
|
343
|
-
return r.disposeIntermediateTensorInfo(
|
|
341
|
+
return r.runWebGLProgram(U, [$, T], v(p.dtype, m.dtype));
|
|
342
|
+
}), I = G({ inputs: { real: y, imag: O }, backend: r });
|
|
343
|
+
return r.disposeIntermediateTensorInfo(y), r.disposeIntermediateTensorInfo(O), I;
|
|
344
344
|
}
|
|
345
345
|
const c = a || v(s.dtype, i.dtype);
|
|
346
|
-
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) &&
|
|
347
|
-
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values,
|
|
346
|
+
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && d != null) {
|
|
347
|
+
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, y = s.dtype === "string" ? (
|
|
348
348
|
// tslint:disable-next-line: no-any
|
|
349
349
|
B(h)
|
|
350
|
-
) : h,
|
|
350
|
+
) : h, O = s.dtype === "string" ? (
|
|
351
351
|
// tslint:disable-next-line: no-any
|
|
352
352
|
B(f)
|
|
353
|
-
) : f, [I, S] =
|
|
354
|
-
return m.values = I,
|
|
353
|
+
) : f, [I, S] = d(s.shape, i.shape, y, O, c), p = r.makeTensorInfo(S, c), m = r.texData.get(p.dataId);
|
|
354
|
+
return m.values = I, p;
|
|
355
355
|
}
|
|
356
356
|
const x = P().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
|
|
357
357
|
let b;
|
|
@@ -359,10 +359,10 @@ function j({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
|
|
|
359
359
|
};
|
|
360
360
|
}
|
|
361
361
|
export {
|
|
362
|
-
|
|
363
|
-
|
|
362
|
+
K as C,
|
|
363
|
+
H as a,
|
|
364
364
|
E as b,
|
|
365
|
-
|
|
365
|
+
q as c,
|
|
366
366
|
B as f,
|
|
367
367
|
k as g,
|
|
368
368
|
Y as u
|
package/dist/layers/BaseLayer.js
CHANGED
|
@@ -1,5 +1,116 @@
|
|
|
1
|
-
import "../index-
|
|
2
|
-
import {
|
|
1
|
+
import { T as g, c as p, e as o, i as v } from "../index-CnHyhpKc.js";
|
|
2
|
+
import { v as _ } from "../variable-BGvK-VN3.js";
|
|
3
|
+
class M {
|
|
4
|
+
parent;
|
|
5
|
+
config;
|
|
6
|
+
_variables = /* @__PURE__ */ new Map();
|
|
7
|
+
_trainable = !0;
|
|
8
|
+
children = [];
|
|
9
|
+
constructor(t, r) {
|
|
10
|
+
this.config = t, this.parent = r, this.parent && this.parent.children.push(this);
|
|
11
|
+
}
|
|
12
|
+
getProfiler() {
|
|
13
|
+
return this.config.layerConfig.profiler;
|
|
14
|
+
}
|
|
15
|
+
startMemory() {
|
|
16
|
+
this.config.layerConfig.profiler?.startMemory();
|
|
17
|
+
}
|
|
18
|
+
endMemory(t) {
|
|
19
|
+
this.config.layerConfig.profiler?.endMemory(t);
|
|
20
|
+
}
|
|
21
|
+
addVariable(t, r) {
|
|
22
|
+
this._variables.set(t, r || null);
|
|
23
|
+
}
|
|
24
|
+
get variables() {
|
|
25
|
+
const t = Array.from(this._variables.values()).filter((e) => e !== null), r = this.children.flatMap((e) => e.variables);
|
|
26
|
+
return [...t, ...r];
|
|
27
|
+
}
|
|
28
|
+
get trainableVariables() {
|
|
29
|
+
const t = Array.from(this._variables.values()).filter(
|
|
30
|
+
(e) => e !== null && e.trainable
|
|
31
|
+
), r = this.children.flatMap((e) => e.trainableVariables);
|
|
32
|
+
return [...t, ...r];
|
|
33
|
+
}
|
|
34
|
+
get trainable() {
|
|
35
|
+
return this._trainable;
|
|
36
|
+
}
|
|
37
|
+
set trainable(t) {
|
|
38
|
+
this._trainable = t, this._variables.forEach((r) => {
|
|
39
|
+
r && (r.trainable = t);
|
|
40
|
+
}), this.children.forEach((r) => {
|
|
41
|
+
r.trainable = t;
|
|
42
|
+
});
|
|
43
|
+
}
|
|
44
|
+
getVariable(t) {
|
|
45
|
+
const r = this._variables.get(t);
|
|
46
|
+
if (!r)
|
|
47
|
+
throw new Error(`Variable ${t} not found`);
|
|
48
|
+
return r;
|
|
49
|
+
}
|
|
50
|
+
hasVariable(t) {
|
|
51
|
+
return this._variables.get(t) !== null;
|
|
52
|
+
}
|
|
53
|
+
setVariable(t, r) {
|
|
54
|
+
if (!this._variables.has(t))
|
|
55
|
+
throw new Error(`Variable ${t} not found`);
|
|
56
|
+
this._variables.set(t, r);
|
|
57
|
+
}
|
|
58
|
+
saveWeights(t) {
|
|
59
|
+
this._variables.forEach((r, e) => {
|
|
60
|
+
r && t.set(e, [r.clone()]);
|
|
61
|
+
}), this.children.forEach((r) => {
|
|
62
|
+
r.saveWeights(t);
|
|
63
|
+
});
|
|
64
|
+
}
|
|
65
|
+
loadWeights(t) {
|
|
66
|
+
this._variables.forEach((r, e) => {
|
|
67
|
+
const i = t.get(e)?.[0];
|
|
68
|
+
if (!i)
|
|
69
|
+
throw new Error(`Weights for ${e} not found`);
|
|
70
|
+
r ? r.assign(i) : this._variables.set(e, _(i, this._trainable));
|
|
71
|
+
}), this.children.forEach((r) => {
|
|
72
|
+
r.loadWeights(t);
|
|
73
|
+
});
|
|
74
|
+
}
|
|
75
|
+
dispose() {
|
|
76
|
+
this._variables.forEach((t) => {
|
|
77
|
+
t?.dispose();
|
|
78
|
+
}), this._variables.clear();
|
|
79
|
+
}
|
|
80
|
+
build() {
|
|
81
|
+
}
|
|
82
|
+
dropout(t) {
|
|
83
|
+
return t;
|
|
84
|
+
}
|
|
85
|
+
call(t, ...r) {
|
|
86
|
+
this.build();
|
|
87
|
+
const e = this.forward(t, ...r);
|
|
88
|
+
if (t.training && e instanceof g) {
|
|
89
|
+
const i = this.dropout(e);
|
|
90
|
+
return i !== e && e.dispose(), i;
|
|
91
|
+
} else
|
|
92
|
+
return e;
|
|
93
|
+
}
|
|
94
|
+
callCheckpoint(t, ...r) {
|
|
95
|
+
return this.build(), this.checkpointingFn(t, ...r);
|
|
96
|
+
}
|
|
97
|
+
checkpointingFn(t, ...r) {
|
|
98
|
+
const e = this.trainableVariables, s = p((...a) => {
|
|
99
|
+
const l = a[a.length - 1], n = a.slice(0, r.length), h = this.forward(t, ...n);
|
|
100
|
+
return l(n), { value: h, gradFunc: (c, f) => {
|
|
101
|
+
const u = o().state.activeTape;
|
|
102
|
+
o().state.activeTape = [];
|
|
103
|
+
const b = v((...d) => this.forward(t, ...d.slice(0, n.length)))([...f, ...e], c);
|
|
104
|
+
return o().state.activeTape = u, b;
|
|
105
|
+
} };
|
|
106
|
+
})(...r, ...e);
|
|
107
|
+
if (t.training) {
|
|
108
|
+
const a = this.dropout(s);
|
|
109
|
+
return a !== s && s.dispose(), a;
|
|
110
|
+
} else
|
|
111
|
+
return s;
|
|
112
|
+
}
|
|
113
|
+
}
|
|
3
114
|
export {
|
|
4
|
-
|
|
115
|
+
M as default
|
|
5
116
|
};
|
|
@@ -7,9 +7,8 @@ export type KVCache = {
|
|
|
7
7
|
cumulativeLength: number;
|
|
8
8
|
};
|
|
9
9
|
export interface AttentionScores {
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
attentionOut?: Tensor;
|
|
10
|
+
meanOfHeads?: boolean;
|
|
11
|
+
attentionOut?: Tensor[];
|
|
13
12
|
}
|
|
14
13
|
interface AttentionForwardAttributes extends ForwardAttributes {
|
|
15
14
|
attentionScores?: AttentionScores;
|
|
@@ -1,15 +1,16 @@
|
|
|
1
|
-
import { attentionMask as
|
|
2
|
-
import
|
|
1
|
+
import { attentionMask as g } from "../ops/attentionMask.js";
|
|
2
|
+
import O from "./BaseLayer.js";
|
|
3
3
|
import { qkv as P } from "../ops/qkv.js";
|
|
4
|
-
import { rope as
|
|
5
|
-
import { appendCache as
|
|
6
|
-
import {
|
|
4
|
+
import { rope as v } from "../ops/rope.js";
|
|
5
|
+
import { appendCache as V } from "../ops/appendCache.js";
|
|
6
|
+
import { H as c, t as C } from "../index-CnHyhpKc.js";
|
|
7
7
|
import { fusedSoftmax as T } from "../ops/fusedSoftmax.js";
|
|
8
|
-
import { d as y } from "../tfjs_backend-
|
|
9
|
-
import {
|
|
10
|
-
import { r as
|
|
11
|
-
import {
|
|
12
|
-
|
|
8
|
+
import { d as y } from "../tfjs_backend-DX9yVvwk.js";
|
|
9
|
+
import { v as b } from "../variable-BGvK-VN3.js";
|
|
10
|
+
import { r as k, d as L } from "../dropout-lQm_YyX3.js";
|
|
11
|
+
import { r as N } from "../reshape-CTIbqjwm.js";
|
|
12
|
+
import { m as R } from "../mat_mul-DeGU1U_C.js";
|
|
13
|
+
class $ extends O {
|
|
13
14
|
divisor;
|
|
14
15
|
index;
|
|
15
16
|
units;
|
|
@@ -22,14 +23,14 @@ class W extends O {
|
|
|
22
23
|
build() {
|
|
23
24
|
this.hasVariable(this.ATTN) === !1 && this.setVariable(
|
|
24
25
|
this.ATTN,
|
|
25
|
-
|
|
26
|
+
b(
|
|
26
27
|
k([this.config.gpt.nEmbed, this.units], 0, 0.02),
|
|
27
28
|
!0
|
|
28
29
|
//`block_${this.index}_attn_cAttn_kernel`
|
|
29
30
|
)
|
|
30
31
|
), this.hasVariable(this.PROJ) === !1 && this.setVariable(
|
|
31
32
|
this.PROJ,
|
|
32
|
-
|
|
33
|
+
b(
|
|
33
34
|
k([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
|
|
34
35
|
!0
|
|
35
36
|
//`block_${this.index}_attn_cProj_kernel`
|
|
@@ -37,12 +38,12 @@ class W extends O {
|
|
|
37
38
|
);
|
|
38
39
|
}
|
|
39
40
|
getAttentionScores(t, i, s, o) {
|
|
40
|
-
const e =
|
|
41
|
+
const e = g(t, i, this.divisor), n = T(e, s ? this.config.gpt.dropout : 0, o);
|
|
41
42
|
return e.dispose(), n;
|
|
42
43
|
}
|
|
43
44
|
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
44
45
|
getAttentionScoresWithPast(t, i, s) {
|
|
45
|
-
const o =
|
|
46
|
+
const o = g(t, i, this.divisor, s), e = T(o, 0, 0);
|
|
46
47
|
return o.dispose(), e;
|
|
47
48
|
}
|
|
48
49
|
getQKV(t) {
|
|
@@ -53,33 +54,33 @@ class W extends O {
|
|
|
53
54
|
return n.dispose(), e.dispose(), p;
|
|
54
55
|
}
|
|
55
56
|
updateCache(t, i, s) {
|
|
56
|
-
const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p =
|
|
57
|
+
const o = this.config.gpt.blockSize, e = t.shape[2], n = s.length || 0, p = V(t, o, n, s.k);
|
|
57
58
|
t.dispose(), s.k && s.k.dispose();
|
|
58
|
-
const
|
|
59
|
+
const a = V(i, o, n, s.v);
|
|
59
60
|
i.dispose(), s.v && s.v.dispose();
|
|
60
61
|
const d = Math.min(n + e, o), h = s.cumulativeLength + e;
|
|
61
|
-
s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(
|
|
62
|
+
s.length = d, s.cumulativeLength = h, s.k = c(p), s.v = c(a);
|
|
62
63
|
}
|
|
63
64
|
forward(t, i) {
|
|
64
65
|
return C(() => {
|
|
65
66
|
this.startMemory();
|
|
66
|
-
const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache,
|
|
67
|
+
const [s, o, e] = this.getQKV(i), n = t.pastKV ? t.pastKV.cumulativeLength : 0, p = this.config.layerConfig.ropeCache, a = p ? v(s, p, n) : s, d = p ? v(o, p, n) : o;
|
|
67
68
|
p && (s.dispose(), o.dispose());
|
|
68
69
|
const h = t.pastKV ? t.pastKV.length : 0;
|
|
69
70
|
t.pastKV && !t.training && this.updateCache(d, e, t.pastKV);
|
|
70
|
-
const u = t.pastKV?.k ? t.pastKV.k : d,
|
|
71
|
-
let
|
|
72
|
-
h > 0 ?
|
|
73
|
-
const
|
|
74
|
-
|
|
75
|
-
const
|
|
76
|
-
if (
|
|
77
|
-
const
|
|
78
|
-
t.attentionScores.attentionOut
|
|
79
|
-
|
|
71
|
+
const u = t.pastKV?.k ? t.pastKV.k : d, m = t.pastKV?.v ? t.pastKV.v : e;
|
|
72
|
+
let r;
|
|
73
|
+
h > 0 ? r = this.getAttentionScoresWithPast(a, u, h) : r = this.getAttentionScores(a, u, t.training, t.seed || 0), a.dispose(), t.pastKV || u.dispose();
|
|
74
|
+
const l = R(r, m), f = t.attentionScores !== void 0 && t.attentionScores.attentionOut !== void 0;
|
|
75
|
+
f || r.dispose(), t.pastKV || m.dispose();
|
|
76
|
+
const A = this.getOutputProjection(l);
|
|
77
|
+
if (l.dispose(), f && t.attentionScores && t.attentionScores.attentionOut !== void 0) {
|
|
78
|
+
const K = r.shape[1], S = r.shape[2];
|
|
79
|
+
t.attentionScores.attentionOut?.push(
|
|
80
|
+
c(r.slice([0, 0, 0, 0], [1, -1, -1, -1]).reshape([K, S, -1]))
|
|
80
81
|
);
|
|
81
82
|
}
|
|
82
|
-
return this.endMemory("CausalSelfAttention"),
|
|
83
|
+
return this.endMemory("CausalSelfAttention"), A;
|
|
83
84
|
});
|
|
84
85
|
}
|
|
85
86
|
dropout(t) {
|
|
@@ -91,5 +92,5 @@ class W extends O {
|
|
|
91
92
|
}
|
|
92
93
|
}
|
|
93
94
|
export {
|
|
94
|
-
|
|
95
|
+
$ as default
|
|
95
96
|
};
|
package/dist/layers/MLP.js
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
import { t as l } from "../index-
|
|
2
|
-
import
|
|
1
|
+
import { t as l } from "../index-CnHyhpKc.js";
|
|
2
|
+
import u from "./BaseLayer.js";
|
|
3
3
|
import { matMulGelu as M } from "../ops/matMulGelu.js";
|
|
4
|
-
import {
|
|
5
|
-
import { r as d } from "../
|
|
6
|
-
import {
|
|
7
|
-
|
|
4
|
+
import { v as o } from "../variable-BGvK-VN3.js";
|
|
5
|
+
import { r as h, d as f } from "../dropout-lQm_YyX3.js";
|
|
6
|
+
import { r as d } from "../reshape-CTIbqjwm.js";
|
|
7
|
+
import { m as c } from "../mat_mul-DeGU1U_C.js";
|
|
8
|
+
class V extends u {
|
|
8
9
|
index;
|
|
9
10
|
hiddenUnits;
|
|
10
11
|
MLPHIDDEN;
|
|
@@ -36,7 +37,7 @@ class O extends u {
|
|
|
36
37
|
forward(i, t) {
|
|
37
38
|
return l(() => {
|
|
38
39
|
this.startMemory();
|
|
39
|
-
const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p =
|
|
40
|
+
const [s, r, e] = t.shape, n = d(t, [s * r, e]), a = M(n, this.getVariable(this.MLPHIDDEN)), p = c(a, this.getVariable(this.MLPOUT));
|
|
40
41
|
a.dispose();
|
|
41
42
|
const m = d(p, [s, r, e]);
|
|
42
43
|
return this.endMemory("MLP"), m;
|
|
@@ -44,12 +45,12 @@ class O extends u {
|
|
|
44
45
|
}
|
|
45
46
|
dropout(i) {
|
|
46
47
|
if (this.config.gpt.dropout > 0) {
|
|
47
|
-
const t =
|
|
48
|
+
const t = f(i, this.config.gpt.dropout);
|
|
48
49
|
return i.dispose(), t;
|
|
49
50
|
}
|
|
50
51
|
return i;
|
|
51
52
|
}
|
|
52
53
|
}
|
|
53
54
|
export {
|
|
54
|
-
|
|
55
|
+
V as default
|
|
55
56
|
};
|
package/dist/layers/RMSNorm.js
CHANGED
|
@@ -1,20 +1,21 @@
|
|
|
1
|
-
import { t as
|
|
2
|
-
import
|
|
3
|
-
import { normRMS as
|
|
4
|
-
import {
|
|
5
|
-
|
|
1
|
+
import { t as s } from "../index-CnHyhpKc.js";
|
|
2
|
+
import e from "./BaseLayer.js";
|
|
3
|
+
import { normRMS as a } from "../ops/normRMS.js";
|
|
4
|
+
import { v as i } from "../variable-BGvK-VN3.js";
|
|
5
|
+
import { o as m } from "../ones-CDWGzVnm.js";
|
|
6
|
+
class f extends e {
|
|
6
7
|
GAMMA;
|
|
7
|
-
constructor(r, t = "",
|
|
8
|
-
super(r,
|
|
8
|
+
constructor(r, t = "", o) {
|
|
9
|
+
super(r, o), this.GAMMA = t, this.addVariable(this.GAMMA, i(m([r.gpt.nEmbed]), !0, this.GAMMA, "float32"));
|
|
9
10
|
}
|
|
10
11
|
forward(r, t) {
|
|
11
|
-
return
|
|
12
|
+
return s(() => {
|
|
12
13
|
this.startMemory();
|
|
13
|
-
const
|
|
14
|
-
return this.endMemory("RMSNorm"),
|
|
14
|
+
const o = a(t, this.getVariable(this.GAMMA));
|
|
15
|
+
return this.endMemory("RMSNorm"), o;
|
|
15
16
|
});
|
|
16
17
|
}
|
|
17
18
|
}
|
|
18
19
|
export {
|
|
19
|
-
|
|
20
|
+
f as default
|
|
20
21
|
};
|
package/dist/layers/RoPECache.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { o as c,
|
|
2
|
-
import { c as d, s as C } from "../sin-
|
|
3
|
-
import { r as h } from "../range-
|
|
1
|
+
import { o as c, j as f, E as l, V as m, f as n, W as u, t as p, H as a } from "../index-CnHyhpKc.js";
|
|
2
|
+
import { c as d, s as C } from "../sin-HzioENy_.js";
|
|
3
|
+
import { r as h } from "../range-CkOJ7090.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
6
6
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
import
|
|
2
|
-
import "../index-
|
|
3
|
-
import "../
|
|
4
|
-
import "../
|
|
5
|
-
import "
|
|
1
|
+
import "../random_width-DI2h9CMs.js";
|
|
2
|
+
import "../index-CnHyhpKc.js";
|
|
3
|
+
import { T as f } from "../TiedEmbedding-DORsPlNL.js";
|
|
4
|
+
import "../tfjs_backend-DX9yVvwk.js";
|
|
5
|
+
import "./BaseLayer.js";
|
|
6
|
+
import "../variable-BGvK-VN3.js";
|
|
7
|
+
import "../gather-BWyutxwi.js";
|
|
6
8
|
export {
|
|
7
|
-
|
|
9
|
+
f as default
|
|
8
10
|
};
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import l from "./CausalSelfAttention.js";
|
|
2
2
|
import r from "./MLP.js";
|
|
3
3
|
import o from "./RMSNorm.js";
|
|
4
|
-
import
|
|
5
|
-
import { t as p } from "../index-
|
|
4
|
+
import d from "./BaseLayer.js";
|
|
5
|
+
import { t as p } from "../index-CnHyhpKc.js";
|
|
6
6
|
class k extends d {
|
|
7
7
|
ln1;
|
|
8
8
|
attn;
|