@genai-fi/nanogpt 0.2.10 → 0.2.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/TeachableLLM.js +8 -6
- package/dist/{complex-x7w5HPOS.js → complex-CJ-qCcLB.js} +1 -1
- package/dist/{index-CWQLouWz.js → index-YPKosni4.js} +52 -48
- package/dist/layers/CausalSelfAttention.d.ts +2 -0
- package/dist/layers/CausalSelfAttention.js +46 -56
- package/dist/layers/RoPECache.d.ts +4 -3
- package/dist/layers/RoPECache.js +17 -22
- package/dist/layers/TiedEmbedding.js +33 -32
- package/dist/main.js +18 -16
- package/dist/{mat_mul-4v7St11W.js → mat_mul-Bu7bhLms.js} +1 -1
- package/dist/ops/attentionMask.js +2 -2
- package/dist/ops/gatherSub.js +2 -2
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/qkv.d.ts +7 -0
- package/dist/ops/qkv.js +127 -0
- package/dist/ops/rope.d.ts +8 -0
- package/dist/ops/rope.js +154 -0
- package/dist/ops/scatterSub.js +10 -10
- package/dist/reshape-DmnmKT6r.js +25 -0
- package/dist/{stack-CTdK-itU.js → stack-BtKpB0Ry.js} +7 -7
- package/dist/sum-D7fu15XL.js +27 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/sparseCrossEntropy.js +31 -30
- package/dist/utilities/profile.js +1 -1
- package/package.json +1 -1
- package/dist/sum-CnIf1YOh.js +0 -49
package/dist/TeachableLLM.js
CHANGED
@@ -1,6 +1,6 @@
 import { defaultConfig as h } from "./config.js";
-import
-import { saveModel as
+import m from "./NanoGPTModel.js";
+import { saveModel as d } from "./utilities/save.js";
 import { loadModel as f } from "./utilities/load.js";
 import u from "./Generator.js";
 import _ from "./Trainer.js";
@@ -13,7 +13,9 @@ import "./jszip.min-CjP2V1VV.js";
 import "./ops/scatterSub.js";
 import "./ops/gatherSub.js";
 import "./ops/attentionMask.js";
-import
+import "./ops/qkv.js";
+import "./ops/rope.js";
+import p from "./utilities/profile.js";
 class a extends c {
   _config;
   _model;
@@ -50,7 +52,7 @@ class a extends c {
   saveModel(t) {
     if (!this._model || !this._tokeniser)
       throw new Error("Model or tokeniser is not initialized.");
-    return
+    return d(this._model, this._tokeniser, t);
   }
   static loadModel(t, r) {
     const e = new a(t);
@@ -65,7 +67,7 @@ class a extends c {
     }), e;
   }
   static create(t, r = {}) {
-    const e = { ...h, ...r }, o = new g(e.vocabSize), s = new
+    const e = { ...h, ...r }, o = new g(e.vocabSize), s = new m(t, e), i = new a(t, o, s);
     return i.setStatus("warmup"), l(s).then(() => {
       i.tokeniser.trained ? i.setStatus("ready") : (i.setStatus("awaitingTokens"), i.tokeniser.once("trainStatus", (n) => {
         n === "trained" && i.setStatus("ready");
@@ -84,7 +86,7 @@ class a extends c {
     if (t) {
       if (!this._model)
         throw new Error("Model is not initialized.");
-      this._model.getProfiler() || this._model.setProfiler(new
+      this._model.getProfiler() || this._model.setProfiler(new p());
     } else
       this._model && this._model.setProfiler(void 0);
   }
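The substantive change in TeachableLLM.js is the pair of new side-effect imports, "./ops/qkv.js" and "./ops/rope.js", matching the existing scatterSub/gatherSub/attentionMask imports: loading the module once is what wires the custom op into TensorFlow.js before any model code runs. The op bodies are not shown at this point in the diff; the sketch below only illustrates the general pattern with TensorFlow.js's real `tf.customGrad` API, and `scaleByTwo` is a hypothetical stand-in, not the package's op.

```ts
import * as tf from "@tensorflow/tfjs";

// Hypothetical op module in the style of dist/ops/qkv.js. Defining the op at
// module scope means a bare side-effect import is enough to make it usable
// everywhere; `scaleByTwo` is illustrative only.
export const scaleByTwo = tf.customGrad((x, save) => {
  const input = x as tf.Tensor;
  (save as (tensors: tf.Tensor[]) => void)([input]);
  return {
    value: tf.mul(input, 2),
    // d(2x)/dx = 2: the upstream gradient is simply doubled.
    gradFunc: (dy: tf.Tensor) => [tf.mul(dy, 2)],
  };
});
```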
package/dist/{index-CWQLouWz.js → index-YPKosni4.js}
CHANGED

@@ -383,7 +383,7 @@ function _t(n, t) {
     return e.set(n, s), e.get(n);
   }
 }
-const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", We = "RealDiv",
+const Ge = "Abs", ne = "Add", Es = "BatchMatMul", se = "Cast", As = "Complex", ze = "ComplexAbs", Bs = "Concat", We = "RealDiv", vs = "Elu", Ms = "Exp", je = "Fill", Ke = "FloorDiv", Fs = "GatherV2", $s = "GatherNd", re = "Identity", Rs = "Imag", xs = "LeakyRelu", Ns = "Log", Ds = "Max", Ve = "Maximum", qe = "Multiply", Cs = "Neg", _s = "Pack", He = "Pow", Ps = "Prelu", Os = "Range", Ls = "Real", Us = "Relu", Gs = "Reshape", zs = "Relu6", Ws = "ScatterNd", js = "Sigmoid", Je = "Sqrt", Ks = "Sum", Vs = "SplitV", qs = "Softmax", Xe = "Sub", Hs = "Transpose", Ye = "ZerosLike", Js = "Step", Xs = "_FusedMatMul";
 /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -438,11 +438,11 @@ function Wt(n) {
   }
   return e;
 }
-function
+function Ys(n) {
   const { kernelName: t, backendName: e } = n, s = ie(t, e);
   ht.has(s) && O(`The kernel '${t}' for backend '${e}' is already registered`), ht.set(s, n);
 }
-function
+function Qs(n) {
   const { kernelName: t } = n;
   It.has(t) && S().getBool("DEBUG") && O(`Overriding the gradient for '${t}'`), It.set(t, n);
 }
@@ -1902,7 +1902,7 @@ function I(n, t, e, s = "numeric") {
   const a = r !== "string" ? ae(n, r) : at(n, [], !0);
   return g.makeTensor(a, i, r);
 }
-function
+function Zs(n, t, e, s = "numeric") {
   if (!Array.isArray(n))
     throw new Error(`Argument ${t} passed to ${e} must be a \`Tensor[]\` or \`TensorLike[]\``);
   return n.map((i, o) => I(i, `${t}[${o}]`, e, s));
@@ -2065,10 +2065,10 @@ function Sn(n, t) {
  * limitations under the License.
  * =============================================================================
  */
-function
+function tr() {
   return g;
 }
-function
+function er() {
   return g.memory();
 }
 function E(n, t) {
@@ -2893,7 +2893,7 @@ function Yn(n, t, e) {
  * limitations under the License.
  * =============================================================================
  */
-function
+function nr(n, t) {
   const e = [];
   for (let s = 0; s < t.length; s++) {
     const r = n[n.length - s - 1], i = t.length - s - 1, o = t[i];
@@ -3061,7 +3061,7 @@ function ss(n, t) {
     a[u] != null && (c[l.name] = a[u]);
   }), s?.forEach((l) => c[l.name] = null), { value: o, grads: c };
 }
-function
+function sr(n) {
   return g.customGrad(n);
 }
 /**
@@ -3841,55 +3841,59 @@ function bs() {
  */
 bs();
 export {
+  Qn as $,
   ds as A,
   Es as B,
   As as C,
-
+  w as D,
   g as E,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  qs as F,
+  $s as G,
+  sr as H,
+  E as I,
+  C as J,
+  js as K,
+  Ns as L,
+  Ds as M,
+  vs as N,
+  Rs as O,
+  _s as P,
+  xs as Q,
+  Gs as R,
+  Ks as S,
+  Cs as T,
+  Ps as U,
+  Ls as V,
+  Us as W,
+  zs as X,
+  Js as Y,
+  Hs as Z,
+  nr as _,
   p as a,
+  Xs as a0,
   Z as b,
-
+  Qs as c,
   I as d,
-
+  tr as e,
   V as f,
   Is as g,
-
-
-
-
-
-
-
+  $t as h,
+  Vs as i,
+  Os as j,
+  Zs as k,
+  y as l,
+  er as m,
+  Gn as n,
   F as o,
-
-
-
+  Bs as p,
+  Fs as q,
+  Ys as r,
   K as s,
-
-
-
- w,
-
-
-
+  Dt as t,
+  Zt as u,
+  G as v,
+  De as w,
+  Ws as x,
+  Ms as y,
+  Ts as z
 };
package/dist/layers/CausalSelfAttention.d.ts
CHANGED

@@ -21,7 +21,9 @@ export default class CausalSelfAttention extends BaseLayer {
     private divisor;
     private index;
     private _trainable;
+    private units;
     constructor(tf: typeof TF, index: number, config: GPTConfig, ropeCache?: RoPECache | undefined);
+    private build;
     get variables(): TF.Variable[];
     get trainable(): boolean;
     set trainable(value: boolean);
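The two new members here, `units` and the private `build`, mirror a change in the implementation below: the fused QKV kernel is no longer a `tf.layers.dense` created eagerly in the constructor, but a raw `tf.variable` allocated on first use. A minimal self-contained sketch of that lazy-build pattern, with illustrative names (`LazyDense` is not a class in this package):

```ts
import * as tf from "@tensorflow/tfjs";

// Sketch of lazy variable creation: nothing is allocated until the first
// forward pass, when build() runs, as in the new CausalSelfAttention.
class LazyDense {
  private kernel: tf.Variable | null = null;

  constructor(private nEmbed: number, private units: number) {}

  build(): void {
    if (this.kernel === null) {
      // Same initialisation as the diff: normal(0, 0.02), trainable = true.
      this.kernel = tf.variable(
        tf.randomNormal([this.nEmbed, this.units], 0, 0.02),
        true
      );
    }
  }

  apply(x: tf.Tensor2D): tf.Tensor {
    this.build(); // idempotent, safe on every call
    return tf.matMul(x, this.kernel!);
  }
}
```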
package/dist/layers/CausalSelfAttention.js
CHANGED

@@ -1,17 +1,10 @@
-import { attentionMask as
-import
-class C extends S {
+import { attentionMask as x } from "../ops/attentionMask.js";
+import j from "./BaseLayer.js";
+import { qkv as w } from "../ops/qkv.js";
+import { rope as y } from "../ops/rope.js";
+class N extends j {
   constructor(t, i, s, e) {
-    super(), this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.
-      units: 3 * s.nEmbed,
-      useBias: s.biasInLinear,
-      name: `block_${i}_attn_cAttn`,
-      kernelInitializer: this.tf.initializers.randomNormal({
-        mean: 0,
-        stddev: 0.02
-      }),
-      biasInitializer: "zeros"
-    }), this.cProj = this.tf.layers.dense({
+    super(), this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.units = s.nEmbed * 3, this.cProj = this.tf.layers.dense({
       units: s.nEmbed,
       useBias: s.biasInLinear,
       name: `block_${i}_attn_cProj`,
@@ -21,11 +14,11 @@ class C extends S {
       }),
       biasInitializer: "zeros"
     }), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.nEmbed / s.nHead);
-    const o = this.tf.zeros([s.blockSize, s.blockSize]),
-    this.maskInf = this.tf.where(this.bias, o,
+    const o = this.tf.zeros([s.blockSize, s.blockSize]), a = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
+    this.maskInf = this.tf.where(this.bias, o, a);
   }
   config;
-  cAttn;
+  cAttn = null;
   cProj;
   attnDropout;
   residDropout;
@@ -35,26 +28,35 @@ class C extends S {
   divisor;
   index;
   _trainable = !0;
+  units;
+  build() {
+    this.cAttn === null && (this.cAttn = this.tf.variable(
+      this.tf.randomNormal([this.config.nEmbed, this.units], 0, 0.02),
+      !0
+      //`block_${this.index}_attn_cAttn_kernel`
+    ));
+  }
   get variables() {
-
-
-
-    ];
+    if (this.cAttn === null)
+      throw new Error("Layer not built yet");
+    return [this.cAttn, ...this.cProj.trainableWeights.map((t) => t.read())];
   }
   get trainable() {
     return this._trainable;
   }
   set trainable(t) {
-    this._trainable = t, this.cAttn.trainable = t, this.cProj.trainable = t;
+    this._trainable = t, this.cAttn && (this.cAttn.trainable = t), this.cProj.trainable = t;
   }
   saveWeights(t) {
-    t.set(`block_${this.index}_cAttn`, this.cAttn.
+    t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj.getWeights());
   }
   loadWeights(t) {
-
+    const i = t.get(`block_${this.index}_cAttn`)?.[0];
+    if (!i) throw new Error(`Weights for block_${this.index}_cAttn not found`);
+    this.cAttn ? this.cAttn.assign(i) : this.cAttn = this.tf.variable(i, !0), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
   }
   getAttentionScores(t, i, s) {
-    const e =
+    const e = x(t, i, this.maskInf, this.divisor), o = this.tf.softmax(e, -1);
     return this.attnDropout.apply(o, { training: s });
   }
   // Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
@@ -64,61 +66,49 @@ class C extends S {
     if (o > 1 && e > 0)
       throw new Error("Cannot use past with T_cur > 1");
     if (o > 1) {
-      const
-      r = r.add(
+      const c = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
+      r = r.add(c);
     }
     const h = this.tf.softmax(r, -1);
     return this.attnDropout.apply(h, { training: s });
   }
   getQKV(t) {
-
-    o.dispose();
-    const a = e / this.config.nHead, u = this.tf.reshape(c, [i, s, this.config.nHead, a]);
-    c.dispose();
-    const f = u.transpose([0, 2, 1, 3]);
-    u.dispose();
-    const d = this.tf.reshape(r, [i, s, this.config.nHead, a]);
-    r.dispose();
-    const n = d.transpose([0, 2, 1, 3]);
-    d.dispose();
-    const l = this.tf.reshape(h, [i, s, this.config.nHead, a]);
-    h.dispose();
-    const p = l.transpose([0, 2, 1, 3]);
-    return l.dispose(), [f, n, p];
+    return w(t, this.cAttn, this.config.nHead);
   }
   getOutputProjection(t, i) {
-    const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed,
+    const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed, a = t.transpose([0, 2, 1, 3]), r = this.tf.reshape(a, [s, e, o]), h = this.cProj.apply(r);
     return this.residDropout.apply(h, { training: i });
   }
   // Added optional KV cache support (pastKV). Returns presentKV for chaining.
   call(t, i = !1, s = !1, e) {
     if (e && !this.config.useRope)
       throw new Error("Cannot use pastKV without RoPE enabled");
-    return this.tf.tidy(() => {
+    return this.build(), this.tf.tidy(() => {
       this.startMemory();
-      const [o,
-
-
+      const [o, a, r] = this.getQKV(t), h = o.shape[2], c = this.config.blockSize, d = e ? e.cumulativeLength : 0, f = this.ropeCache ? y(o, this.ropeCache, d) : o, m = this.ropeCache ? y(a, this.ropeCache, d) : a;
+      this.ropeCache && (o.dispose(), a.dispose());
+      let n = m, l = r, u = 0;
+      e && (u = e.length, n = this.tf.concat([e.k, m], 2), l = this.tf.concat([e.v, r], 2));
       const b = n.shape[2];
-      if (b >
-        const k = b -
-        n = n.slice([0, 0, k, 0], [
+      if (b > c) {
+        const k = b - c, A = n.shape[0], g = n.shape[1], _ = n.shape[3];
+        n = n.slice([0, 0, k, 0], [A, g, c, _]), l = l.slice([0, 0, k, 0], [A, g, c, _]), u = c - h;
       }
-      let
-
-      const
+      let p;
+      u > 0 ? p = this.getAttentionScoresWithPast(f, n, i, u) : p = this.getAttentionScores(f, n, i);
+      const P = this.tf.matMul(p, l), S = this.getOutputProjection(P, i), v = {
        k: this.tf.keep(n),
        v: this.tf.keep(l),
-        length:
+        length: u + h,
        cumulativeLength: e ? e.cumulativeLength + h : h
-      },
-      return this.endMemory("CausalSelfAttention"), { output:
+      }, I = s ? p.mean(1) : void 0;
+      return this.endMemory("CausalSelfAttention"), { output: S, attention: I, presentKV: v };
     });
   }
   dispose() {
-    this.cAttn
+    this.cAttn?.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
   }
 }
 export {
-
+  N as default
 };
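The rewritten `call()` now builds the layer lazily, pulls Q/K/V out of the single fused `qkv` op, applies RoPE with a `cumulativeLength` position offset, concatenates cached keys/values along the sequence axis (axis 2), and, once the cache outgrows `blockSize`, keeps only the most recent positions. A standalone sketch of that trimming step, assuming the `[batch, nHead, seqLen, headDim]` layout used above (`trimKV` is an illustrative helper, not part of the package):

```ts
import * as tf from "@tensorflow/tfjs";

// Keep only the most recent `blockSize` positions of a KV cache, matching
// the `if (b > c)` branch in call().
function trimKV(
  k: tf.Tensor4D,
  v: tf.Tensor4D,
  blockSize: number
): [tf.Tensor4D, tf.Tensor4D] {
  const [batch, heads, seqLen, headDim] = k.shape;
  if (seqLen <= blockSize) return [k, v];
  const start = seqLen - blockSize; // drop the oldest positions
  return [
    k.slice([0, 0, start, 0], [batch, heads, blockSize, headDim]),
    v.slice([0, 0, start, 0], [batch, heads, blockSize, headDim]),
  ];
}
```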
package/dist/layers/RoPECache.d.ts
CHANGED

@@ -3,14 +3,15 @@ import { GPTConfig } from '../config';
 export default class RoPECache {
     private readonly tf;
     private readonly config;
-
+    readonly rotaryDim: number;
     private ropeBase;
     private ropeInvFreq;
     private ropeCos;
     private ropeSin;
     private ropeCacheLen;
     constructor(tf: typeof TF, config: GPTConfig);
-
-
+    ensureRopeCache(needed: number): void;
+    getCos(): TF.Tensor | null;
+    getSin(): TF.Tensor | null;
     dispose(): void;
 }
package/dist/layers/RoPECache.js
CHANGED
@@ -1,12 +1,12 @@
-class
-  constructor(
-    this.tf =
-    const
-    if (this.rotaryDim =
+class n {
+  constructor(i, e) {
+    this.tf = i, this.config = e;
+    const t = this.config.nEmbed / this.config.nHead;
+    if (this.rotaryDim = t, this.rotaryDim % 2 !== 0)
       throw new Error("rotaryDim must be even");
     this.ropeBase = 1e4;
-    const
-    this.ropeInvFreq = this.tf.reciprocal(
+    const s = this.tf.range(0, this.rotaryDim, 2, "float32"), o = s.div(this.tf.scalar(this.rotaryDim, "float32")), r = this.tf.pow(this.tf.scalar(this.ropeBase, "float32"), o);
+    this.ropeInvFreq = this.tf.reciprocal(r), o.dispose(), r.dispose(), s.dispose(), this.config.useRope === !1 ? (this.ropeCos = null, this.ropeSin = null, this.ropeCacheLen = 0) : this.tf.tidy(() => {
       this.ensureRopeCache(this.config.blockSize * 4);
     });
   }
@@ -18,27 +18,22 @@ class b {
   ropeSin = null;
   // [cacheLen, rotaryDim/2]
   ropeCacheLen = 0;
-  ensureRopeCache(
-    if (
+  ensureRopeCache(i) {
+    if (i <= this.ropeCacheLen) return;
     this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose();
-    const
-    this.ropeCos = this.tf.keep(this.tf.cos(
+    const e = Math.max(i, this.ropeCacheLen + this.config.blockSize * 4), s = this.tf.range(0, e, 1, "float32").expandDims(1).mul(this.ropeInvFreq.expandDims(0));
+    this.ropeCos = this.tf.keep(this.tf.cos(s).expandDims(-1)), this.ropeSin = this.tf.keep(this.tf.sin(s).expandDims(-1)), this.ropeCacheLen = e;
   }
-
-
-
-
-    this.
-    const n = t / 2, p = this.ropeCos.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), a = this.ropeSin.slice([o, 0, 0], [e, n, 1]).reshape([1, 1, e, n]), h = s.shape[0], c = s.shape[1], f = this.tf.range(0, t, 2, "int32"), l = this.tf.range(1, t, 2, "int32"), d = (u) => {
-      const m = u.slice([0, 0, 0, 0], [h, c, e, t]), C = t < i ? u.slice([0, 0, 0, t], [h, c, e, i - t]) : null, D = this.tf.gather(m, f, 3), g = this.tf.gather(m, l, 3), x = D.mul(p).sub(g.mul(a)), k = g.mul(p).add(D.mul(a)), R = this.tf.stack([x, k], -1).reshape([h, c, e, t]);
-      return C ? this.tf.concat([R, C], 3) : R;
-    }, y = d(s), S = d(r);
-    return f.dispose(), l.dispose(), [y, S];
+  getCos() {
+    return this.ropeCos;
+  }
+  getSin() {
+    return this.ropeSin;
   }
   dispose() {
     this.ropeCos && this.ropeCos.dispose(), this.ropeSin && this.ropeSin.dispose(), this.ropeInvFreq.dispose();
   }
 }
 export {
-
+  n as default
 };
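With the rotation itself moved out to the new ops/rope.js (the removed method above), RoPECache is reduced to building and serving cos/sin tables. The quantity being cached is the standard RoPE angle table: for each even channel index 2i, invFreq_i = 1 / base^(2i/d), and position p gets angle p · invFreq_i. A standalone sketch of that computation, assuming plain tfjs (`ropeAngles` is an illustrative name):

```ts
import * as tf from "@tensorflow/tfjs";

// Angle table that ropeCos/ropeSin are built from: shape [positions, d/2].
function ropeAngles(positions: number, rotaryDim: number, base = 1e4): tf.Tensor {
  const even = tf.range(0, rotaryDim, 2, "float32");              // 0, 2, 4, ...
  const invFreq = tf.pow(base, even.div(rotaryDim)).reciprocal(); // base^(-2i/d)
  const pos = tf.range(0, positions, 1, "float32");
  // Outer product: angle[p][i] = p * invFreq[i]; cos()/sin() of this is cached.
  return pos.expandDims(1).mul(invFreq.expandDims(0));
}
```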
package/dist/layers/TiedEmbedding.js
CHANGED

@@ -1,7 +1,8 @@
-import { o as h, d as i, E as o,
-import {
-import {
-import {
+import { o as h, d as i, E as o, K as X, N as Y, O as Z, Q as J, T as ee, U as te, V as se, W as ne, X as re, Y as ue, l as L, I as ae, Z as A, a as ie, _ as oe, D as le, f as q, v as C, $ as P, H as U, a0 as H } from "../index-YPKosni4.js";
+import { r as f } from "../reshape-DmnmKT6r.js";
+import { s as ce } from "../sum-D7fu15XL.js";
+import { m } from "../mat_mul-Bu7bhLms.js";
+import { c as pe } from "../complex-CJ-qCcLB.js";
 /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -20,7 +21,7 @@ import { c as pe } from "../complex-x7w5HPOS.js";
  */
 function he(t) {
   const s = { x: i(t, "x", "sigmoid", "float32") };
-  return o.runKernel(
+  return o.runKernel(X, s);
 }
 const fe = /* @__PURE__ */ h({ sigmoid_: he });
 /**
@@ -41,7 +42,7 @@ const fe = /* @__PURE__ */ h({ sigmoid_: he });
  */
 function de(t) {
   const s = { x: i(t, "x", "elu", "float32") };
-  return o.runKernel(
+  return o.runKernel(Y, s);
 }
 const me = /* @__PURE__ */ h({ elu_: de });
 /**
@@ -62,7 +63,7 @@ const me = /* @__PURE__ */ h({ elu_: de });
  */
 function ge(t) {
   const s = { input: i(t, "input", "imag") };
-  return o.runKernel(
+  return o.runKernel(Z, s);
 }
 const $e = /* @__PURE__ */ h({ imag_: ge });
 /**
@@ -83,7 +84,7 @@ const $e = /* @__PURE__ */ h({ imag_: ge });
  */
 function xe(t, e = 0.2) {
   const n = { x: i(t, "x", "leakyRelu") }, r = { alpha: e };
-  return o.runKernel(
+  return o.runKernel(J, n, r);
 }
 const ke = /* @__PURE__ */ h({ leakyRelu_: xe });
 /**
@@ -169,7 +170,7 @@ function Me(t) {
   const s = { x: i(t, "x", "relu") };
   return o.runKernel(ne, s);
 }
-const
+const We = /* @__PURE__ */ h({ relu_: Me });
 /**
  * @license
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -186,11 +187,11 @@ const we = /* @__PURE__ */ h({ relu_: Me });
  * limitations under the License.
  * =============================================================================
  */
-function
+function we(t) {
   const s = { x: i(t, "x", "relu6") };
   return o.runKernel(re, s);
 }
-const ze = /* @__PURE__ */ h({ relu6_:
+const ze = /* @__PURE__ */ h({ relu6_: we });
 /**
  * @license
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -273,7 +274,7 @@ function Te(t, e, s, n) {
   if (e === "linear")
     return t;
   if (e === "relu")
-    return
+    return We(t);
   if (e === "elu")
     return me(t);
   if (e === "relu6")
@@ -310,42 +311,42 @@ function Ne({ a: t, b: e, transposeA: s = !1, transposeB: n = !1, bias: r, activ
   }
   let u = i(t, "a", "fused matMul"), a = i(e, "b", "fused matMul");
   [u, a] = q(u, a);
-  const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2],
+  const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2], W = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], w = n ? a.shape[a.rank - 2] : a.shape[a.rank - 1], T = u.shape.slice(0, -2), y = a.shape.slice(0, -2), B = C(T), N = C(y);
   L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${u.shape} and ${a.shape} and transposeA=${s} and transposeB=${n} must match.`);
-  const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([
+  const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([W, w]), F = s ? f(u, [B, D, W]) : f(u, [B, W, D]), R = n ? f(a, [N, w, b]) : f(a, [N, b, w]);
   let S;
   r != null && (S = i(r, "bias", "fused matMul"), [S] = q(S, u), P(O, S.shape));
-  let
-  l != null && (
-  const
+  let v;
+  l != null && (v = i(l, "prelu weights", "fused matMul"));
+  const G = (x, M) => {
     const [g, $, k, z] = M, d = Ae(f(x, k.shape), k, c);
     let K, _;
     if (!s && !n ? (K = m(d, $, !1, !0), _ = m(g, d, !0, !1)) : !s && n ? (K = m(d, $, !1, !1), _ = m(d, g, !0, !1)) : s && !n ? (K = m($, d, !1, !0), _ = m(g, d, !1, !1)) : (K = m($, d, !0, !0), _ = m(d, g, !0, !0)), r != null) {
-      const
-      return [K, _,
+      const V = Le(z, d);
+      return [K, _, V];
     } else
       return [K, _];
-  },
+  }, I = {
     a: F,
     b: R,
     bias: S,
-    preluActivationWeights:
+    preluActivationWeights: v
   }, j = { transposeA: s, transposeB: n, activation: c, leakyreluAlpha: p };
   return r == null ? U((M, g, $) => {
     const k = (
       // tslint:disable-next-line: no-unnecessary-type-assertion
-      o.runKernel(H,
+      o.runKernel(H, I, j)
     );
-    return $([M, g, k]), { value: f(k, O), gradFunc:
+    return $([M, g, k]), { value: f(k, O), gradFunc: G };
   })(F, R) : U((M, g, $, k) => {
     const z = (
       // tslint:disable-next-line: no-unnecessary-type-assertion
-      o.runKernel(H,
+      o.runKernel(H, I, j)
     );
-    return k([M, g, z, $]), { value: f(z, O), gradFunc:
+    return k([M, g, z, $]), { value: f(z, O), gradFunc: G };
   })(F, R, S);
 }
-const
+const Q = /* @__PURE__ */ h({ fusedMatMul_: Ne });
 /**
  * @license
  * Copyright 2018 Google LLC
@@ -369,7 +370,7 @@ class E extends Error {
  * https://opensource.org/licenses/MIT.
  * =============================================================================
  */
-function
+function ve(t, e, s, n) {
   if (t.rank < 2 || e.rank < 2)
     throw new E(`dot requires both inputs to be rank >= 2 but got x shape = ${t.shape} and y shape = ${e.shape}`);
   if (e.rank >= 3) {
@@ -378,7 +379,7 @@ function Ge(t, e, s, n) {
     throw new E(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${t.shape} and y shape = ${e.shape}`);
   }
   if (t.rank === 2 && e.rank === 2)
-    return
+    return Q({
       a: t,
       b: e,
       transposeA: !1,
@@ -392,7 +393,7 @@ function Ge(t, e, s, n) {
   const l = e.shape.slice(), p = l.pop(), u = l.pop(), a = [...l, p], D = Array.from({ length: e.rank }, (T, y) => y === 0 ? e.rank - 2 : y <= e.rank - 2 ? y - 1 : y);
   e = f(Re(e, D), [u, -1]);
   const b = [...r, ...a];
-  return f(
+  return f(Q({
     a: t,
     b: e,
     transposeA: !1,
@@ -402,7 +403,7 @@ function Ge(t, e, s, n) {
   }), b);
 }
 }
-class
+class Ue {
   vocabSize;
   embedDim;
   tf;
@@ -425,7 +426,7 @@ class Pe {
     return this.tf.gather(this.tiedWeights, e, 0);
   }
   project(e) {
-    return
+    return ve(e, this.tiedWeights.transpose());
   }
   getWeights() {
     return [this.tiedWeights];
@@ -444,5 +445,5 @@ class Pe {
   }
 }
 export {
-
+  Ue as default
 };