@genai-fi/nanogpt 0.4.4 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +44 -41
- package/dist/NanoGPTModel.d.ts +12 -16
- package/dist/NanoGPTModel.js +128 -138
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +8 -5
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +22 -24
- package/dist/layers/CausalSelfAttention.js +73 -127
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -11
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +22 -19
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +3 -2
- package/dist/ops/webgl/matMulGelu.js +77 -75
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +86 -0
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +10 -8
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { o as h, i as f, l as p, x as g, E as u, T } from "./index-iNhkcAEQ.js";
|
|
2
|
+
import { r as b } from "./reshape-DxTPgnwL.js";
|
|
3
|
+
/**
|
|
4
|
+
* @license
|
|
5
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
6
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
* you may not use this file except in compliance with the License.
|
|
8
|
+
* You may obtain a copy of the License at
|
|
9
|
+
*
|
|
10
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
*
|
|
12
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
* See the License for the specific language governing permissions and
|
|
16
|
+
* limitations under the License.
|
|
17
|
+
* =============================================================================
|
|
18
|
+
*/
|
|
19
|
+
function m(e, r) {
|
|
20
|
+
let n = f(e, "broadcastTo", "x");
|
|
21
|
+
const a = n.shape;
|
|
22
|
+
if (p(r), r.length < n.rank)
|
|
23
|
+
throw new Error(`broadcastTo(): shape.length=${r.length} < input.rank=${n.rank}.`);
|
|
24
|
+
if (r.length > n.rank) {
|
|
25
|
+
const t = n.shape.slice();
|
|
26
|
+
for (; t.length < r.length; )
|
|
27
|
+
t.unshift(1);
|
|
28
|
+
n = b(n, t);
|
|
29
|
+
}
|
|
30
|
+
const s = n.shape, o = Array.from(r);
|
|
31
|
+
for (let t = r.length - 1; t >= 0; t--)
|
|
32
|
+
if (s[t] === r[t])
|
|
33
|
+
o[t] = 1;
|
|
34
|
+
else if (n.shape[t] !== 1)
|
|
35
|
+
throw new Error(`broadcastTo(): [${a}] cannot be broadcast to [${r}].`);
|
|
36
|
+
if (o.map((t, l) => t > 1 ? l : -1).filter((t) => t >= 0).length === 0)
|
|
37
|
+
return g(n);
|
|
38
|
+
const i = { x: n }, c = { reps: o };
|
|
39
|
+
return u.runKernel(T, i, c);
|
|
40
|
+
}
|
|
41
|
+
const E = /* @__PURE__ */ h({ broadcastTo_: m });
|
|
42
|
+
export {
|
|
43
|
+
E as b
|
|
44
|
+
};
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as s,
|
|
1
|
+
import { o as s, k as a, j as p, x as i, E as l, C as f } from "./index-iNhkcAEQ.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -17,13 +17,13 @@ import { o as s, j as a, i, w as p, E as l, C as f } from "./index--6vO-cOz.js";
|
|
|
17
17
|
*/
|
|
18
18
|
function h(o, e = 0) {
|
|
19
19
|
a(o.length >= 1, () => "Pass at least one tensor to concat");
|
|
20
|
-
const t =
|
|
20
|
+
const t = p(o, "tensors", "concat", "string_or_numeric");
|
|
21
21
|
if (t[0].dtype === "complex64" && t.forEach((n) => {
|
|
22
22
|
if (n.dtype !== "complex64")
|
|
23
23
|
throw new Error(`Cannot concatenate complex64 tensors with a tensor
|
|
24
24
|
with dtype ${n.dtype}. `);
|
|
25
25
|
}), t.length === 1)
|
|
26
|
-
return
|
|
26
|
+
return i(t[0]);
|
|
27
27
|
const r = t, c = { axis: e };
|
|
28
28
|
return l.runKernel(f, r, c);
|
|
29
29
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as l, h, E as m,
|
|
1
|
+
import { o as l, i as h, E as m, ag as p, l as c, ah as d, ae as g, k as u, V, ai as v, a9 as N, b as w } from "./index-iNhkcAEQ.js";
|
|
2
2
|
import { s as f } from "./index-C4L8Cm77.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as
|
|
1
|
+
import { o as g, i as t, E as h, G as p } from "./index-iNhkcAEQ.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -16,10 +16,10 @@ import { o as h, h as t, E as g, G as p } from "./index--6vO-cOz.js";
|
|
|
16
16
|
* =============================================================================
|
|
17
17
|
*/
|
|
18
18
|
function u(n, s, r = 0, e = 0) {
|
|
19
|
-
const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"),
|
|
20
|
-
return
|
|
19
|
+
const o = t(n, "x", "gather"), a = t(s, "indices", "gather", "int32"), i = { x: o, indices: a }, c = { axis: r, batchDims: e };
|
|
20
|
+
return h.runKernel(p, i, c);
|
|
21
21
|
}
|
|
22
|
-
const d = /* @__PURE__ */
|
|
22
|
+
const d = /* @__PURE__ */ g({ gather_: u });
|
|
23
23
|
export {
|
|
24
24
|
d as g
|
|
25
25
|
};
|
|
@@ -4001,81 +4001,81 @@ function As() {
|
|
|
4001
4001
|
*/
|
|
4002
4002
|
As();
|
|
4003
4003
|
export {
|
|
4004
|
-
|
|
4004
|
+
Oa as $,
|
|
4005
4005
|
Ss as A,
|
|
4006
4006
|
Zs as B,
|
|
4007
4007
|
or as C,
|
|
4008
|
-
|
|
4008
|
+
Wa as D,
|
|
4009
4009
|
g as E,
|
|
4010
4010
|
Bn as F,
|
|
4011
4011
|
Pr as G,
|
|
4012
|
-
|
|
4013
|
-
|
|
4014
|
-
|
|
4015
|
-
|
|
4016
|
-
|
|
4012
|
+
Fs as H,
|
|
4013
|
+
kn as I,
|
|
4014
|
+
En as J,
|
|
4015
|
+
k as K,
|
|
4016
|
+
Lr as L,
|
|
4017
4017
|
ta as M,
|
|
4018
|
-
|
|
4019
|
-
|
|
4018
|
+
rs as N,
|
|
4019
|
+
de as O,
|
|
4020
4020
|
ba as P,
|
|
4021
|
-
|
|
4021
|
+
Ea as Q,
|
|
4022
4022
|
Ia as R,
|
|
4023
4023
|
qa as S,
|
|
4024
|
-
|
|
4024
|
+
Qa as T,
|
|
4025
4025
|
Zt as U,
|
|
4026
|
-
|
|
4027
|
-
|
|
4026
|
+
D as V,
|
|
4027
|
+
To as W,
|
|
4028
4028
|
De as X,
|
|
4029
4029
|
ar as Y,
|
|
4030
4030
|
ne as Z,
|
|
4031
|
-
|
|
4031
|
+
dr as _,
|
|
4032
4032
|
M as a,
|
|
4033
4033
|
xs as a$,
|
|
4034
|
-
|
|
4035
|
-
|
|
4036
|
-
|
|
4037
|
-
|
|
4038
|
-
|
|
4039
|
-
|
|
4040
|
-
|
|
4041
|
-
|
|
4042
|
-
|
|
4043
|
-
|
|
4044
|
-
|
|
4045
|
-
|
|
4046
|
-
|
|
4047
|
-
|
|
4048
|
-
|
|
4049
|
-
|
|
4050
|
-
|
|
4051
|
-
|
|
4052
|
-
|
|
4053
|
-
|
|
4054
|
-
|
|
4055
|
-
|
|
4056
|
-
|
|
4057
|
-
|
|
4058
|
-
|
|
4059
|
-
|
|
4060
|
-
|
|
4061
|
-
|
|
4062
|
-
|
|
4063
|
-
|
|
4064
|
-
|
|
4065
|
-
|
|
4066
|
-
|
|
4067
|
-
|
|
4068
|
-
|
|
4069
|
-
|
|
4070
|
-
|
|
4071
|
-
|
|
4072
|
-
|
|
4073
|
-
|
|
4074
|
-
|
|
4075
|
-
|
|
4076
|
-
|
|
4077
|
-
|
|
4078
|
-
|
|
4034
|
+
aa as a0,
|
|
4035
|
+
xe as a1,
|
|
4036
|
+
V as a2,
|
|
4037
|
+
oa as a3,
|
|
4038
|
+
ns as a4,
|
|
4039
|
+
nt as a5,
|
|
4040
|
+
Ca as a6,
|
|
4041
|
+
Fr as a7,
|
|
4042
|
+
qr as a8,
|
|
4043
|
+
S as a9,
|
|
4044
|
+
_a as aA,
|
|
4045
|
+
er as aB,
|
|
4046
|
+
Pa as aC,
|
|
4047
|
+
Ar as aD,
|
|
4048
|
+
Rr as aE,
|
|
4049
|
+
_r as aF,
|
|
4050
|
+
Or as aG,
|
|
4051
|
+
Gr as aH,
|
|
4052
|
+
jr as aI,
|
|
4053
|
+
Kr as aJ,
|
|
4054
|
+
ha as aK,
|
|
4055
|
+
Jr as aL,
|
|
4056
|
+
ia as aM,
|
|
4057
|
+
Ta as aN,
|
|
4058
|
+
$a as aO,
|
|
4059
|
+
Ds as aP,
|
|
4060
|
+
no as aQ,
|
|
4061
|
+
eo as aR,
|
|
4062
|
+
yr as aS,
|
|
4063
|
+
$r as aT,
|
|
4064
|
+
ao as aU,
|
|
4065
|
+
da as aV,
|
|
4066
|
+
ma as aW,
|
|
4067
|
+
ga as aX,
|
|
4068
|
+
Na as aY,
|
|
4069
|
+
va as aZ,
|
|
4070
|
+
to as a_,
|
|
4071
|
+
la as aa,
|
|
4072
|
+
ua as ab,
|
|
4073
|
+
Za as ac,
|
|
4074
|
+
$t as ad,
|
|
4075
|
+
Rt as ae,
|
|
4076
|
+
Rs as af,
|
|
4077
|
+
xr as ag,
|
|
4078
|
+
Wn as ah,
|
|
4079
4079
|
x as ai,
|
|
4080
4080
|
F as aj,
|
|
4081
4081
|
pe as ak,
|
|
@@ -4084,16 +4084,16 @@ export {
|
|
|
4084
4084
|
jt as an,
|
|
4085
4085
|
ue as ao,
|
|
4086
4086
|
za as ap,
|
|
4087
|
-
|
|
4088
|
-
|
|
4089
|
-
|
|
4090
|
-
|
|
4091
|
-
|
|
4092
|
-
|
|
4093
|
-
|
|
4094
|
-
|
|
4095
|
-
|
|
4096
|
-
|
|
4087
|
+
rr as aq,
|
|
4088
|
+
Br as ar,
|
|
4089
|
+
Wr as as,
|
|
4090
|
+
Sa as at,
|
|
4091
|
+
Aa as au,
|
|
4092
|
+
Ra as av,
|
|
4093
|
+
ro as aw,
|
|
4094
|
+
Io as ax,
|
|
4095
|
+
oo as ay,
|
|
4096
|
+
yo as az,
|
|
4097
4097
|
b,
|
|
4098
4098
|
Vs as b$,
|
|
4099
4099
|
$s as b0,
|
|
@@ -4212,24 +4212,24 @@ export {
|
|
|
4212
4212
|
go as d,
|
|
4213
4213
|
mo as e,
|
|
4214
4214
|
K as f,
|
|
4215
|
-
|
|
4216
|
-
|
|
4217
|
-
|
|
4218
|
-
|
|
4219
|
-
|
|
4220
|
-
|
|
4215
|
+
ss as g,
|
|
4216
|
+
lo as h,
|
|
4217
|
+
T as i,
|
|
4218
|
+
In as j,
|
|
4219
|
+
y as k,
|
|
4220
|
+
xt as l,
|
|
4221
4221
|
po as m,
|
|
4222
|
-
|
|
4222
|
+
Ge as n,
|
|
4223
4223
|
N as o,
|
|
4224
|
-
|
|
4225
|
-
|
|
4224
|
+
z as p,
|
|
4225
|
+
q,
|
|
4226
4226
|
co as r,
|
|
4227
4227
|
tt as s,
|
|
4228
4228
|
E as t,
|
|
4229
|
-
|
|
4229
|
+
Ba as u,
|
|
4230
4230
|
ls as v,
|
|
4231
|
-
|
|
4232
|
-
|
|
4233
|
-
|
|
4231
|
+
Ka as w,
|
|
4232
|
+
qn as x,
|
|
4233
|
+
Ft as y,
|
|
4234
4234
|
C as z
|
|
4235
4235
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { an as D, ao as N,
|
|
2
|
-
import { u as g } from "./gpgpu_math-
|
|
1
|
+
import { an as D, ao as N, N as w, p as R, O as v, K as P } from "./index-iNhkcAEQ.js";
|
|
2
|
+
import { u as g } from "./gpgpu_math-C0zyxKFi.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
5
5
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -23,7 +23,7 @@ function B(t) {
|
|
|
23
23
|
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
|
|
24
24
|
}
|
|
25
25
|
}
|
|
26
|
-
function
|
|
26
|
+
function K(t) {
|
|
27
27
|
return t.map((e) => N(e));
|
|
28
28
|
}
|
|
29
29
|
/**
|
|
@@ -127,12 +127,12 @@ class C {
|
|
|
127
127
|
* =============================================================================
|
|
128
128
|
*/
|
|
129
129
|
class _ {
|
|
130
|
-
constructor(e, o, u,
|
|
130
|
+
constructor(e, o, u, p = !1) {
|
|
131
131
|
this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(o, u);
|
|
132
132
|
const a = this.outputShape.length;
|
|
133
133
|
this.enableShapeUniforms = g(a);
|
|
134
134
|
let n = "";
|
|
135
|
-
if (
|
|
135
|
+
if (p)
|
|
136
136
|
if (a === 0 || R(this.outputShape) === 1)
|
|
137
137
|
n = `
|
|
138
138
|
result.y = 0.;
|
|
@@ -225,7 +225,7 @@ function A(t) {
|
|
|
225
225
|
* =============================================================================
|
|
226
226
|
*/
|
|
227
227
|
function G(t) {
|
|
228
|
-
const { inputs: e, backend: o } = t, { real: u, imag:
|
|
228
|
+
const { inputs: e, backend: o } = t, { real: u, imag: p } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: p }, backend: o });
|
|
229
229
|
return n.complexTensorInfos = { real: l, imag: s }, a;
|
|
230
230
|
}
|
|
231
231
|
/**
|
|
@@ -260,7 +260,7 @@ class V {
|
|
|
260
260
|
`;
|
|
261
261
|
}
|
|
262
262
|
}
|
|
263
|
-
const
|
|
263
|
+
const H = "if (isnan(x)) return x;";
|
|
264
264
|
/**
|
|
265
265
|
* @license
|
|
266
266
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -310,8 +310,8 @@ class L {
|
|
|
310
310
|
* =============================================================================
|
|
311
311
|
*/
|
|
312
312
|
function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
|
|
313
|
-
return ({ inputs:
|
|
314
|
-
const { x: n } =
|
|
313
|
+
return ({ inputs: p, backend: a }) => {
|
|
314
|
+
const { x: n } = p, l = a, s = u || n.dtype;
|
|
315
315
|
if (l.shouldExecuteOnCPU([n]) && o != null) {
|
|
316
316
|
const c = l.texData.get(n.dataId), x = o(c.values, s);
|
|
317
317
|
return l.makeTensorInfo(n.shape, s, x);
|
|
@@ -321,7 +321,7 @@ function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
|
|
|
321
321
|
return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
|
|
322
322
|
};
|
|
323
323
|
}
|
|
324
|
-
function
|
|
324
|
+
function j({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: p, dtype: a }) {
|
|
325
325
|
return ({ inputs: n, backend: l }) => {
|
|
326
326
|
const { a: s, b: i } = n, r = l;
|
|
327
327
|
if (u && s.dtype === "complex64") {
|
|
@@ -329,29 +329,29 @@ function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
|
|
|
329
329
|
[h.complexTensorInfos.real, f.complexTensorInfos.real],
|
|
330
330
|
[h.complexTensorInfos.imag, f.complexTensorInfos.imag]
|
|
331
331
|
].map((S) => {
|
|
332
|
-
const [
|
|
333
|
-
dataId:
|
|
334
|
-
dtype:
|
|
332
|
+
const [d, m] = S, $ = {
|
|
333
|
+
dataId: d.dataId,
|
|
334
|
+
dtype: d.dtype,
|
|
335
335
|
shape: s.shape
|
|
336
336
|
}, T = {
|
|
337
337
|
dataId: m.dataId,
|
|
338
338
|
dtype: m.dtype,
|
|
339
339
|
shape: i.shape
|
|
340
340
|
}, U = new C(t, s.shape, i.shape);
|
|
341
|
-
return r.runWebGLProgram(U, [$, T], v(
|
|
341
|
+
return r.runWebGLProgram(U, [$, T], v(d.dtype, m.dtype));
|
|
342
342
|
}), I = G({ inputs: { real: O, imag: y }, backend: r });
|
|
343
343
|
return r.disposeIntermediateTensorInfo(O), r.disposeIntermediateTensorInfo(y), I;
|
|
344
344
|
}
|
|
345
345
|
const c = a || v(s.dtype, i.dtype);
|
|
346
|
-
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) &&
|
|
346
|
+
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && p != null) {
|
|
347
347
|
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, O = s.dtype === "string" ? (
|
|
348
348
|
// tslint:disable-next-line: no-any
|
|
349
349
|
B(h)
|
|
350
350
|
) : h, y = s.dtype === "string" ? (
|
|
351
351
|
// tslint:disable-next-line: no-any
|
|
352
352
|
B(f)
|
|
353
|
-
) : f, [I, S] =
|
|
354
|
-
return m.values = I,
|
|
353
|
+
) : f, [I, S] = p(s.shape, i.shape, O, y, c), d = r.makeTensorInfo(S, c), m = r.texData.get(d.dataId);
|
|
354
|
+
return m.values = I, d;
|
|
355
355
|
}
|
|
356
356
|
const x = P().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
|
|
357
357
|
let b;
|
|
@@ -359,10 +359,10 @@ function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, support
|
|
|
359
359
|
};
|
|
360
360
|
}
|
|
361
361
|
export {
|
|
362
|
-
|
|
363
|
-
|
|
362
|
+
H as C,
|
|
363
|
+
K as a,
|
|
364
364
|
E as b,
|
|
365
|
-
|
|
365
|
+
j as c,
|
|
366
366
|
B as f,
|
|
367
367
|
k as g,
|
|
368
368
|
Y as u
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import { GPTConfig } from '../config';
|
|
2
2
|
import { default as MemoryProfiler } from '../utilities/profile';
|
|
3
3
|
import { default as RoPECache } from './RoPECache';
|
|
4
|
+
import { Tensor, Variable } from '@tensorflow/tfjs-core';
|
|
4
5
|
export interface LayerConfig {
|
|
5
|
-
|
|
6
|
-
checkpointMLP?: boolean;
|
|
6
|
+
checkpointing?: boolean;
|
|
7
7
|
profiler?: MemoryProfiler;
|
|
8
8
|
ropeCache?: RoPECache;
|
|
9
9
|
}
|
|
@@ -11,10 +11,34 @@ export interface GPTLayerConfig {
|
|
|
11
11
|
gpt: GPTConfig;
|
|
12
12
|
layerConfig: LayerConfig;
|
|
13
13
|
}
|
|
14
|
-
export
|
|
14
|
+
export interface ForwardAttributes {
|
|
15
|
+
training: boolean;
|
|
16
|
+
}
|
|
17
|
+
export default abstract class BaseLayer<ATTR extends ForwardAttributes = ForwardAttributes> {
|
|
18
|
+
readonly parent?: BaseLayer;
|
|
15
19
|
readonly config: GPTLayerConfig;
|
|
16
|
-
|
|
20
|
+
private _variables;
|
|
21
|
+
private _trainable;
|
|
22
|
+
readonly children: BaseLayer[];
|
|
23
|
+
constructor(config: GPTLayerConfig, parent?: BaseLayer);
|
|
17
24
|
getProfiler(): MemoryProfiler | undefined;
|
|
18
25
|
startMemory(): void;
|
|
19
26
|
endMemory(label: string): void;
|
|
27
|
+
addVariable(name: string, variable?: Variable): void;
|
|
28
|
+
get variables(): Variable[];
|
|
29
|
+
get trainableVariables(): Variable[];
|
|
30
|
+
get trainable(): boolean;
|
|
31
|
+
set trainable(value: boolean);
|
|
32
|
+
getVariable(name: string): Variable;
|
|
33
|
+
hasVariable(name: string): boolean;
|
|
34
|
+
setVariable(name: string, variable: Variable): void;
|
|
35
|
+
saveWeights(map: Map<string, Tensor[]>): void;
|
|
36
|
+
loadWeights(weights: Map<string, Tensor[]>): void;
|
|
37
|
+
dispose(): void;
|
|
38
|
+
protected build(): void;
|
|
39
|
+
protected dropout(x: Tensor): Tensor;
|
|
40
|
+
abstract forward(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
|
|
41
|
+
call(attrs: ATTR, ...x: Tensor[]): Tensor | Tensor[];
|
|
42
|
+
callCheckpoint(attrs: ATTR, ...x: Tensor[]): Tensor;
|
|
43
|
+
private checkpointingFn;
|
|
20
44
|
}
|
package/dist/layers/BaseLayer.js
CHANGED
|
@@ -1,18 +1,5 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
constructor(r) {
|
|
4
|
-
this.config = r;
|
|
5
|
-
}
|
|
6
|
-
getProfiler() {
|
|
7
|
-
return this.config.layerConfig.profiler;
|
|
8
|
-
}
|
|
9
|
-
startMemory() {
|
|
10
|
-
this.config.layerConfig.profiler?.startMemory();
|
|
11
|
-
}
|
|
12
|
-
endMemory(r) {
|
|
13
|
-
this.config.layerConfig.profiler?.endMemory(r);
|
|
14
|
-
}
|
|
15
|
-
}
|
|
1
|
+
import "../index-iNhkcAEQ.js";
|
|
2
|
+
import { B as a } from "../BaseLayer-BhrMN8JO.js";
|
|
16
3
|
export {
|
|
17
|
-
|
|
4
|
+
a as default
|
|
18
5
|
};
|
|
@@ -1,38 +1,36 @@
|
|
|
1
|
-
import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
|
|
2
|
-
import { Tensor
|
|
1
|
+
import { default as BaseLayer, ForwardAttributes, GPTLayerConfig } from './BaseLayer';
|
|
2
|
+
import { Tensor } from '@tensorflow/tfjs-core';
|
|
3
3
|
export type KVCache = {
|
|
4
|
-
k
|
|
5
|
-
v
|
|
4
|
+
k?: Tensor;
|
|
5
|
+
v?: Tensor;
|
|
6
6
|
length: number;
|
|
7
7
|
cumulativeLength: number;
|
|
8
8
|
};
|
|
9
|
-
export
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
9
|
+
export interface AttentionScores {
|
|
10
|
+
head: number;
|
|
11
|
+
block: number;
|
|
12
|
+
attentionOut?: Tensor;
|
|
13
|
+
}
|
|
14
|
+
interface AttentionForwardAttributes extends ForwardAttributes {
|
|
15
|
+
attentionScores?: AttentionScores;
|
|
16
|
+
pastKV?: KVCache;
|
|
17
|
+
seed?: number;
|
|
18
|
+
}
|
|
19
|
+
export default class CausalSelfAttention extends BaseLayer<AttentionForwardAttributes> {
|
|
14
20
|
private divisor;
|
|
15
21
|
private index;
|
|
16
|
-
private _trainable;
|
|
17
22
|
private units;
|
|
18
23
|
private projUnits;
|
|
19
|
-
|
|
20
|
-
private
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
set trainable(value: boolean);
|
|
24
|
-
saveWeights(map: Map<string, Tensor[]>): void;
|
|
25
|
-
loadWeights(weights: Map<string, Tensor[]>): void;
|
|
24
|
+
private ATTN;
|
|
25
|
+
private PROJ;
|
|
26
|
+
constructor(index: number, config: GPTLayerConfig, parent?: BaseLayer);
|
|
27
|
+
protected build(): void;
|
|
26
28
|
private getAttentionScores;
|
|
27
29
|
private getAttentionScoresWithPast;
|
|
28
30
|
private getQKV;
|
|
29
31
|
private getOutputProjection;
|
|
30
32
|
private updateCache;
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
output: Tensor;
|
|
34
|
-
attention?: Tensor;
|
|
35
|
-
presentKV?: KVCache;
|
|
36
|
-
};
|
|
37
|
-
dispose(): void;
|
|
33
|
+
forward(attr: AttentionForwardAttributes, x: Tensor): Tensor;
|
|
34
|
+
protected dropout(x: Tensor): Tensor;
|
|
38
35
|
}
|
|
36
|
+
export {};
|