@genai-fi/nanogpt 0.4.4 → 0.5.0
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +44 -41
- package/dist/NanoGPTModel.d.ts +12 -16
- package/dist/NanoGPTModel.js +128 -138
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +8 -5
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +22 -24
- package/dist/layers/CausalSelfAttention.js +73 -127
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -11
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +22 -19
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +3 -2
- package/dist/ops/webgl/matMulGelu.js +77 -75
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +86 -0
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +10 -8
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
package/dist/ops/webgl/matMulGelu.js
CHANGED
@@ -1,7 +1,7 @@
-import { r as
-import { r as
-import { u as H } from "../../gpgpu_math-
-import { m as
+import { r as C, t as R, e as I, p as G, N as L, k as F, O as U } from "../../index-iNhkcAEQ.js";
+import { r as S } from "../../Reshape-BE5rA4rT.js";
+import { u as H } from "../../gpgpu_math-C0zyxKFi.js";
+import { m as B } from "../../mat_mul-D0SifYfJ.js";
 /**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
@@ -19,39 +19,39 @@ import { m as z } from "../../mat_mul-BEHRPMh0.js";
 * =============================================================================
 */
 class W {
-constructor(e, s,
-this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape =
-const
-let
-r && (
+constructor(e, s, n, a = !1, c = !1, o = !1, r = null, u = !1, l = !1) {
+this.variableNames = ["matrixA", "matrixB"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = n, this.enableShapeUniforms = H(this.outputShape.length);
+const h = a ? e[1] : e[2], p = Math.ceil(h / 2), d = a ? "i * 2, rc.y" : "rc.y, i * 2", $ = c ? "rc.z, i * 2" : "i * 2, rc.z", x = a ? ["a.xxyy", "a.zzww"] : ["a.xxzz", "a.yyww"], m = c ? ["b.xzxz", "b.ywyw"] : ["b.xyxy", "b.zwzw"];
+let i = "", b = "";
+r && (u ? i = `vec4 activation(vec4 a) {
 vec4 b = getPreluActivationWeightsAtOutCoords();
 ${r}
-}` :
+}` : l ? i = `vec4 activation(vec4 a) {
 vec4 b = getLeakyreluAlphaAtOutCoords();
 ${r}
-}` :
+}` : i = `vec4 activation(vec4 x) {
 ${r}
-}`,
-const
-o && this.variableNames.push("bias"),
-let f = "rc.x",
-e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (
-${
+}`, b = "result = activation(result);");
+const M = o ? "result += getBiasAtOutCoords();" : "";
+o && this.variableNames.push("bias"), u && this.variableNames.push("preluActivationWeights"), l && this.variableNames.push("leakyreluAlpha");
+let f = "rc.x", v = "rc.x";
+e[0] < s[0] ? f = `imod(rc.x, ${e[0]})` : s[0] < e[0] && (v = `imod(rc.x, ${s[0]})`), this.userCode = `
+${i}
 // Don't use uniform for sharedDimensionPacked for performance.
-const float sharedDimension = ${
+const float sharedDimension = ${p}.0;
 
 vec4 dot2x2ARowBCol(ivec3 rc) {
 vec4 result = vec4(0);
 int batchA = ${f};
-int batchB = ${
-for (int i = 0; i < ${
-vec4 a = getMatrixA(batchA, ${
-vec4 b = getMatrixB(batchB, ${
+int batchB = ${v};
+for (int i = 0; i < ${p}; i++) {
+vec4 a = getMatrixA(batchA, ${d});
+vec4 b = getMatrixB(batchB, ${$});
 
 // These swizzled products need to be separately added.
 // See: https://github.com/tensorflow/tfjs/issues/1735
-result += (${
-result += (${
+result += (${x[0]} * ${m[0]});
+result += (${x[1]} * ${m[1]});
 }
 return result;
 }
@@ -60,97 +60,99 @@ class W {
 ivec3 rc = getOutputCoords();
 vec4 result = dot2x2ARowBCol(rc);
 
-${
+${M}
 
-${
+${b}
 
 setOutput(result);
 }
 `;
 }
 }
-const
+const g = 0.7978845608028654, w = 0.044715, j = `
 vec4 x3 = x * x * x;
 vec4 inner = x + ${w} * x3;
-inner = ${
+inner = ${g} * inner;
 inner = tanh(inner);
 inner = 0.5 * (1.0 + inner);
 vec4 result = x * inner;
 return result;
 `, q = `
-vec4
-vec4
-vec4 u = ${
+vec4 a2 = a * a;
+vec4 a3 = a2 * a;
+vec4 u = ${g} * (a + ${w} * a3);
 vec4 t = tanh(u);
 vec4 sech2 = 1.0 - t * t;
-vec4 du_dx = ${
-vec4 dgelu = 0.5 * (1.0 + t) + 0.5 *
-return dgelu;
+vec4 du_dx = ${g} * (1.0 + 3.0 * ${w} * a2);
+vec4 dgelu = 0.5 * (1.0 + t) + 0.5 * a * sech2 * du_dx;
+return dgelu * b;
 `, se = 1e3;
-function
+function O({
 a: t,
 b: e,
 transposeA: s,
-transposeB:
-backend:
-activationSnippet: c
+transposeB: n,
+backend: a,
+activationSnippet: c,
+multiplier: o
 }) {
-const
+const r = t.shape.length, u = e.shape.length, l = s ? t.shape[r - 2] : t.shape[r - 1], h = n ? e.shape[u - 1] : e.shape[u - 2], p = s ? t.shape[r - 1] : t.shape[r - 2], d = n ? e.shape[u - 2] : e.shape[u - 1], $ = t.shape.slice(0, -2), x = e.shape.slice(0, -2), m = G($), i = G(x), M = L(t.shape.slice(0, -2), e.shape.slice(0, -2)).concat([p, d]);
 F(
-
-() => `Error in matMul: inner shapes (${
+l === h,
+() => `Error in matMul: inner shapes (${l}) and (${h}) of Tensors with shapes ${t.shape} and ${e.shape} and transposeA=${s} and transposeB=${n} must match.`
 );
-const
-$,
+const f = s ? [m, l, p] : [m, p, l], v = n ? [i, d, h] : [i, h, d], A = S({ inputs: { x: t }, backend: a, attrs: { shape: f } }), y = S({ inputs: { x: e }, backend: a, attrs: { shape: v } }), k = [A, y], N = Math.max(m, i), E = c, T = U(t.dtype, e.dtype), _ = new W(
 f,
-
+v,
+[N, p, d],
 s,
-
-!1,
-O,
+n,
 !1,
+E,
+!!o,
 !1
-),
-
-
-
-
+), D = [A, y];
+o && D.push(o);
+const z = a.runWebGLProgram(_, D, T), K = S({ inputs: { x: z }, backend: a, attrs: { shape: M } });
+k.push(z);
+for (const P of k)
+a.disposeIntermediateTensorInfo(P);
+return K;
 }
-function
-const { inputs: e, backend: s } = t, { x:
-if (
+function J(t) {
+const { inputs: e, backend: s } = t, { x: n, kernel: a } = e;
+if (n === void 0 || a === void 0)
 throw new Error("BatchMatMul requires two input tensors.");
-return
-a,
-b:
+return O({
+a: n,
+b: a,
 transposeA: !1,
 transposeB: !1,
 backend: s,
 activationSnippet: j
 });
 }
-const
+const Q = {
 kernelName: "MatMulGelu",
 backendName: "webgl",
-kernelFunc:
+kernelFunc: J
 };
-
+C(Q);
 function V(t) {
-const { dy: e, x: s, kernel:
-return
-const c =
-
+const { dy: e, x: s, kernel: n } = t.inputs, a = t.backend;
+return R(() => {
+const c = I().makeTensorFromTensorInfo(
+O({
 a: s,
-b:
+b: n,
 transposeA: !1,
 transposeB: !1,
-backend:
-activationSnippet: q
+backend: a,
+activationSnippet: q,
+multiplier: e
 })
-), o =
-
-const r = z(o, a, !1, !0), i = z(s, o, !0, !1);
-return [r, i];
+), o = B(c, n, !1, !0), r = B(s, c, !0, !1);
+return [o, r];
 });
 }
 const X = {
@@ -158,9 +160,9 @@ const X = {
 backendName: "webgl",
 kernelFunc: V
 };
-
+C(X);
 export {
 se as MATMUL_SHARED_DIM_THRESHOLD,
-
-
+O as batchMatMulGeluImpl,
+J as batchMatMulKernel
 };
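Note: the j (forward) and q (gradient) template strings in this file implement the tanh approximation of GELU, with g = 0.7978845608028654 (about sqrt(2/pi)) and w = 0.044715, and the new multiplier argument appears to reuse the preluActivationWeights input so the gradient kernel can fold the upstream tensor b into the same shader pass. A minimal unfused sketch of the same math in plain JavaScript (illustrative only; gelu and geluGrad are not exports of this package):

    const G = 0.7978845608028654; // sqrt(2 / pi)
    const W = 0.044715;
    // Forward: 0.5 * x * (1 + tanh(G * (x + W * x^3))), applied per element to the matmul output.
    function gelu(x) {
      return 0.5 * x * (1 + Math.tanh(G * (x + W * x * x * x)));
    }
    // Backward: d(gelu)/dx scaled by the upstream value b, mirroring the q snippet above.
    function geluGrad(a, b) {
      const u = G * (a + W * a * a * a);
      const t = Math.tanh(u);
      const sech2 = 1 - t * t;
      const duDx = G * (1 + 3 * W * a * a);
      return (0.5 * (1 + t) + 0.5 * a * sech2 * duDx) * b;
    }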
package/dist/ops/webgl/matMulMul.d.ts
ADDED
@@ -0,0 +1,14 @@
+import { TensorInfo } from '@tensorflow/tfjs-core';
+import { MathBackendWebGL } from '@tensorflow/tfjs-backend-webgl';
+export declare function batchMatMulKernel(args: {
+inputs: {
+x: TensorInfo;
+kernel: TensorInfo;
+y: TensorInfo;
+};
+attrs: {
+transposeA: boolean;
+transposeB: boolean;
+};
+backend: MathBackendWebGL;
+}): TensorInfo;
package/dist/ops/webgl/matMulMul.js
ADDED
@@ -0,0 +1,28 @@
+import { r as u } from "../../index-iNhkcAEQ.js";
+import { batchMatMulGeluImpl as c } from "./matMulGelu.js";
+const M = `
+return a * b;
+`;
+function p(r) {
+const { inputs: n, backend: o, attrs: a } = r, { x: t, kernel: e, y: l } = n, { transposeA: i, transposeB: s } = a;
+if (t === void 0 || e === void 0)
+throw new Error("BatchMatMul requires two input tensors.");
+return c({
+a: t,
+b: e,
+transposeA: i,
+transposeB: s,
+backend: o,
+activationSnippet: M,
+multiplier: l
+});
+}
+const m = {
+kernelName: "MatMulMul",
+backendName: "webgl",
+kernelFunc: p
+};
+u(m);
+export {
+p as batchMatMulKernel
+};
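Note: the new MatMulMul WebGL kernel reuses batchMatMulGeluImpl with the activation snippet "return a * b;" and the y input passed as multiplier, i.e. a matrix multiply fused with an elementwise multiply in a single shader. An unfused sketch of the same computation with stock TensorFlow.js ops (illustrative only; matMulMulReference is not part of this package):

    import * as tf from "@tensorflow/tfjs";

    // (x matMul kernel) * y, computed as two separate ops instead of one fused program.
    function matMulMulReference(x, kernel, y, transposeA = false, transposeB = false) {
      return tf.mul(tf.matMul(x, kernel, transposeA, transposeB), y);
    }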
package/dist/ops/webgl/normRMS.d.ts
ADDED
@@ -0,0 +1 @@
+export {};
package/dist/ops/webgl/normRMS.js
ADDED
@@ -0,0 +1,86 @@
+import { r as p, e as G } from "../../index-iNhkcAEQ.js";
+import { s as x } from "../../sum-B_92TaHD.js";
+class y {
+variableNames = ["x", "meanSquare", "gamma"];
+outputShape;
+userCode;
+constructor(a, e, o) {
+this.outputShape = [a, e, o], this.userCode = `
+void main() {
+ivec3 coords = getOutputCoords();
+float x = getXAtOutCoords();
+float meanSquare = getMeanSquare(coords.x, coords.y, 0);
+float gamma = getGammaAtOutCoords();
+float invRms = inversesqrt(meanSquare + 1e-8);
+float normalized = x * invRms;
+float outVal = normalized * gamma;
+setOutput(outVal);
+}
+`;
+}
+}
+function v(t) {
+const { x: a, gamma: e } = t.inputs, o = t.backend, r = a.shape[0], n = a.shape[1], m = a.shape[2], u = a.square().mean(-1, !0), s = new y(r, n, m);
+return o.runWebGLProgram(s, [a, u, e], "float32");
+}
+const C = {
+kernelName: "RMSNorm",
+backendName: "webgl",
+kernelFunc: v
+};
+p(C);
+class b {
+variableNames = ["x", "meanSquare", "dyGamma", "dyXMean"];
+outputShape;
+userCode;
+constructor(a, e, o) {
+this.outputShape = [a, e, o], this.userCode = `
+void main() {
+ivec3 coords = getOutputCoords();
+float x = getXAtOutCoords();
+float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
+float dyGamma = getDyGammaAtOutCoords();
+float dyXMean = getDyXMean(coords.x, coords.y, 0) / ${o}.0;
+float invRms = inversesqrt(meanSquare);
+float dx = dyGamma * invRms - x * dyXMean * invRms / meanSquare;
+setOutput(dx);
+}
+`;
+}
+}
+class N {
+variableNames = ["x", "meanSquare", "dy"];
+outputShape;
+userCode;
+constructor(a, e, o) {
+this.outputShape = [a, e, o], this.userCode = `
+void main() {
+ivec3 coords = getOutputCoords();
+float x = getXAtOutCoords();
+float meanSquare = getMeanSquare(coords.x, coords.y, 0) + 1e-8;
+float dy = getDyAtOutCoords();
+float invRms = inversesqrt(meanSquare);
+float dGamma = dy * (x * invRms);
+setOutput(dGamma);
+}
+`;
+}
+}
+function M(t) {
+const { dy: a, x: e, gamma: o } = t.inputs, r = t.backend, n = e.shape[0], m = e.shape[1], u = e.shape[2], s = a.mul(o), c = s.mul(e), i = c.sum(-1, !0);
+c.dispose();
+const l = e.square(), d = l.mean(-1, !0);
+l.dispose();
+const f = new b(n, m, u), S = r.runWebGLProgram(f, [e, d, s, i], "float32");
+s.dispose(), i.dispose();
+const h = new N(n, m, u), g = r.runWebGLProgram(h, [e, d, a], "float32");
+d.dispose();
+const q = x(G().makeTensorFromTensorInfo(g), [0, 1]);
+return r.disposeIntermediateTensorInfo(g), [S, q];
+}
+const k = {
+kernelName: "RMSNormGrad",
+backendName: "webgl",
+kernelFunc: M
+};
+p(k);
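Note: the forward RMSNorm shader above normalizes each vector along the last axis by its root mean square (with a 1e-8 epsilon) and scales by gamma; the two gradient programs produce dx and a per-element dGamma, which is then summed over the first two axes. An unfused TensorFlow.js sketch of the forward pass (illustrative only; rmsNormReference is not part of this package):

    import * as tf from "@tensorflow/tfjs";

    // y = x * rsqrt(mean(x^2, lastAxis) + 1e-8) * gamma, matching the forward shader element-wise.
    function rmsNormReference(x, gamma) {
      const meanSquare = tf.mean(tf.square(x), -1, true); // keepDims so it broadcasts over the last axis
      const invRms = tf.rsqrt(tf.add(meanSquare, 1e-8));
      return tf.mul(tf.mul(x, invRms), gamma);
    }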
package/dist/ops/webgl/qkv.js
CHANGED
package/dist/ops/webgl/rope.js
CHANGED