@genai-fi/nanogpt 0.4.4 → 0.4.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/NanoGPTModel.js +1 -1
- package/dist/TeachableLLM.js +7 -4
- package/dist/layers/CausalSelfAttention.js +44 -43
- package/dist/layers/RMSNorm.d.ts +1 -2
- package/dist/layers/RMSNorm.js +9 -9
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/main.js +21 -18
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/webgl/matMulGelu.d.ts +3 -2
- package/dist/ops/webgl/matMulGelu.js +72 -70
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +78 -0
- package/package.json +1 -1
package/dist/NanoGPTModel.js
CHANGED
|
@@ -132,7 +132,7 @@ class wt extends B {
|
|
|
132
132
|
}) : (this.ropeCache = new K(this.config.gpt), this.config.layerConfig.ropeCache = this.ropeCache), this.drop = st({ rate: this.config.gpt.dropout }), this.blocks = [];
|
|
133
133
|
for (let e = 0; e < this.config.gpt.nLayer; e++)
|
|
134
134
|
this.blocks.push(new W(e, this.config));
|
|
135
|
-
this.lnF = new N(this.config,
|
|
135
|
+
this.lnF = new N(this.config, "final_rms_norm");
|
|
136
136
|
}
|
|
137
137
|
get checkpointing() {
|
|
138
138
|
return this.config.layerConfig.checkpointAttention === !0 || this.config.layerConfig.checkpointMLP === !0;
|
package/dist/TeachableLLM.js
CHANGED
|
@@ -3,8 +3,8 @@ import l from "./NanoGPTModel.js";
|
|
|
3
3
|
import { saveModel as d } from "./utilities/save.js";
|
|
4
4
|
import { loadModel as f } from "./utilities/load.js";
|
|
5
5
|
import u from "./Generator.js";
|
|
6
|
-
import
|
|
7
|
-
import { E as
|
|
6
|
+
import p from "./Trainer.js";
|
|
7
|
+
import { E as _ } from "./index-Dwqa6Zy2.js";
|
|
8
8
|
import { dummyPassAsync as m } from "./utilities/dummy.js";
|
|
9
9
|
import c from "./tokeniser/CharTokeniser.js";
|
|
10
10
|
import g from "./tokeniser/bpe.js";
|
|
@@ -37,9 +37,12 @@ import "./ops/grads/matMulGelu.js";
|
|
|
37
37
|
import "./ops/cpu/gelu.js";
|
|
38
38
|
import "./ops/webgl/gelu.js";
|
|
39
39
|
import "./ops/grads/gelu.js";
|
|
40
|
+
import "./ops/cpu/normRMS.js";
|
|
41
|
+
import "./ops/webgl/normRMS.js";
|
|
42
|
+
import "./ops/grads/normRMS.js";
|
|
40
43
|
import w from "./utilities/profile.js";
|
|
41
44
|
class a {
|
|
42
|
-
ee = new
|
|
45
|
+
ee = new _();
|
|
43
46
|
_config;
|
|
44
47
|
_model;
|
|
45
48
|
_tokeniser;
|
|
@@ -126,7 +129,7 @@ class a {
|
|
|
126
129
|
trainer() {
|
|
127
130
|
if (!this._model || !this._tokeniser)
|
|
128
131
|
throw new Error("Model or tokeniser is not initialized.");
|
|
129
|
-
const t = new
|
|
132
|
+
const t = new p(this._model, this._tokeniser);
|
|
130
133
|
return t.on("start", () => this.setStatus("training")), t.on("stop", () => this.setStatus("ready")), t.on("log", async (e) => {
|
|
131
134
|
const i = this.ee.listeners("trainStep");
|
|
132
135
|
for (const o of i)
|
|
@@ -6,12 +6,12 @@ import { appendCache as E } from "../ops/appendCache.js";
|
|
|
6
6
|
import { D as z, F as S, t as $, c as L, e as j, H as O } from "../index--6vO-cOz.js";
|
|
7
7
|
import { fusedSoftmax as _ } from "../ops/fusedSoftmax.js";
|
|
8
8
|
import { l as W, w as M, d as x } from "../tfjs_backend-DuKis_xG.js";
|
|
9
|
-
import { o as
|
|
10
|
-
import { v as
|
|
11
|
-
import { z as
|
|
9
|
+
import { o as q } from "../ones-D6kB8bdY.js";
|
|
10
|
+
import { v as b } from "../variable-BJTZ3jOy.js";
|
|
11
|
+
import { z as B } from "../zeros-8xl-W2DC.js";
|
|
12
12
|
import { r as C, d as I } from "../dropout-DFEXTPV0.js";
|
|
13
|
-
import { r as
|
|
14
|
-
import { m as
|
|
13
|
+
import { r as F } from "../reshape-z51Eu-re.js";
|
|
14
|
+
import { m as H } from "../mat_mul-BEHRPMh0.js";
|
|
15
15
|
class nt extends T {
|
|
16
16
|
cAttn = null;
|
|
17
17
|
cProj = null;
|
|
@@ -23,16 +23,16 @@ class nt extends T {
|
|
|
23
23
|
units;
|
|
24
24
|
projUnits;
|
|
25
25
|
constructor(t, s) {
|
|
26
|
-
super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(
|
|
27
|
-
const
|
|
28
|
-
this.maskInf = M(this.bias,
|
|
26
|
+
super(s), this.index = t, this.units = s.gpt.nEmbed * 3, this.projUnits = s.gpt.nEmbed, this.bias = W.bandPart(q([s.gpt.blockSize, s.gpt.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.gpt.nEmbed / s.gpt.nHead);
|
|
27
|
+
const e = B([s.gpt.blockSize, s.gpt.blockSize]), o = z([s.gpt.blockSize, s.gpt.blockSize], Number.NEGATIVE_INFINITY);
|
|
28
|
+
this.maskInf = M(this.bias, e, o);
|
|
29
29
|
}
|
|
30
30
|
build() {
|
|
31
|
-
this.cAttn === null && (this.cAttn =
|
|
31
|
+
this.cAttn === null && (this.cAttn = b(
|
|
32
32
|
C([this.config.gpt.nEmbed, this.units], 0, 0.02),
|
|
33
33
|
!0
|
|
34
34
|
//`block_${this.index}_attn_cAttn_kernel`
|
|
35
|
-
)), this.cProj === null && (this.cProj =
|
|
35
|
+
)), this.cProj === null && (this.cProj = b(
|
|
36
36
|
C([this.projUnits, this.config.gpt.nEmbed], 0, 0.02),
|
|
37
37
|
!0
|
|
38
38
|
//`block_${this.index}_attn_cProj_kernel`
|
|
@@ -53,57 +53,58 @@ class nt extends T {
|
|
|
53
53
|
t.set(`block_${this.index}_cAttn`, this.cAttn ? [this.cAttn.clone()] : []), t.set(`block_${this.index}_cProj`, this.cProj ? [this.cProj.clone()] : []);
|
|
54
54
|
}
|
|
55
55
|
loadWeights(t) {
|
|
56
|
-
const s = t.get(`block_${this.index}_cAttn`)?.[0],
|
|
56
|
+
const s = t.get(`block_${this.index}_cAttn`)?.[0], e = t.get(`block_${this.index}_cProj`)?.[0];
|
|
57
57
|
if (!s) throw new Error(`Weights for block_${this.index}_cAttn not found`);
|
|
58
|
-
if (!
|
|
59
|
-
this.cAttn ? this.cAttn.assign(s) : this.cAttn =
|
|
58
|
+
if (!e) throw new Error(`Weights for block_${this.index}_cProj not found`);
|
|
59
|
+
this.cAttn ? this.cAttn.assign(s) : this.cAttn = b(s, !0), this.cProj ? this.cProj.assign(e) : this.cProj = b(e, !0);
|
|
60
60
|
}
|
|
61
|
-
getAttentionScores(t, s,
|
|
61
|
+
getAttentionScores(t, s, e, o) {
|
|
62
62
|
const i = P(t, s, this.divisor, this.maskInf);
|
|
63
|
-
return _(i,
|
|
63
|
+
return _(i, e ? this.config.gpt.dropout : 0, o);
|
|
64
64
|
}
|
|
65
65
|
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
66
|
-
getAttentionScoresWithPast(t, s,
|
|
67
|
-
const
|
|
68
|
-
return _(
|
|
66
|
+
getAttentionScoresWithPast(t, s, e) {
|
|
67
|
+
const o = P(t, s, this.divisor, void 0, e);
|
|
68
|
+
return _(o, 0, 0);
|
|
69
69
|
}
|
|
70
70
|
getQKV(t) {
|
|
71
71
|
return y(t, this.cAttn, this.config.gpt.nHead);
|
|
72
72
|
}
|
|
73
73
|
getOutputProjection(t) {
|
|
74
|
-
const s = t.shape[0],
|
|
74
|
+
const s = t.shape[0], e = t.shape[2], o = this.config.gpt.nEmbed, i = t.transpose([0, 2, 1, 3]), n = F(i, [s, e, o]);
|
|
75
75
|
return x(n, this.cProj);
|
|
76
76
|
}
|
|
77
|
-
updateCache(t, s,
|
|
78
|
-
const i = this.config.gpt.blockSize, n = t.shape[2], r =
|
|
79
|
-
|
|
77
|
+
updateCache(t, s, e, o) {
|
|
78
|
+
const i = this.config.gpt.blockSize, n = t.shape[2], r = o?.length || 0, a = e ? t : E(t, i, r, o?.k);
|
|
79
|
+
e || (t.dispose(), o?.k.dispose());
|
|
80
|
+
const p = e ? s : E(s, i, r, o?.v);
|
|
81
|
+
return e || (s.dispose(), o?.v.dispose()), {
|
|
80
82
|
k: S(a),
|
|
81
83
|
v: S(p),
|
|
82
84
|
length: Math.min(r + n, i),
|
|
83
|
-
cumulativeLength:
|
|
85
|
+
cumulativeLength: o ? o.cumulativeLength + n : n
|
|
84
86
|
};
|
|
85
87
|
}
|
|
86
|
-
forward(t, s = !1,
|
|
88
|
+
forward(t, s = !1, e, o = !1, i) {
|
|
87
89
|
return $(() => {
|
|
88
90
|
this.startMemory();
|
|
89
|
-
const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n,
|
|
91
|
+
const [n, r, a] = this.getQKV(t), p = i ? i.cumulativeLength : 0, c = this.config.layerConfig.ropeCache, u = c ? w(n, c, p) : n, A = c ? w(r, c, p) : r;
|
|
90
92
|
c && (n.dispose(), r.dispose());
|
|
91
|
-
const
|
|
92
|
-
i && (f.dispose(), a.dispose());
|
|
93
|
+
const f = i ? i.length : 0, d = this.updateCache(A, a, s, i), l = d.k, g = d.v;
|
|
93
94
|
let h;
|
|
94
|
-
|
|
95
|
-
const
|
|
96
|
-
|
|
97
|
-
const k = this.getOutputProjection(
|
|
98
|
-
|
|
99
|
-
const v =
|
|
95
|
+
f > 0 ? h = this.getAttentionScoresWithPast(u, l, f) : h = this.getAttentionScores(u, l, s, e), u.dispose(), s && l.dispose();
|
|
96
|
+
const m = H(h, g);
|
|
97
|
+
o || h.dispose(), s && g.dispose();
|
|
98
|
+
const k = this.getOutputProjection(m);
|
|
99
|
+
m.dispose();
|
|
100
|
+
const v = o ? h.mean(1) : void 0;
|
|
100
101
|
return this.endMemory("CausalSelfAttention"), { output: k, attention: v, presentKV: s ? void 0 : d };
|
|
101
102
|
});
|
|
102
103
|
}
|
|
103
|
-
call(t, s = !1,
|
|
104
|
-
if (
|
|
104
|
+
call(t, s = !1, e = !1, o) {
|
|
105
|
+
if (o && !this.config.gpt.useRope)
|
|
105
106
|
throw new Error("Cannot use pastKV without RoPE enabled");
|
|
106
|
-
if (s &&
|
|
107
|
+
if (s && o)
|
|
107
108
|
throw new Error("Cannot use pastKV during training");
|
|
108
109
|
if (t.shape.length !== 3)
|
|
109
110
|
throw new Error(`Input tensor must be rank 3 [B, T, C], got shape ${t.shape}`);
|
|
@@ -115,15 +116,15 @@ class nt extends T {
|
|
|
115
116
|
const r = L(
|
|
116
117
|
// @ts-expect-error Invalid params
|
|
117
118
|
(a, p, c, u) => {
|
|
118
|
-
const
|
|
119
|
+
const A = this.forward(a, !0, i);
|
|
119
120
|
u([a]);
|
|
120
|
-
const
|
|
121
|
-
const [
|
|
121
|
+
const f = (d, l) => {
|
|
122
|
+
const [g] = l, h = j().state.activeTape;
|
|
122
123
|
j().state.activeTape = [];
|
|
123
|
-
const
|
|
124
|
-
return j().state.activeTape = h,
|
|
124
|
+
const m = O((k, v, R) => this.forward(k, !0, i).output)([g, p, c], d);
|
|
125
|
+
return j().state.activeTape = h, m;
|
|
125
126
|
};
|
|
126
|
-
return { value:
|
|
127
|
+
return { value: A.output, gradFunc: f };
|
|
127
128
|
}
|
|
128
129
|
)(t, this.cAttn, this.cProj);
|
|
129
130
|
if (this.config.gpt.dropout > 0) {
|
|
@@ -132,7 +133,7 @@ class nt extends T {
|
|
|
132
133
|
} else
|
|
133
134
|
return { output: r };
|
|
134
135
|
} else {
|
|
135
|
-
const n = this.forward(t, s, i,
|
|
136
|
+
const n = this.forward(t, s, i, e, o);
|
|
136
137
|
if (this.config.gpt.dropout > 0) {
|
|
137
138
|
const r = I(n.output, this.config.gpt.dropout);
|
|
138
139
|
return n.output.dispose(), { output: r, attention: n.attention, presentKV: n.presentKV };
|
package/dist/layers/RMSNorm.d.ts
CHANGED
|
@@ -2,8 +2,7 @@ import { Tensor, Variable } from '@tensorflow/tfjs-core';
|
|
|
2
2
|
import { default as BaseLayer, GPTLayerConfig } from './BaseLayer';
|
|
3
3
|
export default class RMSNorm extends BaseLayer {
|
|
4
4
|
private gamma;
|
|
5
|
-
|
|
6
|
-
constructor(config: GPTLayerConfig, epsilon?: number, name?: string);
|
|
5
|
+
constructor(config: GPTLayerConfig, name?: string);
|
|
7
6
|
get trainableWeights(): Variable[];
|
|
8
7
|
set trainable(value: boolean);
|
|
9
8
|
getWeights(): Tensor[];
|
package/dist/layers/RMSNorm.js
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import { t as r } from "../index--6vO-cOz.js";
|
|
2
2
|
import m from "./BaseLayer.js";
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
|
|
3
|
+
import { normRMS as s } from "../ops/normRMS.js";
|
|
4
|
+
import { v as e } from "../variable-BJTZ3jOy.js";
|
|
5
|
+
import { o as i } from "../ones-D6kB8bdY.js";
|
|
6
|
+
class u extends m {
|
|
6
7
|
gamma;
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
super(t), this.epsilon = s, this.gamma = i(o([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
|
|
8
|
+
constructor(t, a = "") {
|
|
9
|
+
super(t), this.gamma = e(i([t.gpt.nEmbed]), !0, `${a}_gamma`, "float32");
|
|
10
10
|
}
|
|
11
11
|
get trainableWeights() {
|
|
12
12
|
return [this.gamma];
|
|
@@ -23,8 +23,8 @@ class d extends m {
|
|
|
23
23
|
apply(t) {
|
|
24
24
|
return r(() => {
|
|
25
25
|
this.startMemory();
|
|
26
|
-
const a = t
|
|
27
|
-
return this.endMemory("RMSNorm"),
|
|
26
|
+
const a = s(t, this.gamma);
|
|
27
|
+
return this.endMemory("RMSNorm"), a;
|
|
28
28
|
});
|
|
29
29
|
}
|
|
30
30
|
dispose() {
|
|
@@ -32,5 +32,5 @@ class d extends m {
|
|
|
32
32
|
}
|
|
33
33
|
}
|
|
34
34
|
export {
|
|
35
|
-
|
|
35
|
+
u as default
|
|
36
36
|
};
|
|
@@ -12,7 +12,7 @@ class W extends p {
|
|
|
12
12
|
_trainable = !0;
|
|
13
13
|
skipped = !1;
|
|
14
14
|
constructor(t, s) {
|
|
15
|
-
super(s), this.index = t, this.ln1 = new a(s,
|
|
15
|
+
super(s), this.index = t, this.ln1 = new a(s, `block_${this.index}_rms1`), this.attn = new h(this.index, s), this.ln2 = new a(s, `block_${this.index}_rms2`), this.mlp = new o(this.index, s);
|
|
16
16
|
}
|
|
17
17
|
get variables() {
|
|
18
18
|
return [
|
package/dist/main.js
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import { default as
|
|
2
|
-
import { default as
|
|
3
|
-
import { default as
|
|
4
|
-
import { default as
|
|
5
|
-
import { default as
|
|
6
|
-
import { default as
|
|
7
|
-
import { estimateMemoryUsage as
|
|
1
|
+
import { default as E } from "./NanoGPTModel.js";
|
|
2
|
+
import { default as G } from "./TeachableLLM.js";
|
|
3
|
+
import { default as R } from "./tokeniser/CharTokeniser.js";
|
|
4
|
+
import { default as q } from "./tokeniser/bpe.js";
|
|
5
|
+
import { default as A } from "./utilities/waitForModel.js";
|
|
6
|
+
import { default as I } from "./data/textLoader.js";
|
|
7
|
+
import { estimateMemoryUsage as K, estimateParameterCount as O, estimateResources as Q, estimateTrainingMemoryUsage as S, validateConfig as V } from "./utilities/parameters.js";
|
|
8
8
|
import "./index--6vO-cOz.js";
|
|
9
9
|
import "./ops/cpu/scatterSub.js";
|
|
10
10
|
import "./ops/webgl/scatterSub.js";
|
|
@@ -31,16 +31,19 @@ import "./ops/grads/matMulGelu.js";
|
|
|
31
31
|
import "./ops/cpu/gelu.js";
|
|
32
32
|
import "./ops/webgl/gelu.js";
|
|
33
33
|
import "./ops/grads/gelu.js";
|
|
34
|
+
import "./ops/cpu/normRMS.js";
|
|
35
|
+
import "./ops/webgl/normRMS.js";
|
|
36
|
+
import "./ops/grads/normRMS.js";
|
|
34
37
|
export {
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
38
|
+
q as BPETokeniser,
|
|
39
|
+
R as CharTokeniser,
|
|
40
|
+
E as NanoGPT,
|
|
41
|
+
G as TeachableLLM,
|
|
42
|
+
K as estimateMemoryUsage,
|
|
43
|
+
O as estimateParameterCount,
|
|
44
|
+
Q as estimateResources,
|
|
45
|
+
S as estimateTrainingMemoryUsage,
|
|
46
|
+
I as loadTextData,
|
|
47
|
+
V as validateConfig,
|
|
48
|
+
A as waitForModel
|
|
46
49
|
};
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import { r as o, t as d } from "../../index--6vO-cOz.js";
|
|
2
|
+
function i(t) {
|
|
3
|
+
const { inputs: e } = t, { x: n, gamma: s } = e, r = n, a = s;
|
|
4
|
+
return d(() => {
|
|
5
|
+
const u = r.square().mean(-1, !0).add(1e-8).rsqrt();
|
|
6
|
+
return r.mul(u).mul(a);
|
|
7
|
+
});
|
|
8
|
+
}
|
|
9
|
+
const f = {
|
|
10
|
+
kernelName: "RMSNorm",
|
|
11
|
+
backendName: "cpu",
|
|
12
|
+
kernelFunc: i
|
|
13
|
+
};
|
|
14
|
+
o(f);
|
|
15
|
+
const g = {
|
|
16
|
+
kernelName: "RMSNorm",
|
|
17
|
+
backendName: "tensorflow",
|
|
18
|
+
kernelFunc: i
|
|
19
|
+
};
|
|
20
|
+
o(g);
|
|
21
|
+
function N(t) {
|
|
22
|
+
const { dy: e, x: n, gamma: s } = t.inputs;
|
|
23
|
+
return d(() => {
|
|
24
|
+
const r = n.shape[n.shape.length - 1], a = n.square().mean(-1, !0), m = a.add(1e-8).rsqrt(), u = n.mul(m), l = e.mul(u).sum([0, 1]), c = e.mul(s), k = c.mul(n).sum(-1, !0).div(r);
|
|
25
|
+
return [c.mul(m).sub(n.mul(k).mul(m).div(a.add(1e-8))), l];
|
|
26
|
+
});
|
|
27
|
+
}
|
|
28
|
+
const S = {
|
|
29
|
+
kernelName: "RMSNormGrad",
|
|
30
|
+
backendName: "cpu",
|
|
31
|
+
kernelFunc: N
|
|
32
|
+
};
|
|
33
|
+
o(S);
|
|
34
|
+
const R = {
|
|
35
|
+
kernelName: "RMSNormGrad",
|
|
36
|
+
backendName: "tensorflow",
|
|
37
|
+
kernelFunc: N
|
|
38
|
+
};
|
|
39
|
+
o(R);
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import { g as t, e as g } from "../../index--6vO-cOz.js";
|
|
2
|
+
function s(r, a, n) {
|
|
3
|
+
return g().runKernel("RMSNormGrad", { dy: r, x: a, gamma: n });
|
|
4
|
+
}
|
|
5
|
+
const u = {
|
|
6
|
+
kernelName: "RMSNorm",
|
|
7
|
+
inputsToSave: ["x", "gamma"],
|
|
8
|
+
outputsToSave: [],
|
|
9
|
+
gradFunc: (r, a) => {
|
|
10
|
+
const [n, e] = a, [m, o] = s(r, n, e);
|
|
11
|
+
return {
|
|
12
|
+
x: () => m,
|
|
13
|
+
gamma: () => o
|
|
14
|
+
};
|
|
15
|
+
}
|
|
16
|
+
};
|
|
17
|
+
t(u);
|
|
18
|
+
export {
|
|
19
|
+
u as normRMSGradConfig
|
|
20
|
+
};
|
|
@@ -7,9 +7,10 @@ type BatchMatMulConfig = {
|
|
|
7
7
|
transposeA: boolean;
|
|
8
8
|
transposeB: boolean;
|
|
9
9
|
backend: MathBackendWebGL;
|
|
10
|
-
activationSnippet
|
|
10
|
+
activationSnippet?: string;
|
|
11
|
+
multiplier?: TensorInfo;
|
|
11
12
|
};
|
|
12
|
-
export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, }: BatchMatMulConfig): TensorInfo;
|
|
13
|
+
export declare function batchMatMulGeluImpl({ a, b, transposeA, transposeB, backend, activationSnippet, multiplier, }: BatchMatMulConfig): TensorInfo;
|
|
13
14
|
export declare function batchMatMulKernel(args: {
|
|
14
15
|
inputs: {
|
|
15
16
|
x: TensorInfo;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { r as
|
|
2
|
-
import { r as
|
|
1
|
+
import { r as C, t as R, e as I, n as G, O as L, j as F, Q as U } from "../../index--6vO-cOz.js";
|
|
2
|
+
import { r as S } from "../../Reshape-CiAY8ltP.js";
|
|
3
3
|
import { u as H } from "../../gpgpu_math-CUzjlO9A.js";
|
|
4
|
-
import { m as
|
|
4
|
+
import { m as B } from "../../mat_mul-BEHRPMh0.js";
|
|
5
5
|
/**
|
|
6
6
|
* @license
|
|
7
7
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -19,39 +19,39 @@ import { m as z } from "../../mat_mul-BEHRPMh0.js";
|
|
|
19
19
|
* =============================================================================
|
|
20
20
|
*/
|
|
21
21
|
class W {
|
|
22
|
-
constructor(e, s,
|
|
23
|
-
this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape =
|
|
24
|
-
const
|
|
25
|
-
let
|
|
26
|
-
r && (
|
|
22
|
+
constructor(e, s, n, a = !1, c = !1, o = !1, r = null, u = !1, l = !1) {
|
|
23
|
+
this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = n, this.enableShapeUniforms = H(this.outputShape.length);
|
|
24
|
+
const h = a ? e[1] : e[2], p = Math.ceil(h / 2), d = a ? "i * 2, rc.y" : "rc.y, i * 2", $ = c ? "rc.z, i * 2" : "i * 2, rc.z", x = a ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], m = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
|
|
25
|
+
let i = "", b = "";
|
|
26
|
+
r && (u ? i = `vec4 activation(vec4 a) {
|
|
27
27
|
vec4 b = getPreluActivationWeightsAtOutCoords();
|
|
28
28
|
${r}
|
|
29
|
-
}` :
|
|
29
|
+
}` : l ? i = `vec4 activation(vec4 a) {
|
|
30
30
|
vec4 b = getLeakyreluAlphaAtOutCoords();
|
|
31
31
|
${r}
|
|
32
|
-
}` :
|
|
32
|
+
}` : i = `vec4 activation(vec4 x) {
|
|
33
33
|
${r}
|
|
34
|
-
}`,
|
|
35
|
-
const
|
|
36
|
-
o && this.variableNames.push("bias"),
|
|
37
|
-
let f = "rc.x",
|
|
38
|
-
e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (
|
|
39
|
-
${
|
|
34
|
+
}`, b = "result = activation(result);");
|
|
35
|
+
const M = o ? "result += getBiasAtOutCoords();" : "";
|
|
36
|
+
o && this.variableNames.push("bias"), u && this.variableNames.push("preluActivationWeights"), l && this.variableNames.push("leakyreluAlpha");
|
|
37
|
+
let f = "rc.x", v = "rc.x";
|
|
38
|
+
e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (v = `imod(rc.x, ${s[0]})`), this.userCode = `
|
|
39
|
+
${i}
|
|
40
40
|
// Don't use uniform for sharedDimensionPacked for performance.
|
|
41
|
-
const float sharedDimension = ${
|
|
41
|
+
const float sharedDimension = ${p}.0;
|
|
42
42
|
|
|
43
43
|
vec4 dot2x2ARowBCol(ivec3 rc) {
|
|
44
44
|
vec4 result = vec4(0);
|
|
45
45
|
int batchA = ${f};
|
|
46
|
-
int batchB = ${
|
|
47
|
-
for (int i = 0; i < ${
|
|
48
|
-
vec4 a = getMatrixA(batchA, ${
|
|
49
|
-
vec4 b = getMatrixB(batchB, ${
|
|
46
|
+
int batchB = ${v};
|
|
47
|
+
for (int i = 0; i < ${p}; i++) {
|
|
48
|
+
vec4 a = getMatrixA(batchA, ${d});
|
|
49
|
+
vec4 b = getMatrixB(batchB, ${$});
|
|
50
50
|
|
|
51
51
|
// These swizzled products need to be separately added.
|
|
52
52
|
// See: https://github.com/tensorflow/tfjs/issues/1735
|
|
53
|
-
result += (${
|
|
54
|
-
result += (${
|
|
53
|
+
result += (${x[0]} * ${m[0]});
|
|
54
|
+
result += (${x[1]} * ${m[1]});
|
|
55
55
|
}
|
|
56
56
|
return result;
|
|
57
57
|
}
|
|
@@ -60,69 +60,72 @@ class W {
|
|
|
60
60
|
ivec3 rc = getOutputCoords();
|
|
61
61
|
vec4 result = dot2x2ARowBCol(rc);
|
|
62
62
|
|
|
63
|
-
${
|
|
63
|
+
${M}
|
|
64
64
|
|
|
65
|
-
${
|
|
65
|
+
${b}
|
|
66
66
|
|
|
67
67
|
setOutput(result);
|
|
68
68
|
}
|
|
69
69
|
`;
|
|
70
70
|
}
|
|
71
71
|
}
|
|
72
|
-
const
|
|
72
|
+
const g = 0.7978845608028654, w = 0.044715, j = `
|
|
73
73
|
vec4 x3 = x * x * x;
|
|
74
74
|
vec4 inner = x + ${w} * x3;
|
|
75
|
-
inner = ${
|
|
75
|
+
inner = ${g} * inner;
|
|
76
76
|
inner = tanh(inner);
|
|
77
77
|
inner = 0.5 * (1.0 + inner);
|
|
78
78
|
vec4 result = x * inner;
|
|
79
79
|
return result;
|
|
80
80
|
`, q = `
|
|
81
|
-
vec4
|
|
82
|
-
vec4
|
|
83
|
-
vec4 u = ${
|
|
81
|
+
vec4 a2 = a * a;
|
|
82
|
+
vec4 a3 = a2 * a;
|
|
83
|
+
vec4 u = ${g} * (a + ${w} * a3);
|
|
84
84
|
vec4 t = tanh(u);
|
|
85
85
|
vec4 sech2 = 1.0 - t * t;
|
|
86
|
-
vec4 du_dx = ${
|
|
87
|
-
vec4 dgelu = 0.5 * (1.0 + t) + 0.5 *
|
|
88
|
-
return dgelu;
|
|
86
|
+
vec4 du_dx = ${g} * (1.0 + 3.0 * ${w} * a2);
|
|
87
|
+
vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * a * sech2 * du_dx;
|
|
88
|
+
return dgelu * b;
|
|
89
89
|
`, se = 1e3;
|
|
90
|
-
function
|
|
90
|
+
function O({
|
|
91
91
|
a: t,
|
|
92
92
|
b: e,
|
|
93
93
|
transposeA: s,
|
|
94
|
-
transposeB:
|
|
95
|
-
backend:
|
|
96
|
-
activationSnippet: c
|
|
94
|
+
transposeB: n,
|
|
95
|
+
backend: a,
|
|
96
|
+
activationSnippet: c,
|
|
97
|
+
multiplier: o
|
|
97
98
|
}) {
|
|
98
|
-
const
|
|
99
|
+
const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
|
|
99
100
|
F(
|
|
100
|
-
|
|
101
|
-
() => `Error in matMul: inner shapes (${
|
|
101
|
+
l === h,
|
|
102
|
+
() => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
|
|
102
103
|
);
|
|
103
|
-
const
|
|
104
|
-
$,
|
|
104
|
+
const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), D = [A, y], E = Math.max(m, i), N = c, T = U(t.dtype, e.dtype), _ = new W(
|
|
105
105
|
f,
|
|
106
|
-
|
|
106
|
+
v,
|
|
107
|
+
[E, p, d],
|
|
107
108
|
s,
|
|
108
|
-
|
|
109
|
-
!1,
|
|
110
|
-
O,
|
|
109
|
+
n,
|
|
111
110
|
!1,
|
|
111
|
+
N,
|
|
112
|
+
!!o,
|
|
112
113
|
!1
|
|
113
|
-
),
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
114
|
+
), k = [A, y];
|
|
115
|
+
o && k.push(o);
|
|
116
|
+
const z = a.runWebGLProgram(_, k, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
|
|
117
|
+
D.push(z);
|
|
118
|
+
for (const P of D)
|
|
119
|
+
a.disposeIntermediateTensorInfo(P);
|
|
120
|
+
return K;
|
|
118
121
|
}
|
|
119
122
|
function Q(t) {
|
|
120
|
-
const { inputs: e, backend: s } = t, { x:
|
|
121
|
-
if (
|
|
123
|
+
const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
|
|
124
|
+
if (n === void 0 || a === void 0)
|
|
122
125
|
throw new Error("BatchMatMul requires two input tensors.");
|
|
123
|
-
return
|
|
124
|
-
a,
|
|
125
|
-
b:
|
|
126
|
+
return O({
|
|
127
|
+
a: n,
|
|
128
|
+
b: a,
|
|
126
129
|
transposeA: !1,
|
|
127
130
|
transposeB: !1,
|
|
128
131
|
backend: s,
|
|
@@ -134,23 +137,22 @@ const J = {
|
|
|
134
137
|
backendName: "webgl",
|
|
135
138
|
kernelFunc: Q
|
|
136
139
|
};
|
|
137
|
-
|
|
140
|
+
C(J);
|
|
138
141
|
function V(t) {
|
|
139
|
-
const { dy: e, x: s, kernel:
|
|
140
|
-
return
|
|
141
|
-
const c =
|
|
142
|
-
|
|
142
|
+
const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
|
|
143
|
+
return R(() => {
|
|
144
|
+
const c = I().makeTensorFromTensorInfo(
|
|
145
|
+
O({
|
|
143
146
|
a: s,
|
|
144
|
-
b:
|
|
147
|
+
b: n,
|
|
145
148
|
transposeA: !1,
|
|
146
149
|
transposeB: !1,
|
|
147
|
-
backend:
|
|
148
|
-
activationSnippet: q
|
|
150
|
+
backend: a,
|
|
151
|
+
activationSnippet: q,
|
|
152
|
+
multiplier: e
|
|
149
153
|
})
|
|
150
|
-
), o =
|
|
151
|
-
|
|
152
|
-
const r = z(o, a, !1, !0), i = z(s, o, !0, !1);
|
|
153
|
-
return [r, i];
|
|
154
|
+
), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
|
|
155
|
+
return [o, r];
|
|
154
156
|
});
|
|
155
157
|
}
|
|
156
158
|
const X = {
|
|
@@ -158,9 +160,9 @@ const X = {
|
|
|
158
160
|
backendName: "webgl",
|
|
159
161
|
kernelFunc: V
|
|
160
162
|
};
|
|
161
|
-
|
|
163
|
+
C(X);
|
|
162
164
|
export {
|
|
163
165
|
se as MATMUL_SHARED_DIM_THRESHOLD,
|
|
164
|
-
|
|
166
|
+
O as batchMatMulGeluImpl,
|
|
165
167
|
Q as batchMatMulKernel
|
|
166
168
|
};
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import { r as c, e as h } from "../../index--6vO-cOz.js";
|
|
2
|
+
import { s as q } from "../../sum-DdkDf2MG.js";
|
|
3
|
+
class G {
|
|
4
|
+
variableNames = ["x", "meanSquare", "gamma"];
|
|
5
|
+
outputShape;
|
|
6
|
+
userCode;
|
|
7
|
+
constructor(e, a, o) {
|
|
8
|
+
this.outputShape = [e, a, o], this.userCode = `
|
|
9
|
+
void main() {
|
|
10
|
+
ivec3 coords = getOutputCoords();
|
|
11
|
+
float x = getXAtOutCoords();
|
|
12
|
+
float meanSquare = getMeanSquare(coords.x, coords.y, 0);
|
|
13
|
+
float gamma = getGammaAtOutCoords();
|
|
14
|
+
float invRms = inversesqrt(meanSquare + 1e-8);
|
|
15
|
+
float normalized = x * invRms;
|
|
16
|
+
float outVal = normalized * gamma;
|
|
17
|
+
setOutput(outVal);
|
|
18
|
+
}
|
|
19
|
+
`;
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
function v(t) {
|
|
23
|
+
const { x: e, gamma: a } = t.inputs, o = t.backend, r = e.shape[0], s = e.shape[1], n = e.shape[2], m = e.square().mean(-1, !0), u = new G(r, s, n);
|
|
24
|
+
return o.runWebGLProgram(u, [e, m, a], "float32");
|
|
25
|
+
}
|
|
26
|
+
const x = {
|
|
27
|
+
kernelName: "RMSNorm",
|
|
28
|
+
backendName: "webgl",
|
|
29
|
+
kernelFunc: v
|
|
30
|
+
};
|
|
31
|
+
c(x);
|
|
32
|
+
class y {
|
|
33
|
+
variableNames = ["x", "meanSquare", "dyGamma", "dyXMean"];
|
|
34
|
+
outputShape;
|
|
35
|
+
userCode;
|
|
36
|
+
constructor(e, a, o) {
|
|
37
|
+
this.outputShape = [e, a, o], this.userCode = `
|
|
38
|
+
void main() {
|
|
39
|
+
ivec3 coords = getOutputCoords();
|
|
40
|
+
float x = getXAtOutCoords();
|
|
41
|
+
float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
|
|
42
|
+
float dyGamma = getDyGammaAtOutCoords();
|
|
43
|
+
float dyXMean = getDyXMean(coords.x, coords.y, 0) / ${o}.0;
|
|
44
|
+
float invRms = inversesqrt(meanSquare);
|
|
45
|
+
float dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare;
|
|
46
|
+
setOutput(dx);
|
|
47
|
+
}
|
|
48
|
+
`;
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
class C {
|
|
52
|
+
variableNames = ["x", "meanSquare", "dy"];
|
|
53
|
+
outputShape;
|
|
54
|
+
userCode;
|
|
55
|
+
constructor(e, a, o) {
|
|
56
|
+
this.outputShape = [e, a, o], this.userCode = `
|
|
57
|
+
void main() {
|
|
58
|
+
ivec3 coords = getOutputCoords();
|
|
59
|
+
float x = getXAtOutCoords();
|
|
60
|
+
float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
|
|
61
|
+
float dy = getDyAtOutCoords();
|
|
62
|
+
float invRms = inversesqrt(meanSquare);
|
|
63
|
+
float dGamma = dy * (x * invRms);
|
|
64
|
+
setOutput(dGamma);
|
|
65
|
+
}
|
|
66
|
+
`;
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
function b(t) {
|
|
70
|
+
const { dy: e, x: a, gamma: o } = t.inputs, r = t.backend, s = a.shape[0], n = a.shape[1], m = a.shape[2], u = a.square().mean(-1, !0), d = e.mul(o), l = d.mul(a).sum(-1, !0), i = new y(s, n, m), g = r.runWebGLProgram(i, [a, u, d, l], "float32"), p = new C(s, n, m), S = r.runWebGLProgram(p, [a, u, e], "float32"), f = q(h().makeTensorFromTensorInfo(S), [0, 1]);
|
|
71
|
+
return [g, f];
|
|
72
|
+
}
|
|
73
|
+
const N = {
|
|
74
|
+
kernelName: "RMSNormGrad",
|
|
75
|
+
backendName: "webgl",
|
|
76
|
+
kernelFunc: b
|
|
77
|
+
};
|
|
78
|
+
c(N);
|